1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Test for parse_flags_from_env.cc
17
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/command_line_flags.h"
32
33 namespace xla {
34
35 // Test that XLA flags can be set from the environment.
36 // Failure messages are accompanied by the text in msg[].
TestParseFlagsFromEnv(const char * msg)37 static void TestParseFlagsFromEnv(const char* msg) {
38 // Initialize module under test.
39 int* pargc;
40 std::vector<char*>* pargv;
41 ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv);
42
43 // Check that actual flags can be parsed.
44 bool simple = false;
45 std::string with_value;
46 std::string embedded_quotes;
47 std::string single_quoted;
48 std::string double_quoted;
49 std::vector<tensorflow::Flag> flag_list = {
50 tensorflow::Flag("simple", &simple, ""),
51 tensorflow::Flag("with_value", &with_value, ""),
52 tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
53 tensorflow::Flag("single_quoted", &single_quoted, ""),
54 tensorflow::Flag("double_quoted", &double_quoted, ""),
55 };
56 bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
57 CHECK_EQ(*pargc, 1) << msg;
58 const std::vector<char*>& argv_second = *pargv;
59 CHECK_NE(argv_second[0], nullptr) << msg;
60 CHECK_EQ(argv_second[1], nullptr) << msg;
61 CHECK(parsed_ok) << msg;
62 CHECK(simple) << msg;
63 CHECK_EQ(with_value, "a_value") << msg;
64 CHECK_EQ(embedded_quotes, "single'double\"") << msg;
65 CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
66 CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
67 }
68
// The flag settings to test: a bare boolean, a plain value, a value containing
// raw quote characters, and single- and double-quoted values that contain
// spaces, backslashes and a newline.
static constexpr char kTestFlagString[] =
    "--simple --with_value=a_value "
    "--embedded_quotes=single'double\" "
    "--single_quoted='single quoted \\\\ \n \"' "
    "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
76
77 // Test that the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,Basic)78 TEST(ParseFlagsFromEnv, Basic) {
79 // Prepare environment.
80 tensorflow::setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
81 TestParseFlagsFromEnv("(flags in environment variable)");
82 }
83
84 // Test that a file named by the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,File)85 TEST(ParseFlagsFromEnv, File) {
86 // environment variables where tmp dir may be specified.
87 static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
88 static const char kTempDir[] = "/tmp"; // default temp dir if all else fails.
89 const char* tmp_dir = nullptr;
90 for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
91 tmp_dir = getenv(kTempVars[i]);
92 }
93 if (tmp_dir == nullptr) {
94 tmp_dir = kTempDir;
95 }
96 std::string tmp_file =
97 absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
98 FILE* fp = fopen(tmp_file.c_str(), "w");
99 CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
100 for (int i = 0; kTestFlagString[i] != '\0'; i++) {
101 putc(kTestFlagString[i], fp);
102 }
103 fflush(fp);
104 CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
105 fclose(fp);
106 // Prepare environment.
107 tensorflow::setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
108 TestParseFlagsFromEnv("(flags in file)");
109 unlink(tmp_file.c_str());
110 }
111
// Path of this test binary (argv[0]). main() records it so that tests can
// re-launch the binary as a child process. It stays null until main() sets it.
static const char* binary_name = nullptr;
114
115 // Test that when we use both the environment variable and actual
116 // commend line flags (when the latter is possible), the latter win.
TEST(ParseFlagsFromEnv,EnvAndFlag)117 TEST(ParseFlagsFromEnv, EnvAndFlag) {
118 static struct {
119 const char* env;
120 const char* arg;
121 const char* expected_value;
122 } test[] = {
123 {nullptr, nullptr, "1\n"},
124 {nullptr, "--int_flag=2", "2\n"},
125 {"--int_flag=3", nullptr, "3\n"},
126 {"--int_flag=3", "--int_flag=2", "2\n"}, // flag beats environment
127 };
128 for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
129 if (test[i].env == nullptr) {
130 // Might be set from previous tests.
131 tensorflow::unsetenv("TF_XLA_FLAGS");
132 } else {
133 tensorflow::setenv("TF_XLA_FLAGS", test[i].env, /*overwrite=*/true);
134 }
135 tensorflow::SubProcess child;
136 std::vector<std::string> argv;
137 argv.push_back(binary_name);
138 argv.push_back("--recursing");
139 if (test[i].arg != nullptr) {
140 argv.push_back(test[i].arg);
141 }
142 child.SetProgram(binary_name, argv);
143 child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
144 child.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
145 CHECK(child.Start()) << "test " << i;
146 std::string stdout_str;
147 std::string stderr_str;
148 int child_status = child.Communicate(nullptr, &stdout_str, &stderr_str);
149 CHECK_EQ(child_status, 0) << "test " << i << "\nstdout\n"
150 << stdout_str << "\nstderr\n"
151 << stderr_str;
152 // On windows, we get CR characters. Remove them.
153 stdout_str.erase(std::remove(stdout_str.begin(), stdout_str.end(), '\r'),
154 stdout_str.end());
155 CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
156 }
157 }
158
159 } // namespace xla
160
main(int argc,char * argv[])161 int main(int argc, char* argv[]) {
162 // Save name of binary so that it may invoke itself.
163 xla::binary_name = argv[0];
164 bool recursing = false;
165 int32_t int_flag = 1;
166 const std::vector<tensorflow::Flag> flag_list = {
167 tensorflow::Flag("recursing", &recursing,
168 "Whether the binary is being invoked recursively."),
169 tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
170 };
171 std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
172 bool parse_ok =
173 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
174 if (!parse_ok) {
175 LOG(QFATAL) << "can't parse from environment\n" << usage;
176 }
177 parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
178 if (!parse_ok) {
179 LOG(QFATAL) << usage;
180 }
181 if (recursing) {
182 printf("%d\n", int_flag);
183 exit(0);
184 }
185 testing::InitGoogleTest(&argc, argv);
186 return RUN_ALL_TESTS();
187 }
188