1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/framework/cc_op_gen.h"
17
18 #include <unordered_map>
19 #include <unordered_set>
20 #include <vector>
21
22 #include "absl/strings/escaping.h"
23 #include "tensorflow/core/framework/api_def.pb.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/op_def_util.h"
27 #include "tensorflow/core/framework/op_gen_lib.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/lib/hash/hash.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/types.h"
38 #include "tensorflow/core/public/version.h"
39
40 namespace tensorflow {
41 namespace {
42
// Column limit handed to WordWrap() when laying out generated declarations.
constexpr int kRightMargin = 79;
44
// Strips build-output prefixes from a generated header path. Given
//   bazel-out/.../(bin|genfiles)/(external/YYY/)?XX
// returns XX. A path containing none of these markers is returned unchanged.
string GetPath(const string& dot_h_fname) {
  // sizeof() on a char array counts the terminating '\0', so the number of
  // visible characters in each marker is sizeof(marker) - 1.
  static const char kBinDir[] = "/bin/";
  static const char kGenfilesDir[] = "/genfiles/";
  static const char kExternalDir[] = "external/";

  string path = dot_h_fname;
  const auto bin_pos = dot_h_fname.find(kBinDir);
  if (bin_pos != string::npos) {
    path = dot_h_fname.substr(bin_pos + sizeof(kBinDir) - 1);
  } else {
    const auto genfiles_pos = dot_h_fname.find(kGenfilesDir);
    if (genfiles_pos != string::npos) {
      path = dot_h_fname.substr(genfiles_pos + sizeof(kGenfilesDir) - 1);
    }
  }

  const bool under_external =
      path.size() > sizeof(kExternalDir) &&
      path.compare(0, sizeof(kExternalDir) - 1, kExternalDir) == 0;
  if (!under_external) return path;

  // Drop "external/" and then the repository directory that follows it.
  path = path.substr(sizeof(kExternalDir) - 1);
  const auto repo_end = path.find('/');
  if (repo_end != string::npos) {
    path = path.substr(repo_end + 1);
  }
  return path;
}
70
// Returns the last path component of `path` with its extension removed:
//   some/path/to/file.xx  ->  file
string GetFilename(const string& path) {
  const size_t last_slash = path.rfind('/');
  const size_t name_begin = (last_slash == string::npos) ? 0 : last_slash + 1;
  const size_t last_dot = path.rfind('.');
  // No '.' at all, or the last '.' sits inside a directory component: there
  // is no extension to strip, so keep everything after the final '/'.
  if (last_dot == string::npos || last_dot < name_begin) {
    return path.substr(name_begin);
  }
  return path.substr(name_begin, last_dot - name_begin);
}
80
// Builds an include-guard macro name from a header path: ASCII letters are
// upper-cased, every other character becomes '_', and one trailing '_' is
// appended. Example:
//   cc/ops/gen_foo_ops.h  ->  CC_OPS_GEN_FOO_OPS_H_
string ToGuard(const string& path) {
  // Start from all underscores; the extra final slot is the trailing '_'.
  string guard(path.size() + 1, '_');
  for (size_t i = 0; i < path.size(); ++i) {
    const char ch = path[i];
    if (ch >= 'a' && ch <= 'z') {
      guard[i] = static_cast<char>(ch - 'a' + 'A');
    } else if (ch >= 'A' && ch <= 'Z') {
      guard[i] = ch;
    }
    // Any other character keeps the '_' it was initialised with.
  }
  return guard;
}
100
101 // Converts: some_name_xyz
102 // to: Some Name Xyz
ToTitle(const string & name)103 string ToTitle(const string& name) {
104 string title = name;
105 for (int i = 0; i < title.size(); ++i) {
106 if (title[i] == '_') title[i] = ' ';
107 }
108 str_util::TitlecaseString(&title, " ");
109 return title;
110 }
111
112 // Change: Into:
113 // ABC /// ABC
114 // ///
115 // DEF /// DEF
MakeComment(StringPiece text,StringPiece indent)116 string MakeComment(StringPiece text, StringPiece indent) {
117 string ret;
118 while (!text.empty()) {
119 int last_non_space = -1;
120 int newline;
121 for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
122 if (text[newline] == '\n') break;
123 if (text[newline] != ' ') last_non_space = newline;
124 }
125 if (last_non_space == -1) {
126 strings::StrAppend(&ret, indent, "///\n");
127 } else {
128 strings::StrAppend(&ret, indent, "/// ",
129 text.substr(0, last_non_space + 1), "\n");
130 }
131 text.remove_prefix(newline + 1);
132 }
133 return ret;
134 }
135
PrintString(const string & str)136 string PrintString(const string& str) {
137 return strings::StrCat("\"", absl::CEscape(str), "\"");
138 }
139
PrintTensorShape(const TensorShapeProto & shape_proto)140 string PrintTensorShape(const TensorShapeProto& shape_proto) {
141 PartialTensorShape shape(shape_proto);
142 if (shape.IsIdenticalTo(PartialTensorShape())) {
143 return "::tensorflow::PartialTensorShape() /* unknown */";
144 }
145 string ret = "{";
146 for (int d = 0; d < shape.dims(); ++d) {
147 if (d > 0) strings::StrAppend(&ret, ", ");
148 strings::StrAppend(&ret, shape.dim_size(d));
149 }
150 strings::StrAppend(&ret, "}");
151 return ret;
152 }
153
154 template <typename T>
PrintArray(int64 num_elts,const T * array)155 string PrintArray(int64 num_elts, const T* array) {
156 string ret;
157 for (int64 i = 0; i < num_elts; ++i) {
158 if (i > 0) strings::StrAppend(&ret, ", ");
159 strings::StrAppend(&ret, array[i]);
160 }
161 return ret;
162 }
163
PrintTensor(const TensorProto & tensor_proto)164 string PrintTensor(const TensorProto& tensor_proto) {
165 Tensor t(tensor_proto.dtype());
166 CHECK(t.FromProto(tensor_proto));
167 const int64 num_elts = t.NumElements();
168 switch (t.dtype()) {
169 case DT_FLOAT:
170 return PrintArray(num_elts, t.flat<float>().data());
171 case DT_DOUBLE:
172 return PrintArray(num_elts, t.flat<double>().data());
173 case DT_INT32:
174 return PrintArray(num_elts, t.flat<int32>().data());
175 case DT_UINT8:
176 case DT_QUINT8:
177 return PrintArray(num_elts, t.flat<uint8>().data());
178 case DT_UINT16:
179 case DT_QUINT16:
180 return PrintArray(num_elts, t.flat<uint16>().data());
181 case DT_INT16:
182 case DT_QINT16:
183 return PrintArray(num_elts, t.flat<int16>().data());
184 case DT_INT8:
185 case DT_QINT8:
186 return PrintArray(num_elts, t.flat<int8>().data());
187 case DT_INT64:
188 return PrintArray(num_elts, t.flat<int64>().data());
189 case DT_BOOL:
190 return PrintArray(num_elts, t.flat<bool>().data());
191 case DT_STRING: {
192 string ret;
193 for (int64 i = 0; i < num_elts; ++i) {
194 if (i > 0) strings::StrAppend(&ret, " ");
195 strings::StrAppend(&ret, absl::CEscape(t.flat<tstring>()(i)));
196 }
197 return ret;
198 }
199 default: {
200 LOG(FATAL) << "Not handling type " << DataType_Name(t.dtype());
201 return string();
202 }
203 }
204 }
205
PrintTensorProto(const TensorProto & proto)206 string PrintTensorProto(const TensorProto& proto) {
207 return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ",
208 PrintTensorShape(proto.tensor_shape()),
209 ").AsTensorProto()");
210 }
211
PrintAttrValue(const string & op,const AttrValue & attr_value)212 string PrintAttrValue(const string& op, const AttrValue& attr_value) {
213 switch (attr_value.value_case()) {
214 case AttrValue::kS:
215 return PrintString(attr_value.s());
216 case AttrValue::kI:
217 return strings::StrCat(attr_value.i());
218 case AttrValue::kF: {
219 const float f = attr_value.f();
220 return strings::StrCat(attr_value.f(), floorf(f) == f ? ".0" : "", "f");
221 }
222 case AttrValue::kB:
223 return attr_value.b() ? "true" : "false";
224 case AttrValue::kType:
225 return DataType_Name(attr_value.type());
226 case AttrValue::kShape:
227 return PrintTensorShape(attr_value.shape());
228 case AttrValue::kTensor:
229 return PrintTensorProto(attr_value.tensor());
230 case AttrValue::kList: {
231 string ret = "{";
232 if (attr_value.list().s_size() > 0) {
233 for (int i = 0; i < attr_value.list().s_size(); ++i) {
234 if (i > 0) strings::StrAppend(&ret, ", ");
235 strings::StrAppend(&ret, PrintString(attr_value.list().s(i)));
236 }
237 } else if (attr_value.list().i_size() > 0) {
238 for (int i = 0; i < attr_value.list().i_size(); ++i) {
239 if (i > 0) strings::StrAppend(&ret, ", ");
240 strings::StrAppend(&ret, attr_value.list().i(i));
241 }
242 } else if (attr_value.list().f_size() > 0) {
243 for (int i = 0; i < attr_value.list().f_size(); ++i) {
244 if (i > 0) strings::StrAppend(&ret, ", ");
245 const float f = attr_value.list().f(i);
246 strings::StrAppend(&ret, f, floorf(f) == f ? ".0" : "", "f");
247 }
248 } else if (attr_value.list().b_size() > 0) {
249 for (int i = 0; i < attr_value.list().b_size(); ++i) {
250 if (i > 0) strings::StrAppend(&ret, ", ");
251 strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
252 }
253 } else if (attr_value.list().type_size() > 0) {
254 for (int i = 0; i < attr_value.list().type_size(); ++i) {
255 if (i > 0) strings::StrAppend(&ret, ", ");
256 strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i)));
257 }
258 } else if (attr_value.list().shape_size() > 0) {
259 for (int i = 0; i < attr_value.list().shape_size(); ++i) {
260 if (i > 0) strings::StrAppend(&ret, ", ");
261 strings::StrAppend(&ret,
262 PrintTensorShape(attr_value.list().shape(i)));
263 }
264 } else if (attr_value.list().tensor_size() > 0) {
265 for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
266 if (i > 0) strings::StrAppend(&ret, ", ");
267 strings::StrAppend(&ret,
268 PrintTensorProto(attr_value.list().tensor(i)));
269 }
270 }
271 strings::StrAppend(&ret, "}");
272 return ret;
273 }
274 default:
275 LOG(FATAL) << "Unsupported Attr type: " << op << " "
276 << attr_value.value_case();
277 }
278 return "<Unknown AttrValue type>"; // Prevent missing return warning
279 }
280
IsEmptyList(const AttrValue::ListValue & list)281 bool IsEmptyList(const AttrValue::ListValue& list) {
282 return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
283 list.b_size() == 0 && list.type_size() == 0 &&
284 list.shape_size() == 0 && list.tensor_size() == 0;
285 }
286
// Converts a snake_case (optionally '>'-namespaced) name to CamelCase:
//   some_name_xyz  ->  SomeNameXyz
//   ns>op_name     ->  NsOpName
// The separators '_' and '>' are dropped, and the first character after a
// separator (or at the start of the string) is upper-cased.
string ToCamelCase(const string& str) {
  string result;
  result.reserve(str.size());
  bool capitalize_next = true;
  for (const char c : str) {
    if (c == '_' || c == '>') {
      capitalize_next = true;
      continue;
    }
    if (capitalize_next) {
      // toupper() is only defined for values representable as unsigned char
      // (or EOF), so a possibly negative char must be converted first.
      result += static_cast<char>(toupper(static_cast<unsigned char>(c)));
      capitalize_next = false;
    } else {
      result += c;
    }
  }
  return result;
}
307
// Returns a copy of `str` in which every '>' namespace separator has been
// replaced by '_'.
string SeparateNamespaces(const string& str) {
  string result = str;
  for (char& ch : result) {
    if (ch == '>') ch = '_';
  }
  return result;
}
322
323 // Returns a <string, bool> pair. The string is the C++ type name to be used for
324 // attr_type when defining an object of that type. The bool is a flag to
325 // indicate whether to treat the type as const when accepting the C++ type as an
326 // argument to a function.
AttrTypeName(StringPiece attr_type)327 std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
328 static const auto* attr_type_map =
329 new std::unordered_map<StringPiece, std::pair<const char*, bool>,
330 StringPieceHasher>{
331 {"string", {"StringPiece", false}},
332 {"list(string)", {"gtl::ArraySlice<::tensorflow::tstring>", true}},
333 {"int", {"int64", false}},
334 {"list(int)", {"gtl::ArraySlice<int>", true}},
335 {"float", {"float", false}},
336 {"list(float)", {"gtl::ArraySlice<float>", true}},
337 {"bool", {"bool", false}},
338 {"list(bool)", {"gtl::ArraySlice<bool>", true}},
339 {"type", {"DataType", false}},
340 {"list(type)", {"DataTypeSlice", true}},
341 {"shape", {"PartialTensorShape", false}},
342 {"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}},
343 {"tensor", {"TensorProto", true}},
344 {"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
345 {"func", {"NameAttrList", true}},
346 {"list(func)", {"gtl::ArraySlice<NameAttrList>", true}},
347 };
348
349 auto entry = attr_type_map->find(attr_type);
350 if (entry == attr_type_map->end()) {
351 LOG(FATAL) << "Unsupported Attr type: " << attr_type;
352 return {"", false};
353 }
354 return entry->second;
355 }
356
ListElementTypeName(StringPiece attr_type)357 const char* ListElementTypeName(StringPiece attr_type) {
358 static const auto* attr_list_type_map =
359 new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
360 {"list(string)", "string"},
361 {"list(int)", "int"},
362 {"list(float)", "float"},
363 {"list(bool)", "bool"},
364 {"list(type)", "DataType"},
365 {"list(shape)", "PartialTensorShape"},
366 {"list(tensor)", "TensorProto"},
367 };
368
369 auto entry = attr_list_type_map->find(attr_type);
370 if (entry == attr_list_type_map->end()) {
371 LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
372 return "";
373 }
374 return entry->second;
375 }
376
IsCPPKeyword(StringPiece name)377 bool IsCPPKeyword(StringPiece name) {
378 static const std::unordered_set<StringPiece, StringPieceHasher>
379 // Keywords obtained from http://en.cppreference.com/w/cpp/keyword
380 kCPPReserved{
381 "alignas",
382 "alignof",
383 "and",
384 "and_eq",
385 "asm",
386 "atomic_cancel",
387 "atomic_commit",
388 "atomic_noexcept",
389 "auto",
390 "bitand",
391 "bitor",
392 "bool",
393 "break",
394 "case",
395 "catch",
396 "char",
397 "char16_t",
398 "char32_t",
399 "class",
400 "compl",
401 "concept",
402 "const",
403 "const_cast",
404 "constexpr",
405 "continue",
406 "decltype",
407 "default",
408 "delete",
409 "do",
410 "double",
411 "dynamic_cast",
412 "else",
413 "enum",
414 "explicit",
415 "export",
416 "extern",
417 "false",
418 "final",
419 "float",
420 "for",
421 "friend",
422 "goto",
423 "if",
424 "import",
425 "inline",
426 "int",
427 "long",
428 "module",
429 "mutable",
430 "namespace",
431 "new",
432 "noexcept",
433 "not",
434 "not_eq",
435 "nullptr",
436 "operator",
437 "or",
438 "or_eq",
439 "override",
440 "private",
441 "protected",
442 "public",
443 "register",
444 "reinterpret_cast",
445 "requires",
446 "return",
447 "short",
448 "signed",
449 "sizeof",
450 "static",
451 "static_assert",
452 "static_cast",
453 "struct",
454 "switch",
455 "synchronized",
456 "template",
457 "this",
458 "thread_local",
459 "throw",
460 "true",
461 "try",
462 "typedef",
463 "typeid",
464 "typename",
465 "union",
466 "unsigned",
467 "using",
468 "virtual",
469 "void",
470 "volatile",
471 "wchar_t",
472 "while",
473 "xor",
474 "xor_eq",
475
476 // The following are not C++ keywords, but names of local variables
477 // and parameters used in the op constructor. Treating them as
478 // keywords, so that other parameter names don't conflict with these.
479 "builder",
480 "node",
481 "ret",
482 "scope",
483 "unique_name",
484 };
485 return kCPPReserved.count(name) > 0;
486 }
487
AvoidCPPKeywords(StringPiece name)488 string AvoidCPPKeywords(StringPiece name) {
489 if (IsCPPKeyword(name)) {
490 return strings::StrCat(name, "_");
491 }
492 return string(name);
493 }
494
InferArgAttributes(const OpDef::ArgDef & arg,std::unordered_map<string,string> * inferred_attrs)495 void InferArgAttributes(const OpDef::ArgDef& arg,
496 std::unordered_map<string, string>* inferred_attrs) {
497 if (!arg.type_attr().empty()) {
498 gtl::InsertIfNotPresent(inferred_attrs, arg.type_attr(), arg.name());
499 } else if (!arg.type_list_attr().empty()) {
500 gtl::InsertIfNotPresent(inferred_attrs, arg.type_list_attr(), arg.name());
501 }
502 if (!arg.number_attr().empty()) {
503 gtl::InsertIfNotPresent(inferred_attrs, arg.number_attr(), arg.name());
504 }
505 }
506
InferOpAttributes(const OpDef & op_def,std::unordered_map<string,string> * inferred_input_attrs)507 void InferOpAttributes(
508 const OpDef& op_def,
509 std::unordered_map<string, string>* inferred_input_attrs) {
510 for (int i = 0; i < op_def.input_arg_size(); ++i) {
511 const auto& arg(op_def.input_arg(i));
512 InferArgAttributes(arg, inferred_input_attrs);
513 }
514 }
515
ArgIsList(const OpDef::ArgDef & arg)516 bool ArgIsList(const OpDef::ArgDef& arg) {
517 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
518 }
519
HasOptionalAttrs(const ApiDef & api_def,const std::unordered_map<string,string> & inferred_input_attrs)520 bool HasOptionalAttrs(
521 const ApiDef& api_def,
522 const std::unordered_map<string, string>& inferred_input_attrs) {
523 for (int i = 0; i < api_def.attr_size(); ++i) {
524 const auto& attr(api_def.attr(i));
525 if ((inferred_input_attrs.find(attr.name()) ==
526 inferred_input_attrs.end()) &&
527 attr.has_default_value()) {
528 return true;
529 }
530 }
531 return false;
532 }
533
534 struct OpInfo {
535 // graph_op_def: The OpDef used by the runtime, has the names that
536 // must be used when calling NodeBuilder.
537 // interface_op_def: The OpDef used in the interface in the generated
538 // code, with possibly overridden names and defaults.
539 explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
540 const std::vector<string>& aliases);
541 string GetOpAttrStruct() const;
542 string GetConstructorDecl(StringPiece op_name_prefix,
543 bool include_attr) const;
544 void WriteClassDecl(WritableFile* h) const;
545 void GetOutput(string* out) const;
546 string GetConstructorBody() const;
547 void WriteClassDef(WritableFile* cc) const;
548
549 string op_name;
550 std::vector<string> arg_types;
551 std::vector<string> arg_names;
552 std::vector<string> output_types;
553 std::vector<string> output_names;
554 std::vector<bool> is_list_output;
555 bool has_optional_attrs;
556 string comment;
557
558 const OpDef& graph_op_def;
559 const ApiDef& api_def;
560 const std::vector<string>& aliases;
561 // Map from type attribute to corresponding original argument name.
562 std::unordered_map<string, string> inferred_input_attrs;
563 };
564
OpInfo(const OpDef & graph_op_def,const ApiDef & api_def,const std::vector<string> & aliases)565 OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
566 const std::vector<string>& aliases)
567 : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
568 op_name = SeparateNamespaces(api_def.endpoint(0).name());
569 InferOpAttributes(graph_op_def, &inferred_input_attrs);
570 has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
571 arg_types.push_back("const ::tensorflow::Scope&");
572 arg_names.push_back("scope");
573
574 if (graph_op_def.has_deprecation()) {
575 if (!api_def.summary().empty()) {
576 comment = strings::StrCat(api_def.summary(), "\n");
577 }
578 strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
579 graph_op_def.deprecation().version(), ":\n",
580 graph_op_def.deprecation().explanation(), ".\n");
581 } else if (api_def.summary().empty()) {
582 comment = "TODO: add doc.\n";
583 } else {
584 comment = strings::StrCat(api_def.summary(), "\n");
585 }
586 if (!api_def.description().empty()) {
587 strings::StrAppend(&comment, "\n", api_def.description(), "\n");
588 }
589 strings::StrAppend(&comment, "\nArgs:\n* scope: A Scope object\n");
590
591 // Process inputs
592 for (int i = 0; i < api_def.arg_order_size(); ++i) {
593 const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
594 const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
595 arg_types.push_back(strings::StrCat(
596 "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
597 arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
598
599 // TODO(keveman): Include input type information.
600 StringPiece description = api_def_arg.description();
601 if (!description.empty()) {
602 ConsumeEquals(&description);
603 strings::StrAppend(&comment, "* ",
604 AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
605 api_def_arg.description(), "\n");
606 }
607 }
608
609 // Process attrs
610 string required_attrs_comment;
611 string optional_attrs_comment;
612 for (int i = 0; i < graph_op_def.attr_size(); ++i) {
613 // ApiDef attributes must be in the same order as in OpDef since
614 // we initialize ApiDef based on OpDef.
615 const auto& attr(graph_op_def.attr(i));
616 const auto& api_def_attr(api_def.attr(i));
617 CHECK_EQ(attr.name(), api_def_attr.name());
618 // Skip inferred arguments
619 if (inferred_input_attrs.count(attr.name()) > 0) continue;
620
621 const auto entry = AttrTypeName(attr.type());
622 const auto attr_type_name = entry.first;
623 const bool use_const = entry.second;
624 string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());
625
626 string attr_comment;
627 if (!api_def_attr.description().empty()) {
628 // TODO(keveman): Word wrap and indent this, to handle multi-line
629 // descriptions.
630 strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
631 api_def_attr.description(), "\n");
632 }
633 if (api_def_attr.has_default_value()) {
634 strings::StrAppend(&optional_attrs_comment, attr_comment);
635 } else {
636 strings::StrAppend(&required_attrs_comment, attr_comment);
637 arg_types.push_back(strings::StrCat(
638 use_const ? "const " : "", attr_type_name, use_const ? "&" : ""));
639 arg_names.push_back(attr_name);
640 }
641 }
642
643 strings::StrAppend(&comment, required_attrs_comment);
644
645 if (!optional_attrs_comment.empty()) {
646 strings::StrAppend(&comment, "\nOptional attributes (see `Attrs`):\n");
647 strings::StrAppend(&comment, optional_attrs_comment);
648 }
649
650 // Process outputs
651 for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
652 // ApiDef arguments must be in the same order as in OpDef since
653 // we initialize ApiDef based on OpDef.
654 const auto& arg = graph_op_def.output_arg(i);
655 const auto& api_def_arg(api_def.out_arg(i));
656 CHECK_EQ(arg.name(), api_def_arg.name());
657
658 bool is_list = ArgIsList(arg);
659 output_types.push_back(
660 strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
661 output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
662 is_list_output.push_back(is_list);
663 }
664
665 strings::StrAppend(&comment, "\nReturns:\n");
666 if (graph_op_def.output_arg_size() == 0) { // No outputs.
667 strings::StrAppend(&comment, "* the created `Operation`\n");
668 } else if (graph_op_def.output_arg_size() == 1) { // One output
669 if (is_list_output[0]) {
670 strings::StrAppend(&comment, "* `OutputList`: ");
671 } else {
672 strings::StrAppend(&comment, "* `Output`: ");
673 }
674 if (api_def.out_arg(0).description().empty()) {
675 strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
676 " tensor.\n");
677 } else {
678 // TODO(josh11b): Word wrap this.
679 strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
680 }
681 } else { // Multiple outputs.
682 for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
683 if (is_list_output[i]) {
684 strings::StrAppend(&comment, "* `OutputList`");
685 } else {
686 strings::StrAppend(&comment, "* `Output`");
687 }
688 strings::StrAppend(&comment, " ", output_names[i]);
689 if (api_def.out_arg(i).description().empty()) {
690 strings::StrAppend(&comment, "\n");
691 } else {
692 // TODO(josh11b): Word wrap this.
693 strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
694 "\n");
695 }
696 }
697 }
698
699 if (!aliases.empty()) {
700 strings::StrAppend(&comment, "\nAliases:\n");
701 for (const auto& alias : aliases) {
702 strings::StrAppend(&comment, "* ", alias, "\n");
703 }
704 }
705 comment = MakeComment(comment, "");
706 }
707
GetOpAttrStruct() const708 string OpInfo::GetOpAttrStruct() const {
709 string struct_fields;
710 string setters;
711 string defaults_static_storage;
712
713 for (int i = 0; i < graph_op_def.attr_size(); ++i) {
714 const auto& attr(graph_op_def.attr(i));
715 const auto& api_def_attr(api_def.attr(i));
716 // If attr will be inferred or it doesn't have a default value, don't
717 // add it to the struct.
718 if ((inferred_input_attrs.find(attr.name()) !=
719 inferred_input_attrs.end()) ||
720 !api_def_attr.has_default_value()) {
721 continue;
722 }
723 const auto entry = AttrTypeName(attr.type());
724 const auto attr_type_name = entry.first;
725 const bool use_const = entry.second;
726 const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
727 const string suffix =
728 (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
729 const string attr_func_def =
730 strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "",
731 attr_type_name, use_const ? "&" : "");
732
733 string attr_comment;
734 if (!api_def_attr.description().empty()) {
735 strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
736 }
737 strings::StrAppend(&attr_comment, "Defaults to ",
738 SummarizeAttrValue(api_def_attr.default_value()), "\n");
739 attr_comment = MakeComment(attr_comment, " ");
740
741 strings::StrAppend(&setters, attr_comment);
742 strings::StrAppend(&setters, " TF_MUST_USE_RESULT Attrs ", attr_func_def,
743 " x) {\n");
744 strings::StrAppend(&setters, " Attrs ret = *this;\n");
745 strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(),
746 "_ = x;\n");
747 strings::StrAppend(&setters, " return ret;\n }\n\n");
748
749 string field_initiliazer;
750 auto& default_value = api_def_attr.default_value();
751 if (default_value.value_case() == AttrValue::kList &&
752 !IsEmptyList(default_value.list())) {
753 // Non-empty lists need static storage for their defaults. Define a
754 // function with static local variable that stores the array.
755 strings::StrAppend(&defaults_static_storage, " static ",
756 attr_type_name, " Default_", api_def_attr.rename_to(),
757 "() {\n");
758 strings::StrAppend(
759 &defaults_static_storage, " static const ",
760 ListElementTypeName(attr.type()), " kStorage[] = ",
761 PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
762 ";\n");
763 strings::StrAppend(&defaults_static_storage, " return ",
764 attr_type_name, "(kStorage);\n }\n");
765 // Set the field_initializer to call the defined function.
766 strings::StrAppend(&field_initiliazer, "Default_",
767 api_def_attr.rename_to(), "()");
768 } else {
769 field_initiliazer =
770 PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
771 }
772 strings::StrAppend(&struct_fields, " ", attr_type_name, " ",
773 api_def_attr.rename_to(), "_ = ", field_initiliazer,
774 ";\n");
775 }
776
777 if (struct_fields.empty()) {
778 return "";
779 }
780
781 string attrs_comment =
782 strings::StrCat("Optional attribute setters for ", op_name, "\n");
783 string struct_decl = MakeComment(attrs_comment, " ");
784 strings::StrAppend(&struct_decl, " struct Attrs {\n");
785 strings::StrAppend(&struct_decl, setters, struct_fields);
786 if (!defaults_static_storage.empty()) {
787 strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage);
788 }
789 strings::StrAppend(&struct_decl, " };\n");
790
791 return struct_decl;
792 }
793
GetConstructorDecl(StringPiece op_name_prefix,bool include_attr) const794 string OpInfo::GetConstructorDecl(StringPiece op_name_prefix,
795 bool include_attr) const {
796 const string prefix = strings::StrCat(op_name_prefix, op_name, "(");
797 string c_decl;
798 for (int i = 0; i < arg_types.size(); ++i) {
799 if (i > 0) strings::StrAppend(&c_decl, ", ");
800 strings::StrAppend(&c_decl, arg_types[i], " ", arg_names[i]);
801 }
802 if (include_attr && has_optional_attrs) {
803 strings::StrAppend(&c_decl, ", const ", op_name, "::Attrs& attrs");
804 }
805 strings::StrAppend(&c_decl, ")");
806 return WordWrap(prefix, c_decl, kRightMargin);
807 }
808
WriteClassDecl(WritableFile * h) const809 void OpInfo::WriteClassDecl(WritableFile* h) const {
810 string class_decl = comment;
811 strings::StrAppend(&class_decl, "class ", op_name, " {\n");
812 strings::StrAppend(&class_decl, " public:\n");
813 if (has_optional_attrs) {
814 strings::StrAppend(&class_decl, GetOpAttrStruct());
815 }
816 strings::StrAppend(&class_decl, " ",
817 GetConstructorDecl("", /* include_attr */ false), ";\n");
818 if (has_optional_attrs) {
819 strings::StrAppend(&class_decl, " ",
820 GetConstructorDecl("", /* include_attr */ true), ";\n");
821 }
822 if (output_types.empty()) {
823 // Allow casting this class to Operation.
824 strings::StrAppend(&class_decl,
825 " operator ::tensorflow::Operation() const { "
826 "return operation; }\n");
827 } else if (output_types.size() == 1) {
828 if (is_list_output[0]) {
829 // Write the subscript operator, allowing out[i] for the list-typed
830 // output.
831 strings::StrAppend(&class_decl,
832 " ::tensorflow::Output operator[](size_t index) "
833 "const { return ",
834 output_names[0], "[index]; }\n\n");
835
836 } else {
837 // Write type cast functions, allowing casting this class to Input and
838 // Output.
839 strings::StrAppend(&class_decl,
840 " operator ::tensorflow::Output() const { return ",
841 output_names[0], "; }\n");
842 strings::StrAppend(&class_decl,
843 " operator ::tensorflow::Input() const { return ",
844 output_names[0], "; }\n");
845 // Write node() to get the Node* directly.
846 strings::StrAppend(&class_decl,
847 " ::tensorflow::Node* node() const { return ",
848 output_names[0], ".node(); }\n");
849 }
850 }
851 // Add the static functions to set optional attrs
852 if (has_optional_attrs) {
853 strings::StrAppend(&class_decl, "\n");
854 for (int i = 0; i < graph_op_def.attr_size(); ++i) {
855 const auto& attr(graph_op_def.attr(i));
856 const auto& api_def_attr(api_def.attr(i));
857 if ((inferred_input_attrs.find(attr.name()) !=
858 inferred_input_attrs.end()) ||
859 !api_def_attr.has_default_value()) {
860 continue;
861 }
862 const auto entry = AttrTypeName(attr.type());
863 const auto attr_type_name = entry.first;
864 const bool use_const = entry.second;
865 const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
866 const string suffix =
867 (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
868 const string attr_func_def = strings::StrCat(
869 camel_case_name, suffix, "(", use_const ? "const " : "",
870 attr_type_name, use_const ? "&" : "");
871 strings::StrAppend(&class_decl, " static Attrs ", attr_func_def,
872 " x) {\n");
873 strings::StrAppend(&class_decl, " return Attrs().", camel_case_name,
874 suffix, "(x);\n");
875 strings::StrAppend(&class_decl, " }\n");
876 }
877 }
878
879 strings::StrAppend(&class_decl, "\n Operation operation;\n");
880 for (int i = 0; i < output_types.size(); ++i) {
881 strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
882 ";\n");
883 }
884
885 strings::StrAppend(&class_decl, "};\n");
886 if (!aliases.empty()) {
887 for (const auto& alias : aliases) {
888 strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n");
889 }
890 }
891 strings::StrAppend(&class_decl, "\n");
892 TF_CHECK_OK(h->Append(class_decl));
893 }
894
GetOutput(string * out) const895 void OpInfo::GetOutput(string* out) const {
896 const string scope_str = arg_names[0];
897 string return_on_error =
898 strings::StrCat("if (!", scope_str, ".ok()) return;");
899
900 strings::StrAppend(out, " this->operation = Operation(ret);\n");
901
902 // No outputs.
903 if (graph_op_def.output_arg_size() == 0) {
904 strings::StrAppend(out, " return;\n");
905 return;
906 }
907 if (graph_op_def.output_arg_size() == 1) {
908 // One output, no need for NameRangeMap
909 if (is_list_output[0]) {
910 strings::StrAppend(out,
911 " for (int32 i = 0; i < ret->num_outputs(); ++i)\n");
912 strings::StrAppend(out, " this->", output_names[0],
913 ".push_back(Output(ret, i));\n");
914 } else {
915 strings::StrAppend(out, " this->", output_names[0],
916 " = Output(ret, 0);\n");
917 }
918 return;
919 }
920 strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
921 strings::StrAppend(out,
922 " ::tensorflow::Status _status_ = "
923 "::tensorflow::NameRangesForNode(*ret, ret->op_def(), "
924 "nullptr, &_outputs_range);\n");
925 strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
926 ".UpdateStatus(_status_);\n", " return;\n");
927 strings::StrAppend(out, " }\n\n");
928
929 for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
930 const string arg_range = strings::StrCat(
931 "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
932 if (is_list_output[i]) {
933 strings::StrAppend(out, " for (int32 i = ", arg_range, ".first; i < ",
934 arg_range, ".second; ++i)\n");
935 strings::StrAppend(out, " this->", output_names[i],
936 ".push_back(Output(ret, i));\n");
937 } else {
938 strings::StrAppend(out, " this->", output_names[i], " = Output(ret, ",
939 arg_range, ".first);\n");
940 }
941 }
942 }
943
GetConstructorBody() const944 string OpInfo::GetConstructorBody() const {
945 const string scope_str = arg_names[0];
946
947 string body;
948 string return_on_error =
949 strings::StrCat("if (!", scope_str, ".ok()) return;");
950
951 strings::StrAppend(&body, " ", return_on_error, "\n");
952
953 for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
954 const auto& arg(graph_op_def.input_arg(i));
955 const auto& api_def_arg(api_def.in_arg(i));
956 strings::StrAppend(
957 &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
958 ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
959 AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
960 strings::StrAppend(&body, " ", return_on_error, "\n");
961 }
962
963 strings::StrAppend(&body, " ::tensorflow::Node* ret;\n");
964 strings::StrAppend(&body, " const auto unique_name = ", scope_str,
965 ".GetUniqueNameForOp(\"", op_name, "\");\n");
966 strings::StrAppend(
967 &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
968 graph_op_def.name(), "\")\n");
969 const string spaces = " ";
970 for (int i = 0; i < api_def.in_arg_size(); ++i) {
971 const auto& arg(api_def.in_arg(i));
972 strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
973 }
974 for (int i = 0; i < api_def.attr_size(); ++i) {
975 const auto& graph_attr(graph_op_def.attr(i));
976 const auto& api_def_attr(api_def.attr(i));
977 if (inferred_input_attrs.find(api_def_attr.name()) !=
978 inferred_input_attrs.end()) {
979 continue;
980 }
981 const string attr_name =
982 api_def_attr.has_default_value()
983 ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
984 : AvoidCPPKeywords(api_def_attr.rename_to());
985 strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
986 attr_name, ")\n");
987 }
988 strings::StrAppend(&body, " ;\n");
989 strings::StrAppend(&body, " ", scope_str, ".UpdateBuilder(&builder);\n");
990 strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
991 scope_str, ".graph(), &ret));\n");
992 strings::StrAppend(&body, " ", return_on_error, "\n");
993 strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
994 ".DoShapeInference(ret));\n");
995
996 GetOutput(&body);
997 return body;
998 }
999
WriteClassDef(WritableFile * cc) const1000 void OpInfo::WriteClassDef(WritableFile* cc) const {
1001 string class_def;
1002 strings::StrAppend(&class_def,
1003 GetConstructorDecl(strings::StrCat(op_name, "::"),
1004 /* include_attr */ true),
1005 " {\n");
1006 strings::StrAppend(&class_def, GetConstructorBody());
1007 strings::StrAppend(&class_def, "}\n\n");
1008
1009 if (has_optional_attrs) {
1010 strings::StrAppend(&class_def,
1011 GetConstructorDecl(strings::StrCat(op_name, "::"),
1012 /* include_attr */ false));
1013 strings::StrAppend(&class_def, "\n : ", op_name, "(");
1014 int i = 0;
1015 for (; i < arg_names.size(); ++i) {
1016 if (i > 0) strings::StrAppend(&class_def, ", ");
1017 strings::StrAppend(&class_def, arg_names[i]);
1018 }
1019 if (i > 0) strings::StrAppend(&class_def, ", ");
1020 strings::StrAppend(&class_def, op_name, "::Attrs()");
1021 strings::StrAppend(&class_def, ") {}\n\n");
1022 }
1023 TF_CHECK_OK(cc->Append(class_def));
1024 }
1025
WriteCCOp(const OpDef & graph_op_def,const ApiDef & api_def,const std::vector<string> & aliases,WritableFile * h,WritableFile * cc)1026 void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
1027 const std::vector<string>& aliases, WritableFile* h,
1028 WritableFile* cc) {
1029 OpInfo op_info(graph_op_def, api_def, aliases);
1030
1031 op_info.WriteClassDecl(h);
1032 op_info.WriteClassDef(cc);
1033 }
1034
StartFiles(bool internal,const string & dot_h_fname,WritableFile * h,WritableFile * cc,string * op_header_guard)1035 void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h,
1036 WritableFile* cc, string* op_header_guard) {
1037 const string header =
1038 R"header(// This file is MACHINE GENERATED! Do not edit.
1039
1040 #include "tensorflow/cc/framework/ops.h"
1041 #include "tensorflow/cc/framework/scope.h"
1042 #include "tensorflow/core/framework/tensor.h"
1043 #include "tensorflow/core/framework/tensor_shape.h"
1044 #include "tensorflow/core/framework/types.h"
1045 #include "tensorflow/core/lib/gtl/array_slice.h"
1046 )header";
1047
1048 // TODO(keveman): Make namespaces configurable.
1049 const string namespace_begin = internal ? R"namespace(
1050 namespace tensorflow {
1051 namespace ops {
1052 namespace internal {
1053 // NOTE: This namespace has internal TensorFlow details that
1054 // are not part of TensorFlow's public API.
1055
1056 )namespace"
1057 : R"namespace(
1058 namespace tensorflow {
1059 namespace ops {
1060
1061 )namespace";
1062
1063 const string op_header = GetPath(dot_h_fname);
1064 *op_header_guard = ToGuard(op_header);
1065 const string cc_header = strings::StrCat(
1066 R"include(// This file is MACHINE GENERATED! Do not edit.
1067
1068
1069 #include "tensorflow/cc/ops/const_op.h"
1070 )include",
1071 "#include \"", op_header, "\"\n", namespace_begin);
1072
1073 const string filename = GetFilename(dot_h_fname);
1074 const string doxygen = strings::StrCat("/// @defgroup ", filename, " ",
1075 ToTitle(filename), "\n", "/// @{\n\n");
1076
1077 TF_CHECK_OK(h->Append(
1078 strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
1079 "#ifndef ",
1080 *op_header_guard,
1081 "\n"
1082 "#define ",
1083 *op_header_guard, "\n\n")));
1084 TF_CHECK_OK(h->Append(header));
1085 TF_CHECK_OK(h->Append(namespace_begin));
1086 TF_CHECK_OK(h->Append(doxygen));
1087 TF_CHECK_OK(cc->Append(cc_header));
1088 }
1089
FinishFiles(bool internal,WritableFile * h,WritableFile * cc,const string & op_header_guard)1090 void FinishFiles(bool internal, WritableFile* h, WritableFile* cc,
1091 const string& op_header_guard) {
1092 const string footer = internal ? R"footer(} // namespace internal
1093 } // namespace ops
1094 } // namespace tensorflow
1095 )footer"
1096 :
1097 R"footer(/// @}
1098
1099 } // namespace ops
1100 } // namespace tensorflow
1101 )footer";
1102
1103 TF_CHECK_OK(h->Append(footer));
1104 TF_CHECK_OK(
1105 h->Append(strings::StrCat("\n#endif ", "// ", op_header_guard, "\n")));
1106 TF_CHECK_OK(cc->Append(footer));
1107
1108 TF_CHECK_OK(cc->Close());
1109 TF_CHECK_OK(h->Close());
1110 }
1111
// Returns `fname` with "_internal" inserted before the file extension, or
// appended when the file name has no extension:
//   "a/array_ops.h"  -> "a/array_ops_internal.h"
//   "a/array_ops"    -> "a/array_ops_internal"
//   "a.dir/array_ops" -> "a.dir/array_ops_internal"
//
// Only a '.' located after the last '/' counts as the start of an extension,
// so a dot inside a directory component never splits the path.
string MakeInternal(const string& fname) {
  static const char kInternalSuffix[] = "_internal";

  const auto dot_pos = fname.rfind('.');
  const auto slash_pos = fname.rfind('/');
  const bool has_extension =
      dot_pos != string::npos &&
      (slash_pos == string::npos || dot_pos > slash_pos);

  string result = fname;
  if (has_extension) {
    result.insert(dot_pos, kInternalSuffix);
  } else {
    result.append(kInternalSuffix);
  }
  return result;
}
1121
1122 } // namespace
1123
WriteCCOps(const OpList & ops,const ApiDefMap & api_def_map,const string & dot_h_fname,const string & dot_cc_fname)1124 void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
1125 const string& dot_h_fname, const string& dot_cc_fname) {
1126 Env* env = Env::Default();
1127
1128 // Write the initial boilerplate to the .h and .cc files.
1129 std::unique_ptr<WritableFile> h = nullptr;
1130 std::unique_ptr<WritableFile> cc = nullptr;
1131 TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
1132 TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
1133 string op_header_guard;
1134 StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard);
1135
1136 // Create the internal versions of these files for the hidden ops.
1137 std::unique_ptr<WritableFile> internal_h = nullptr;
1138 std::unique_ptr<WritableFile> internal_cc = nullptr;
1139 const string internal_dot_h_fname = MakeInternal(dot_h_fname);
1140 TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h));
1141 TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc));
1142 string internal_op_header_guard;
1143 StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(),
1144 internal_cc.get(), &internal_op_header_guard);
1145
1146 for (const auto& graph_op_def : ops.op()) {
1147 // Skip deprecated ops.
1148 // TODO(josh11b): If needed, can put them into a "deprecated" namespace
1149 // instead of skipping.
1150 if (graph_op_def.has_deprecation() &&
1151 graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
1152 continue;
1153 }
1154
1155 // We use a hand-written wrapper for "Const", since the generated
1156 // code depends on it.
1157 if (graph_op_def.name() == "Const") continue;
1158
1159 const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
1160
1161 std::vector<string> aliases;
1162 if (api_def->visibility() == ApiDef::SKIP) continue;
1163 // First endpoint is canonical, the rest are aliases.
1164 for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
1165 ++endpoint_i) {
1166 aliases.push_back(api_def->endpoint(endpoint_i).name());
1167 }
1168 if (api_def->visibility() == ApiDef::HIDDEN) {
1169 // Write hidden ops to _internal.h and _internal.cc.
1170 WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
1171 internal_cc.get());
1172 continue;
1173 }
1174 // This isn't a hidden op, write it to the main files.
1175 WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
1176 }
1177
1178 FinishFiles(false, h.get(), cc.get(), op_header_guard);
1179 FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(),
1180 internal_op_header_guard);
1181 }
1182
1183 } // namespace tensorflow
1184