• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/java/src/gen/cc/op_specs.h"

#include <cctype>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "re2/re2.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
27 
28 namespace tensorflow {
29 namespace java {
30 namespace {
31 
IsRealNumbers(const AttrValue & values)32 inline bool IsRealNumbers(const AttrValue& values) {
33   if (!values.has_list()) {
34     return RealNumberTypes().Contains(values.type());
35   }
36   for (int i = 0; i < values.list().type_size(); ++i) {
37     if (!RealNumberTypes().Contains(values.list().type(i))) {
38       return false;
39     }
40   }
41   return true;
42 }
43 
44 class TypeResolver {
45  public:
TypeResolver(const OpDef & op_def)46   explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
47 
48   // Returns the class type of an input/output argument
49   //
50   // For example, if the argument's datatype is DT_STRING, this method will
51   // return "java.lang.String", so the argument can become "Operand<String>"
52   // in the Ops API
53   Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
54 
55   // Returns types of an input attribute
56   //
57   // The first element of the pair is the class type of this attribute while
58   // the second is its JNI/primitive type equivalent, required for explicit
59   // unboxing.
60   //
61   // For example, if the attribute is of type "float", this method will return
62   // <java.lang.Float, float>, so the attribute can be used as a "Float" object
63   // in the Ops API and casted to a "float" when passing through the JNI layer.
64   std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
65                                 bool* iterable_out);
66 
67   // Returns true if the type of this attribute has already been resolved
IsAttributeVisited(const string & attr_name)68   bool IsAttributeVisited(const string& attr_name) {
69     return visited_attrs_.find(attr_name) != visited_attrs_.cend();
70   }
71 
72  private:
73   const OpDef op_def_;
74   std::map<std::string, Type> visited_attrs_;
75   char next_generic_letter_ = 'T';
76 
MakeTypePair(const Type & type,const Type & jni_type)77   std::pair<Type, Type> MakeTypePair(const Type& type, const Type& jni_type) {
78     return std::make_pair(type, jni_type);
79   }
MakeTypePair(const Type & type)80   std::pair<Type, Type> MakeTypePair(const Type& type) {
81     return std::make_pair(type, type);
82   }
NextGeneric()83   Type NextGeneric() {
84     char generic_letter = next_generic_letter_++;
85     if (next_generic_letter_ > 'Z') {
86       next_generic_letter_ = 'A';
87     }
88     return Type::Generic(string(1, generic_letter));
89   }
90 };
91 
TypeOf(const OpDef_ArgDef & arg_def,bool * iterable_out)92 Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
93   *iterable_out = false;
94   Type type = Type::Wildcard();
95   if (arg_def.type() != DataType::DT_INVALID) {
96     type = Type::ForDataType(arg_def.type());
97 
98   } else if (!arg_def.type_attr().empty()) {
99     // resolve type from attribute (if already visited, retrieve its type)
100     if (IsAttributeVisited(arg_def.type_attr())) {
101       type = visited_attrs_.at(arg_def.type_attr());
102     } else {
103       for (const auto& attr_def : op_def_.attr()) {
104         if (attr_def.name() == arg_def.type_attr()) {
105           type = TypesOf(attr_def, iterable_out).first;
106           break;
107         }
108       }
109     }
110   } else if (!arg_def.type_list_attr().empty()) {
111     // type is a list of tensors that can be of different data types, so leave
112     // it as a list of wildcards
113     *iterable_out = true;
114     visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
115 
116   } else {
117     LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
118                << "\" in operation \"" << op_def_.name() << "\"";
119   }
120   if (!arg_def.number_attr().empty()) {
121     // when number_attr is set, argument has to be a list of tensors
122     *iterable_out = true;
123     visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
124   }
125   return type;
126 }
127 
TypesOf(const OpDef_AttrDef & attr_def,bool * iterable_out)128 std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
129                                             bool* iterable_out) {
130   std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
131   *iterable_out = false;
132   StringPiece attr_type = attr_def.type();
133   if (absl::ConsumePrefix(&attr_type, "list(")) {
134     attr_type.remove_suffix(1);  // remove closing brace
135     *iterable_out = true;
136   }
137   if (attr_type == "string") {
138     types = MakeTypePair(Type::Class("String"));
139 
140   } else if (attr_type == "int") {
141     types = MakeTypePair(Type::Class("Long"), Type::Long());
142 
143   } else if (attr_type == "float") {
144     types = MakeTypePair(Type::Class("Float"), Type::Float());
145 
146   } else if (attr_type == "bool") {
147     types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
148 
149   } else if (attr_type == "shape") {
150     types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
151 
152   } else if (attr_type == "tensor") {
153     types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
154                              .add_parameter(Type::Wildcard()));
155 
156   } else if (attr_type == "type") {
157     Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
158     if (IsRealNumbers(attr_def.allowed_values())) {
159       type.add_supertype(Type::Class("Number"));
160     }
161     types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
162 
163   } else {
164     LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
165                << "\" in operation \"" << op_def_.name() << "\"";
166   }
167   visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
168   return types;
169 }
170 
// Converts a snake_case identifier to camelCase, or to PascalCase when
// `upper` is true.
//
// Underscores are dropped and the character following a run of underscores
// is upper-cased; a trailing underscore is simply dropped. Only ASCII
// letters change case (std::toupper in the default "C" locale).
std::string SnakeToCamelCase(const std::string& str, bool upper = false) {
  std::string result;
  result.reserve(str.size());
  bool capitalize_next = upper;
  for (const char c : str) {
    if (c == '_') {
      capitalize_next = true;
      continue;
    }
    if (capitalize_next) {
      // std::toupper requires a value representable as unsigned char; passing
      // a negative char (a non-ASCII byte where char is signed) is UB.
      result += static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
      capitalize_next = false;
    } else {
      result += c;
    }
  }
  return result;
}
187 
FindAndCut(string * input,const RE2 & expr,string * before_match,string * ret_match=nullptr)188 bool FindAndCut(string* input, const RE2& expr, string* before_match,
189                 string* ret_match = nullptr) {
190   string match;
191   if (!RE2::PartialMatch(*input, expr, &match)) return false;
192   *before_match = input->substr(0, input->find(match));
193   *input = input->substr(before_match->size() + match.size());
194   if (ret_match != nullptr) *ret_match = match;
195   return true;
196 }
197 
ParseDocumentation(const string & inp)198 string ParseDocumentation(const string& inp) {
199   std::stringstream javadoc_text;
200 
201   // TODO(karllessard) This is a very minimalist utility method for converting
202   // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
203   // for alternatives to increase the level of support for markups.
204   std::vector<string> markups_subexpr;
205   markups_subexpr.push_back("\n+\\*\\s+");                // lists
206   markups_subexpr.push_back("\n{2,}");                    // paragraphs
207   markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n");  // code blocks
208   markups_subexpr.push_back("`+");           // inlined code and code blocks
209   markups_subexpr.push_back("\\*{1,2}\\b");  // text emphasis
210   markups_subexpr.push_back("\\[");          // hyperlinks
211   const RE2 markup_expr("(" + absl::StrJoin(markups_subexpr, "|") + ")");
212 
213   bool in_list = false;
214   string input = inp;
215   while (true) {
216     string text, markup;
217     if (!FindAndCut(&input, markup_expr, &text, &markup)) {
218       javadoc_text << input;
219       break;  // end of loop
220     }
221     javadoc_text << text;
222     if (absl::StartsWith(markup, "\n")) {
223       javadoc_text << "\n";
224       if (absl::StrContains(markup, "*")) {
225         // new list item
226         javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
227         in_list = true;
228       } else if (in_list) {
229         // end of list
230         javadoc_text << "</li>\n</ul>\n";
231         in_list = false;
232       } else if (!absl::StartsWith(input, "```")) {
233         // new paragraph (not required if a <pre> block follows)
234         javadoc_text << "<p>\n";
235       }
236     } else if (absl::StartsWith(markup, "```")) {
237       // code blocks
238       if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
239         javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
240       } else {
241         javadoc_text << markup;
242       }
243     } else if (absl::StartsWith("(" + markup + ")", "`")) {
244       // inlined code
245       if (FindAndCut(&input, markup, &text)) {
246         javadoc_text << "{@code " << text << "}";
247       } else {
248         javadoc_text << markup;
249       }
250     } else if (markup == "**") {
251       // text emphasis (strong)
252       if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
253         javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
254       } else {
255         javadoc_text << markup;
256       }
257     } else if (markup == "*") {
258       // text emphasis (normal)
259       if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
260         javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
261       } else {
262         javadoc_text << markup;
263       }
264     } else if (absl::StartsWith(markup, "[")) {
265       // hyperlinks
266       string label;
267       string link;
268       if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
269                             &link) &&
270           absl::StartsWith(input, label + link)) {
271         input = input.substr(label.size() + link.size());
272         javadoc_text << "<a href=\"" << link << "\">"
273                      << ParseDocumentation(label) << "</a>";
274       } else {
275         javadoc_text << markup;
276       }
277     } else {
278       // safe fallback
279       javadoc_text << markup;
280     }
281   }
282   return javadoc_text.str();
283 }
284 
CreateInput(const OpDef_ArgDef & input_def,const ApiDef::Arg & input_api_def,TypeResolver * type_resolver)285 ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
286                          const ApiDef::Arg& input_api_def,
287                          TypeResolver* type_resolver) {
288   bool iterable = false;
289   Type type = type_resolver->TypeOf(input_def, &iterable);
290   Type var_type =
291       Type::Interface("Operand", "org.tensorflow").add_parameter(type);
292   if (iterable) {
293     var_type = Type::IterableOf(var_type);
294   }
295   return ArgumentSpec(
296       input_api_def.name(),
297       Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
298       type, ParseDocumentation(input_api_def.description()), iterable);
299 }
300 
CreateAttribute(const OpDef_AttrDef & attr_def,const ApiDef::Attr & attr_api_def,TypeResolver * type_resolver)301 AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
302                               const ApiDef::Attr& attr_api_def,
303                               TypeResolver* type_resolver) {
304   bool iterable = false;
305   std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
306   Type var_type = types.first.kind() == Type::GENERIC
307                       ? Type::ClassOf(types.first)
308                       : types.first;
309   if (iterable) {
310     var_type = Type::ListOf(var_type);
311   }
312   return AttributeSpec(
313       attr_api_def.name(),
314       Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
315       types.first, types.second, ParseDocumentation(attr_api_def.description()),
316       iterable,
317       attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
318 }
319 
CreateOutput(const OpDef_ArgDef & output_def,const ApiDef::Arg & output_api,TypeResolver * type_resolver)320 ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
321                           const ApiDef::Arg& output_api,
322                           TypeResolver* type_resolver) {
323   bool iterable = false;
324   Type type = type_resolver->TypeOf(output_def, &iterable);
325   Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
326   if (iterable) {
327     var_type = Type::ListOf(var_type);
328   }
329   return ArgumentSpec(
330       output_api.name(),
331       Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
332       type, ParseDocumentation(output_api.description()), iterable);
333 }
334 
CreateEndpoint(const OpDef & op_def,const ApiDef & api_def,const ApiDef_Endpoint & endpoint_def)335 EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
336                             const ApiDef_Endpoint& endpoint_def) {
337   std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
338   string package;
339   string name;
340   if (name_tokens.size() > 1) {
341     package = name_tokens.at(0);
342     name = name_tokens.at(1);
343   } else {
344     package = "core";  // generate unclassified ops in the 'core' package
345     name = name_tokens.at(0);
346   }
347   return EndpointSpec(package, name,
348                       Javadoc::Create(ParseDocumentation(api_def.summary()))
349                           .details(ParseDocumentation(api_def.description())));
350 }
351 
352 }  // namespace
353 
Create(const OpDef & op_def,const ApiDef & api_def)354 OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
355   OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
356             op_def.deprecation().explanation());
357   TypeResolver type_resolver(op_def);
358   for (const string& next_input_name : api_def.arg_order()) {
359     for (int i = 0; i < op_def.input_arg().size(); ++i) {
360       if (op_def.input_arg(i).name() == next_input_name) {
361         op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
362                                          &type_resolver));
363         break;
364       }
365     }
366   }
367   for (int i = 0; i < op_def.attr().size(); ++i) {
368     // do not parse attributes already visited, they have probably been inferred
369     // before as an input argument type
370     if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
371       AttributeSpec attr =
372           CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
373       // attributes with a default value are optional
374       if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
375         op.optional_attributes_.push_back(attr);
376       } else {
377         op.attributes_.push_back(attr);
378       }
379     }
380   }
381   for (int i = 0; i < op_def.output_arg().size(); ++i) {
382     op.outputs_.push_back(
383         CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
384   }
385   for (const auto& endpoint_def : api_def.endpoint()) {
386     op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
387   }
388   return op;
389 }
390 
391 }  // namespace java
392 }  // namespace tensorflow
393