• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_builder.h"
17 
18 #include <limits>
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/lib/strings/scanner.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 
30 using ::tensorflow::strings::Scanner;
31 
32 namespace tensorflow {
33 
34 namespace {
35 
AttrError(StringPiece orig,const string & op_name)36 string AttrError(StringPiece orig, const string& op_name) {
37   return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
38 }
39 
ConsumeAttrName(StringPiece * sp,StringPiece * out)40 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
41   return Scanner(*sp)
42       .One(Scanner::LETTER)
43       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
44       .StopCapture()
45       .AnySpace()
46       .OneLiteral(":")
47       .AnySpace()
48       .GetResult(sp, out);
49 }
50 
ConsumeListPrefix(StringPiece * sp)51 bool ConsumeListPrefix(StringPiece* sp) {
52   return Scanner(*sp)
53       .OneLiteral("list")
54       .AnySpace()
55       .OneLiteral("(")
56       .AnySpace()
57       .GetResult(sp);
58 }
59 
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)60 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
61   const string quote_str(1, quote_ch);
62   return Scanner(*sp)
63       .OneLiteral(quote_str.c_str())
64       .RestartCapture()
65       .ScanEscapedUntil(quote_ch)
66       .StopCapture()
67       .OneLiteral(quote_str.c_str())
68       .AnySpace()
69       .GetResult(sp, out);
70 }
71 
ConsumeAttrType(StringPiece * sp,StringPiece * out)72 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
73   return Scanner(*sp)
74       .Many(Scanner::LOWERLETTER_DIGIT)
75       .StopCapture()
76       .AnySpace()
77       .GetResult(sp, out);
78 }
79 
ConsumeAttrNumber(StringPiece * sp,int64 * out)80 bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
81   Scanner scan(*sp);
82   StringPiece match;
83   StringPiece remaining;
84 
85   scan.AnySpace().RestartCapture();
86   if (scan.Peek() == '-') {
87     scan.OneLiteral("-");
88   }
89   if (!scan.Many(Scanner::DIGIT)
90            .StopCapture()
91            .AnySpace()
92            .GetResult(&remaining, &match)) {
93     return false;
94   }
95   int64 value = 0;
96   if (!strings::safe_strto64(match, &value)) {
97     return false;
98   }
99   *out = value;
100   *sp = remaining;
101   return true;
102 }
103 
// Error-check helper for FinalizeAttr(). If `expr` is false, it appends an
// error to *errors and returns from the enclosing function. The message is the
// StrCat of the remaining arguments followed by AttrError()'s context suffix.
// The use site must have `errors`, `orig`, and `op_def` in scope.
#define VERIFY(expr, ...)                                                 \
  do {                                                                    \
    if (!(expr)) {                                                        \
      errors->push_back(                                                  \
          strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
      return;                                                             \
    }                                                                     \
  } while (false)
112 
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)113 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
114   auto capture_begin = sp->begin();
115   if (str_util::ConsumePrefix(sp, "numbertype") ||
116       str_util::ConsumePrefix(sp, "numerictype") ||
117       str_util::ConsumePrefix(sp, "quantizedtype") ||
118       str_util::ConsumePrefix(sp, "realnumbertype") ||
119       str_util::ConsumePrefix(sp, "realnumberictype")) {
120     *out = StringPiece(capture_begin, sp->begin() - capture_begin);
121     return true;
122   }
123   return false;
124 }
125 
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)126 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
127   if (type_string == "numbertype" || type_string == "numerictype") {
128     for (DataType dt : NumberTypes()) {
129       allowed->mutable_list()->add_type(dt);
130     }
131   } else if (type_string == "quantizedtype") {
132     for (DataType dt : QuantizedTypes()) {
133       allowed->mutable_list()->add_type(dt);
134     }
135   } else if (type_string == "realnumbertype" ||
136              type_string == "realnumerictype") {
137     for (DataType dt : RealNumberTypes()) {
138       allowed->mutable_list()->add_type(dt);
139     }
140   } else {
141     return false;
142   }
143   return true;
144 }
145 
FinalizeAttr(StringPiece spec,OpDef * op_def,std::vector<string> * errors)146 void FinalizeAttr(StringPiece spec, OpDef* op_def,
147                   std::vector<string>* errors) {
148   OpDef::AttrDef* attr = op_def->add_attr();
149   StringPiece orig(spec);
150 
151   // Parse "<name>:" at the beginning.
152   StringPiece tmp_name;
153   VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
154   attr->set_name(tmp_name.data(), tmp_name.size());
155 
156   // Read "<type>" or "list(<type>)".
157   bool is_list = ConsumeListPrefix(&spec);
158   string type;
159   StringPiece type_string;  // Used if type == "type"
160   if (str_util::ConsumePrefix(&spec, "string")) {
161     type = "string";
162   } else if (str_util::ConsumePrefix(&spec, "int")) {
163     type = "int";
164   } else if (str_util::ConsumePrefix(&spec, "float")) {
165     type = "float";
166   } else if (str_util::ConsumePrefix(&spec, "bool")) {
167     type = "bool";
168   } else if (str_util::ConsumePrefix(&spec, "type")) {
169     type = "type";
170   } else if (str_util::ConsumePrefix(&spec, "shape")) {
171     type = "shape";
172   } else if (str_util::ConsumePrefix(&spec, "tensor")) {
173     type = "tensor";
174   } else if (str_util::ConsumePrefix(&spec, "func")) {
175     type = "func";
176   } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
177     type = "type";
178     AttrValue* allowed = attr->mutable_allowed_values();
179     VERIFY(ProcessCompoundType(type_string, allowed),
180            "Expected to see a compound type, saw: ", type_string);
181   } else if (str_util::ConsumePrefix(&spec, "{")) {
182     // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
183     AttrValue* allowed = attr->mutable_allowed_values();
184     str_util::RemoveLeadingWhitespace(&spec);
185     if (str_util::StartsWith(spec, "\"") || str_util::StartsWith(spec, "'")) {
186       type = "string";  // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
187       while (true) {
188         StringPiece escaped_string;
189         VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
190                    ConsumeQuotedString('\'', &spec, &escaped_string),
191                "Trouble parsing allowed string at '", spec, "'");
192         string unescaped;
193         string error;
194         VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
195                "Trouble unescaping \"", escaped_string,
196                "\", got error: ", error);
197         allowed->mutable_list()->add_s(unescaped);
198         if (str_util::ConsumePrefix(&spec, ",")) {
199           str_util::RemoveLeadingWhitespace(&spec);
200           if (str_util::ConsumePrefix(&spec, "}"))
201             break;  // Allow ending with ", }".
202         } else {
203           VERIFY(str_util::ConsumePrefix(&spec, "}"),
204                  "Expected , or } after strings in list, not: '", spec, "'");
205           break;
206         }
207       }
208     } else {  // "{ bool, numbertype, string }"
209       type = "type";
210       while (true) {
211         VERIFY(ConsumeAttrType(&spec, &type_string),
212                "Trouble parsing type string at '", spec, "'");
213         if (ProcessCompoundType(type_string, allowed)) {
214           // Processed a compound type.
215         } else {
216           DataType dt;
217           VERIFY(DataTypeFromString(type_string, &dt),
218                  "Unrecognized type string '", type_string, "'");
219           allowed->mutable_list()->add_type(dt);
220         }
221         if (str_util::ConsumePrefix(&spec, ",")) {
222           str_util::RemoveLeadingWhitespace(&spec);
223           if (str_util::ConsumePrefix(&spec, "}"))
224             break;  // Allow ending with ", }".
225         } else {
226           VERIFY(str_util::ConsumePrefix(&spec, "}"),
227                  "Expected , or } after types in list, not: '", spec, "'");
228           break;
229         }
230       }
231     }
232   } else {  // if spec.Consume("{")
233     VERIFY(false, "Trouble parsing type string at '", spec, "'");
234   }
235   str_util::RemoveLeadingWhitespace(&spec);
236 
237   // Write the type into *attr.
238   if (is_list) {
239     VERIFY(str_util::ConsumePrefix(&spec, ")"),
240            "Expected ) to close 'list(', not: '", spec, "'");
241     str_util::RemoveLeadingWhitespace(&spec);
242     attr->set_type(strings::StrCat("list(", type, ")"));
243   } else {
244     attr->set_type(type);
245   }
246 
247   // Read optional minimum constraint at the end.
248   if ((is_list || type == "int") && str_util::ConsumePrefix(&spec, ">=")) {
249     int64 min_limit = -999;
250     VERIFY(ConsumeAttrNumber(&spec, &min_limit),
251            "Could not parse integer lower limit after '>=', found '", spec,
252            "' instead");
253     attr->set_has_minimum(true);
254     attr->set_minimum(min_limit);
255   }
256 
257   // Parse default value, if present.
258   if (str_util::ConsumePrefix(&spec, "=")) {
259     str_util::RemoveLeadingWhitespace(&spec);
260     VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
261            "Could not parse default value '", spec, "'");
262   } else {
263     VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
264   }
265 }
266 
267 #undef VERIFY
268 
// Returns the context suffix appended to Input()/Output() parse errors.
// It names which builder call produced the spec, quotes the spec, and names
// the op.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const kind = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", kind, "(\"", orig, "\") for Op ", op_name);
}
273 
ConsumeInOutName(StringPiece * sp,StringPiece * out)274 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
275   return Scanner(*sp)
276       .One(Scanner::LOWERLETTER)
277       .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
278       .StopCapture()
279       .AnySpace()
280       .OneLiteral(":")
281       .AnySpace()
282       .GetResult(sp, out);
283 }
284 
ConsumeInOutRefOpen(StringPiece * sp)285 bool ConsumeInOutRefOpen(StringPiece* sp) {
286   return Scanner(*sp)
287       .OneLiteral("Ref")
288       .AnySpace()
289       .OneLiteral("(")
290       .AnySpace()
291       .GetResult(sp);
292 }
293 
ConsumeInOutRefClose(StringPiece * sp)294 bool ConsumeInOutRefClose(StringPiece* sp) {
295   return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
296 }
297 
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)298 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
299   return Scanner(*sp)
300       .One(Scanner::LETTER)
301       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
302       .StopCapture()
303       .AnySpace()
304       .GetResult(sp, out);
305 }
306 
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)307 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
308   return Scanner(*sp)
309       .OneLiteral("*")
310       .AnySpace()
311       .RestartCapture()
312       .One(Scanner::LETTER)
313       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
314       .StopCapture()
315       .AnySpace()
316       .GetResult(sp, out);
317 }
318 
ConsumeControlOutName(StringPiece * sp,StringPiece * out)319 bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) {
320   return Scanner(*sp)
321       .One(Scanner::LETTER)
322       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
323       .StopCapture()
324       .GetResult(sp, out);
325 }
326 
// Error-check helper for FinalizeInputOrOutput(). If `expr` is false, it
// appends an error to *errors and returns from the enclosing function. The
// message is the StrCat of the remaining arguments followed by InOutError()'s
// context suffix. The use site must have `errors`, `is_output`, `orig`, and
// `op_def` in scope.
#define VERIFY(expr, ...)                                             \
  do {                                                                \
    if (!(expr)) {                                                    \
      errors->push_back(strings::StrCat(                              \
          __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
      return;                                                         \
    }                                                                 \
  } while (false)
335 
FinalizeInputOrOutput(StringPiece spec,bool is_output,OpDef * op_def,std::vector<string> * errors)336 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
337                            std::vector<string>* errors) {
338   OpDef::ArgDef* arg =
339       is_output ? op_def->add_output_arg() : op_def->add_input_arg();
340 
341   StringPiece orig(spec);
342 
343   // Parse "<name>:" at the beginning.
344   StringPiece tmp_name;
345   VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
346   arg->set_name(tmp_name.data(), tmp_name.size());
347 
348   // Detect "Ref(...)".
349   if (ConsumeInOutRefOpen(&spec)) {
350     arg->set_is_ref(true);
351   }
352 
353   {  // Parse "<name|type>" or "<name>*<name|type>".
354     StringPiece first, second, type_or_attr;
355     VERIFY(ConsumeInOutNameOrType(&spec, &first),
356            "Trouble parsing either a type or an attr name at '", spec, "'");
357     if (ConsumeInOutTimesType(&spec, &second)) {
358       arg->set_number_attr(first.data(), first.size());
359       type_or_attr = second;
360     } else {
361       type_or_attr = first;
362     }
363     DataType dt;
364     if (DataTypeFromString(type_or_attr, &dt)) {
365       arg->set_type(dt);
366     } else {
367       const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
368       VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
369       if (attr->type() == "type") {
370         arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
371       } else {
372         VERIFY(attr->type() == "list(type)", "Reference to attr '",
373                type_or_attr, "' with type ", attr->type(),
374                " that isn't type or list(type)");
375         arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
376       }
377     }
378   }
379 
380   // Closing ) for Ref(.
381   if (arg->is_ref()) {
382     VERIFY(ConsumeInOutRefClose(&spec),
383            "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
384   }
385 
386   // Should not have anything else.
387   VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
388 
389   // Int attrs that are the length of an input or output get a default
390   // minimum of 1.
391   if (!arg->number_attr().empty()) {
392     OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
393     if (attr != nullptr && !attr->has_minimum()) {
394       attr->set_has_minimum(true);
395       attr->set_minimum(1);
396     }
397   } else if (!arg->type_list_attr().empty()) {
398     // If an input or output has type specified by a list(type) attr,
399     // it gets a default minimum of 1 as well.
400     OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
401     if (attr != nullptr && attr->type() == "list(type)" &&
402         !attr->has_minimum()) {
403       attr->set_has_minimum(true);
404       attr->set_minimum(1);
405     }
406   }
407 
408   // If the arg's dtype is resource we should mark the op as stateful as it
409   // likely touches a resource manager. This deliberately doesn't cover inputs /
410   // outputs which resolve to resource via Attrs as those mostly operate on
411   // resource handles as an opaque type (as opposed to ops which explicitly take
412   // / produce resources).
413   if (arg->type() == DT_RESOURCE) {
414     op_def->set_is_stateful(true);
415   }
416 }
417 
418 #undef VERIFY
419 
ControlOutError(StringPiece orig,const string & op_name)420 string ControlOutError(StringPiece orig, const string& op_name) {
421   return strings::StrCat(" from ControlOutput(\"", orig, "\") for Op ",
422                          op_name);
423 }
424 
FinalizeControlOutput(StringPiece name,OpDef * op_def,std::vector<string> * errors)425 void FinalizeControlOutput(StringPiece name, OpDef* op_def,
426                            std::vector<string>* errors) {
427   StringPiece orig(name);
428 
429   // Parse control output name.
430   StringPiece tmp_name;
431   if (!ConsumeControlOutName(&orig, &tmp_name)) {
432     errors->push_back(strings::StrCat("Trouble parsing 'name:'",
433                                       ControlOutError(orig, op_def->name())));
434   }
435 
436   *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size());
437 }
438 
num_leading_spaces(StringPiece s)439 int num_leading_spaces(StringPiece s) {
440   size_t i = 0;
441   while (i < s.size() && s[i] == ' ') {
442     ++i;
443   }
444   return i;
445 }
446 
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)447 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
448   return Scanner(*sp)
449       .One(Scanner::LETTER)
450       .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
451       .StopCapture()
452       .AnySpace()
453       .OneLiteral(":")
454       .AnySpace()
455       .GetResult(sp, out);
456 }
457 
IsDocNameColon(StringPiece s)458 bool IsDocNameColon(StringPiece s) {
459   return ConsumeDocNameColon(&s, nullptr /* out */);
460 }
461 
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)462 void FinalizeDoc(const string& text, OpDef* op_def,
463                  std::vector<string>* errors) {
464   std::vector<string> lines = str_util::Split(text, '\n');
465 
466   // Remove trailing spaces.
467   for (string& line : lines) {
468     str_util::StripTrailingWhitespace(&line);
469   }
470 
471   // First non-blank line -> summary.
472   int l = 0;
473   while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
474   if (static_cast<size_t>(l) < lines.size()) {
475     op_def->set_summary(lines[l]);
476     ++l;
477   }
478   while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
479 
480   // Lines until we see name: -> description.
481   int start_l = l;
482   while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
483     ++l;
484   }
485   int end_l = l;
486   // Trim trailing blank lines from the description.
487   while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
488   string desc = str_util::Join(
489       gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
490   if (!desc.empty()) op_def->set_description(desc);
491 
492   // name: description
493   //   possibly continued on the next line
494   //   if so, we remove the minimum indent
495   StringPiece name;
496   std::vector<StringPiece> description;
497   while (static_cast<size_t>(l) < lines.size()) {
498     description.clear();
499     description.push_back(lines[l]);
500     ConsumeDocNameColon(&description.back(), &name);
501     ++l;
502     while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
503       description.push_back(lines[l]);
504       ++l;
505     }
506     // Remove any trailing blank lines.
507     while (!description.empty() && description.back().empty()) {
508       description.pop_back();
509     }
510     // Compute the minimum indent of all lines after the first.
511     int min_indent = -1;
512     for (size_t i = 1; i < description.size(); ++i) {
513       if (!description[i].empty()) {
514         int indent = num_leading_spaces(description[i]);
515         if (min_indent < 0 || indent < min_indent) min_indent = indent;
516       }
517     }
518     // Remove min_indent spaces from all lines after the first.
519     for (size_t i = 1; i < description.size(); ++i) {
520       if (!description[i].empty()) description[i].remove_prefix(min_indent);
521     }
522     // Concatenate lines into a single string.
523     const string complete(str_util::Join(description, "\n"));
524 
525     // Find name.
526     bool found = false;
527     for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
528       if (op_def->input_arg(i).name() == name) {
529         op_def->mutable_input_arg(i)->set_description(complete);
530         found = true;
531       }
532     }
533     for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
534       if (op_def->output_arg(i).name() == name) {
535         op_def->mutable_output_arg(i)->set_description(complete);
536         found = true;
537       }
538     }
539     for (int i = 0; !found && i < op_def->attr_size(); ++i) {
540       if (op_def->attr(i).name() == name) {
541         op_def->mutable_attr(i)->set_description(complete);
542         found = true;
543       }
544     }
545     if (!found) {
546       errors->push_back(
547           strings::StrCat("No matching input/output/attr for name '", name,
548                           "' from Doc() for Op ", op_def->name()));
549       return;
550     }
551   }
552 }
553 
554 }  // namespace
555 
OpDefBuilder(string op_name)556 OpDefBuilder::OpDefBuilder(string op_name) {
557   op_def()->set_name(std::move(op_name));
558 }
559 
Attr(string spec)560 OpDefBuilder& OpDefBuilder::Attr(string spec) {
561   attrs_.push_back(std::move(spec));
562   return *this;
563 }
564 
Input(string spec)565 OpDefBuilder& OpDefBuilder::Input(string spec) {
566   inputs_.push_back(std::move(spec));
567   return *this;
568 }
569 
Output(string spec)570 OpDefBuilder& OpDefBuilder::Output(string spec) {
571   outputs_.push_back(std::move(spec));
572   return *this;
573 }
574 
ControlOutput(string name)575 OpDefBuilder& OpDefBuilder::ControlOutput(string name) {
576   control_outputs_.push_back(std::move(name));
577   return *this;
578 }
579 
580 #ifndef TF_LEAN_BINARY
Doc(string text)581 OpDefBuilder& OpDefBuilder::Doc(string text) {
582   if (!doc_.empty()) {
583     errors_.push_back(
584         strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
585   } else {
586     doc_ = std::move(text);
587   }
588   return *this;
589 }
590 #endif
591 
SetIsCommutative()592 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
593   op_def()->set_is_commutative(true);
594   return *this;
595 }
596 
SetIsAggregate()597 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
598   op_def()->set_is_aggregate(true);
599   return *this;
600 }
601 
SetIsStateful()602 OpDefBuilder& OpDefBuilder::SetIsStateful() {
603   op_def()->set_is_stateful(true);
604   return *this;
605 }
606 
SetAllowsUninitializedInput()607 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
608   op_def()->set_allows_uninitialized_input(true);
609   return *this;
610 }
611 
Deprecated(int version,string explanation)612 OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
613   if (op_def()->has_deprecation()) {
614     errors_.push_back(
615         strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
616   } else {
617     OpDeprecation* deprecation = op_def()->mutable_deprecation();
618     deprecation->set_version(version);
619     deprecation->set_explanation(std::move(explanation));
620   }
621   return *this;
622 }
623 
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))624 OpDefBuilder& OpDefBuilder::SetShapeFn(
625     Status (*fn)(shape_inference::InferenceContext*)) {
626   if (op_reg_data_.shape_inference_fn != nullptr) {
627     errors_.push_back(
628         strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
629   } else {
630     op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
631   }
632   return *this;
633 }
634 
Finalize(OpRegistrationData * op_reg_data) const635 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
636   std::vector<string> errors = errors_;
637   *op_reg_data = op_reg_data_;
638 
639   OpDef* op_def = &op_reg_data->op_def;
640   for (StringPiece attr : attrs_) {
641     FinalizeAttr(attr, op_def, &errors);
642   }
643   for (StringPiece input : inputs_) {
644     FinalizeInputOrOutput(input, false, op_def, &errors);
645   }
646   for (StringPiece output : outputs_) {
647     FinalizeInputOrOutput(output, true, op_def, &errors);
648   }
649   for (StringPiece control_output : control_outputs_) {
650     FinalizeControlOutput(control_output, op_def, &errors);
651   }
652   FinalizeDoc(doc_, op_def, &errors);
653 
654   if (errors.empty()) return Status::OK();
655   return errors::InvalidArgument(str_util::Join(errors, "\n"));
656 }
657 
658 }  // namespace tensorflow
659