• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/framework/python_op_gen.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/framework/op_gen_lib.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/io/inputbuffer.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/lib/strings/scanner.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/env.h"
32 #include "tensorflow/core/platform/init_main.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace tensorflow {
36 namespace {
37 
ReadOpListFromFile(const string & filename,std::vector<string> * op_list)38 Status ReadOpListFromFile(const string& filename,
39                           std::vector<string>* op_list) {
40   std::unique_ptr<RandomAccessFile> file;
41   TF_CHECK_OK(Env::Default()->NewRandomAccessFile(filename, &file));
42   std::unique_ptr<io::InputBuffer> input_buffer(
43       new io::InputBuffer(file.get(), 256 << 10));
44   string line_contents;
45   Status s = input_buffer->ReadLine(&line_contents);
46   while (s.ok()) {
47     // The parser assumes that the op name is the first string on each
48     // line with no preceding whitespace, and ignores lines that do
49     // not start with an op name as a comment.
50     strings::Scanner scanner{StringPiece(line_contents)};
51     StringPiece op_name;
52     if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
53             .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
54             .GetResult(nullptr, &op_name)) {
55       op_list->emplace_back(op_name);
56     }
57     s = input_buffer->ReadLine(&line_contents);
58   }
59   if (!errors::IsOutOfRange(s)) return s;
60   return Status::OK();
61 }
62 
63 // The argument parsing is deliberately simplistic to support our only
64 // known use cases:
65 //
66 // 1. Read all op names from a file.
67 // 2. Read all op names from the arg as a comma-delimited list.
68 //
69 // Expected command-line argument syntax:
70 // ARG ::= '@' FILENAME
71 //       |  OP_NAME [',' OP_NAME]*
72 //       |  ''
ParseOpListCommandLine(const char * arg,std::vector<string> * op_list)73 Status ParseOpListCommandLine(const char* arg, std::vector<string>* op_list) {
74   std::vector<string> op_names = str_util::Split(arg, ',');
75   if (op_names.size() == 1 && op_names[0].empty()) {
76     return Status::OK();
77   } else if (op_names.size() == 1 && op_names[0].substr(0, 1) == "@") {
78     const string filename = op_names[0].substr(1);
79     return tensorflow::ReadOpListFromFile(filename, op_list);
80   } else {
81     *op_list = std::move(op_names);
82   }
83   return Status::OK();
84 }
85 
86 // Use the name of the current executable to infer the C++ source file
87 // where the REGISTER_OP() call for the operator can be found.
88 // Returns the name of the file.
89 // Returns an empty string if the current executable's name does not
90 // follow a known pattern.
InferSourceFileName(const char * argv_zero)91 string InferSourceFileName(const char* argv_zero) {
92   StringPiece command_str = io::Basename(argv_zero);
93 
94   // For built-in ops, the Bazel build creates a separate executable
95   // with the name gen_<op type>_ops_py_wrappers_cc containing the
96   // operators defined in <op type>_ops.cc
97   const char* kExecPrefix = "gen_";
98   const char* kExecSuffix = "_py_wrappers_cc";
99   if (str_util::ConsumePrefix(&command_str, kExecPrefix) &&
100       str_util::EndsWith(command_str, kExecSuffix)) {
101     command_str.remove_suffix(strlen(kExecSuffix));
102     return strings::StrCat(command_str, ".cc");
103   } else {
104     return string("");
105   }
106 }
107 
PrintAllPythonOps(const std::vector<string> & op_list,const std::vector<string> & api_def_dirs,const string & source_file_name,bool require_shapes,bool op_list_is_whitelist)108 void PrintAllPythonOps(const std::vector<string>& op_list,
109                        const std::vector<string>& api_def_dirs,
110                        const string& source_file_name, bool require_shapes,
111                        bool op_list_is_whitelist) {
112   OpList ops;
113   OpRegistry::Global()->Export(false, &ops);
114 
115   ApiDefMap api_def_map(ops);
116   if (!api_def_dirs.empty()) {
117     Env* env = Env::Default();
118 
119     for (const auto& api_def_dir : api_def_dirs) {
120       std::vector<string> api_files;
121       TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
122                                         &api_files));
123       TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
124     }
125     api_def_map.UpdateDocs();
126   }
127 
128   if (op_list_is_whitelist) {
129     std::unordered_set<string> whitelist(op_list.begin(), op_list.end());
130     OpList pruned_ops;
131     for (const auto& op_def : ops.op()) {
132       if (whitelist.find(op_def.name()) != whitelist.end()) {
133         *pruned_ops.mutable_op()->Add() = op_def;
134       }
135     }
136     PrintPythonOps(pruned_ops, api_def_map, {}, require_shapes,
137                    source_file_name);
138   } else {
139     PrintPythonOps(ops, api_def_map, op_list, require_shapes, source_file_name);
140   }
141 }
142 
143 }  // namespace
144 }  // namespace tensorflow
145 
main(int argc,char * argv[])146 int main(int argc, char* argv[]) {
147   tensorflow::port::InitMain(argv[0], &argc, &argv);
148 
149   tensorflow::string source_file_name =
150       tensorflow::InferSourceFileName(argv[0]);
151 
152   // Usage:
153   //   gen_main api_def_dir1,api_def_dir2,...
154   //       [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
155   if (argc < 3) {
156     return -1;
157   }
158   std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
159       argv[1], ",", tensorflow::str_util::SkipEmpty());
160 
161   if (argc == 3) {
162     tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
163                                   tensorflow::string(argv[2]) == "1",
164                                   false /* op_list_is_whitelist */);
165   } else if (argc == 4) {
166     std::vector<tensorflow::string> hidden_ops;
167     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
168     tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
169                                   tensorflow::string(argv[3]) == "1",
170                                   false /* op_list_is_whitelist */);
171   } else if (argc == 5) {
172     std::vector<tensorflow::string> op_list;
173     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
174     tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
175                                   tensorflow::string(argv[3]) == "1",
176                                   tensorflow::string(argv[4]) == "1");
177   } else {
178     return -1;
179   }
180   return 0;
181 }
182