1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/js/ops/ts_op_gen.h"
17 #include <unordered_map>
18
19 #include "tensorflow/core/framework/api_def.pb.h"
20 #include "tensorflow/core/framework/op_def_util.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/public/version.h"
24
25 namespace tensorflow {
26 namespace {
27
IsListAttr(const OpDef_ArgDef & arg)28 static bool IsListAttr(const OpDef_ArgDef& arg) {
29 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
30 }
31
32 // Struct to hold a combo OpDef and ArgDef for a given Op argument:
33 struct ArgDefs {
ArgDefstensorflow::__anonc3e3999b0111::ArgDefs34 ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
35 : op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
36
37 const OpDef::ArgDef& op_def_arg;
38 const ApiDef::Arg& api_def_arg;
39 };
40
41 // Struct to hold a combo OpDef::AttrDef and ApiDef::Attr for an Op.
42 struct OpAttrs {
OpAttrstensorflow::__anonc3e3999b0111::OpAttrs43 OpAttrs(const OpDef::AttrDef& op_def_attr, const ApiDef::Attr& api_def_attr)
44 : op_def_attr(op_def_attr), api_def_attr(api_def_attr) {}
45
46 const OpDef::AttrDef& op_def_attr;
47 const ApiDef::Attr& api_def_attr;
48 };
49
50 // Helper class to generate TypeScript code for a given OpDef:
51 class GenTypeScriptOp {
52 public:
53 GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
54 ~GenTypeScriptOp();
55
56 // Returns the generated code as a string:
57 string Code();
58
59 private:
60 void ProcessArgs();
61 void ProcessAttrs();
62 void AddAttrForArg(const string& attr, int arg_index);
63 string InputForAttr(const OpDef::AttrDef& op_def_attr);
64
65 void AddMethodSignature();
66 void AddOpAttrs();
67 void AddMethodReturnAndClose();
68
69 const OpDef& op_def_;
70 const ApiDef& api_def_;
71
72 // Placeholder string for all generated code:
73 string result_;
74
75 // Holds in-order vector of Op inputs:
76 std::vector<ArgDefs> input_op_args_;
77
78 // Holds in-order vector of Op attributes:
79 std::vector<OpAttrs> op_attrs_;
80
81 // Stores attributes-to-arguments by name:
82 typedef std::unordered_map<string, std::vector<int>> AttrArgIdxMap;
83 AttrArgIdxMap attr_arg_idx_map_;
84
85 // Holds number of outputs:
86 int num_outputs_;
87 };
88
GenTypeScriptOp(const OpDef & op_def,const ApiDef & api_def)89 GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
90 : op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
91
~GenTypeScriptOp()92 GenTypeScriptOp::~GenTypeScriptOp() {}
93
Code()94 string GenTypeScriptOp::Code() {
95 ProcessArgs();
96 ProcessAttrs();
97
98 // Generate exported function for Op:
99 AddMethodSignature();
100 AddOpAttrs();
101 AddMethodReturnAndClose();
102
103 strings::StrAppend(&result_, "\n");
104 return result_;
105 }
106
ProcessArgs()107 void GenTypeScriptOp::ProcessArgs() {
108 for (int i = 0; i < api_def_.arg_order_size(); i++) {
109 auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
110 if (op_def_arg == nullptr) {
111 LOG(WARNING) << "Could not find OpDef::ArgDef for "
112 << api_def_.arg_order(i);
113 continue;
114 }
115 auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
116 if (api_def_arg == nullptr) {
117 LOG(WARNING) << "Could not find ApiDef::Arg for "
118 << api_def_.arg_order(i);
119 continue;
120 }
121
122 // Map attr names to arg indexes:
123 if (!op_def_arg->type_attr().empty()) {
124 AddAttrForArg(op_def_arg->type_attr(), i);
125 } else if (!op_def_arg->type_list_attr().empty()) {
126 AddAttrForArg(op_def_arg->type_list_attr(), i);
127 }
128 if (!op_def_arg->number_attr().empty()) {
129 AddAttrForArg(op_def_arg->number_attr(), i);
130 }
131
132 input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
133 }
134
135 num_outputs_ = api_def_.out_arg_size();
136 }
137
ProcessAttrs()138 void GenTypeScriptOp::ProcessAttrs() {
139 for (int i = 0; i < op_def_.attr_size(); i++) {
140 op_attrs_.push_back(OpAttrs(op_def_.attr(i), api_def_.attr(i)));
141 }
142 }
143
AddAttrForArg(const string & attr,int arg_index)144 void GenTypeScriptOp::AddAttrForArg(const string& attr, int arg_index) {
145 // Keep track of attributes-to-arguments by name. These will be used for
146 // construction Op attributes that require information about the inputs.
147 auto iter = attr_arg_idx_map_.find(attr);
148 if (iter == attr_arg_idx_map_.end()) {
149 attr_arg_idx_map_.insert(AttrArgIdxMap::value_type(attr, {arg_index}));
150 } else {
151 iter->second.push_back(arg_index);
152 }
153 }
154
InputForAttr(const OpDef::AttrDef & op_def_attr)155 string GenTypeScriptOp::InputForAttr(const OpDef::AttrDef& op_def_attr) {
156 string inputs;
157 auto arg_list = attr_arg_idx_map_.find(op_def_attr.name());
158 if (arg_list != attr_arg_idx_map_.end()) {
159 for (auto iter = arg_list->second.begin(); iter != arg_list->second.end();
160 ++iter) {
161 strings::StrAppend(&inputs, input_op_args_[*iter].op_def_arg.name());
162 }
163 }
164 return inputs;
165 }
166
AddMethodSignature()167 void GenTypeScriptOp::AddMethodSignature() {
168 strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
169 "(");
170
171 bool is_first = true;
172 for (auto& in_arg : input_op_args_) {
173 if (is_first) {
174 is_first = false;
175 } else {
176 strings::StrAppend(&result_, ", ");
177 }
178
179 auto op_def_arg = in_arg.op_def_arg;
180
181 strings::StrAppend(&result_, op_def_arg.name(), ": ");
182 if (IsListAttr(op_def_arg)) {
183 strings::StrAppend(&result_, "tfc.Tensor[]");
184 } else {
185 strings::StrAppend(&result_, "tfc.Tensor");
186 }
187 }
188
189 if (num_outputs_ == 1) {
190 strings::StrAppend(&result_, "): tfc.Tensor {\n");
191 } else {
192 strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
193 }
194 }
195
AddOpAttrs()196 void GenTypeScriptOp::AddOpAttrs() {
197 strings::StrAppend(&result_, " const opAttrs = [\n");
198
199 bool is_first = true;
200 for (auto& attr : op_attrs_) {
201 if (is_first) {
202 is_first = false;
203 } else {
204 strings::StrAppend(&result_, ",\n");
205 }
206
207 // Append 4 spaces to start:
208 strings::StrAppend(&result_, " ");
209
210 if (attr.op_def_attr.type() == "type") {
211 // Type OpAttributes can be generated from a helper function:
212 strings::StrAppend(&result_, "createTensorsTypeOpAttr('",
213 attr.op_def_attr.name(), "', ",
214 InputForAttr(attr.op_def_attr), ")");
215 } else if (attr.op_def_attr.type() == "int") {
216 strings::StrAppend(&result_, "{name: '", attr.op_def_attr.name(), "', ");
217 strings::StrAppend(&result_, "type: nodeBackend().binding.TF_ATTR_INT, ");
218 strings::StrAppend(&result_, "value: ", InputForAttr(attr.op_def_attr),
219 ".length}");
220 }
221 }
222 strings::StrAppend(&result_, "\n ];\n");
223 }
224
AddMethodReturnAndClose()225 void GenTypeScriptOp::AddMethodReturnAndClose() {
226 strings::StrAppend(&result_, " return null;\n}\n");
227 }
228
WriteTSOp(const OpDef & op_def,const ApiDef & api_def,WritableFile * ts)229 void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
230 GenTypeScriptOp ts_op(op_def, api_def);
231 TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code()));
232 }
233
// Appends the fixed preamble of the generated TypeScript file to `ts_file`:
// a license comment, a "MACHINE GENERATED" notice, and the imports that the
// emitted wrappers reference (`tfc` in signatures, `createTensorsTypeOpAttr`
// and `nodeBackend` in the opAttrs array). Any append error is fatal
// (TF_CHECK_OK).
void StartFile(WritableFile* ts_file) {
  // Raw string literal: every character up to `)header"` -- including leading
  // spaces and the trailing blank line -- is written to the file verbatim.
  const string header =
      R"header(/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// This file is MACHINE GENERATED! Do not edit

import * as tfc from '@tensorflow/tfjs-core';
import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';

)header";

  TF_CHECK_OK(ts_file->Append(header));
}
262
263 } // namespace
264
WriteTSOps(const OpList & ops,const ApiDefMap & api_def_map,const string & ts_filename)265 void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
266 const string& ts_filename) {
267 Env* env = Env::Default();
268
269 std::unique_ptr<WritableFile> ts_file = nullptr;
270 TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
271
272 StartFile(ts_file.get());
273
274 for (const auto& op_def : ops.op()) {
275 // Skip deprecated ops
276 if (op_def.has_deprecation() &&
277 op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
278 continue;
279 }
280
281 const auto* api_def = api_def_map.GetApiDef(op_def.name());
282 if (api_def->visibility() == ApiDef::VISIBLE) {
283 WriteTSOp(op_def, *api_def, ts_file.get());
284 }
285 }
286
287 TF_CHECK_OK(ts_file->Close());
288 }
289
290 } // namespace tensorflow
291