1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Test for parse_flags_from_env.cc
17
18 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
19
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <vector>
23
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/subprocess.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/util/command_line_flags.h"
31
32 namespace xla {
33
34 // Test that XLA flags can be set from the environment.
35 // Failure messages are accompanied by the text in msg[].
TestParseFlagsFromEnv(const char * msg)36 static void TestParseFlagsFromEnv(const char* msg) {
37 // Initialize module under test.
38 int* pargc;
39 std::vector<char*>* pargv;
40 ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv);
41
42 // Check that actual flags can be parsed.
43 bool simple = false;
44 string with_value;
45 string embedded_quotes;
46 string single_quoted;
47 string double_quoted;
48 std::vector<tensorflow::Flag> flag_list = {
49 tensorflow::Flag("simple", &simple, ""),
50 tensorflow::Flag("with_value", &with_value, ""),
51 tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
52 tensorflow::Flag("single_quoted", &single_quoted, ""),
53 tensorflow::Flag("double_quoted", &double_quoted, ""),
54 };
55 bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
56 CHECK_EQ(*pargc, 1) << msg;
57 const std::vector<char*>& argv_second = *pargv;
58 CHECK_NE(argv_second[0], nullptr) << msg;
59 CHECK_EQ(argv_second[1], nullptr) << msg;
60 CHECK(parsed_ok) << msg;
61 CHECK(simple) << msg;
62 CHECK_EQ(with_value, "a_value") << msg;
63 CHECK_EQ(embedded_quotes, "single'double\"") << msg;
64 CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
65 CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
66 }
67
// Flag settings shared by the Basic and File tests.  Besides a plain boolean
// flag and a simple key=value flag, the quoting cases check that:
//   * a quote character inside an unquoted value is kept literally;
//   * a single-quoted value is taken verbatim, backslashes included;
//   * a double-quoted value has "\\" reduced to "\" and "\"" reduced to '"'.
// (The escapes in that list are as seen by the flag parser, i.e. after C++
// string-literal unescaping.)
static constexpr char kTestFlagString[] =
    "--simple "
    "--with_value=a_value "
    "--embedded_quotes=single'double\" "
    "--single_quoted='single quoted \\\\ \n \"' "
    "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
75
76 // Test that the environent variable is parsed correctly.
TEST(ParseFlagsFromEnv,Basic)77 TEST(ParseFlagsFromEnv, Basic) {
78 // Prepare environment.
79 setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
80 TestParseFlagsFromEnv("(flags in environment variable)");
81 }
82
83 // Test that a file named by the environent variable is parsed correctly.
TEST(ParseFlagsFromEnv,File)84 TEST(ParseFlagsFromEnv, File) {
85 // environment variables where tmp dir may be specified.
86 static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
87 static const char kTempDir[] = "/tmp"; // default temp dir if all else fails.
88 const char* tmp_dir = nullptr;
89 for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
90 tmp_dir = getenv(kTempVars[i]);
91 }
92 if (tmp_dir == nullptr) {
93 tmp_dir = kTempDir;
94 }
95 string tmp_file =
96 absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
97 FILE* fp = fopen(tmp_file.c_str(), "w");
98 CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
99 for (int i = 0; kTestFlagString[i] != '\0'; i++) {
100 putc(kTestFlagString[i], fp);
101 }
102 fflush(fp);
103 CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
104 fclose(fp);
105 // Prepare environment.
106 setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
107 TestParseFlagsFromEnv("(flags in file)");
108 unlink(tmp_file.c_str());
109 }
110
// Path of this test binary (argv[0]), recorded by main() so that the
// EnvAndFlag test can re-invoke the binary as a child process.
static const char* binary_name = nullptr;
113
114 // Test that when we use both the environment variable and actual
115 // commend line flags (when the latter is possible), the latter win.
TEST(ParseFlagsFromEnv,EnvAndFlag)116 TEST(ParseFlagsFromEnv, EnvAndFlag) {
117 static struct {
118 const char* env;
119 const char* arg;
120 const char* expected_value;
121 } test[] = {
122 {nullptr, nullptr, "1\n"},
123 {nullptr, "--int_flag=2", "2\n"},
124 {"--int_flag=3", nullptr, "3\n"},
125 {"--int_flag=3", "--int_flag=2", "2\n"}, // flag beats environment
126 };
127 for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
128 if (test[i].env != nullptr) {
129 setenv("TF_XLA_FLAGS", test[i].env, true /*overwrite*/);
130 }
131 tensorflow::SubProcess child;
132 std::vector<string> argv;
133 argv.push_back(binary_name);
134 argv.push_back("--recursing");
135 if (test[i].arg != nullptr) {
136 argv.push_back(test[i].arg);
137 }
138 child.SetProgram(binary_name, argv);
139 child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
140 CHECK(child.Start()) << "test " << i;
141 string stdout_str;
142 int child_status = child.Communicate(nullptr, &stdout_str, nullptr);
143 CHECK_EQ(child_status, 0) << "test " << i;
144 CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
145 }
146 }
147
148 } // namespace xla
149
main(int argc,char * argv[])150 int main(int argc, char* argv[]) {
151 // Save name of binary so that it may invoke itself.
152 xla::binary_name = argv[0];
153 bool recursing = false;
154 xla::int32 int_flag = 1;
155 const std::vector<tensorflow::Flag> flag_list = {
156 tensorflow::Flag("recursing", &recursing,
157 "Whether the binary is being invoked recusively."),
158 tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
159 };
160 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
161 bool parse_ok =
162 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
163 if (!parse_ok) {
164 LOG(QFATAL) << "can't parse from environment\n" << usage;
165 }
166 parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
167 if (!parse_ok) {
168 LOG(QFATAL) << usage;
169 }
170 if (recursing) {
171 printf("%d\n", int_flag);
172 exit(0);
173 }
174 testing::InitGoogleTest(&argc, argv);
175 return RUN_ALL_TESTS();
176 }
177