• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/js/ops/ts_op_gen.h"
17 #include <unordered_map>
18 
19 #include "tensorflow/core/framework/api_def.pb.h"
20 #include "tensorflow/core/framework/op_def_util.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/public/version.h"
24 
25 namespace tensorflow {
26 namespace {
27 
IsListAttr(const OpDef_ArgDef & arg)28 static bool IsListAttr(const OpDef_ArgDef& arg) {
29   return !arg.type_list_attr().empty() || !arg.number_attr().empty();
30 }
31 
32 // Struct to hold a combo OpDef and ArgDef for a given Op argument:
33 struct ArgDefs {
ArgDefstensorflow::__anonc3e3999b0111::ArgDefs34   ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
35       : op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
36 
37   const OpDef::ArgDef& op_def_arg;
38   const ApiDef::Arg& api_def_arg;
39 };
40 
41 // Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op.
42 struct OpAttrs {
OpAttrstensorflow::__anonc3e3999b0111::OpAttrs43   OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
44       : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
45 
46   const OpDef::AttrDef& op_def_attr;
47   const ApiDef::Attr& api_def_attr;
48 };
49 
50 // Helper class to generate TypeScript code for a given OpDef:
51 class GenTypeScriptOp {
52  public:
53   GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
54   ~GenTypeScriptOp();
55 
56   // Returns the generated code as a string:
57   string Code();
58 
59  private:
60   void ProcessArgs();
61   void ProcessAttrs();
62   void AddAttrForArg(const string& attr, int arg_index);
63   string InputForAttr(const OpDef::AttrDef& op_def_attr);
64 
65   void AddMethodSignature();
66   void AddOpAttrs();
67   void AddMethodReturnAndClose();
68 
69   const OpDef& op_def_;
70   const ApiDef& api_def_;
71 
72   // Placeholder string for all generated code:
73   string result_;
74 
75   // Holds in-order vector of Op inputs:
76   std::vector<ArgDefs> input_op_args_;
77 
78   // Holds in-order vector of Op attributes:
79   std::vector<OpAttrs> op_attrs_;
80 
81   // Stores attributes-to-arguments by name:
82   typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
83   AttrArgIdxMap attr_arg_idx_map_;
84 
85   // Holds number of outputs:
86   int num_outputs_;
87 };
88 
GenTypeScriptOp(const OpDef & op_def,const ApiDef & api_def)89 GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
90     : op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
91 
~GenTypeScriptOp()92 GenTypeScriptOp::~GenTypeScriptOp() {}
93 
Code()94 string GenTypeScriptOp::Code() {
95   ProcessArgs();
96   ProcessAttrs();
97 
98   // Generate exported function for Op:
99   AddMethodSignature();
100   AddOpAttrs();
101   AddMethodReturnAndClose();
102 
103   strings::StrAppend(&result_, "\n");
104   return result_;
105 }
106 
ProcessArgs()107 void GenTypeScriptOp::ProcessArgs() {
108   for (int i = 0; i < api_def_.arg_order_size(); i++) {
109     auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
110     if (op_def_arg == nullptr) {
111       LOG(WARNING) << "Could not find OpDef::ArgDef for "
112                    << api_def_.arg_order(i);
113       continue;
114     }
115     auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
116     if (api_def_arg == nullptr) {
117       LOG(WARNING) << "Could not find ApiDef::Arg for "
118                    << api_def_.arg_order(i);
119       continue;
120     }
121 
122     // Map attr names to arg indexes:
123     if (!op_def_arg->type_attr().empty()) {
124       AddAttrForArg(op_def_arg->type_attr(), i);
125     } else if (!op_def_arg->type_list_attr().empty()) {
126       AddAttrForArg(op_def_arg->type_list_attr(), i);
127     }
128     if (!op_def_arg->number_attr().empty()) {
129       AddAttrForArg(op_def_arg->number_attr(), i);
130     }
131 
132     input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
133   }
134 
135   num_outputs_ = api_def_.out_arg_size();
136 }
137 
ProcessAttrs()138 void GenTypeScriptOp::ProcessAttrs() {
139   for (int i = 0; i < op_def_.attr_size(); i++) {
140     op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
141   }
142 }
143 
AddAttrForArg(const string & attr,int arg_index)144 void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
145   // Keep track of attributes-to-arguments by name. These will be used for
146   // construction Op attributes that require information about the inputs.
147   auto iter = attr_arg_idx_map_.find(attr);
148   if (iter == attr_arg_idx_map_.end()) {
149     attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
150   } else {
151     iter->second.push_back(arg_index);
152   }
153 }
154 
InputForAttr(const OpDef::AttrDef & op_def_attr)155 string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
156   string inputs;
157   auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
158   if (arg_list != attr_arg_idx_map_.end()) {
159     for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
160          ++iter) {
161       strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
162     }
163   }
164   return inputs;
165 }
166 
AddMethodSignature()167 void GenTypeScriptOp::AddMethodSignature() {
168   strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
169                      "(");
170 
171   bool is_first = true;
172   for (auto& in_arg : input_op_args_) {
173     if (is_first) {
174       is_first = false;
175     } else {
176       strings::StrAppend(&result_, ", ");
177     }
178 
179     auto op_def_arg = in_arg.op_def_arg;
180 
181     strings::StrAppend(&result_, op_def_arg.name(), ": ");
182     if (IsListAttr(op_def_arg)) {
183       strings::StrAppend(&result_, "tfc.Tensor[]");
184     } else {
185       strings::StrAppend(&result_, "tfc.Tensor");
186     }
187   }
188 
189   if (num_outputs_ == 1) {
190     strings::StrAppend(&result_, "): tfc.Tensor {\n");
191   } else {
192     strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
193   }
194 }
195 
AddOpAttrs()196 void GenTypeScriptOp::AddOpAttrs() {
197   strings::StrAppend(&result_, "  const opAttrs = [\n");
198 
199   bool is_first = true;
200   for (auto& attr : op_attrs_) {
201     if (is_first) {
202       is_first = false;
203     } else {
204       strings::StrAppend(&result_, ",\n");
205     }
206 
207     // Append 4 spaces to start:
208     strings::StrAppend(&result_, "    ");
209 
210     if (attr.op_def_attr.type() == "type") {
211       // Type OpAttributes can be generated from a helper function:
212       strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
213                          attr.op_def_attr.name(), "', ",
214                          InputForAttr(attr.op_def_attr), ")");
215     } else if (attr.op_def_attr.type() == "int") {
216       strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
217       strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
218       strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
219                          ".length}");
220     }
221   }
222   strings::StrAppend(&result_, "\n  ];\n");
223 }
224 
AddMethodReturnAndClose()225 void GenTypeScriptOp::AddMethodReturnAndClose() {
226   strings::StrAppend(&result_, "  return null;\n}\n");
227 }
228 
WriteTSOp(const OpDef & op_def,const ApiDef & api_def,WritableFile * ts)229 void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
230   GenTypeScriptOp ts_op(op_def, api_def);
231   TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code()));
232 }
233 
StartFile(WritableFile * ts_file)234 void StartFile(WritableFile* ts_file) {
235   const string header =
236       R"header(/**
237  * @license
238  * Copyright 2018 Google Inc. All Rights Reserved.
239  * Licensed under the Apache License, Version 2.0 (the "License");
240  * you may not use this file except in compliance with the License.
241  * You may obtain a copy of the License at
242  *
243  * http://www.apache.org/licenses/LICENSE-2.0
244  *
245  * Unless required by applicable law or agreed to in writing, software
246  * distributed under the License is distributed on an "AS IS" BASIS,
247  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
248  * See the License for the specific language governing permissions and
249  * limitations under the License.
250  * =============================================================================
251  */
252 
253 // This file is MACHINE GENERATED! Do not edit
254 
255 import * as tfc from '@tensorflow/tfjs-core';
256 import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
257 
258 )header";
259 
260   TF_CHECK_OK(ts_file->Append(header));
261 }
262 
263 }  // namespace
264 
WriteTSOps(const OpList & ops,const ApiDefMap & api_def_map,const string & ts_filename)265 void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
266                 const string& ts_filename) {
267   Env* env = Env::Default();
268 
269   std::unique_ptr<WritableFile> ts_file = nullptr;
270   TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
271 
272   StartFile(ts_file.get());
273 
274   for (const auto& op_def : ops.op()) {
275     // Skip deprecated ops
276     if (op_def.has_deprecation() &&
277         op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
278       continue;
279     }
280 
281     const auto* api_def = api_def_map.GetApiDef(op_def.name());
282     if (api_def->visibility() == ApiDef::VISIBLE) {
283       WriteTSOp(op_def, *api_def, ts_file.get());
284     }
285   }
286 
287   TF_CHECK_OK(ts_file->Close());
288 }
289 
290 }  // namespace tensorflow
291