1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_gen_lib.h"
17
18 #include <algorithm>
19 #include <vector>
20
21 #include "absl/strings/escaping.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/lib/strings/strcat.h"
27 #include "tensorflow/core/platform/protobuf.h"
28 #include "tensorflow/core/util/proto/proto_utils.h"
29
30 namespace tensorflow {
31
WordWrap(StringPiece prefix,StringPiece str,int width)32 string WordWrap(StringPiece prefix, StringPiece str, int width) {
33 const string indent_next_line = "\n" + Spaces(prefix.size());
34 width -= prefix.size();
35 string result;
36 strings::StrAppend(&result, prefix);
37
38 while (!str.empty()) {
39 if (static_cast<int>(str.size()) <= width) {
40 // Remaining text fits on one line.
41 strings::StrAppend(&result, str);
42 break;
43 }
44 auto space = str.rfind(' ', width);
45 if (space == StringPiece::npos) {
46 // Rather make a too-long line and break at a space.
47 space = str.find(' ');
48 if (space == StringPiece::npos) {
49 strings::StrAppend(&result, str);
50 break;
51 }
52 }
53 // Breaking at character at position <space>.
54 StringPiece to_append = str.substr(0, space);
55 str.remove_prefix(space + 1);
56 // Remove spaces at break.
57 while (str_util::EndsWith(to_append, " ")) {
58 to_append.remove_suffix(1);
59 }
60 while (absl::ConsumePrefix(&str, " ")) {
61 }
62
63 // Go on to the next line.
64 strings::StrAppend(&result, to_append);
65 if (!str.empty()) strings::StrAppend(&result, indent_next_line);
66 }
67
68 return result;
69 }
70
ConsumeEquals(StringPiece * description)71 bool ConsumeEquals(StringPiece* description) {
72 if (absl::ConsumePrefix(description, "=")) {
73 while (absl::ConsumePrefix(description,
74 " ")) { // Also remove spaces after "=".
75 }
76 return true;
77 }
78 return false;
79 }
80
81 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
82 // Returns whether `split_ch` was found. Afterwards, `*before_split`
83 // contains the maximum prefix of the input `*orig` that doesn't
84 // contain `split_ch`, and `*orig` contains everything after the
85 // first `split_ch`.
SplitAt(char split_ch,StringPiece * orig,StringPiece * before_split)86 static bool SplitAt(char split_ch, StringPiece* orig,
87 StringPiece* before_split) {
88 auto pos = orig->find(split_ch);
89 if (pos == StringPiece::npos) {
90 *before_split = *orig;
91 *orig = StringPiece();
92 return false;
93 } else {
94 *before_split = orig->substr(0, pos);
95 orig->remove_prefix(pos + 1);
96 return true;
97 }
98 }
99
100 // Does this line start with "<spaces><field>:" where "<field>" is
101 // in multi_line_fields? Sets *colon_pos to the position of the colon.
StartsWithFieldName(StringPiece line,const std::vector<string> & multi_line_fields)102 static bool StartsWithFieldName(StringPiece line,
103 const std::vector<string>& multi_line_fields) {
104 StringPiece up_to_colon;
105 if (!SplitAt(':', &line, &up_to_colon)) return false;
106 while (absl::ConsumePrefix(&up_to_colon, " "))
107 ; // Remove leading spaces.
108 for (const auto& field : multi_line_fields) {
109 if (up_to_colon == field) {
110 return true;
111 }
112 }
113 return false;
114 }
115
ConvertLine(StringPiece line,const std::vector<string> & multi_line_fields,string * ml)116 static bool ConvertLine(StringPiece line,
117 const std::vector<string>& multi_line_fields,
118 string* ml) {
119 // Is this a field we should convert?
120 if (!StartsWithFieldName(line, multi_line_fields)) {
121 return false;
122 }
123 // Has a matching field name, so look for "..." after the colon.
124 StringPiece up_to_colon;
125 StringPiece after_colon = line;
126 SplitAt(':', &after_colon, &up_to_colon);
127 while (absl::ConsumePrefix(&after_colon, " "))
128 ; // Remove leading spaces.
129 if (!absl::ConsumePrefix(&after_colon, "\"")) {
130 // We only convert string fields, so don't convert this line.
131 return false;
132 }
133 auto last_quote = after_colon.rfind('\"');
134 if (last_quote == StringPiece::npos) {
135 // Error: we don't see the expected matching quote, abort the conversion.
136 return false;
137 }
138 StringPiece escaped = after_colon.substr(0, last_quote);
139 StringPiece suffix = after_colon.substr(last_quote + 1);
140 // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
141
142 string unescaped;
143 if (!absl::CUnescape(escaped, &unescaped, nullptr)) {
144 // Error unescaping, abort the conversion.
145 return false;
146 }
147 // No more errors possible at this point.
148
149 // Find a string to mark the end that isn't in unescaped.
150 string end = "END";
151 for (int s = 0; unescaped.find(end) != string::npos; ++s) {
152 end = strings::StrCat("END", s);
153 }
154
155 // Actually start writing the converted output.
156 strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
157 if (!suffix.empty()) {
158 // Output suffix, in case there was a trailing comment in the source.
159 strings::StrAppend(ml, suffix);
160 }
161 strings::StrAppend(ml, "\n");
162 return true;
163 }
164
PBTxtToMultiline(StringPiece pbtxt,const std::vector<string> & multi_line_fields)165 string PBTxtToMultiline(StringPiece pbtxt,
166 const std::vector<string>& multi_line_fields) {
167 string ml;
168 // Probably big enough, since the input and output are about the
169 // same size, but just a guess.
170 ml.reserve(pbtxt.size() * (17. / 16));
171 StringPiece line;
172 while (!pbtxt.empty()) {
173 // Split pbtxt into its first line and everything after.
174 SplitAt('\n', &pbtxt, &line);
175 // Convert line or output it unchanged
176 if (!ConvertLine(line, multi_line_fields, &ml)) {
177 strings::StrAppend(&ml, line, "\n");
178 }
179 }
180 return ml;
181 }
182
183 // Given a single line of text `line` with first : at `colon`, determine if
184 // there is an "<<END" expression after the colon and if so return true and set
185 // `*end` to everything after the "<<".
FindMultiline(StringPiece line,size_t colon,string * end)186 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
187 if (colon == StringPiece::npos) return false;
188 line.remove_prefix(colon + 1);
189 while (absl::ConsumePrefix(&line, " ")) {
190 }
191 if (absl::ConsumePrefix(&line, "<<")) {
192 *end = string(line);
193 return true;
194 }
195 return false;
196 }
197
PBTxtFromMultiline(StringPiece multiline_pbtxt)198 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
199 string pbtxt;
200 // Probably big enough, since the input and output are about the
201 // same size, but just a guess.
202 pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
203 StringPiece line;
204 while (!multiline_pbtxt.empty()) {
205 // Split multiline_pbtxt into its first line and everything after.
206 if (!SplitAt('\n', &multiline_pbtxt, &line)) {
207 strings::StrAppend(&pbtxt, line);
208 break;
209 }
210
211 string end;
212 auto colon = line.find(':');
213 if (!FindMultiline(line, colon, &end)) {
214 // Normal case: not a multi-line string, just output the line as-is.
215 strings::StrAppend(&pbtxt, line, "\n");
216 continue;
217 }
218
219 // Multi-line case:
220 // something: <<END
221 // xx
222 // yy
223 // END
224 // Should be converted to:
225 // something: "xx\nyy"
226
227 // Output everything up to the colon (" something:").
228 strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
229
230 // Add every line to unescaped until we see the "END" string.
231 string unescaped;
232 bool first = true;
233 while (!multiline_pbtxt.empty()) {
234 SplitAt('\n', &multiline_pbtxt, &line);
235 if (absl::ConsumePrefix(&line, end)) break;
236 if (first) {
237 first = false;
238 } else {
239 unescaped.push_back('\n');
240 }
241 strings::StrAppend(&unescaped, line);
242 line = StringPiece();
243 }
244
245 // Escape what we extracted and then output it in quotes.
246 strings::StrAppend(&pbtxt, " \"", absl::CEscape(unescaped), "\"", line,
247 "\n");
248 }
249 return pbtxt;
250 }
251
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. An empty `from` leaves `*s` unchanged: it would
// otherwise match at every position without ever advancing.
static void StringReplace(const std::string& from, const std::string& to,
                          std::string* s) {
  if (from.empty()) return;
  std::string result;
  result.reserve(s->size());
  std::string::size_type pos = 0;
  for (auto found = s->find(from, pos); found != std::string::npos;
       found = s->find(from, pos)) {
    result.append(*s, pos, found - pos);  // Text before this match.
    result.append(to);
    pos = found + from.size();
  }
  result.append(*s, pos, std::string::npos);  // Tail after the last match.
  s->swap(result);
}
272
RenameInDocs(const string & from,const string & to,ApiDef * api_def)273 static void RenameInDocs(const string& from, const string& to,
274 ApiDef* api_def) {
275 const string from_quoted = strings::StrCat("`", from, "`");
276 const string to_quoted = strings::StrCat("`", to, "`");
277 for (int i = 0; i < api_def->in_arg_size(); ++i) {
278 if (!api_def->in_arg(i).description().empty()) {
279 StringReplace(from_quoted, to_quoted,
280 api_def->mutable_in_arg(i)->mutable_description());
281 }
282 }
283 for (int i = 0; i < api_def->out_arg_size(); ++i) {
284 if (!api_def->out_arg(i).description().empty()) {
285 StringReplace(from_quoted, to_quoted,
286 api_def->mutable_out_arg(i)->mutable_description());
287 }
288 }
289 for (int i = 0; i < api_def->attr_size(); ++i) {
290 if (!api_def->attr(i).description().empty()) {
291 StringReplace(from_quoted, to_quoted,
292 api_def->mutable_attr(i)->mutable_description());
293 }
294 }
295 if (!api_def->summary().empty()) {
296 StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
297 }
298 if (!api_def->description().empty()) {
299 StringReplace(from_quoted, to_quoted, api_def->mutable_description());
300 }
301 }
302
303 namespace {
304
305 // Initializes given ApiDef with data in OpDef.
InitApiDefFromOpDef(const OpDef & op_def,ApiDef * api_def)306 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
307 api_def->set_graph_op_name(op_def.name());
308 api_def->set_visibility(ApiDef::VISIBLE);
309
310 auto* endpoint = api_def->add_endpoint();
311 endpoint->set_name(op_def.name());
312
313 for (const auto& op_in_arg : op_def.input_arg()) {
314 auto* api_in_arg = api_def->add_in_arg();
315 api_in_arg->set_name(op_in_arg.name());
316 api_in_arg->set_rename_to(op_in_arg.name());
317 api_in_arg->set_description(op_in_arg.description());
318
319 *api_def->add_arg_order() = op_in_arg.name();
320 }
321 for (const auto& op_out_arg : op_def.output_arg()) {
322 auto* api_out_arg = api_def->add_out_arg();
323 api_out_arg->set_name(op_out_arg.name());
324 api_out_arg->set_rename_to(op_out_arg.name());
325 api_out_arg->set_description(op_out_arg.description());
326 }
327 for (const auto& op_attr : op_def.attr()) {
328 auto* api_attr = api_def->add_attr();
329 api_attr->set_name(op_attr.name());
330 api_attr->set_rename_to(op_attr.name());
331 if (op_attr.has_default_value()) {
332 *api_attr->mutable_default_value() = op_attr.default_value();
333 }
334 api_attr->set_description(op_attr.description());
335 }
336 api_def->set_summary(op_def.summary());
337 api_def->set_description(op_def.description());
338 }
339
340 // Updates base_arg based on overrides in new_arg.
MergeArg(ApiDef::Arg * base_arg,const ApiDef::Arg & new_arg)341 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
342 if (!new_arg.rename_to().empty()) {
343 base_arg->set_rename_to(new_arg.rename_to());
344 }
345 if (!new_arg.description().empty()) {
346 base_arg->set_description(new_arg.description());
347 }
348 }
349
350 // Updates base_attr based on overrides in new_attr.
MergeAttr(ApiDef::Attr * base_attr,const ApiDef::Attr & new_attr)351 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
352 if (!new_attr.rename_to().empty()) {
353 base_attr->set_rename_to(new_attr.rename_to());
354 }
355 if (new_attr.has_default_value()) {
356 *base_attr->mutable_default_value() = new_attr.default_value();
357 }
358 if (!new_attr.description().empty()) {
359 base_attr->set_description(new_attr.description());
360 }
361 }
362
363 // Updates base_api_def based on overrides in new_api_def.
MergeApiDefs(ApiDef * base_api_def,const ApiDef & new_api_def)364 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
365 // Merge visibility
366 if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
367 base_api_def->set_visibility(new_api_def.visibility());
368 }
369 // Merge endpoints
370 if (new_api_def.endpoint_size() > 0) {
371 base_api_def->clear_endpoint();
372 std::copy(
373 new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
374 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
375 }
376 // Merge args
377 for (const auto& new_arg : new_api_def.in_arg()) {
378 bool found_base_arg = false;
379 for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
380 auto* base_arg = base_api_def->mutable_in_arg(i);
381 if (base_arg->name() == new_arg.name()) {
382 MergeArg(base_arg, new_arg);
383 found_base_arg = true;
384 break;
385 }
386 }
387 if (!found_base_arg) {
388 return errors::FailedPrecondition("Argument ", new_arg.name(),
389 " not defined in base api for ",
390 base_api_def->graph_op_name());
391 }
392 }
393 for (const auto& new_arg : new_api_def.out_arg()) {
394 bool found_base_arg = false;
395 for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
396 auto* base_arg = base_api_def->mutable_out_arg(i);
397 if (base_arg->name() == new_arg.name()) {
398 MergeArg(base_arg, new_arg);
399 found_base_arg = true;
400 break;
401 }
402 }
403 if (!found_base_arg) {
404 return errors::FailedPrecondition("Argument ", new_arg.name(),
405 " not defined in base api for ",
406 base_api_def->graph_op_name());
407 }
408 }
409 // Merge arg order
410 if (new_api_def.arg_order_size() > 0) {
411 // Validate that new arg_order is correct.
412 if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
413 return errors::FailedPrecondition(
414 "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
415 base_api_def->graph_op_name(),
416 ". Expected: ", base_api_def->arg_order_size());
417 }
418 if (!std::is_permutation(new_api_def.arg_order().begin(),
419 new_api_def.arg_order().end(),
420 base_api_def->arg_order().begin())) {
421 return errors::FailedPrecondition(
422 "Invalid arg_order: ", absl::StrJoin(new_api_def.arg_order(), ", "),
423 " for ", base_api_def->graph_op_name(),
424 ". All elements in arg_order override must match base arg_order: ",
425 absl::StrJoin(base_api_def->arg_order(), ", "));
426 }
427
428 base_api_def->clear_arg_order();
429 std::copy(
430 new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
431 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
432 }
433 // Merge attributes
434 for (const auto& new_attr : new_api_def.attr()) {
435 bool found_base_attr = false;
436 for (int i = 0; i < base_api_def->attr_size(); ++i) {
437 auto* base_attr = base_api_def->mutable_attr(i);
438 if (base_attr->name() == new_attr.name()) {
439 MergeAttr(base_attr, new_attr);
440 found_base_attr = true;
441 break;
442 }
443 }
444 if (!found_base_attr) {
445 return errors::FailedPrecondition("Attribute ", new_attr.name(),
446 " not defined in base api for ",
447 base_api_def->graph_op_name());
448 }
449 }
450 // Merge summary
451 if (!new_api_def.summary().empty()) {
452 base_api_def->set_summary(new_api_def.summary());
453 }
454 // Merge description
455 auto description = new_api_def.description().empty()
456 ? base_api_def->description()
457 : new_api_def.description();
458
459 if (!new_api_def.description_prefix().empty()) {
460 description =
461 strings::StrCat(new_api_def.description_prefix(), "\n", description);
462 }
463 if (!new_api_def.description_suffix().empty()) {
464 description =
465 strings::StrCat(description, "\n", new_api_def.description_suffix());
466 }
467 base_api_def->set_description(description);
468 return Status::OK();
469 }
470 } // namespace
471
ApiDefMap(const OpList & op_list)472 ApiDefMap::ApiDefMap(const OpList& op_list) {
473 for (const auto& op : op_list.op()) {
474 ApiDef api_def;
475 InitApiDefFromOpDef(op, &api_def);
476 map_[op.name()] = api_def;
477 }
478 }
479
~ApiDefMap()480 ApiDefMap::~ApiDefMap() {}
481
LoadFileList(Env * env,const std::vector<string> & filenames)482 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
483 for (const auto& filename : filenames) {
484 TF_RETURN_IF_ERROR(LoadFile(env, filename));
485 }
486 return Status::OK();
487 }
488
LoadFile(Env * env,const string & filename)489 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
490 if (filename.empty()) return Status::OK();
491 string contents;
492 TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
493 Status status = LoadApiDef(contents);
494 if (!status.ok()) {
495 // Return failed status annotated with filename to aid in debugging.
496 return Status(status.code(),
497 strings::StrCat("Error parsing ApiDef file ", filename, ": ",
498 status.error_message()));
499 }
500 return Status::OK();
501 }
502
LoadApiDef(const string & api_def_file_contents)503 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
504 const string contents = PBTxtFromMultiline(api_def_file_contents);
505 ApiDefs api_defs;
506 TF_RETURN_IF_ERROR(
507 proto_utils::ParseTextFormatFromString(contents, &api_defs));
508 for (const auto& api_def : api_defs.op()) {
509 // Check if the op definition is loaded. If op definition is not
510 // loaded, then we just skip this ApiDef.
511 if (map_.find(api_def.graph_op_name()) != map_.end()) {
512 // Overwrite current api def with data in api_def.
513 TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
514 }
515 }
516 return Status::OK();
517 }
518
UpdateDocs()519 void ApiDefMap::UpdateDocs() {
520 for (auto& name_and_api_def : map_) {
521 auto& api_def = name_and_api_def.second;
522 CHECK_GT(api_def.endpoint_size(), 0);
523 const string canonical_name = api_def.endpoint(0).name();
524 if (api_def.graph_op_name() != canonical_name) {
525 RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
526 }
527 for (const auto& in_arg : api_def.in_arg()) {
528 if (in_arg.name() != in_arg.rename_to()) {
529 RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
530 }
531 }
532 for (const auto& out_arg : api_def.out_arg()) {
533 if (out_arg.name() != out_arg.rename_to()) {
534 RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
535 }
536 }
537 for (const auto& attr : api_def.attr()) {
538 if (attr.name() != attr.rename_to()) {
539 RenameInDocs(attr.name(), attr.rename_to(), &api_def);
540 }
541 }
542 }
543 }
544
GetApiDef(const string & name) const545 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
546 return gtl::FindOrNull(map_, name);
547 }
548 } // namespace tensorflow
549