1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cctype>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "re2/re2.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/java/src/gen/cc/op_specs.h"
27
28 namespace tensorflow {
29 namespace java {
30 namespace {
31
IsRealNumbers(const AttrValue & values)32 inline bool IsRealNumbers(const AttrValue& values) {
33 if (!values.has_list()) {
34 return RealNumberTypes().Contains(values.type());
35 }
36 for (int i = 0; i < values.list().type_size(); ++i) {
37 if (!RealNumberTypes().Contains(values.list().type(i))) {
38 return false;
39 }
40 }
41 return true;
42 }
43
44 class TypeResolver {
45 public:
TypeResolver(const OpDef & op_def)46 explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
47
48 // Returns the class type of an input/output argument
49 //
50 // For example, if the argument's datatype is DT_STRING, this method will
51 // return "java.lang.String", so the argument can become "Operand<String>"
52 // in the Ops API
53 Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
54
55 // Returns types of an input attribute
56 //
57 // The first element of the pair is the class type of this attribute while
58 // the second is its JNI/primitive type equivalent, required for explicit
59 // unboxing.
60 //
61 // For example, if the attribute is of type "float", this method will return
62 // <java.lang.Float, float>, so the attribute can be used as a "Float" object
63 // in the Ops API and casted to a "float" when passing through the JNI layer.
64 std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
65 bool* iterable_out);
66
67 // Returns true if the type of this attribute has already been resolved
IsAttributeVisited(const string & attr_name)68 bool IsAttributeVisited(const string& attr_name) {
69 return visited_attrs_.find(attr_name) != visited_attrs_.cend();
70 }
71
72 private:
73 const OpDef op_def_;
74 std::map<std::string, Type> visited_attrs_;
75 char next_generic_letter_ = 'T';
76
MakeTypePair(const Type & type,const Type & jni_type)77 std::pair<Type, Type> MakeTypePair(const Type& type, const Type& jni_type) {
78 return std::make_pair(type, jni_type);
79 }
MakeTypePair(const Type & type)80 std::pair<Type, Type> MakeTypePair(const Type& type) {
81 return std::make_pair(type, type);
82 }
NextGeneric()83 Type NextGeneric() {
84 char generic_letter = next_generic_letter_++;
85 if (next_generic_letter_ > 'Z') {
86 next_generic_letter_ = 'A';
87 }
88 return Type::Generic(string(1, generic_letter));
89 }
90 };
91
TypeOf(const OpDef_ArgDef & arg_def,bool * iterable_out)92 Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
93 *iterable_out = false;
94 Type type = Type::Wildcard();
95 if (arg_def.type() != DataType::DT_INVALID) {
96 type = Type::ForDataType(arg_def.type());
97
98 } else if (!arg_def.type_attr().empty()) {
99 // resolve type from attribute (if already visited, retrieve its type)
100 if (IsAttributeVisited(arg_def.type_attr())) {
101 type = visited_attrs_.at(arg_def.type_attr());
102 } else {
103 for (const auto& attr_def : op_def_.attr()) {
104 if (attr_def.name() == arg_def.type_attr()) {
105 type = TypesOf(attr_def, iterable_out).first;
106 break;
107 }
108 }
109 }
110 } else if (!arg_def.type_list_attr().empty()) {
111 // type is a list of tensors that can be of different data types, so leave
112 // it as a list of wildcards
113 *iterable_out = true;
114 visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
115
116 } else {
117 LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
118 << "\" in operation \"" << op_def_.name() << "\"";
119 }
120 if (!arg_def.number_attr().empty()) {
121 // when number_attr is set, argument has to be a list of tensors
122 *iterable_out = true;
123 visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
124 }
125 return type;
126 }
127
TypesOf(const OpDef_AttrDef & attr_def,bool * iterable_out)128 std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
129 bool* iterable_out) {
130 std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
131 *iterable_out = false;
132 StringPiece attr_type = attr_def.type();
133 if (absl::ConsumePrefix(&attr_type, "list(")) {
134 attr_type.remove_suffix(1); // remove closing brace
135 *iterable_out = true;
136 }
137 if (attr_type == "string") {
138 types = MakeTypePair(Type::Class("String"));
139
140 } else if (attr_type == "int") {
141 types = MakeTypePair(Type::Class("Long"), Type::Long());
142
143 } else if (attr_type == "float") {
144 types = MakeTypePair(Type::Class("Float"), Type::Float());
145
146 } else if (attr_type == "bool") {
147 types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
148
149 } else if (attr_type == "shape") {
150 types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
151
152 } else if (attr_type == "tensor") {
153 types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
154 .add_parameter(Type::Wildcard()));
155
156 } else if (attr_type == "type") {
157 Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
158 if (IsRealNumbers(attr_def.allowed_values())) {
159 type.add_supertype(Type::Class("Number"));
160 }
161 types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
162
163 } else {
164 LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
165 << "\" in operation \"" << op_def_.name() << "\"";
166 }
167 visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
168 return types;
169 }
170
// Converts a snake_case identifier to camelCase.
//
// Every underscore is dropped and the character following it is upper-cased.
// When `upper` is true the very first character is upper-cased as well
// (UpperCamelCase). Consecutive or trailing underscores are simply removed.
string SnakeToCamelCase(const string& str, bool upper = false) {
  string result;
  result.reserve(str.size());
  bool capitalize_next = upper;
  for (const char c : str) {
    if (c == '_') {
      capitalize_next = true;
      continue;
    }
    if (capitalize_next) {
      // std::toupper is only defined for values representable as unsigned
      // char (or EOF); go through unsigned char so that bytes with the high
      // bit set are safe, and narrow the int result back explicitly.
      result += static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
      capitalize_next = false;
    } else {
      result += c;
    }
  }
  return result;
}
187
FindAndCut(string * input,const RE2 & expr,string * before_match,string * ret_match=nullptr)188 bool FindAndCut(string* input, const RE2& expr, string* before_match,
189 string* ret_match = nullptr) {
190 string match;
191 if (!RE2::PartialMatch(*input, expr, &match)) return false;
192 *before_match = input->substr(0, input->find(match));
193 *input = input->substr(before_match->size() + match.size());
194 if (ret_match != nullptr) *ret_match = match;
195 return true;
196 }
197
// Converts the markdown found in op descriptions to Javadoc/HTML markup.
//
// Handled constructs:
//   - "* " list items      -> <ul><li>...</li></ul>
//   - blank lines          -> <p> (or the end of the current list)
//   - ``` fenced blocks    -> <pre>{@code ...}</pre>
//   - `inlined code`       -> {@code ...}
//   - **strong** / *em*    -> <b>...</b> / <i>...</i>
//   - [label](http...)     -> <a href="...">label</a>
// A markup whose closing delimiter cannot be found is emitted verbatim.
string ParseDocumentation(const string& inp) {
  std::stringstream javadoc_text;

  // TODO(karllessard) This is a very minimalist utility method for converting
  // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
  // for alternatives to increase the level of support for markups.
  std::vector<string> markups_subexpr;
  markups_subexpr.push_back("\n+\\*\\s+");                 // lists
  markups_subexpr.push_back("\n{2,}");                     // paragraphs
  markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n");   // code blocks
  markups_subexpr.push_back("`+");  // inlined code and code blocks
  markups_subexpr.push_back("\\*{1,2}\\b");                // text emphasis
  markups_subexpr.push_back("\\[");                        // hyperlinks
  const RE2 markup_expr("(" + absl::StrJoin(markups_subexpr, "|") + ")");

  bool in_list = false;
  string input = inp;
  while (true) {
    string text, markup;
    if (!FindAndCut(&input, markup_expr, &text, &markup)) {
      javadoc_text << input;
      break;  // end of loop
    }
    javadoc_text << text;
    if (absl::StartsWith(markup, "\n")) {
      javadoc_text << "\n";
      if (absl::StrContains(markup, "*")) {
        // new list item
        javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
        in_list = true;
      } else if (in_list) {
        // end of list
        javadoc_text << "</li>\n</ul>\n";
        in_list = false;
      } else if (!absl::StartsWith(input, "```")) {
        // new paragraph (not required if a <pre> block follows)
        javadoc_text << "<p>\n";
      }
    } else if (absl::StartsWith(markup, "```")) {
      // code blocks
      if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
        javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
      } else {
        javadoc_text << markup;
      }
    } else if (absl::StartsWith(markup, "`")) {
      // inlined code: it is closed by the same run of backticks that opened
      // it. The parentheses turn the delimiter into the capturing group that
      // FindAndCut requires.
      if (FindAndCut(&input, "(" + markup + ")", &text)) {
        javadoc_text << "{@code " << text << "}";
      } else {
        javadoc_text << markup;
      }
    } else if (markup == "**") {
      // text emphasis (strong)
      if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
        javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
      } else {
        javadoc_text << markup;
      }
    } else if (markup == "*") {
      // text emphasis (normal)
      if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
        javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
      } else {
        javadoc_text << markup;
      }
    } else if (absl::StartsWith(markup, "[")) {
      // hyperlinks: the opening bracket has been consumed already, so what
      // remains must start with `label](http...)`. The pattern is anchored to
      // the beginning of the input and the link stops at the first closing
      // parenthesis or whitespace, so that it cannot swallow the text that
      // follows.
      string label;
      string link;
      if (RE2::PartialMatch(input, "^([^\\[\\]]+)\\]\\((http[^\\s)]+)\\)",
                            &label, &link)) {
        // skip both captures plus the "](" and ")" delimiters
        input = input.substr(label.size() + link.size() + 3);
        javadoc_text << "<a href=\"" << link << "\">"
                     << ParseDocumentation(label) << "</a>";
      } else {
        javadoc_text << markup;
      }
    } else {
      // safe fallback
      javadoc_text << markup;
    }
  }
  if (in_list) {
    // the text ended inside a list: close it so the markup stays balanced
    javadoc_text << "\n</li>\n</ul>";
  }
  return javadoc_text.str();
}
284
CreateInput(const OpDef_ArgDef & input_def,const ApiDef::Arg & input_api_def,TypeResolver * type_resolver)285 ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
286 const ApiDef::Arg& input_api_def,
287 TypeResolver* type_resolver) {
288 bool iterable = false;
289 Type type = type_resolver->TypeOf(input_def, &iterable);
290 Type var_type =
291 Type::Interface("Operand", "org.tensorflow").add_parameter(type);
292 if (iterable) {
293 var_type = Type::IterableOf(var_type);
294 }
295 return ArgumentSpec(
296 input_api_def.name(),
297 Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
298 type, ParseDocumentation(input_api_def.description()), iterable);
299 }
300
CreateAttribute(const OpDef_AttrDef & attr_def,const ApiDef::Attr & attr_api_def,TypeResolver * type_resolver)301 AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
302 const ApiDef::Attr& attr_api_def,
303 TypeResolver* type_resolver) {
304 bool iterable = false;
305 std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
306 Type var_type = types.first.kind() == Type::GENERIC
307 ? Type::ClassOf(types.first)
308 : types.first;
309 if (iterable) {
310 var_type = Type::ListOf(var_type);
311 }
312 return AttributeSpec(
313 attr_api_def.name(),
314 Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
315 types.first, types.second, ParseDocumentation(attr_api_def.description()),
316 iterable,
317 attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
318 }
319
CreateOutput(const OpDef_ArgDef & output_def,const ApiDef::Arg & output_api,TypeResolver * type_resolver)320 ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
321 const ApiDef::Arg& output_api,
322 TypeResolver* type_resolver) {
323 bool iterable = false;
324 Type type = type_resolver->TypeOf(output_def, &iterable);
325 Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
326 if (iterable) {
327 var_type = Type::ListOf(var_type);
328 }
329 return ArgumentSpec(
330 output_api.name(),
331 Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
332 type, ParseDocumentation(output_api.description()), iterable);
333 }
334
CreateEndpoint(const OpDef & op_def,const ApiDef & api_def,const ApiDef_Endpoint & endpoint_def)335 EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
336 const ApiDef_Endpoint& endpoint_def) {
337 std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
338 string package;
339 string name;
340 if (name_tokens.size() > 1) {
341 package = name_tokens.at(0);
342 name = name_tokens.at(1);
343 } else {
344 package = "core"; // generate unclassified ops in the 'core' package
345 name = name_tokens.at(0);
346 }
347 return EndpointSpec(package, name,
348 Javadoc::Create(ParseDocumentation(api_def.summary()))
349 .details(ParseDocumentation(api_def.description())));
350 }
351
352 } // namespace
353
Create(const OpDef & op_def,const ApiDef & api_def)354 OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
355 OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
356 op_def.deprecation().explanation());
357 TypeResolver type_resolver(op_def);
358 for (const string& next_input_name : api_def.arg_order()) {
359 for (int i = 0; i < op_def.input_arg().size(); ++i) {
360 if (op_def.input_arg(i).name() == next_input_name) {
361 op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
362 &type_resolver));
363 break;
364 }
365 }
366 }
367 for (int i = 0; i < op_def.attr().size(); ++i) {
368 // do not parse attributes already visited, they have probably been inferred
369 // before as an input argument type
370 if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
371 AttributeSpec attr =
372 CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
373 // attributes with a default value are optional
374 if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
375 op.optional_attributes_.push_back(attr);
376 } else {
377 op.attributes_.push_back(attr);
378 }
379 }
380 }
381 for (int i = 0; i < op_def.output_arg().size(); ++i) {
382 op.outputs_.push_back(
383 CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
384 }
385 for (const auto& endpoint_def : api_def.endpoint()) {
386 op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
387 }
388 return op;
389 }
390
391 } // namespace java
392 } // namespace tensorflow
393