• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Test for parse_flags_from_env.cc
17 
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/command_line_flags.h"
32 
33 namespace xla {
34 
35 // Test that XLA flags can be set from the environment.
36 // Failure messages are accompanied by the text in msg[].
TestParseFlagsFromEnv(const char * msg)37 static void TestParseFlagsFromEnv(const char* msg) {
38   // Initialize module under test.
39   int* pargc;
40   std::vector<char*>* pargv;
41   ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv);
42 
43   // Check that actual flags can be parsed.
44   bool simple = false;
45   std::string with_value;
46   std::string embedded_quotes;
47   std::string single_quoted;
48   std::string double_quoted;
49   std::vector<tensorflow::Flag> flag_list = {
50       tensorflow::Flag("simple", &simple, ""),
51       tensorflow::Flag("with_value", &with_value, ""),
52       tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
53       tensorflow::Flag("single_quoted", &single_quoted, ""),
54       tensorflow::Flag("double_quoted", &double_quoted, ""),
55   };
56   bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
57   CHECK_EQ(*pargc, 1) << msg;
58   const std::vector<char*>& argv_second = *pargv;
59   CHECK_NE(argv_second[0], nullptr) << msg;
60   CHECK_EQ(argv_second[1], nullptr) << msg;
61   CHECK(parsed_ok) << msg;
62   CHECK(simple) << msg;
63   CHECK_EQ(with_value, "a_value") << msg;
64   CHECK_EQ(embedded_quotes, "single'double\"") << msg;
65   CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
66   CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
67 }
68 
// The flag settings to test, shared by the Basic (environment value) and
// File (file contents) tests. They cover:
//   * a bare boolean flag;
//   * a plain key=value pair;
//   * quote characters embedded in an unquoted value;
//   * single- and double-quoted values containing backslashes and a literal
//     newline.
constexpr char kTestFlagString[] =
    "--simple --with_value=a_value "
    "--embedded_quotes=single'double\" "
    "--single_quoted='single quoted \\\\ \n \"' "
    "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
76 
77 // Test that the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,Basic)78 TEST(ParseFlagsFromEnv, Basic) {
79   // Prepare environment.
80   tensorflow::setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
81   TestParseFlagsFromEnv("(flags in environment variable)");
82 }
83 
84 // Test that a file named by the environment variable is parsed correctly.
TEST(ParseFlagsFromEnv,File)85 TEST(ParseFlagsFromEnv, File) {
86   // environment variables where  tmp dir may be specified.
87   static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
88   static const char kTempDir[] = "/tmp";  // default temp dir if all else fails.
89   const char* tmp_dir = nullptr;
90   for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
91     tmp_dir = getenv(kTempVars[i]);
92   }
93   if (tmp_dir == nullptr) {
94     tmp_dir = kTempDir;
95   }
96   std::string tmp_file =
97       absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
98   FILE* fp = fopen(tmp_file.c_str(), "w");
99   CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
100   for (int i = 0; kTestFlagString[i] != '\0'; i++) {
101     putc(kTestFlagString[i], fp);
102   }
103   fflush(fp);
104   CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
105   fclose(fp);
106   // Prepare environment.
107   tensorflow::setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
108   TestParseFlagsFromEnv("(flags in file)");
109   unlink(tmp_file.c_str());
110 }
111 
// Path of the running test binary (argv[0]), recorded by main() so that the
// EnvAndFlag test can launch this same binary as a child process. Null until
// main() assigns it.
static const char* binary_name = nullptr;
114 
115 // Test that when we use both the environment variable and actual
116 // commend line flags (when the latter is possible), the latter win.
TEST(ParseFlagsFromEnv,EnvAndFlag)117 TEST(ParseFlagsFromEnv, EnvAndFlag) {
118   static struct {
119     const char* env;
120     const char* arg;
121     const char* expected_value;
122   } test[] = {
123       {nullptr, nullptr, "1\n"},
124       {nullptr, "--int_flag=2", "2\n"},
125       {"--int_flag=3", nullptr, "3\n"},
126       {"--int_flag=3", "--int_flag=2", "2\n"},  // flag beats environment
127   };
128   for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
129     if (test[i].env == nullptr) {
130       // Might be set from previous tests.
131       tensorflow::unsetenv("TF_XLA_FLAGS");
132     } else {
133       tensorflow::setenv("TF_XLA_FLAGS", test[i].env, /*overwrite=*/true);
134     }
135     tensorflow::SubProcess child;
136     std::vector<std::string> argv;
137     argv.push_back(binary_name);
138     argv.push_back("--recursing");
139     if (test[i].arg != nullptr) {
140       argv.push_back(test[i].arg);
141     }
142     child.SetProgram(binary_name, argv);
143     child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
144     child.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE);
145     CHECK(child.Start()) << "test " << i;
146     std::string stdout_str;
147     std::string stderr_str;
148     int child_status = child.Communicate(nullptr, &stdout_str, &stderr_str);
149     CHECK_EQ(child_status, 0) << "test " << i << "\nstdout\n"
150                               << stdout_str << "\nstderr\n"
151                               << stderr_str;
152     // On windows, we get CR characters. Remove them.
153     stdout_str.erase(std::remove(stdout_str.begin(), stdout_str.end(), '\r'),
154                      stdout_str.end());
155     CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
156   }
157 }
158 
159 }  // namespace xla
160 
main(int argc,char * argv[])161 int main(int argc, char* argv[]) {
162   // Save name of binary so that it may invoke itself.
163   xla::binary_name = argv[0];
164   bool recursing = false;
165   int32_t int_flag = 1;
166   const std::vector<tensorflow::Flag> flag_list = {
167       tensorflow::Flag("recursing", &recursing,
168                        "Whether the binary is being invoked recursively."),
169       tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
170   };
171   std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
172   bool parse_ok =
173       xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
174   if (!parse_ok) {
175     LOG(QFATAL) << "can't parse from environment\n" << usage;
176   }
177   parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
178   if (!parse_ok) {
179     LOG(QFATAL) << usage;
180   }
181   if (recursing) {
182     printf("%d\n", int_flag);
183     exit(0);
184   }
185   testing::InitGoogleTest(&argc, argv);
186   return RUN_ALL_TESTS();
187 }
188