• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_builder.h"
17 
18 #include <limits>
19 #include <vector>
20 
21 #include "absl/strings/escaping.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/lib/strings/scanner.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 
32 using ::tensorflow::strings::Scanner;
33 
34 namespace tensorflow {
35 
36 namespace {
37 
AttrError(StringPiece orig,const string & op_name)38 string AttrError(StringPiece orig, const string& op_name) {
39   return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
40 }
41 
ConsumeAttrName(StringPiece * sp,StringPiece * out)42 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
43   return Scanner(*sp)
44       .One(Scanner::LETTER)
45       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
46       .StopCapture()
47       .AnySpace()
48       .OneLiteral(":")
49       .AnySpace()
50       .GetResult(sp, out);
51 }
52 
ConsumeListPrefix(StringPiece * sp)53 bool ConsumeListPrefix(StringPiece* sp) {
54   return Scanner(*sp)
55       .OneLiteral("list")
56       .AnySpace()
57       .OneLiteral("(")
58       .AnySpace()
59       .GetResult(sp);
60 }
61 
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)62 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
63   const string quote_str(1, quote_ch);
64   return Scanner(*sp)
65       .OneLiteral(quote_str.c_str())
66       .RestartCapture()
67       .ScanEscapedUntil(quote_ch)
68       .StopCapture()
69       .OneLiteral(quote_str.c_str())
70       .AnySpace()
71       .GetResult(sp, out);
72 }
73 
ConsumeAttrType(StringPiece * sp,StringPiece * out)74 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
75   return Scanner(*sp)
76       .Many(Scanner::LOWERLETTER_DIGIT)
77       .StopCapture()
78       .AnySpace()
79       .GetResult(sp, out);
80 }
81 
ConsumeAttrNumber(StringPiece * sp,int64 * out)82 bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
83   Scanner scan(*sp);
84   StringPiece match;
85   StringPiece remaining;
86 
87   scan.AnySpace().RestartCapture();
88   if (scan.Peek() == '-') {
89     scan.OneLiteral("-");
90   }
91   if (!scan.Many(Scanner::DIGIT)
92            .StopCapture()
93            .AnySpace()
94            .GetResult(&remaining, &match)) {
95     return false;
96   }
97   int64 value = 0;
98   if (!strings::safe_strto64(match, &value)) {
99     return false;
100   }
101   *out = value;
102   *sp = remaining;
103   return true;
104 }
105 
// Guard used while parsing Attr() specs: when `expr` is false, appends an error
// made of the variadic message pieces plus AttrError(...) to `errors` and
// returns from the enclosing function. Expects `errors`, `orig` and `op_def`
// to be in scope wherever it is expanded.
#define VERIFY(expr, ...)                                               \
  do {                                                                  \
    if (expr) break;                                                    \
    errors->push_back(                                                  \
        strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
    return;                                                             \
  } while (false)
114 
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)115 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
116   auto capture_begin = sp->begin();
117   if (absl::ConsumePrefix(sp, "numbertype") ||
118       absl::ConsumePrefix(sp, "numerictype") ||
119       absl::ConsumePrefix(sp, "quantizedtype") ||
120       absl::ConsumePrefix(sp, "realnumbertype") ||
121       absl::ConsumePrefix(sp, "realnumberictype")) {
122     *out = StringPiece(capture_begin, sp->begin() - capture_begin);
123     return true;
124   }
125   return false;
126 }
127 
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)128 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
129   if (type_string == "numbertype" || type_string == "numerictype") {
130     for (DataType dt : NumberTypes()) {
131       allowed->mutable_list()->add_type(dt);
132     }
133   } else if (type_string == "quantizedtype") {
134     for (DataType dt : QuantizedTypes()) {
135       allowed->mutable_list()->add_type(dt);
136     }
137   } else if (type_string == "realnumbertype" ||
138              type_string == "realnumerictype") {
139     for (DataType dt : RealNumberTypes()) {
140       allowed->mutable_list()->add_type(dt);
141     }
142   } else {
143     return false;
144   }
145   return true;
146 }
147 
FinalizeAttr(StringPiece spec,bool allow_attr_type_any,OpDef * op_def,std::vector<string> * errors)148 void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def,
149                   std::vector<string>* errors) {
150   OpDef::AttrDef* attr = op_def->add_attr();
151   StringPiece orig(spec);
152 
153   // Parse "<name>:" at the beginning.
154   StringPiece tmp_name;
155   VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
156   attr->set_name(tmp_name.data(), tmp_name.size());
157 
158   // Read "<type>" or "list(<type>)".
159   bool is_list = ConsumeListPrefix(&spec);
160   string type;
161   StringPiece type_string;  // Used if type == "type"
162   if (absl::ConsumePrefix(&spec, "string")) {
163     type = "string";
164   } else if (absl::ConsumePrefix(&spec, "int")) {
165     type = "int";
166   } else if (absl::ConsumePrefix(&spec, "float")) {
167     type = "float";
168   } else if (absl::ConsumePrefix(&spec, "bool")) {
169     type = "bool";
170   } else if (absl::ConsumePrefix(&spec, "type")) {
171     type = "type";
172   } else if (absl::ConsumePrefix(&spec, "shape")) {
173     type = "shape";
174   } else if (absl::ConsumePrefix(&spec, "tensor")) {
175     type = "tensor";
176   } else if (absl::ConsumePrefix(&spec, "func")) {
177     type = "func";
178   } else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) {
179     type = "any";
180   } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
181     type = "type";
182     AttrValue* allowed = attr->mutable_allowed_values();
183     VERIFY(ProcessCompoundType(type_string, allowed),
184            "Expected to see a compound type, saw: ", type_string);
185   } else if (absl::ConsumePrefix(&spec, "{")) {
186     // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
187     AttrValue* allowed = attr->mutable_allowed_values();
188     str_util::RemoveLeadingWhitespace(&spec);
189     if (absl::StartsWith(spec, "\"") || absl::StartsWith(spec, "'")) {
190       type = "string";  // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
191       while (true) {
192         StringPiece escaped_string;
193         VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
194                    ConsumeQuotedString('\'', &spec, &escaped_string),
195                "Trouble parsing allowed string at '", spec, "'");
196         string unescaped;
197         string error;
198         VERIFY(absl::CUnescape(escaped_string, &unescaped, &error),
199                "Trouble unescaping \"", escaped_string,
200                "\", got error: ", error);
201         allowed->mutable_list()->add_s(unescaped);
202         if (absl::ConsumePrefix(&spec, ",")) {
203           str_util::RemoveLeadingWhitespace(&spec);
204           if (absl::ConsumePrefix(&spec, "}"))
205             break;  // Allow ending with ", }".
206         } else {
207           VERIFY(absl::ConsumePrefix(&spec, "}"),
208                  "Expected , or } after strings in list, not: '", spec, "'");
209           break;
210         }
211       }
212     } else {  // "{ bool, numbertype, string }"
213       type = "type";
214       while (true) {
215         VERIFY(ConsumeAttrType(&spec, &type_string),
216                "Trouble parsing type string at '", spec, "'");
217         if (ProcessCompoundType(type_string, allowed)) {
218           // Processed a compound type.
219         } else {
220           DataType dt;
221           VERIFY(DataTypeFromString(type_string, &dt),
222                  "Unrecognized type string '", type_string, "'");
223           allowed->mutable_list()->add_type(dt);
224         }
225         if (absl::ConsumePrefix(&spec, ",")) {
226           str_util::RemoveLeadingWhitespace(&spec);
227           if (absl::ConsumePrefix(&spec, "}"))
228             break;  // Allow ending with ", }".
229         } else {
230           VERIFY(absl::ConsumePrefix(&spec, "}"),
231                  "Expected , or } after types in list, not: '", spec, "'");
232           break;
233         }
234       }
235     }
236   } else {  // if spec.Consume("{")
237     VERIFY(false, "Trouble parsing type string at '", spec, "'");
238   }
239   str_util::RemoveLeadingWhitespace(&spec);
240 
241   // Write the type into *attr.
242   if (is_list) {
243     VERIFY(absl::ConsumePrefix(&spec, ")"),
244            "Expected ) to close 'list(', not: '", spec, "'");
245     str_util::RemoveLeadingWhitespace(&spec);
246     attr->set_type(strings::StrCat("list(", type, ")"));
247   } else {
248     attr->set_type(type);
249   }
250 
251   // Read optional minimum constraint at the end.
252   if ((is_list || type == "int") && absl::ConsumePrefix(&spec, ">=")) {
253     int64 min_limit = -999;
254     VERIFY(ConsumeAttrNumber(&spec, &min_limit),
255            "Could not parse integer lower limit after '>=', found '", spec,
256            "' instead");
257     attr->set_has_minimum(true);
258     attr->set_minimum(min_limit);
259   }
260 
261   // Parse default value, if present.
262   if (absl::ConsumePrefix(&spec, "=")) {
263     str_util::RemoveLeadingWhitespace(&spec);
264     VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
265            "Could not parse default value '", spec, "'");
266   } else {
267     VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
268   }
269 }
270 
271 #undef VERIFY
272 
// Builds the suffix appended to errors produced while parsing an Input() or
// Output() spec, e.g. ` from Input("x: float") for Op Foo`.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const kind = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", kind, "(\"", orig, "\") for Op ", op_name);
}
277 
ConsumeInOutName(StringPiece * sp,StringPiece * out)278 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
279   return Scanner(*sp)
280       .One(Scanner::LOWERLETTER)
281       .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
282       .StopCapture()
283       .AnySpace()
284       .OneLiteral(":")
285       .AnySpace()
286       .GetResult(sp, out);
287 }
288 
ConsumeInOutRefOpen(StringPiece * sp)289 bool ConsumeInOutRefOpen(StringPiece* sp) {
290   return Scanner(*sp)
291       .OneLiteral("Ref")
292       .AnySpace()
293       .OneLiteral("(")
294       .AnySpace()
295       .GetResult(sp);
296 }
297 
ConsumeInOutRefClose(StringPiece * sp)298 bool ConsumeInOutRefClose(StringPiece* sp) {
299   return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
300 }
301 
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)302 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
303   return Scanner(*sp)
304       .One(Scanner::LETTER)
305       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
306       .StopCapture()
307       .AnySpace()
308       .GetResult(sp, out);
309 }
310 
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)311 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
312   return Scanner(*sp)
313       .OneLiteral("*")
314       .AnySpace()
315       .RestartCapture()
316       .One(Scanner::LETTER)
317       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
318       .StopCapture()
319       .AnySpace()
320       .GetResult(sp, out);
321 }
322 
ConsumeControlOutName(StringPiece * sp,StringPiece * out)323 bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) {
324   return Scanner(*sp)
325       .One(Scanner::LETTER)
326       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
327       .StopCapture()
328       .GetResult(sp, out);
329 }
330 
// Guard used while parsing Input()/Output() specs: when `expr` is false,
// appends an error made of the variadic message pieces plus InOutError(...) to
// `errors` and returns from the enclosing function. Expects `errors`,
// `is_output`, `orig` and `op_def` to be in scope wherever it is expanded.
#define VERIFY(expr, ...)                                           \
  do {                                                              \
    if (expr) break;                                                \
    errors->push_back(strings::StrCat(                              \
        __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
    return;                                                         \
  } while (false)
339 
FinalizeInputOrOutput(StringPiece spec,bool is_output,OpDef * op_def,std::vector<string> * errors)340 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
341                            std::vector<string>* errors) {
342   OpDef::ArgDef* arg =
343       is_output ? op_def->add_output_arg() : op_def->add_input_arg();
344 
345   StringPiece orig(spec);
346 
347   // Parse "<name>:" at the beginning.
348   StringPiece tmp_name;
349   VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
350   arg->set_name(tmp_name.data(), tmp_name.size());
351 
352   // Detect "Ref(...)".
353   if (ConsumeInOutRefOpen(&spec)) {
354     arg->set_is_ref(true);
355   }
356 
357   {  // Parse "<name|type>" or "<name>*<name|type>".
358     StringPiece first, second, type_or_attr;
359     VERIFY(ConsumeInOutNameOrType(&spec, &first),
360            "Trouble parsing either a type or an attr name at '", spec, "'");
361     if (ConsumeInOutTimesType(&spec, &second)) {
362       arg->set_number_attr(first.data(), first.size());
363       type_or_attr = second;
364     } else {
365       type_or_attr = first;
366     }
367     DataType dt;
368     if (DataTypeFromString(type_or_attr, &dt)) {
369       arg->set_type(dt);
370     } else {
371       const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
372       VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
373       if (attr->type() == "type") {
374         arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
375       } else {
376         VERIFY(attr->type() == "list(type)", "Reference to attr '",
377                type_or_attr, "' with type ", attr->type(),
378                " that isn't type or list(type)");
379         arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
380       }
381     }
382   }
383 
384   // Closing ) for Ref(.
385   if (arg->is_ref()) {
386     VERIFY(ConsumeInOutRefClose(&spec),
387            "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
388   }
389 
390   // Should not have anything else.
391   VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
392 
393   // Int attrs that are the length of an input or output get a default
394   // minimum of 1.
395   if (!arg->number_attr().empty()) {
396     OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
397     if (attr != nullptr && !attr->has_minimum()) {
398       attr->set_has_minimum(true);
399       attr->set_minimum(1);
400     }
401   } else if (!arg->type_list_attr().empty()) {
402     // If an input or output has type specified by a list(type) attr,
403     // it gets a default minimum of 1 as well.
404     OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
405     if (attr != nullptr && attr->type() == "list(type)" &&
406         !attr->has_minimum()) {
407       attr->set_has_minimum(true);
408       attr->set_minimum(1);
409     }
410   }
411 
412   // If the arg's dtype is resource we should mark the op as stateful as it
413   // likely touches a resource manager. This deliberately doesn't cover inputs /
414   // outputs which resolve to resource via Attrs as those mostly operate on
415   // resource handles as an opaque type (as opposed to ops which explicitly take
416   // / produce resources).
417   if (arg->type() == DT_RESOURCE) {
418     op_def->set_is_stateful(true);
419   }
420 }
421 
422 #undef VERIFY
423 
ControlOutError(StringPiece orig,const string & op_name)424 string ControlOutError(StringPiece orig, const string& op_name) {
425   return strings::StrCat(" from ControlOutput(\"", orig, "\") for Op ",
426                          op_name);
427 }
428 
FinalizeControlOutput(StringPiece name,OpDef * op_def,std::vector<string> * errors)429 void FinalizeControlOutput(StringPiece name, OpDef* op_def,
430                            std::vector<string>* errors) {
431   StringPiece orig(name);
432 
433   // Parse control output name.
434   StringPiece tmp_name;
435   if (!ConsumeControlOutName(&orig, &tmp_name)) {
436     errors->push_back(strings::StrCat("Trouble parsing 'name:'",
437                                       ControlOutError(orig, op_def->name())));
438   }
439 
440   *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size());
441 }
442 
num_leading_spaces(StringPiece s)443 int num_leading_spaces(StringPiece s) {
444   size_t i = 0;
445   while (i < s.size() && s[i] == ' ') {
446     ++i;
447   }
448   return i;
449 }
450 
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)451 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
452   return Scanner(*sp)
453       .One(Scanner::LETTER)
454       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
455       .StopCapture()
456       .AnySpace()
457       .OneLiteral(":")
458       .AnySpace()
459       .GetResult(sp, out);
460 }
461 
IsDocNameColon(StringPiece s)462 bool IsDocNameColon(StringPiece s) {
463   return ConsumeDocNameColon(&s, nullptr /* out */);
464 }
465 
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)466 void FinalizeDoc(const string& text, OpDef* op_def,
467                  std::vector<string>* errors) {
468   std::vector<string> lines = str_util::Split(text, '\n');
469 
470   // Remove trailing spaces.
471   for (string& line : lines) {
472     absl::StripTrailingAsciiWhitespace(&line);
473   }
474 
475   // First non-blank line -> summary.
476   int l = 0;
477   while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
478   if (static_cast<size_t>(l) < lines.size()) {
479     op_def->set_summary(lines[l]);
480     ++l;
481   }
482   while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
483 
484   // Lines until we see name: -> description.
485   int start_l = l;
486   while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
487     ++l;
488   }
489   int end_l = l;
490   // Trim trailing blank lines from the description.
491   while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
492   string desc = absl::StrJoin(
493       gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
494   if (!desc.empty()) op_def->set_description(desc);
495 
496   // name: description
497   //   possibly continued on the next line
498   //   if so, we remove the minimum indent
499   StringPiece name;
500   std::vector<StringPiece> description;
501   while (static_cast<size_t>(l) < lines.size()) {
502     description.clear();
503     description.push_back(lines[l]);
504     ConsumeDocNameColon(&description.back(), &name);
505     ++l;
506     while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
507       description.push_back(lines[l]);
508       ++l;
509     }
510     // Remove any trailing blank lines.
511     while (!description.empty() && description.back().empty()) {
512       description.pop_back();
513     }
514     // Compute the minimum indent of all lines after the first.
515     int min_indent = -1;
516     for (size_t i = 1; i < description.size(); ++i) {
517       if (!description[i].empty()) {
518         int indent = num_leading_spaces(description[i]);
519         if (min_indent < 0 || indent < min_indent) min_indent = indent;
520       }
521     }
522     // Remove min_indent spaces from all lines after the first.
523     for (size_t i = 1; i < description.size(); ++i) {
524       if (!description[i].empty()) description[i].remove_prefix(min_indent);
525     }
526     // Concatenate lines into a single string.
527     const string complete(absl::StrJoin(description, "\n"));
528 
529     // Find name.
530     bool found = false;
531     for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
532       if (op_def->input_arg(i).name() == name) {
533         op_def->mutable_input_arg(i)->set_description(complete);
534         found = true;
535       }
536     }
537     for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
538       if (op_def->output_arg(i).name() == name) {
539         op_def->mutable_output_arg(i)->set_description(complete);
540         found = true;
541       }
542     }
543     for (int i = 0; !found && i < op_def->attr_size(); ++i) {
544       if (op_def->attr(i).name() == name) {
545         op_def->mutable_attr(i)->set_description(complete);
546         found = true;
547       }
548     }
549     if (!found) {
550       errors->push_back(
551           strings::StrCat("No matching input/output/attr for name '", name,
552                           "' from Doc() for Op ", op_def->name()));
553       return;
554     }
555   }
556 }
557 
558 }  // namespace
559 
OpDefBuilder(string op_name)560 OpDefBuilder::OpDefBuilder(string op_name) {
561   op_def()->set_name(std::move(op_name));
562 }
563 
Attr(string spec)564 OpDefBuilder& OpDefBuilder::Attr(string spec) {
565   attrs_.push_back(std::move(spec));
566   return *this;
567 }
568 
Input(string spec)569 OpDefBuilder& OpDefBuilder::Input(string spec) {
570   inputs_.push_back(std::move(spec));
571   return *this;
572 }
573 
Output(string spec)574 OpDefBuilder& OpDefBuilder::Output(string spec) {
575   outputs_.push_back(std::move(spec));
576   return *this;
577 }
578 
ControlOutput(string name)579 OpDefBuilder& OpDefBuilder::ControlOutput(string name) {
580   control_outputs_.push_back(std::move(name));
581   return *this;
582 }
583 
584 #ifndef TF_LEAN_BINARY
Doc(string text)585 OpDefBuilder& OpDefBuilder::Doc(string text) {
586   if (!doc_.empty()) {
587     errors_.push_back(
588         strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
589   } else {
590     doc_ = std::move(text);
591   }
592   return *this;
593 }
594 #endif
595 
SetIsCommutative()596 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
597   op_def()->set_is_commutative(true);
598   return *this;
599 }
600 
SetIsAggregate()601 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
602   op_def()->set_is_aggregate(true);
603   return *this;
604 }
605 
SetIsStateful()606 OpDefBuilder& OpDefBuilder::SetIsStateful() {
607   op_def()->set_is_stateful(true);
608   return *this;
609 }
610 
SetAllowsUninitializedInput()611 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
612   op_def()->set_allows_uninitialized_input(true);
613   return *this;
614 }
615 
Deprecated(int version,string explanation)616 OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
617   if (op_def()->has_deprecation()) {
618     errors_.push_back(
619         strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
620   } else {
621     OpDeprecation* deprecation = op_def()->mutable_deprecation();
622     deprecation->set_version(version);
623     deprecation->set_explanation(std::move(explanation));
624   }
625   return *this;
626 }
627 
SetShapeFn(OpShapeInferenceFn fn)628 OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) {
629   if (op_reg_data_.shape_inference_fn != nullptr) {
630     errors_.push_back(
631         strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
632   } else {
633     op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
634   }
635   return *this;
636 }
637 
AllowAttrTypeAny()638 OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() {
639   allow_attr_type_any_ = true;
640   return *this;
641 }
642 
Finalize(OpRegistrationData * op_reg_data) const643 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
644   std::vector<string> errors = errors_;
645   *op_reg_data = op_reg_data_;
646 
647   OpDef* op_def = &op_reg_data->op_def;
648   for (StringPiece attr : attrs_) {
649     FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors);
650   }
651   for (StringPiece input : inputs_) {
652     FinalizeInputOrOutput(input, false, op_def, &errors);
653   }
654   for (StringPiece output : outputs_) {
655     FinalizeInputOrOutput(output, true, op_def, &errors);
656   }
657   for (StringPiece control_output : control_outputs_) {
658     FinalizeControlOutput(control_output, op_def, &errors);
659   }
660   FinalizeDoc(doc_, op_def, &errors);
661 
662   if (errors.empty()) return Status::OK();
663   return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
664 }
665 
666 }  // namespace tensorflow
667