1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_builder.h"
17
18 #include <limits>
19 #include <vector>
20
21 #include "absl/strings/escaping.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/gtl/array_slice.h"
28 #include "tensorflow/core/lib/strings/scanner.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31
32 using ::tensorflow::strings::Scanner;
33
34 namespace tensorflow {
35
36 namespace {
37
AttrError(StringPiece orig,const string & op_name)38 string AttrError(StringPiece orig, const string& op_name) {
39 return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
40 }
41
ConsumeAttrName(StringPiece * sp,StringPiece * out)42 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
43 return Scanner(*sp)
44 .One(Scanner::LETTER)
45 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
46 .StopCapture()
47 .AnySpace()
48 .OneLiteral(":")
49 .AnySpace()
50 .GetResult(sp, out);
51 }
52
ConsumeListPrefix(StringPiece * sp)53 bool ConsumeListPrefix(StringPiece* sp) {
54 return Scanner(*sp)
55 .OneLiteral("list")
56 .AnySpace()
57 .OneLiteral("(")
58 .AnySpace()
59 .GetResult(sp);
60 }
61
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)62 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
63 const string quote_str(1, quote_ch);
64 return Scanner(*sp)
65 .OneLiteral(quote_str.c_str())
66 .RestartCapture()
67 .ScanEscapedUntil(quote_ch)
68 .StopCapture()
69 .OneLiteral(quote_str.c_str())
70 .AnySpace()
71 .GetResult(sp, out);
72 }
73
ConsumeAttrType(StringPiece * sp,StringPiece * out)74 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
75 return Scanner(*sp)
76 .Many(Scanner::LOWERLETTER_DIGIT)
77 .StopCapture()
78 .AnySpace()
79 .GetResult(sp, out);
80 }
81
ConsumeAttrNumber(StringPiece * sp,int64 * out)82 bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
83 Scanner scan(*sp);
84 StringPiece match;
85 StringPiece remaining;
86
87 scan.AnySpace().RestartCapture();
88 if (scan.Peek() == '-') {
89 scan.OneLiteral("-");
90 }
91 if (!scan.Many(Scanner::DIGIT)
92 .StopCapture()
93 .AnySpace()
94 .GetResult(&remaining, &match)) {
95 return false;
96 }
97 int64 value = 0;
98 if (!strings::safe_strto64(match, &value)) {
99 return false;
100 }
101 *out = value;
102 *sp = remaining;
103 return true;
104 }
105
// Error-check helper used by FinalizeAttr. If `expr` is false, it appends the
// remaining arguments to *errors, followed by the AttrError() context, and
// returns from the enclosing void function. It expects `errors`, `orig` and
// `op_def` to be in scope wherever it is expanded.
#define VERIFY(expr, ...)                                                 \
  do {                                                                    \
    if (expr) break;                                                      \
    errors->push_back(                                                    \
        strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name())));   \
    return;                                                               \
  } while (false)
114
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)115 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
116 auto capture_begin = sp->begin();
117 if (absl::ConsumePrefix(sp, "numbertype") ||
118 absl::ConsumePrefix(sp, "numerictype") ||
119 absl::ConsumePrefix(sp, "quantizedtype") ||
120 absl::ConsumePrefix(sp, "realnumbertype") ||
121 absl::ConsumePrefix(sp, "realnumberictype")) {
122 *out = StringPiece(capture_begin, sp->begin() - capture_begin);
123 return true;
124 }
125 return false;
126 }
127
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)128 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
129 if (type_string == "numbertype" || type_string == "numerictype") {
130 for (DataType dt : NumberTypes()) {
131 allowed->mutable_list()->add_type(dt);
132 }
133 } else if (type_string == "quantizedtype") {
134 for (DataType dt : QuantizedTypes()) {
135 allowed->mutable_list()->add_type(dt);
136 }
137 } else if (type_string == "realnumbertype" ||
138 type_string == "realnumerictype") {
139 for (DataType dt : RealNumberTypes()) {
140 allowed->mutable_list()->add_type(dt);
141 }
142 } else {
143 return false;
144 }
145 return true;
146 }
147
FinalizeAttr(StringPiece spec,bool allow_attr_type_any,OpDef * op_def,std::vector<string> * errors)148 void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def,
149 std::vector<string>* errors) {
150 OpDef::AttrDef* attr = op_def->add_attr();
151 StringPiece orig(spec);
152
153 // Parse "<name>:" at the beginning.
154 StringPiece tmp_name;
155 VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
156 attr->set_name(tmp_name.data(), tmp_name.size());
157
158 // Read "<type>" or "list(<type>)".
159 bool is_list = ConsumeListPrefix(&spec);
160 string type;
161 StringPiece type_string; // Used if type == "type"
162 if (absl::ConsumePrefix(&spec, "string")) {
163 type = "string";
164 } else if (absl::ConsumePrefix(&spec, "int")) {
165 type = "int";
166 } else if (absl::ConsumePrefix(&spec, "float")) {
167 type = "float";
168 } else if (absl::ConsumePrefix(&spec, "bool")) {
169 type = "bool";
170 } else if (absl::ConsumePrefix(&spec, "type")) {
171 type = "type";
172 } else if (absl::ConsumePrefix(&spec, "shape")) {
173 type = "shape";
174 } else if (absl::ConsumePrefix(&spec, "tensor")) {
175 type = "tensor";
176 } else if (absl::ConsumePrefix(&spec, "func")) {
177 type = "func";
178 } else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) {
179 type = "any";
180 } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
181 type = "type";
182 AttrValue* allowed = attr->mutable_allowed_values();
183 VERIFY(ProcessCompoundType(type_string, allowed),
184 "Expected to see a compound type, saw: ", type_string);
185 } else if (absl::ConsumePrefix(&spec, "{")) {
186 // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
187 AttrValue* allowed = attr->mutable_allowed_values();
188 str_util::RemoveLeadingWhitespace(&spec);
189 if (absl::StartsWith(spec, "\"") || absl::StartsWith(spec, "'")) {
190 type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
191 while (true) {
192 StringPiece escaped_string;
193 VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
194 ConsumeQuotedString('\'', &spec, &escaped_string),
195 "Trouble parsing allowed string at '", spec, "'");
196 string unescaped;
197 string error;
198 VERIFY(absl::CUnescape(escaped_string, &unescaped, &error),
199 "Trouble unescaping \"", escaped_string,
200 "\", got error: ", error);
201 allowed->mutable_list()->add_s(unescaped);
202 if (absl::ConsumePrefix(&spec, ",")) {
203 str_util::RemoveLeadingWhitespace(&spec);
204 if (absl::ConsumePrefix(&spec, "}"))
205 break; // Allow ending with ", }".
206 } else {
207 VERIFY(absl::ConsumePrefix(&spec, "}"),
208 "Expected , or } after strings in list, not: '", spec, "'");
209 break;
210 }
211 }
212 } else { // "{ bool, numbertype, string }"
213 type = "type";
214 while (true) {
215 VERIFY(ConsumeAttrType(&spec, &type_string),
216 "Trouble parsing type string at '", spec, "'");
217 if (ProcessCompoundType(type_string, allowed)) {
218 // Processed a compound type.
219 } else {
220 DataType dt;
221 VERIFY(DataTypeFromString(type_string, &dt),
222 "Unrecognized type string '", type_string, "'");
223 allowed->mutable_list()->add_type(dt);
224 }
225 if (absl::ConsumePrefix(&spec, ",")) {
226 str_util::RemoveLeadingWhitespace(&spec);
227 if (absl::ConsumePrefix(&spec, "}"))
228 break; // Allow ending with ", }".
229 } else {
230 VERIFY(absl::ConsumePrefix(&spec, "}"),
231 "Expected , or } after types in list, not: '", spec, "'");
232 break;
233 }
234 }
235 }
236 } else { // if spec.Consume("{")
237 VERIFY(false, "Trouble parsing type string at '", spec, "'");
238 }
239 str_util::RemoveLeadingWhitespace(&spec);
240
241 // Write the type into *attr.
242 if (is_list) {
243 VERIFY(absl::ConsumePrefix(&spec, ")"),
244 "Expected ) to close 'list(', not: '", spec, "'");
245 str_util::RemoveLeadingWhitespace(&spec);
246 attr->set_type(strings::StrCat("list(", type, ")"));
247 } else {
248 attr->set_type(type);
249 }
250
251 // Read optional minimum constraint at the end.
252 if ((is_list || type == "int") && absl::ConsumePrefix(&spec, ">=")) {
253 int64 min_limit = -999;
254 VERIFY(ConsumeAttrNumber(&spec, &min_limit),
255 "Could not parse integer lower limit after '>=', found '", spec,
256 "' instead");
257 attr->set_has_minimum(true);
258 attr->set_minimum(min_limit);
259 }
260
261 // Parse default value, if present.
262 if (absl::ConsumePrefix(&spec, "=")) {
263 str_util::RemoveLeadingWhitespace(&spec);
264 VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
265 "Could not parse default value '", spec, "'");
266 } else {
267 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
268 }
269 }
270
271 #undef VERIFY
272
// Builds the context suffix appended to Input()/Output() parse errors.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const which = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", which, "(\"", orig, "\") for Op ", op_name);
}
277
ConsumeInOutName(StringPiece * sp,StringPiece * out)278 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
279 return Scanner(*sp)
280 .One(Scanner::LOWERLETTER)
281 .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
282 .StopCapture()
283 .AnySpace()
284 .OneLiteral(":")
285 .AnySpace()
286 .GetResult(sp, out);
287 }
288
ConsumeInOutRefOpen(StringPiece * sp)289 bool ConsumeInOutRefOpen(StringPiece* sp) {
290 return Scanner(*sp)
291 .OneLiteral("Ref")
292 .AnySpace()
293 .OneLiteral("(")
294 .AnySpace()
295 .GetResult(sp);
296 }
297
ConsumeInOutRefClose(StringPiece * sp)298 bool ConsumeInOutRefClose(StringPiece* sp) {
299 return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
300 }
301
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)302 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
303 return Scanner(*sp)
304 .One(Scanner::LETTER)
305 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
306 .StopCapture()
307 .AnySpace()
308 .GetResult(sp, out);
309 }
310
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)311 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
312 return Scanner(*sp)
313 .OneLiteral("*")
314 .AnySpace()
315 .RestartCapture()
316 .One(Scanner::LETTER)
317 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
318 .StopCapture()
319 .AnySpace()
320 .GetResult(sp, out);
321 }
322
ConsumeControlOutName(StringPiece * sp,StringPiece * out)323 bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) {
324 return Scanner(*sp)
325 .One(Scanner::LETTER)
326 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
327 .StopCapture()
328 .GetResult(sp, out);
329 }
330
// Error-check helper used by FinalizeInputOrOutput. If `expr` is false, it
// appends the remaining arguments to *errors, followed by the InOutError()
// context, and returns from the enclosing void function. It expects `errors`,
// `is_output`, `orig` and `op_def` to be in scope at the expansion site.
#define VERIFY(expr, ...)                                                 \
  do {                                                                    \
    if (expr) break;                                                      \
    errors->push_back(strings::StrCat(                                    \
        __VA_ARGS__, InOutError(is_output, orig, op_def->name())));       \
    return;                                                               \
  } while (false)
339
FinalizeInputOrOutput(StringPiece spec,bool is_output,OpDef * op_def,std::vector<string> * errors)340 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
341 std::vector<string>* errors) {
342 OpDef::ArgDef* arg =
343 is_output ? op_def->add_output_arg() : op_def->add_input_arg();
344
345 StringPiece orig(spec);
346
347 // Parse "<name>:" at the beginning.
348 StringPiece tmp_name;
349 VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
350 arg->set_name(tmp_name.data(), tmp_name.size());
351
352 // Detect "Ref(...)".
353 if (ConsumeInOutRefOpen(&spec)) {
354 arg->set_is_ref(true);
355 }
356
357 { // Parse "<name|type>" or "<name>*<name|type>".
358 StringPiece first, second, type_or_attr;
359 VERIFY(ConsumeInOutNameOrType(&spec, &first),
360 "Trouble parsing either a type or an attr name at '", spec, "'");
361 if (ConsumeInOutTimesType(&spec, &second)) {
362 arg->set_number_attr(first.data(), first.size());
363 type_or_attr = second;
364 } else {
365 type_or_attr = first;
366 }
367 DataType dt;
368 if (DataTypeFromString(type_or_attr, &dt)) {
369 arg->set_type(dt);
370 } else {
371 const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
372 VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
373 if (attr->type() == "type") {
374 arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
375 } else {
376 VERIFY(attr->type() == "list(type)", "Reference to attr '",
377 type_or_attr, "' with type ", attr->type(),
378 " that isn't type or list(type)");
379 arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
380 }
381 }
382 }
383
384 // Closing ) for Ref(.
385 if (arg->is_ref()) {
386 VERIFY(ConsumeInOutRefClose(&spec),
387 "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
388 }
389
390 // Should not have anything else.
391 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
392
393 // Int attrs that are the length of an input or output get a default
394 // minimum of 1.
395 if (!arg->number_attr().empty()) {
396 OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
397 if (attr != nullptr && !attr->has_minimum()) {
398 attr->set_has_minimum(true);
399 attr->set_minimum(1);
400 }
401 } else if (!arg->type_list_attr().empty()) {
402 // If an input or output has type specified by a list(type) attr,
403 // it gets a default minimum of 1 as well.
404 OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
405 if (attr != nullptr && attr->type() == "list(type)" &&
406 !attr->has_minimum()) {
407 attr->set_has_minimum(true);
408 attr->set_minimum(1);
409 }
410 }
411
412 // If the arg's dtype is resource we should mark the op as stateful as it
413 // likely touches a resource manager. This deliberately doesn't cover inputs /
414 // outputs which resolve to resource via Attrs as those mostly operate on
415 // resource handles as an opaque type (as opposed to ops which explicitly take
416 // / produce resources).
417 if (arg->type() == DT_RESOURCE) {
418 op_def->set_is_stateful(true);
419 }
420 }
421
422 #undef VERIFY
423
ControlOutError(StringPiece orig,const string & op_name)424 string ControlOutError(StringPiece orig, const string& op_name) {
425 return strings::StrCat(" from ControlOutput(\"", orig, "\") for Op ",
426 op_name);
427 }
428
FinalizeControlOutput(StringPiece name,OpDef * op_def,std::vector<string> * errors)429 void FinalizeControlOutput(StringPiece name, OpDef* op_def,
430 std::vector<string>* errors) {
431 StringPiece orig(name);
432
433 // Parse control output name.
434 StringPiece tmp_name;
435 if (!ConsumeControlOutName(&orig, &tmp_name)) {
436 errors->push_back(strings::StrCat("Trouble parsing 'name:'",
437 ControlOutError(orig, op_def->name())));
438 }
439
440 *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size());
441 }
442
num_leading_spaces(StringPiece s)443 int num_leading_spaces(StringPiece s) {
444 size_t i = 0;
445 while (i < s.size() && s[i] == ' ') {
446 ++i;
447 }
448 return i;
449 }
450
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)451 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
452 return Scanner(*sp)
453 .One(Scanner::LETTER)
454 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
455 .StopCapture()
456 .AnySpace()
457 .OneLiteral(":")
458 .AnySpace()
459 .GetResult(sp, out);
460 }
461
IsDocNameColon(StringPiece s)462 bool IsDocNameColon(StringPiece s) {
463 return ConsumeDocNameColon(&s, nullptr /* out */);
464 }
465
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)466 void FinalizeDoc(const string& text, OpDef* op_def,
467 std::vector<string>* errors) {
468 std::vector<string> lines = str_util::Split(text, '\n');
469
470 // Remove trailing spaces.
471 for (string& line : lines) {
472 absl::StripTrailingAsciiWhitespace(&line);
473 }
474
475 // First non-blank line -> summary.
476 int l = 0;
477 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
478 if (static_cast<size_t>(l) < lines.size()) {
479 op_def->set_summary(lines[l]);
480 ++l;
481 }
482 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
483
484 // Lines until we see name: -> description.
485 int start_l = l;
486 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
487 ++l;
488 }
489 int end_l = l;
490 // Trim trailing blank lines from the description.
491 while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
492 string desc = absl::StrJoin(
493 gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
494 if (!desc.empty()) op_def->set_description(desc);
495
496 // name: description
497 // possibly continued on the next line
498 // if so, we remove the minimum indent
499 StringPiece name;
500 std::vector<StringPiece> description;
501 while (static_cast<size_t>(l) < lines.size()) {
502 description.clear();
503 description.push_back(lines[l]);
504 ConsumeDocNameColon(&description.back(), &name);
505 ++l;
506 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
507 description.push_back(lines[l]);
508 ++l;
509 }
510 // Remove any trailing blank lines.
511 while (!description.empty() && description.back().empty()) {
512 description.pop_back();
513 }
514 // Compute the minimum indent of all lines after the first.
515 int min_indent = -1;
516 for (size_t i = 1; i < description.size(); ++i) {
517 if (!description[i].empty()) {
518 int indent = num_leading_spaces(description[i]);
519 if (min_indent < 0 || indent < min_indent) min_indent = indent;
520 }
521 }
522 // Remove min_indent spaces from all lines after the first.
523 for (size_t i = 1; i < description.size(); ++i) {
524 if (!description[i].empty()) description[i].remove_prefix(min_indent);
525 }
526 // Concatenate lines into a single string.
527 const string complete(absl::StrJoin(description, "\n"));
528
529 // Find name.
530 bool found = false;
531 for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
532 if (op_def->input_arg(i).name() == name) {
533 op_def->mutable_input_arg(i)->set_description(complete);
534 found = true;
535 }
536 }
537 for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
538 if (op_def->output_arg(i).name() == name) {
539 op_def->mutable_output_arg(i)->set_description(complete);
540 found = true;
541 }
542 }
543 for (int i = 0; !found && i < op_def->attr_size(); ++i) {
544 if (op_def->attr(i).name() == name) {
545 op_def->mutable_attr(i)->set_description(complete);
546 found = true;
547 }
548 }
549 if (!found) {
550 errors->push_back(
551 strings::StrCat("No matching input/output/attr for name '", name,
552 "' from Doc() for Op ", op_def->name()));
553 return;
554 }
555 }
556 }
557
558 } // namespace
559
OpDefBuilder(string op_name)560 OpDefBuilder::OpDefBuilder(string op_name) {
561 op_def()->set_name(std::move(op_name));
562 }
563
Attr(string spec)564 OpDefBuilder& OpDefBuilder::Attr(string spec) {
565 attrs_.push_back(std::move(spec));
566 return *this;
567 }
568
Input(string spec)569 OpDefBuilder& OpDefBuilder::Input(string spec) {
570 inputs_.push_back(std::move(spec));
571 return *this;
572 }
573
Output(string spec)574 OpDefBuilder& OpDefBuilder::Output(string spec) {
575 outputs_.push_back(std::move(spec));
576 return *this;
577 }
578
ControlOutput(string name)579 OpDefBuilder& OpDefBuilder::ControlOutput(string name) {
580 control_outputs_.push_back(std::move(name));
581 return *this;
582 }
583
584 #ifndef TF_LEAN_BINARY
Doc(string text)585 OpDefBuilder& OpDefBuilder::Doc(string text) {
586 if (!doc_.empty()) {
587 errors_.push_back(
588 strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
589 } else {
590 doc_ = std::move(text);
591 }
592 return *this;
593 }
594 #endif
595
SetIsCommutative()596 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
597 op_def()->set_is_commutative(true);
598 return *this;
599 }
600
SetIsAggregate()601 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
602 op_def()->set_is_aggregate(true);
603 return *this;
604 }
605
SetIsStateful()606 OpDefBuilder& OpDefBuilder::SetIsStateful() {
607 op_def()->set_is_stateful(true);
608 return *this;
609 }
610
SetAllowsUninitializedInput()611 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
612 op_def()->set_allows_uninitialized_input(true);
613 return *this;
614 }
615
Deprecated(int version,string explanation)616 OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
617 if (op_def()->has_deprecation()) {
618 errors_.push_back(
619 strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
620 } else {
621 OpDeprecation* deprecation = op_def()->mutable_deprecation();
622 deprecation->set_version(version);
623 deprecation->set_explanation(std::move(explanation));
624 }
625 return *this;
626 }
627
SetShapeFn(OpShapeInferenceFn fn)628 OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) {
629 if (op_reg_data_.shape_inference_fn != nullptr) {
630 errors_.push_back(
631 strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
632 } else {
633 op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
634 }
635 return *this;
636 }
637
AllowAttrTypeAny()638 OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() {
639 allow_attr_type_any_ = true;
640 return *this;
641 }
642
Finalize(OpRegistrationData * op_reg_data) const643 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
644 std::vector<string> errors = errors_;
645 *op_reg_data = op_reg_data_;
646
647 OpDef* op_def = &op_reg_data->op_def;
648 for (StringPiece attr : attrs_) {
649 FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors);
650 }
651 for (StringPiece input : inputs_) {
652 FinalizeInputOrOutput(input, false, op_def, &errors);
653 }
654 for (StringPiece output : outputs_) {
655 FinalizeInputOrOutput(output, true, op_def, &errors);
656 }
657 for (StringPiece control_output : control_outputs_) {
658 FinalizeControlOutput(control_output, op_def, &errors);
659 }
660 FinalizeDoc(doc_, op_def, &errors);
661
662 if (errors.empty()) return Status::OK();
663 return errors::InvalidArgument(absl::StrJoin(errors, "\n"));
664 }
665
666 } // namespace tensorflow
667