• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Test for parse_flags_from_env.cc
17 
18 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
19 
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <vector>
23 
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/subprocess.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/util/command_line_flags.h"
31 
32 namespace xla {
33 
34 // Test that XLA flags can be set from the environment.
35 // Failure messages are accompanied by the text in msg[].
TestParseFlagsFromEnv(const char * msg)36 static void TestParseFlagsFromEnv(const char* msg) {
37   // Initialize module under test.
38   int* pargc;
39   std::vector<char*>* pargv;
40   ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv);
41 
42   // Check that actual flags can be parsed.
43   bool simple = false;
44   string with_value;
45   string embedded_quotes;
46   string single_quoted;
47   string double_quoted;
48   std::vector<tensorflow::Flag> flag_list = {
49       tensorflow::Flag("simple", &simple, ""),
50       tensorflow::Flag("with_value", &with_value, ""),
51       tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
52       tensorflow::Flag("single_quoted", &single_quoted, ""),
53       tensorflow::Flag("double_quoted", &double_quoted, ""),
54   };
55   bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
56   CHECK_EQ(*pargc, 1) << msg;
57   const std::vector<char*>& argv_second = *pargv;
58   CHECK_NE(argv_second[0], nullptr) << msg;
59   CHECK_EQ(argv_second[1], nullptr) << msg;
60   CHECK(parsed_ok) << msg;
61   CHECK(simple) << msg;
62   CHECK_EQ(with_value, "a_value") << msg;
63   CHECK_EQ(embedded_quotes, "single'double\"") << msg;
64   CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
65   CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
66 }
67 
68 // The flags settings to test.
69 static const char kTestFlagString[] =
70     "--simple "
71     "--with_value=a_value "
72     "--embedded_quotes=single'double\" "
73     "--single_quoted='single quoted \\\\ \n \"' "
74     "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
75 
76 // Test that the environent variable is parsed correctly.
TEST(ParseFlagsFromEnv,Basic)77 TEST(ParseFlagsFromEnv, Basic) {
78   // Prepare environment.
79   setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
80   TestParseFlagsFromEnv("(flags in environment variable)");
81 }
82 
83 // Test that a file named by the environent variable is parsed correctly.
TEST(ParseFlagsFromEnv,File)84 TEST(ParseFlagsFromEnv, File) {
85   // environment variables where  tmp dir may be specified.
86   static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
87   static const char kTempDir[] = "/tmp";  // default temp dir if all else fails.
88   const char* tmp_dir = nullptr;
89   for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
90     tmp_dir = getenv(kTempVars[i]);
91   }
92   if (tmp_dir == nullptr) {
93     tmp_dir = kTempDir;
94   }
95   string tmp_file =
96       absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid());
97   FILE* fp = fopen(tmp_file.c_str(), "w");
98   CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
99   for (int i = 0; kTestFlagString[i] != '\0'; i++) {
100     putc(kTestFlagString[i], fp);
101   }
102   fflush(fp);
103   CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
104   fclose(fp);
105   // Prepare environment.
106   setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
107   TestParseFlagsFromEnv("(flags in file)");
108   unlink(tmp_file.c_str());
109 }
110 
111 // Name of the test binary.
112 static const char* binary_name;
113 
114 // Test that when we use both the environment variable and actual
115 // commend line flags (when the latter is possible), the latter win.
TEST(ParseFlagsFromEnv,EnvAndFlag)116 TEST(ParseFlagsFromEnv, EnvAndFlag) {
117   static struct {
118     const char* env;
119     const char* arg;
120     const char* expected_value;
121   } test[] = {
122       {nullptr, nullptr, "1\n"},
123       {nullptr, "--int_flag=2", "2\n"},
124       {"--int_flag=3", nullptr, "3\n"},
125       {"--int_flag=3", "--int_flag=2", "2\n"},  // flag beats environment
126   };
127   for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
128     if (test[i].env != nullptr) {
129       setenv("TF_XLA_FLAGS", test[i].env, true /*overwrite*/);
130     }
131     tensorflow::SubProcess child;
132     std::vector<string> argv;
133     argv.push_back(binary_name);
134     argv.push_back("--recursing");
135     if (test[i].arg != nullptr) {
136       argv.push_back(test[i].arg);
137     }
138     child.SetProgram(binary_name, argv);
139     child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
140     CHECK(child.Start()) << "test " << i;
141     string stdout_str;
142     int child_status = child.Communicate(nullptr, &stdout_str, nullptr);
143     CHECK_EQ(child_status, 0) << "test " << i;
144     CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
145   }
146 }
147 
148 }  // namespace xla
149 
main(int argc,char * argv[])150 int main(int argc, char* argv[]) {
151   // Save name of binary so that it may invoke itself.
152   xla::binary_name = argv[0];
153   bool recursing = false;
154   xla::int32 int_flag = 1;
155   const std::vector<tensorflow::Flag> flag_list = {
156       tensorflow::Flag("recursing", &recursing,
157                        "Whether the binary is being invoked recusively."),
158       tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
159   };
160   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
161   bool parse_ok =
162       xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list);
163   if (!parse_ok) {
164     LOG(QFATAL) << "can't parse from environment\n" << usage;
165   }
166   parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
167   if (!parse_ok) {
168     LOG(QFATAL) << usage;
169   }
170   if (recursing) {
171     printf("%d\n", int_flag);
172     exit(0);
173   }
174   testing::InitGoogleTest(&argc, argv);
175   return RUN_ALL_TESTS();
176 }
177