1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_builder.h"
17
18 #include <limits>
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/framework/types.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/array_slice.h"
26 #include "tensorflow/core/lib/strings/scanner.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29
30 using ::tensorflow::strings::Scanner;
31
32 namespace tensorflow {
33
34 namespace {
35
AttrError(StringPiece orig,const string & op_name)36 string AttrError(StringPiece orig, const string& op_name) {
37 return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
38 }
39
ConsumeAttrName(StringPiece * sp,StringPiece * out)40 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) {
41 return Scanner(*sp)
42 .One(Scanner::LETTER)
43 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
44 .StopCapture()
45 .AnySpace()
46 .OneLiteral(":")
47 .AnySpace()
48 .GetResult(sp, out);
49 }
50
ConsumeListPrefix(StringPiece * sp)51 bool ConsumeListPrefix(StringPiece* sp) {
52 return Scanner(*sp)
53 .OneLiteral("list")
54 .AnySpace()
55 .OneLiteral("(")
56 .AnySpace()
57 .GetResult(sp);
58 }
59
ConsumeQuotedString(char quote_ch,StringPiece * sp,StringPiece * out)60 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) {
61 const string quote_str(1, quote_ch);
62 return Scanner(*sp)
63 .OneLiteral(quote_str.c_str())
64 .RestartCapture()
65 .ScanEscapedUntil(quote_ch)
66 .StopCapture()
67 .OneLiteral(quote_str.c_str())
68 .AnySpace()
69 .GetResult(sp, out);
70 }
71
ConsumeAttrType(StringPiece * sp,StringPiece * out)72 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) {
73 return Scanner(*sp)
74 .Many(Scanner::LOWERLETTER_DIGIT)
75 .StopCapture()
76 .AnySpace()
77 .GetResult(sp, out);
78 }
79
ConsumeAttrNumber(StringPiece * sp,int64 * out)80 bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
81 Scanner scan(*sp);
82 StringPiece match;
83 StringPiece remaining;
84
85 scan.AnySpace().RestartCapture();
86 if (scan.Peek() == '-') {
87 scan.OneLiteral("-");
88 }
89 if (!scan.Many(Scanner::DIGIT)
90 .StopCapture()
91 .AnySpace()
92 .GetResult(&remaining, &match)) {
93 return false;
94 }
95 int64 value = 0;
96 if (!strings::safe_strto64(match, &value)) {
97 return false;
98 }
99 *out = value;
100 *sp = remaining;
101 return true;
102 }
103
// Attr-parsing check. When `expr` is false, appends the message pieces in
// __VA_ARGS__ plus the AttrError() suffix to `errors` and returns from the
// enclosing function. Requires `errors`, `orig` and `op_def` to be in scope
// where the macro is expanded.
#define VERIFY(expr, ...)                                 \
  do {                                                    \
    if (expr) break;                                      \
    errors->push_back(strings::StrCat(                    \
        __VA_ARGS__, AttrError(orig, op_def->name())));   \
    return;                                               \
  } while (false)
112
ConsumeCompoundAttrType(StringPiece * sp,StringPiece * out)113 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
114 auto capture_begin = sp->begin();
115 if (str_util::ConsumePrefix(sp, "numbertype") ||
116 str_util::ConsumePrefix(sp, "numerictype") ||
117 str_util::ConsumePrefix(sp, "quantizedtype") ||
118 str_util::ConsumePrefix(sp, "realnumbertype") ||
119 str_util::ConsumePrefix(sp, "realnumberictype")) {
120 *out = StringPiece(capture_begin, sp->begin() - capture_begin);
121 return true;
122 }
123 return false;
124 }
125
ProcessCompoundType(const StringPiece type_string,AttrValue * allowed)126 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
127 if (type_string == "numbertype" || type_string == "numerictype") {
128 for (DataType dt : NumberTypes()) {
129 allowed->mutable_list()->add_type(dt);
130 }
131 } else if (type_string == "quantizedtype") {
132 for (DataType dt : QuantizedTypes()) {
133 allowed->mutable_list()->add_type(dt);
134 }
135 } else if (type_string == "realnumbertype" ||
136 type_string == "realnumerictype") {
137 for (DataType dt : RealNumberTypes()) {
138 allowed->mutable_list()->add_type(dt);
139 }
140 } else {
141 return false;
142 }
143 return true;
144 }
145
FinalizeAttr(StringPiece spec,OpDef * op_def,std::vector<string> * errors)146 void FinalizeAttr(StringPiece spec, OpDef* op_def,
147 std::vector<string>* errors) {
148 OpDef::AttrDef* attr = op_def->add_attr();
149 StringPiece orig(spec);
150
151 // Parse "<name>:" at the beginning.
152 StringPiece tmp_name;
153 VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'");
154 attr->set_name(tmp_name.data(), tmp_name.size());
155
156 // Read "<type>" or "list(<type>)".
157 bool is_list = ConsumeListPrefix(&spec);
158 string type;
159 StringPiece type_string; // Used if type == "type"
160 if (str_util::ConsumePrefix(&spec, "string")) {
161 type = "string";
162 } else if (str_util::ConsumePrefix(&spec, "int")) {
163 type = "int";
164 } else if (str_util::ConsumePrefix(&spec, "float")) {
165 type = "float";
166 } else if (str_util::ConsumePrefix(&spec, "bool")) {
167 type = "bool";
168 } else if (str_util::ConsumePrefix(&spec, "type")) {
169 type = "type";
170 } else if (str_util::ConsumePrefix(&spec, "shape")) {
171 type = "shape";
172 } else if (str_util::ConsumePrefix(&spec, "tensor")) {
173 type = "tensor";
174 } else if (str_util::ConsumePrefix(&spec, "func")) {
175 type = "func";
176 } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
177 type = "type";
178 AttrValue* allowed = attr->mutable_allowed_values();
179 VERIFY(ProcessCompoundType(type_string, allowed),
180 "Expected to see a compound type, saw: ", type_string);
181 } else if (str_util::ConsumePrefix(&spec, "{")) {
182 // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
183 AttrValue* allowed = attr->mutable_allowed_values();
184 str_util::RemoveLeadingWhitespace(&spec);
185 if (str_util::StartsWith(spec, "\"") || str_util::StartsWith(spec, "'")) {
186 type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
187 while (true) {
188 StringPiece escaped_string;
189 VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) ||
190 ConsumeQuotedString('\'', &spec, &escaped_string),
191 "Trouble parsing allowed string at '", spec, "'");
192 string unescaped;
193 string error;
194 VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
195 "Trouble unescaping \"", escaped_string,
196 "\", got error: ", error);
197 allowed->mutable_list()->add_s(unescaped);
198 if (str_util::ConsumePrefix(&spec, ",")) {
199 str_util::RemoveLeadingWhitespace(&spec);
200 if (str_util::ConsumePrefix(&spec, "}"))
201 break; // Allow ending with ", }".
202 } else {
203 VERIFY(str_util::ConsumePrefix(&spec, "}"),
204 "Expected , or } after strings in list, not: '", spec, "'");
205 break;
206 }
207 }
208 } else { // "{ bool, numbertype, string }"
209 type = "type";
210 while (true) {
211 VERIFY(ConsumeAttrType(&spec, &type_string),
212 "Trouble parsing type string at '", spec, "'");
213 if (ProcessCompoundType(type_string, allowed)) {
214 // Processed a compound type.
215 } else {
216 DataType dt;
217 VERIFY(DataTypeFromString(type_string, &dt),
218 "Unrecognized type string '", type_string, "'");
219 allowed->mutable_list()->add_type(dt);
220 }
221 if (str_util::ConsumePrefix(&spec, ",")) {
222 str_util::RemoveLeadingWhitespace(&spec);
223 if (str_util::ConsumePrefix(&spec, "}"))
224 break; // Allow ending with ", }".
225 } else {
226 VERIFY(str_util::ConsumePrefix(&spec, "}"),
227 "Expected , or } after types in list, not: '", spec, "'");
228 break;
229 }
230 }
231 }
232 } else { // if spec.Consume("{")
233 VERIFY(false, "Trouble parsing type string at '", spec, "'");
234 }
235 str_util::RemoveLeadingWhitespace(&spec);
236
237 // Write the type into *attr.
238 if (is_list) {
239 VERIFY(str_util::ConsumePrefix(&spec, ")"),
240 "Expected ) to close 'list(', not: '", spec, "'");
241 str_util::RemoveLeadingWhitespace(&spec);
242 attr->set_type(strings::StrCat("list(", type, ")"));
243 } else {
244 attr->set_type(type);
245 }
246
247 // Read optional minimum constraint at the end.
248 if ((is_list || type == "int") && str_util::ConsumePrefix(&spec, ">=")) {
249 int64 min_limit = -999;
250 VERIFY(ConsumeAttrNumber(&spec, &min_limit),
251 "Could not parse integer lower limit after '>=', found '", spec,
252 "' instead");
253 attr->set_has_minimum(true);
254 attr->set_minimum(min_limit);
255 }
256
257 // Parse default value, if present.
258 if (str_util::ConsumePrefix(&spec, "=")) {
259 str_util::RemoveLeadingWhitespace(&spec);
260 VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
261 "Could not parse default value '", spec, "'");
262 } else {
263 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
264 }
265 }
266
267 #undef VERIFY
268
// Builds the suffix appended to Input()/Output() parse errors: it says which
// of the two the spec came from, quotes the spec `orig`, and names the op.
string InOutError(bool is_output, StringPiece orig, const string& op_name) {
  const char* const direction = is_output ? "Output" : "Input";
  return strings::StrCat(" from ", direction, "(\"", orig, "\") for Op ",
                         op_name);
}
273
ConsumeInOutName(StringPiece * sp,StringPiece * out)274 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) {
275 return Scanner(*sp)
276 .One(Scanner::LOWERLETTER)
277 .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE)
278 .StopCapture()
279 .AnySpace()
280 .OneLiteral(":")
281 .AnySpace()
282 .GetResult(sp, out);
283 }
284
ConsumeInOutRefOpen(StringPiece * sp)285 bool ConsumeInOutRefOpen(StringPiece* sp) {
286 return Scanner(*sp)
287 .OneLiteral("Ref")
288 .AnySpace()
289 .OneLiteral("(")
290 .AnySpace()
291 .GetResult(sp);
292 }
293
ConsumeInOutRefClose(StringPiece * sp)294 bool ConsumeInOutRefClose(StringPiece* sp) {
295 return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp);
296 }
297
ConsumeInOutNameOrType(StringPiece * sp,StringPiece * out)298 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) {
299 return Scanner(*sp)
300 .One(Scanner::LETTER)
301 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
302 .StopCapture()
303 .AnySpace()
304 .GetResult(sp, out);
305 }
306
ConsumeInOutTimesType(StringPiece * sp,StringPiece * out)307 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) {
308 return Scanner(*sp)
309 .OneLiteral("*")
310 .AnySpace()
311 .RestartCapture()
312 .One(Scanner::LETTER)
313 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
314 .StopCapture()
315 .AnySpace()
316 .GetResult(sp, out);
317 }
318
ConsumeControlOutName(StringPiece * sp,StringPiece * out)319 bool ConsumeControlOutName(StringPiece* sp, StringPiece* out) {
320 return Scanner(*sp)
321 .One(Scanner::LETTER)
322 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
323 .StopCapture()
324 .GetResult(sp, out);
325 }
326
// Input/output-parsing check. When `expr` is false, appends the message
// pieces in __VA_ARGS__ plus the InOutError() suffix to `errors` and returns
// from the enclosing function. Requires `errors`, `is_output`, `orig` and
// `op_def` to be in scope where the macro is expanded.
#define VERIFY(expr, ...)                                           \
  do {                                                              \
    if (expr) break;                                                \
    errors->push_back(strings::StrCat(                              \
        __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
    return;                                                         \
  } while (false)
335
FinalizeInputOrOutput(StringPiece spec,bool is_output,OpDef * op_def,std::vector<string> * errors)336 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
337 std::vector<string>* errors) {
338 OpDef::ArgDef* arg =
339 is_output ? op_def->add_output_arg() : op_def->add_input_arg();
340
341 StringPiece orig(spec);
342
343 // Parse "<name>:" at the beginning.
344 StringPiece tmp_name;
345 VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'");
346 arg->set_name(tmp_name.data(), tmp_name.size());
347
348 // Detect "Ref(...)".
349 if (ConsumeInOutRefOpen(&spec)) {
350 arg->set_is_ref(true);
351 }
352
353 { // Parse "<name|type>" or "<name>*<name|type>".
354 StringPiece first, second, type_or_attr;
355 VERIFY(ConsumeInOutNameOrType(&spec, &first),
356 "Trouble parsing either a type or an attr name at '", spec, "'");
357 if (ConsumeInOutTimesType(&spec, &second)) {
358 arg->set_number_attr(first.data(), first.size());
359 type_or_attr = second;
360 } else {
361 type_or_attr = first;
362 }
363 DataType dt;
364 if (DataTypeFromString(type_or_attr, &dt)) {
365 arg->set_type(dt);
366 } else {
367 const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
368 VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
369 if (attr->type() == "type") {
370 arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
371 } else {
372 VERIFY(attr->type() == "list(type)", "Reference to attr '",
373 type_or_attr, "' with type ", attr->type(),
374 " that isn't type or list(type)");
375 arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
376 }
377 }
378 }
379
380 // Closing ) for Ref(.
381 if (arg->is_ref()) {
382 VERIFY(ConsumeInOutRefClose(&spec),
383 "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
384 }
385
386 // Should not have anything else.
387 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
388
389 // Int attrs that are the length of an input or output get a default
390 // minimum of 1.
391 if (!arg->number_attr().empty()) {
392 OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
393 if (attr != nullptr && !attr->has_minimum()) {
394 attr->set_has_minimum(true);
395 attr->set_minimum(1);
396 }
397 } else if (!arg->type_list_attr().empty()) {
398 // If an input or output has type specified by a list(type) attr,
399 // it gets a default minimum of 1 as well.
400 OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
401 if (attr != nullptr && attr->type() == "list(type)" &&
402 !attr->has_minimum()) {
403 attr->set_has_minimum(true);
404 attr->set_minimum(1);
405 }
406 }
407
408 // If the arg's dtype is resource we should mark the op as stateful as it
409 // likely touches a resource manager. This deliberately doesn't cover inputs /
410 // outputs which resolve to resource via Attrs as those mostly operate on
411 // resource handles as an opaque type (as opposed to ops which explicitly take
412 // / produce resources).
413 if (arg->type() == DT_RESOURCE) {
414 op_def->set_is_stateful(true);
415 }
416 }
417
418 #undef VERIFY
419
ControlOutError(StringPiece orig,const string & op_name)420 string ControlOutError(StringPiece orig, const string& op_name) {
421 return strings::StrCat(" from ControlOutput(\"", orig, "\") for Op ",
422 op_name);
423 }
424
FinalizeControlOutput(StringPiece name,OpDef * op_def,std::vector<string> * errors)425 void FinalizeControlOutput(StringPiece name, OpDef* op_def,
426 std::vector<string>* errors) {
427 StringPiece orig(name);
428
429 // Parse control output name.
430 StringPiece tmp_name;
431 if (!ConsumeControlOutName(&orig, &tmp_name)) {
432 errors->push_back(strings::StrCat("Trouble parsing 'name:'",
433 ControlOutError(orig, op_def->name())));
434 }
435
436 *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size());
437 }
438
num_leading_spaces(StringPiece s)439 int num_leading_spaces(StringPiece s) {
440 size_t i = 0;
441 while (i < s.size() && s[i] == ' ') {
442 ++i;
443 }
444 return i;
445 }
446
ConsumeDocNameColon(StringPiece * sp,StringPiece * out)447 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) {
448 return Scanner(*sp)
449 .One(Scanner::LETTER)
450 .Any(Scanner::LETTER_DIGIT_UNDERSCORE)
451 .StopCapture()
452 .AnySpace()
453 .OneLiteral(":")
454 .AnySpace()
455 .GetResult(sp, out);
456 }
457
IsDocNameColon(StringPiece s)458 bool IsDocNameColon(StringPiece s) {
459 return ConsumeDocNameColon(&s, nullptr /* out */);
460 }
461
FinalizeDoc(const string & text,OpDef * op_def,std::vector<string> * errors)462 void FinalizeDoc(const string& text, OpDef* op_def,
463 std::vector<string>* errors) {
464 std::vector<string> lines = str_util::Split(text, '\n');
465
466 // Remove trailing spaces.
467 for (string& line : lines) {
468 str_util::StripTrailingWhitespace(&line);
469 }
470
471 // First non-blank line -> summary.
472 int l = 0;
473 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
474 if (static_cast<size_t>(l) < lines.size()) {
475 op_def->set_summary(lines[l]);
476 ++l;
477 }
478 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
479
480 // Lines until we see name: -> description.
481 int start_l = l;
482 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
483 ++l;
484 }
485 int end_l = l;
486 // Trim trailing blank lines from the description.
487 while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
488 string desc = str_util::Join(
489 gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
490 if (!desc.empty()) op_def->set_description(desc);
491
492 // name: description
493 // possibly continued on the next line
494 // if so, we remove the minimum indent
495 StringPiece name;
496 std::vector<StringPiece> description;
497 while (static_cast<size_t>(l) < lines.size()) {
498 description.clear();
499 description.push_back(lines[l]);
500 ConsumeDocNameColon(&description.back(), &name);
501 ++l;
502 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) {
503 description.push_back(lines[l]);
504 ++l;
505 }
506 // Remove any trailing blank lines.
507 while (!description.empty() && description.back().empty()) {
508 description.pop_back();
509 }
510 // Compute the minimum indent of all lines after the first.
511 int min_indent = -1;
512 for (size_t i = 1; i < description.size(); ++i) {
513 if (!description[i].empty()) {
514 int indent = num_leading_spaces(description[i]);
515 if (min_indent < 0 || indent < min_indent) min_indent = indent;
516 }
517 }
518 // Remove min_indent spaces from all lines after the first.
519 for (size_t i = 1; i < description.size(); ++i) {
520 if (!description[i].empty()) description[i].remove_prefix(min_indent);
521 }
522 // Concatenate lines into a single string.
523 const string complete(str_util::Join(description, "\n"));
524
525 // Find name.
526 bool found = false;
527 for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
528 if (op_def->input_arg(i).name() == name) {
529 op_def->mutable_input_arg(i)->set_description(complete);
530 found = true;
531 }
532 }
533 for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
534 if (op_def->output_arg(i).name() == name) {
535 op_def->mutable_output_arg(i)->set_description(complete);
536 found = true;
537 }
538 }
539 for (int i = 0; !found && i < op_def->attr_size(); ++i) {
540 if (op_def->attr(i).name() == name) {
541 op_def->mutable_attr(i)->set_description(complete);
542 found = true;
543 }
544 }
545 if (!found) {
546 errors->push_back(
547 strings::StrCat("No matching input/output/attr for name '", name,
548 "' from Doc() for Op ", op_def->name()));
549 return;
550 }
551 }
552 }
553
554 } // namespace
555
OpDefBuilder(string op_name)556 OpDefBuilder::OpDefBuilder(string op_name) {
557 op_def()->set_name(std::move(op_name));
558 }
559
Attr(string spec)560 OpDefBuilder& OpDefBuilder::Attr(string spec) {
561 attrs_.push_back(std::move(spec));
562 return *this;
563 }
564
Input(string spec)565 OpDefBuilder& OpDefBuilder::Input(string spec) {
566 inputs_.push_back(std::move(spec));
567 return *this;
568 }
569
Output(string spec)570 OpDefBuilder& OpDefBuilder::Output(string spec) {
571 outputs_.push_back(std::move(spec));
572 return *this;
573 }
574
ControlOutput(string name)575 OpDefBuilder& OpDefBuilder::ControlOutput(string name) {
576 control_outputs_.push_back(std::move(name));
577 return *this;
578 }
579
580 #ifndef TF_LEAN_BINARY
Doc(string text)581 OpDefBuilder& OpDefBuilder::Doc(string text) {
582 if (!doc_.empty()) {
583 errors_.push_back(
584 strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
585 } else {
586 doc_ = std::move(text);
587 }
588 return *this;
589 }
590 #endif
591
SetIsCommutative()592 OpDefBuilder& OpDefBuilder::SetIsCommutative() {
593 op_def()->set_is_commutative(true);
594 return *this;
595 }
596
SetIsAggregate()597 OpDefBuilder& OpDefBuilder::SetIsAggregate() {
598 op_def()->set_is_aggregate(true);
599 return *this;
600 }
601
SetIsStateful()602 OpDefBuilder& OpDefBuilder::SetIsStateful() {
603 op_def()->set_is_stateful(true);
604 return *this;
605 }
606
SetAllowsUninitializedInput()607 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
608 op_def()->set_allows_uninitialized_input(true);
609 return *this;
610 }
611
Deprecated(int version,string explanation)612 OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
613 if (op_def()->has_deprecation()) {
614 errors_.push_back(
615 strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
616 } else {
617 OpDeprecation* deprecation = op_def()->mutable_deprecation();
618 deprecation->set_version(version);
619 deprecation->set_explanation(std::move(explanation));
620 }
621 return *this;
622 }
623
SetShapeFn(Status (* fn)(shape_inference::InferenceContext *))624 OpDefBuilder& OpDefBuilder::SetShapeFn(
625 Status (*fn)(shape_inference::InferenceContext*)) {
626 if (op_reg_data_.shape_inference_fn != nullptr) {
627 errors_.push_back(
628 strings::StrCat("SetShapeFn called twice for Op ", op_def()->name()));
629 } else {
630 op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn);
631 }
632 return *this;
633 }
634
Finalize(OpRegistrationData * op_reg_data) const635 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
636 std::vector<string> errors = errors_;
637 *op_reg_data = op_reg_data_;
638
639 OpDef* op_def = &op_reg_data->op_def;
640 for (StringPiece attr : attrs_) {
641 FinalizeAttr(attr, op_def, &errors);
642 }
643 for (StringPiece input : inputs_) {
644 FinalizeInputOrOutput(input, false, op_def, &errors);
645 }
646 for (StringPiece output : outputs_) {
647 FinalizeInputOrOutput(output, true, op_def, &errors);
648 }
649 for (StringPiece control_output : control_outputs_) {
650 FinalizeControlOutput(control_output, op_def, &errors);
651 }
652 FinalizeDoc(doc_, op_def, &errors);
653
654 if (errors.empty()) return Status::OK();
655 return errors::InvalidArgument(str_util::Join(errors, "\n"));
656 }
657
658 } // namespace tensorflow
659