1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_gen_lib.h"
17
18 #include <algorithm>
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/gtl/map_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/util/proto/proto_utils.h"
27
28 namespace tensorflow {
29
WordWrap(StringPiece prefix,StringPiece str,int width)30 string WordWrap(StringPiece prefix, StringPiece str, int width) {
31 const string indent_next_line = "\n" + Spaces(prefix.size());
32 width -= prefix.size();
33 string result;
34 strings::StrAppend(&result, prefix);
35
36 while (!str.empty()) {
37 if (static_cast<int>(str.size()) <= width) {
38 // Remaining text fits on one line.
39 strings::StrAppend(&result, str);
40 break;
41 }
42 auto space = str.rfind(' ', width);
43 if (space == StringPiece::npos) {
44 // Rather make a too-long line and break at a space.
45 space = str.find(' ');
46 if (space == StringPiece::npos) {
47 strings::StrAppend(&result, str);
48 break;
49 }
50 }
51 // Breaking at character at position <space>.
52 StringPiece to_append = str.substr(0, space);
53 str.remove_prefix(space + 1);
54 // Remove spaces at break.
55 while (str_util::EndsWith(to_append, " ")) {
56 to_append.remove_suffix(1);
57 }
58 while (str_util::ConsumePrefix(&str, " ")) {
59 }
60
61 // Go on to the next line.
62 strings::StrAppend(&result, to_append);
63 if (!str.empty()) strings::StrAppend(&result, indent_next_line);
64 }
65
66 return result;
67 }
68
ConsumeEquals(StringPiece * description)69 bool ConsumeEquals(StringPiece* description) {
70 if (str_util::ConsumePrefix(description, "=")) {
71 while (str_util::ConsumePrefix(description,
72 " ")) { // Also remove spaces after "=".
73 }
74 return true;
75 }
76 return false;
77 }
78
79 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
80 // Returns whether `split_ch` was found. Afterwards, `*before_split`
81 // contains the maximum prefix of the input `*orig` that doesn't
82 // contain `split_ch`, and `*orig` contains everything after the
83 // first `split_ch`.
SplitAt(char split_ch,StringPiece * orig,StringPiece * before_split)84 static bool SplitAt(char split_ch, StringPiece* orig,
85 StringPiece* before_split) {
86 auto pos = orig->find(split_ch);
87 if (pos == StringPiece::npos) {
88 *before_split = *orig;
89 *orig = StringPiece();
90 return false;
91 } else {
92 *before_split = orig->substr(0, pos);
93 orig->remove_prefix(pos + 1);
94 return true;
95 }
96 }
97
// Does this line start with "<spaces><field>:" where "<field>" is one of
// the names in multi_line_fields?
StartsWithFieldName(StringPiece line,const std::vector<string> & multi_line_fields)100 static bool StartsWithFieldName(StringPiece line,
101 const std::vector<string>& multi_line_fields) {
102 StringPiece up_to_colon;
103 if (!SplitAt(':', &line, &up_to_colon)) return false;
104 while (str_util::ConsumePrefix(&up_to_colon, " "))
105 ; // Remove leading spaces.
106 for (const auto& field : multi_line_fields) {
107 if (up_to_colon == field) {
108 return true;
109 }
110 }
111 return false;
112 }
113
ConvertLine(StringPiece line,const std::vector<string> & multi_line_fields,string * ml)114 static bool ConvertLine(StringPiece line,
115 const std::vector<string>& multi_line_fields,
116 string* ml) {
117 // Is this a field we should convert?
118 if (!StartsWithFieldName(line, multi_line_fields)) {
119 return false;
120 }
121 // Has a matching field name, so look for "..." after the colon.
122 StringPiece up_to_colon;
123 StringPiece after_colon = line;
124 SplitAt(':', &after_colon, &up_to_colon);
125 while (str_util::ConsumePrefix(&after_colon, " "))
126 ; // Remove leading spaces.
127 if (!str_util::ConsumePrefix(&after_colon, "\"")) {
128 // We only convert string fields, so don't convert this line.
129 return false;
130 }
131 auto last_quote = after_colon.rfind('\"');
132 if (last_quote == StringPiece::npos) {
133 // Error: we don't see the expected matching quote, abort the conversion.
134 return false;
135 }
136 StringPiece escaped = after_colon.substr(0, last_quote);
137 StringPiece suffix = after_colon.substr(last_quote + 1);
138 // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
139
140 string unescaped;
141 if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
142 // Error unescaping, abort the conversion.
143 return false;
144 }
145 // No more errors possible at this point.
146
147 // Find a string to mark the end that isn't in unescaped.
148 string end = "END";
149 for (int s = 0; unescaped.find(end) != string::npos; ++s) {
150 end = strings::StrCat("END", s);
151 }
152
153 // Actually start writing the converted output.
154 strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
155 if (!suffix.empty()) {
156 // Output suffix, in case there was a trailing comment in the source.
157 strings::StrAppend(ml, suffix);
158 }
159 strings::StrAppend(ml, "\n");
160 return true;
161 }
162
PBTxtToMultiline(StringPiece pbtxt,const std::vector<string> & multi_line_fields)163 string PBTxtToMultiline(StringPiece pbtxt,
164 const std::vector<string>& multi_line_fields) {
165 string ml;
166 // Probably big enough, since the input and output are about the
167 // same size, but just a guess.
168 ml.reserve(pbtxt.size() * (17. / 16));
169 StringPiece line;
170 while (!pbtxt.empty()) {
171 // Split pbtxt into its first line and everything after.
172 SplitAt('\n', &pbtxt, &line);
173 // Convert line or output it unchanged
174 if (!ConvertLine(line, multi_line_fields, &ml)) {
175 strings::StrAppend(&ml, line, "\n");
176 }
177 }
178 return ml;
179 }
180
181 // Given a single line of text `line` with first : at `colon`, determine if
182 // there is an "<<END" expression after the colon and if so return true and set
183 // `*end` to everything after the "<<".
FindMultiline(StringPiece line,size_t colon,string * end)184 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
185 if (colon == StringPiece::npos) return false;
186 line.remove_prefix(colon + 1);
187 while (str_util::ConsumePrefix(&line, " ")) {
188 }
189 if (str_util::ConsumePrefix(&line, "<<")) {
190 *end = string(line);
191 return true;
192 }
193 return false;
194 }
195
PBTxtFromMultiline(StringPiece multiline_pbtxt)196 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
197 string pbtxt;
198 // Probably big enough, since the input and output are about the
199 // same size, but just a guess.
200 pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
201 StringPiece line;
202 while (!multiline_pbtxt.empty()) {
203 // Split multiline_pbtxt into its first line and everything after.
204 if (!SplitAt('\n', &multiline_pbtxt, &line)) {
205 strings::StrAppend(&pbtxt, line);
206 break;
207 }
208
209 string end;
210 auto colon = line.find(':');
211 if (!FindMultiline(line, colon, &end)) {
212 // Normal case: not a multi-line string, just output the line as-is.
213 strings::StrAppend(&pbtxt, line, "\n");
214 continue;
215 }
216
217 // Multi-line case:
218 // something: <<END
219 // xx
220 // yy
221 // END
222 // Should be converted to:
223 // something: "xx\nyy"
224
225 // Output everything up to the colon (" something:").
226 strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
227
228 // Add every line to unescaped until we see the "END" string.
229 string unescaped;
230 bool first = true;
231 while (!multiline_pbtxt.empty()) {
232 SplitAt('\n', &multiline_pbtxt, &line);
233 if (str_util::ConsumePrefix(&line, end)) break;
234 if (first) {
235 first = false;
236 } else {
237 unescaped.push_back('\n');
238 }
239 strings::StrAppend(&unescaped, line);
240 line = StringPiece();
241 }
242
243 // Escape what we extracted and then output it in quotes.
244 strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
245 "\n");
246 }
247 return pbtxt;
248 }
249
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. An empty `from` leaves `*s` unchanged.
static void StringReplace(const std::string& from, const std::string& to,
                          std::string* s) {
  // An empty pattern matches at every position without advancing, which
  // made the previous split-based loop spin forever.
  if (from.empty()) return;
  std::string result;
  result.reserve(s->size());
  std::string::size_type pos = 0;
  for (auto found = s->find(from); found != std::string::npos;
       found = s->find(from, pos)) {
    result.append(*s, pos, found - pos);  // Text before this match.
    result.append(to);
    pos = found + from.size();
  }
  result.append(*s, pos, std::string::npos);  // Tail after the last match.
  s->swap(result);
}
270
RenameInDocs(const string & from,const string & to,ApiDef * api_def)271 static void RenameInDocs(const string& from, const string& to,
272 ApiDef* api_def) {
273 const string from_quoted = strings::StrCat("`", from, "`");
274 const string to_quoted = strings::StrCat("`", to, "`");
275 for (int i = 0; i < api_def->in_arg_size(); ++i) {
276 if (!api_def->in_arg(i).description().empty()) {
277 StringReplace(from_quoted, to_quoted,
278 api_def->mutable_in_arg(i)->mutable_description());
279 }
280 }
281 for (int i = 0; i < api_def->out_arg_size(); ++i) {
282 if (!api_def->out_arg(i).description().empty()) {
283 StringReplace(from_quoted, to_quoted,
284 api_def->mutable_out_arg(i)->mutable_description());
285 }
286 }
287 for (int i = 0; i < api_def->attr_size(); ++i) {
288 if (!api_def->attr(i).description().empty()) {
289 StringReplace(from_quoted, to_quoted,
290 api_def->mutable_attr(i)->mutable_description());
291 }
292 }
293 if (!api_def->summary().empty()) {
294 StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
295 }
296 if (!api_def->description().empty()) {
297 StringReplace(from_quoted, to_quoted, api_def->mutable_description());
298 }
299 }
300
301 namespace {
302
303 // Initializes given ApiDef with data in OpDef.
InitApiDefFromOpDef(const OpDef & op_def,ApiDef * api_def)304 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
305 api_def->set_graph_op_name(op_def.name());
306 api_def->set_visibility(ApiDef::VISIBLE);
307
308 auto* endpoint = api_def->add_endpoint();
309 endpoint->set_name(op_def.name());
310
311 for (const auto& op_in_arg : op_def.input_arg()) {
312 auto* api_in_arg = api_def->add_in_arg();
313 api_in_arg->set_name(op_in_arg.name());
314 api_in_arg->set_rename_to(op_in_arg.name());
315 api_in_arg->set_description(op_in_arg.description());
316
317 *api_def->add_arg_order() = op_in_arg.name();
318 }
319 for (const auto& op_out_arg : op_def.output_arg()) {
320 auto* api_out_arg = api_def->add_out_arg();
321 api_out_arg->set_name(op_out_arg.name());
322 api_out_arg->set_rename_to(op_out_arg.name());
323 api_out_arg->set_description(op_out_arg.description());
324 }
325 for (const auto& op_attr : op_def.attr()) {
326 auto* api_attr = api_def->add_attr();
327 api_attr->set_name(op_attr.name());
328 api_attr->set_rename_to(op_attr.name());
329 if (op_attr.has_default_value()) {
330 *api_attr->mutable_default_value() = op_attr.default_value();
331 }
332 api_attr->set_description(op_attr.description());
333 }
334 api_def->set_summary(op_def.summary());
335 api_def->set_description(op_def.description());
336 }
337
338 // Updates base_arg based on overrides in new_arg.
MergeArg(ApiDef::Arg * base_arg,const ApiDef::Arg & new_arg)339 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
340 if (!new_arg.rename_to().empty()) {
341 base_arg->set_rename_to(new_arg.rename_to());
342 }
343 if (!new_arg.description().empty()) {
344 base_arg->set_description(new_arg.description());
345 }
346 }
347
348 // Updates base_attr based on overrides in new_attr.
MergeAttr(ApiDef::Attr * base_attr,const ApiDef::Attr & new_attr)349 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
350 if (!new_attr.rename_to().empty()) {
351 base_attr->set_rename_to(new_attr.rename_to());
352 }
353 if (new_attr.has_default_value()) {
354 *base_attr->mutable_default_value() = new_attr.default_value();
355 }
356 if (!new_attr.description().empty()) {
357 base_attr->set_description(new_attr.description());
358 }
359 }
360
361 // Updates base_api_def based on overrides in new_api_def.
MergeApiDefs(ApiDef * base_api_def,const ApiDef & new_api_def)362 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
363 // Merge visibility
364 if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
365 base_api_def->set_visibility(new_api_def.visibility());
366 }
367 // Merge endpoints
368 if (new_api_def.endpoint_size() > 0) {
369 base_api_def->clear_endpoint();
370 std::copy(
371 new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
372 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
373 }
374 // Merge args
375 for (const auto& new_arg : new_api_def.in_arg()) {
376 bool found_base_arg = false;
377 for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
378 auto* base_arg = base_api_def->mutable_in_arg(i);
379 if (base_arg->name() == new_arg.name()) {
380 MergeArg(base_arg, new_arg);
381 found_base_arg = true;
382 break;
383 }
384 }
385 if (!found_base_arg) {
386 return errors::FailedPrecondition("Argument ", new_arg.name(),
387 " not defined in base api for ",
388 base_api_def->graph_op_name());
389 }
390 }
391 for (const auto& new_arg : new_api_def.out_arg()) {
392 bool found_base_arg = false;
393 for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
394 auto* base_arg = base_api_def->mutable_out_arg(i);
395 if (base_arg->name() == new_arg.name()) {
396 MergeArg(base_arg, new_arg);
397 found_base_arg = true;
398 break;
399 }
400 }
401 if (!found_base_arg) {
402 return errors::FailedPrecondition("Argument ", new_arg.name(),
403 " not defined in base api for ",
404 base_api_def->graph_op_name());
405 }
406 }
407 // Merge arg order
408 if (new_api_def.arg_order_size() > 0) {
409 // Validate that new arg_order is correct.
410 if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
411 return errors::FailedPrecondition(
412 "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
413 base_api_def->graph_op_name(),
414 ". Expected: ", base_api_def->arg_order_size());
415 }
416 if (!std::is_permutation(new_api_def.arg_order().begin(),
417 new_api_def.arg_order().end(),
418 base_api_def->arg_order().begin())) {
419 return errors::FailedPrecondition(
420 "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
421 " for ", base_api_def->graph_op_name(),
422 ". All elements in arg_order override must match base arg_order: ",
423 str_util::Join(base_api_def->arg_order(), ", "));
424 }
425
426 base_api_def->clear_arg_order();
427 std::copy(
428 new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
429 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
430 }
431 // Merge attributes
432 for (const auto& new_attr : new_api_def.attr()) {
433 bool found_base_attr = false;
434 for (int i = 0; i < base_api_def->attr_size(); ++i) {
435 auto* base_attr = base_api_def->mutable_attr(i);
436 if (base_attr->name() == new_attr.name()) {
437 MergeAttr(base_attr, new_attr);
438 found_base_attr = true;
439 break;
440 }
441 }
442 if (!found_base_attr) {
443 return errors::FailedPrecondition("Attribute ", new_attr.name(),
444 " not defined in base api for ",
445 base_api_def->graph_op_name());
446 }
447 }
448 // Merge summary
449 if (!new_api_def.summary().empty()) {
450 base_api_def->set_summary(new_api_def.summary());
451 }
452 // Merge description
453 auto description = new_api_def.description().empty()
454 ? base_api_def->description()
455 : new_api_def.description();
456
457 if (!new_api_def.description_prefix().empty()) {
458 description =
459 strings::StrCat(new_api_def.description_prefix(), "\n", description);
460 }
461 if (!new_api_def.description_suffix().empty()) {
462 description =
463 strings::StrCat(description, "\n", new_api_def.description_suffix());
464 }
465 base_api_def->set_description(description);
466 return Status::OK();
467 }
468 } // namespace
469
ApiDefMap(const OpList & op_list)470 ApiDefMap::ApiDefMap(const OpList& op_list) {
471 for (const auto& op : op_list.op()) {
472 ApiDef api_def;
473 InitApiDefFromOpDef(op, &api_def);
474 map_[op.name()] = api_def;
475 }
476 }
477
~ApiDefMap()478 ApiDefMap::~ApiDefMap() {}
479
LoadFileList(Env * env,const std::vector<string> & filenames)480 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
481 for (const auto& filename : filenames) {
482 TF_RETURN_IF_ERROR(LoadFile(env, filename));
483 }
484 return Status::OK();
485 }
486
LoadFile(Env * env,const string & filename)487 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
488 if (filename.empty()) return Status::OK();
489 string contents;
490 TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
491 Status status = LoadApiDef(contents);
492 if (!status.ok()) {
493 // Return failed status annotated with filename to aid in debugging.
494 return Status(status.code(),
495 strings::StrCat("Error parsing ApiDef file ", filename, ": ",
496 status.error_message()));
497 }
498 return Status::OK();
499 }
500
LoadApiDef(const string & api_def_file_contents)501 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
502 const string contents = PBTxtFromMultiline(api_def_file_contents);
503 ApiDefs api_defs;
504 TF_RETURN_IF_ERROR(
505 proto_utils::ParseTextFormatFromString(contents, &api_defs));
506 for (const auto& api_def : api_defs.op()) {
507 // Check if the op definition is loaded. If op definition is not
508 // loaded, then we just skip this ApiDef.
509 if (map_.find(api_def.graph_op_name()) != map_.end()) {
510 // Overwrite current api def with data in api_def.
511 TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
512 }
513 }
514 return Status::OK();
515 }
516
UpdateDocs()517 void ApiDefMap::UpdateDocs() {
518 for (auto& name_and_api_def : map_) {
519 auto& api_def = name_and_api_def.second;
520 CHECK_GT(api_def.endpoint_size(), 0);
521 const string canonical_name = api_def.endpoint(0).name();
522 if (api_def.graph_op_name() != canonical_name) {
523 RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
524 }
525 for (const auto& in_arg : api_def.in_arg()) {
526 if (in_arg.name() != in_arg.rename_to()) {
527 RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
528 }
529 }
530 for (const auto& out_arg : api_def.out_arg()) {
531 if (out_arg.name() != out_arg.rename_to()) {
532 RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
533 }
534 }
535 for (const auto& attr : api_def.attr()) {
536 if (attr.name() != attr.rename_to()) {
537 RenameInDocs(attr.name(), attr.rename_to(), &api_def);
538 }
539 }
540 }
541 }
542
GetApiDef(const string & name) const543 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
544 return gtl::FindOrNull(map_, name);
545 }
546 } // namespace tensorflow
547