• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <list>
17 #include <map>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/op_gen_lib.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/java/src/gen/cc/java_defs.h"
31 #include "tensorflow/java/src/gen/cc/op_generator.h"
32 #include "tensorflow/java/src/gen/cc/op_specs.h"
33 #include "tensorflow/java/src/gen/cc/source_writer.h"
34 
35 namespace tensorflow {
36 namespace java {
37 namespace {
38 
// License header written verbatim at the top of every generated Java source
// file (see GenerateOp). The text ends with a newline after the closing "*/".
constexpr char kLicense[] =
    "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n\n"
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    "you may not use this file except in compliance with the License.\n"
    "You may obtain a copy of the License at\n\n"
    "    http://www.apache.org/licenses/LICENSE-2.0\n\n"
    "Unless required by applicable law or agreed to in writing, software\n"
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    "See the License for the specific language governing permissions and\n"
    "limitations under the License.\n"
    "=======================================================================*/\n";
55 
// There are three different modes to render an op class, depending on the
// number and type of its outputs. The mode is chosen in GenerateOp.
enum RenderMode {
  // No specialization of the op class; applied when the op does not comply
  // with any other mode.
  DEFAULT,
  // The op class implements the Operand<T> interface, allowing an instance to
  // be passed directly as an input to another operation.
  OPERAND,
  // The op class implements the Iterable<Operand<T>> interface, allowing an
  // instance to be passed directly as a list input to another operation.
  LIST_OPERAND,
};
70 
AddArgument(const Variable & var,const string & description,Method * method_out,Javadoc * javadoc_out)71 void AddArgument(const Variable& var, const string& description,
72                  Method* method_out, Javadoc* javadoc_out) {
73   method_out->add_argument(var);
74   javadoc_out->add_param_tag(var.name(), description);
75 }
76 
CollectOpDependencies(const OpSpec & op,RenderMode mode,std::list<Type> * out)77 void CollectOpDependencies(const OpSpec& op, RenderMode mode,
78                            std::list<Type>* out) {
79   out->push_back(Type::Class("Operation", "org.tensorflow"));
80   out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
81   out->push_back(Type::Class("Scope", "org.tensorflow.op"));
82   if (mode == OPERAND) {
83     out->push_back(Type::Class("Output", "org.tensorflow"));
84   } else if (mode == LIST_OPERAND) {
85     out->push_back(Type::Interface("Iterator", "java.util"));
86   }
87   // Don't pay attention to duplicate types in the dependency list, they will
88   // be filtered out by the SourceWriter.
89   for (const ArgumentSpec& input : op.inputs()) {
90     out->push_back(input.var().type());
91     if (input.iterable()) {
92       out->push_back(Type::Class("Operands", "org.tensorflow.op"));
93     }
94   }
95   for (const ArgumentSpec& output : op.outputs()) {
96     out->push_back(output.var().type());
97     if (output.iterable()) {
98       out->push_back(Type::Class("Arrays", "java.util"));
99     }
100   }
101   for (const AttributeSpec& attribute : op.attributes()) {
102     out->push_back(attribute.var().type());
103     out->push_back(attribute.jni_type());
104     if (attribute.has_default_value() &&
105         attribute.type().kind() == Type::GENERIC) {
106       out->push_back(Type::ForDataType(attribute.default_value()->type()));
107     }
108   }
109   for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
110     out->push_back(optional_attribute.var().type());
111   }
112 }
113 
WriteSetAttrDirective(const AttributeSpec & attr,bool optional,SourceWriter * writer)114 void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
115                            SourceWriter* writer) {
116   string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
117   if (attr.iterable()) {
118     string array_name = attr.var().name() + "Array";
119     writer->AppendType(attr.jni_type())
120         .Append("[] " + array_name + " = new ")
121         .AppendType(attr.jni_type())
122         .Append("[" + var_name + ".size()];")
123         .EndLine()
124         .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
125         .Append(array_name + "[i] = ");
126     if (attr.type().kind() == Type::GENERIC) {
127       writer->Append("DataType.fromClass(" + var_name + ".get(i));");
128     } else {
129       writer->Append(var_name + ".get(i);");
130     }
131     writer->EndLine()
132         .EndBlock()
133         .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
134         .Append(array_name + ");")
135         .EndLine();
136   } else {
137     writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
138     if (attr.var().type().name() == "Class") {
139       writer->Append("DataType.fromClass(" + var_name + "));");
140     } else {
141       writer->Append(var_name + ");");
142     }
143     writer->EndLine();
144   }
145 }
146 
RenderSecondaryFactoryMethod(const OpSpec & op,const Type & op_class,std::map<string,Type> default_types,SourceWriter * writer)147 void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
148                                   std::map<string, Type> default_types,
149                                   SourceWriter* writer) {
150   // Build the return type for the secondary factory, replacing generic
151   // parameters with their default value if any
152   Type return_type = Type::Class(op_class.name(), op_class.package());
153   for (const Type& parameter : op_class.parameters()) {
154     if (parameter.kind() == Type::GENERIC &&
155         default_types.find(parameter.name()) != default_types.end()) {
156       return_type.add_parameter(default_types.at(parameter.name()));
157     } else {
158       return_type.add_parameter(parameter);
159     }
160   }
161   Method factory = Method::Create("create", return_type);
162   Javadoc factory_doc = Javadoc::Create(
163       "Factory method to create a class wrapping a new " + op_class.name() +
164       " operation using default output types.");
165   Variable scope =
166       Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
167   AddArgument(scope, "current scope", &factory, &factory_doc);
168   std::stringstream factory_statement;
169   factory_statement << "return create(scope";
170   for (const ArgumentSpec& input : op.inputs()) {
171     AddArgument(input.var(), input.description(), &factory, &factory_doc);
172     factory_statement << ", " << input.var().name();
173   }
174   for (const AttributeSpec& attr : op.attributes()) {
175     // Only add attributes that are not types or have no default value to the
176     // signature of the secondary factory
177     factory_statement << ", ";
178     if (attr.type().kind() == Type::GENERIC &&
179         default_types.find(attr.type().name()) != default_types.end()) {
180       factory_statement << default_types.at(attr.type().name()).name()
181                         << ".class";
182     } else {
183       AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
184       factory_statement << attr.var().name();
185     }
186   }
187   if (!op.optional_attributes().empty()) {
188     Variable options_var = Variable::Varargs("options", Type::Class("Options"));
189     AddArgument(options_var, "carries optional attributes values", &factory,
190                 &factory_doc);
191     factory_statement << ", " << options_var.name();
192   }
193   factory_doc.add_tag("return", "a new instance of " + op_class.name());
194 
195   writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
196   writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
197   writer->EndMethod();
198 }
199 
RenderFactoryMethods(const OpSpec & op,const Type & op_class,SourceWriter * writer)200 void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
201                           SourceWriter* writer) {
202   Method factory = Method::Create("create", op_class);
203   Javadoc factory_doc =
204       Javadoc::Create("Factory method to create a class wrapping a new " +
205                       op_class.name() + " operation.");
206   Variable scope =
207       Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
208   AddArgument(scope, "current scope", &factory, &factory_doc);
209   for (const ArgumentSpec& input : op.inputs()) {
210     AddArgument(input.var(), input.description(), &factory, &factory_doc);
211   }
212   std::map<string, Type> default_types;
213   for (const AttributeSpec& attr : op.attributes()) {
214     AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
215     // If this attribute is a type with a default value, save its value
216     // for passing it implicitly in a secondary factory method
217     if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
218       Type default_type = Type::ForDataType(attr.default_value()->type());
219       if (!default_type.wildcard()) {
220         default_types.insert(std::make_pair(attr.type().name(), default_type));
221       }
222     }
223   }
224   if (!op.optional_attributes().empty()) {
225     AddArgument(Variable::Varargs("options", Type::Class("Options")),
226                 "carries optional attributes values", &factory, &factory_doc);
227   }
228   factory_doc.add_tag("return", "a new instance of " + op_class.name());
229 
230   writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
231   writer->Append("OperationBuilder opBuilder = scope.env().opBuilder(\"" +
232                  op.graph_op_name() + "\", scope.makeOpName(\"" +
233                  op_class.name() + "\"));");
234   writer->EndLine();
235   for (const ArgumentSpec& input : op.inputs()) {
236     if (input.iterable()) {
237       writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
238                      input.var().name() + "));");
239       writer->EndLine();
240     } else {
241       writer->Append("opBuilder.addInput(" + input.var().name() +
242                      ".asOutput());");
243       writer->EndLine();
244     }
245   }
246   // Add control dependencies, if any.
247   writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
248   writer->EndLine();
249 
250   for (const AttributeSpec& attribute : op.attributes()) {
251     WriteSetAttrDirective(attribute, false, writer);
252   }
253   if (!op.optional_attributes().empty()) {
254     writer->BeginBlock("if (options != null)")
255         .BeginBlock("for (Options opts : options)");
256     for (const AttributeSpec& attribute : op.optional_attributes()) {
257       writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
258       WriteSetAttrDirective(attribute, true, writer);
259       writer->EndBlock();
260     }
261     writer->EndBlock().EndBlock();
262   }
263   writer->Append("return new ")
264       .AppendType(op_class)
265       .Append("(opBuilder.build());")
266       .EndLine();
267   writer->EndMethod();
268 
269   // If this operation has type attributes with a default value, create a
270   // second factory method that infers those values implicitly
271   if (!default_types.empty()) {
272     RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
273   }
274 }
275 
RenderConstructor(const OpSpec & op,const Type & op_class,SourceWriter * writer)276 void RenderConstructor(const OpSpec& op, const Type& op_class,
277                        SourceWriter* writer) {
278   Variable operation =
279       Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
280   Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
281   for (const ArgumentSpec& output : op.outputs()) {
282     if (output.iterable() && !output.type().wildcard()) {
283       constructor.add_annotation(
284           Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
285       break;
286     }
287   }
288   writer->BeginMethod(constructor, PRIVATE)
289       .Append("super(operation);")
290       .EndLine();
291   if (!op.outputs().empty()) {
292     writer->Append("int outputIdx = 0;").EndLine();
293     for (const ArgumentSpec& output : op.outputs()) {
294       if (output.iterable()) {
295         string var_length = output.var().name() + "Length";
296         writer->Append("int " + var_length)
297             .Append(" = operation.outputListLength(\"" + output.op_def_name() +
298                     "\");")
299             .EndLine()
300             .Append(output.var().name() + " = Arrays.asList(");
301         if (!output.type().wildcard()) {
302           writer->Append("(")
303               .AppendType(output.var().type().parameters().front())
304               .Append("[])");
305         }
306         writer->Append("operation.outputList(outputIdx, " + var_length + "));")
307             .EndLine()
308             .Append("outputIdx += " + var_length + ";")
309             .EndLine();
310       } else {
311         writer
312             ->Append(output.var().name() + " = operation.output(outputIdx++);")
313             .EndLine();
314       }
315     }
316   }
317   writer->EndMethod();
318 }
319 
RenderGettersAndSetters(const OpSpec & op,SourceWriter * writer)320 void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
321   for (const AttributeSpec& attr : op.optional_attributes()) {
322     Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
323     Javadoc setter_doc = Javadoc::Create();
324     AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
325     writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
326         .Append("return new Options()." + attr.var().name() + "(" +
327                 attr.var().name() + ");")
328         .EndLine()
329         .EndMethod();
330   }
331   for (const ArgumentSpec& output : op.outputs()) {
332     Method getter = Method::Create(output.var().name(), output.var().type());
333     Javadoc getter_doc = Javadoc::Create(output.description());
334     writer->BeginMethod(getter, PUBLIC, &getter_doc)
335         .Append("return " + output.var().name() + ";")
336         .EndLine()
337         .EndMethod();
338   }
339 }
340 
RenderInterfaceImpl(const OpSpec & op,RenderMode mode,SourceWriter * writer)341 void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
342                          SourceWriter* writer) {
343   ArgumentSpec output = op.outputs().front();
344 
345   if (mode == OPERAND) {
346     bool cast2obj = output.type().wildcard();
347     Type return_type =
348         Type::Class("Output", "org.tensorflow")
349             .add_parameter(cast2obj ? Type::Class("Object") : output.type());
350     Method as_output = Method::Create("asOutput", return_type)
351                            .add_annotation(Annotation::Create("Override"));
352     if (cast2obj) {
353       as_output.add_annotation(
354           Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
355     }
356     writer->BeginMethod(as_output, PUBLIC);
357     if (cast2obj) {
358       writer->Append("return (").AppendType(return_type).Append(") ");
359     } else {
360       writer->Append("return ");
361     }
362     writer->Append(output.var().name() + ";").EndLine().EndMethod();
363 
364   } else if (mode == LIST_OPERAND) {
365     Type operand = Type::Interface("Operand", "org.tensorflow");
366     if (output.type().wildcard()) {
367       operand.add_parameter(Type::Class("Object"));
368     } else {
369       operand.add_parameter(output.type());
370     }
371     Type return_type =
372         Type::Interface("Iterator", "java.util").add_parameter(operand);
373     Method iterator =
374         Method::Create("iterator", return_type)
375             .add_annotation(Annotation::Create("Override"))
376             .add_annotation(Annotation::Create("SuppressWarnings")
377                                 .attributes("{\"rawtypes\", \"unchecked\"}"));
378     // cast the output list using a raw List
379     writer->BeginMethod(iterator, PUBLIC)
380         .Append("return (" + return_type.name() + ") ")
381         .Append(output.var().name() + ".iterator();")
382         .EndLine()
383         .EndMethod();
384   }
385 }
386 
RenderOptionsClass(const OpSpec & op,const Type & op_class,SourceWriter * writer)387 void RenderOptionsClass(const OpSpec& op, const Type& op_class,
388                         SourceWriter* writer) {
389   Type options_class = Type::Class("Options");
390   Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
391                                         op_class.canonical_name() + "}");
392   writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
393   for (const AttributeSpec& attr : op.optional_attributes()) {
394     Method setter = Method::Create(attr.var().name(), options_class);
395     Javadoc setter_doc = Javadoc::Create();
396     AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
397     writer->BeginMethod(setter, PUBLIC, &setter_doc)
398         .Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
399         .EndLine()
400         .Append("return this;")
401         .EndLine()
402         .EndMethod();
403   }
404   writer->EndLine();
405   for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
406     writer->WriteField(optional_attribute.var(), PRIVATE);
407   }
408   Method constructor = Method::ConstructorFor(options_class);
409   writer->BeginMethod(constructor, PRIVATE).EndMethod();
410   writer->EndType();
411 }
412 
ClassOf(const EndpointSpec & endpoint,const string & base_package)413 inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
414   return Type::Class(
415       endpoint.name(),
416       base_package + "." + absl::AsciiStrToLower(endpoint.package()));
417 }
418 
GenerateOp(const OpSpec & op,const EndpointSpec & endpoint,const string & base_package,const string & output_dir,Env * env)419 void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
420                 const string& base_package, const string& output_dir,
421                 Env* env) {
422   Type op_class(
423       ClassOf(endpoint, base_package)
424           .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
425   Javadoc op_javadoc(endpoint.javadoc());
426 
427   // op interfaces
428   RenderMode mode = DEFAULT;
429   if (op.outputs().size() == 1) {
430     const ArgumentSpec& output = op.outputs().front();
431     Type operand_type(output.type().wildcard() ? Type::Class("Object")
432                                                : output.type());
433     Type operand_inf(Type::Interface("Operand", "org.tensorflow")
434                          .add_parameter(operand_type));
435     if (output.iterable()) {
436       mode = LIST_OPERAND;
437       op_class.add_supertype(Type::IterableOf(operand_inf));
438     } else {
439       mode = OPERAND;
440       op_class.add_supertype(operand_inf);
441     }
442   }
443   // op generic parameters
444   std::set<string> generics;
445   for (const ArgumentSpec& output : op.outputs()) {
446     if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
447         generics.find(output.type().name()) == generics.end()) {
448       op_class.add_parameter(output.type());
449       op_javadoc.add_param_tag(
450           "<" + output.type().name() + ">",
451           "data type for {@code " + output.var().name() + "()} output");
452       generics.insert(output.type().name());
453     }
454   }
455   // op annotations
456   if (endpoint.deprecated()) {
457     op_class.add_annotation(Annotation::Create("Deprecated"));
458     string explanation;
459     if (!op.endpoints().front().deprecated()) {
460       explanation =
461           "use {@link " +
462           ClassOf(op.endpoints().front(), base_package).canonical_name() +
463           "} instead";
464     } else {
465       explanation = op.deprecation_explanation();
466     }
467     op_javadoc.add_tag("deprecated", explanation);
468   }
469   if (!op.hidden()) {
470     // expose the op in the Ops Graph API only if it is visible
471     Annotation oper_annot =
472         Annotation::Create("Operator", "org.tensorflow.op.annotation");
473     if (endpoint.package() != kDefaultEndpointPackage) {
474       oper_annot.attributes("group = \"" + endpoint.package() + "\"");
475     }
476     op_class.add_annotation(oper_annot);
477   }
478   // create op class file
479   const string op_dir_name = io::JoinPath(
480       output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
481   if (!env->FileExists(op_dir_name).ok()) {
482     TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
483         << op_dir_name;
484   }
485   const string op_file_name = op_class.name() + ".java";
486   std::unique_ptr<tensorflow::WritableFile> op_file;
487   TF_CHECK_OK(
488       env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
489       << op_file_name;
490 
491   // render endpoint source code
492   SourceFileWriter writer(op_file.get());
493   std::list<Type> dependencies;
494   CollectOpDependencies(op, mode, &dependencies);
495   writer.Write(kLicense)
496       .EndLine()
497       .Write("// This class has been generated, DO NOT EDIT!")
498       .EndLine()
499       .EndLine()
500       .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc);
501   if (!op.optional_attributes().empty()) {
502     RenderOptionsClass(op, op_class, &writer);
503   }
504   RenderFactoryMethods(op, op_class, &writer);
505   RenderGettersAndSetters(op, &writer);
506   if (mode != DEFAULT) {
507     RenderInterfaceImpl(op, mode, &writer);
508   }
509   writer.EndLine();
510   for (const ArgumentSpec& output : op.outputs()) {
511     writer.WriteField(output.var(), PRIVATE);
512   }
513   RenderConstructor(op, op_class, &writer);
514   writer.EndType();
515 }
516 
CanGenerateOp(const OpDef & op_def,const ApiDef & api_def)517 bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
518   if (api_def.visibility() == ApiDef::SKIP) {
519     return false;
520   }
521   for (const auto& attr : op_def.attr()) {
522     if (attr.type() == "func" || attr.type() == "list(func)") {
523       return false;  // TODO(karllessard) add support for function attributes
524     }
525   }
526   return true;
527 }
528 
529 }  // namespace
530 
Run(const OpList & op_list,const string & base_package,const string & output_dir)531 Status OpGenerator::Run(const OpList& op_list, const string& base_package,
532                         const string& output_dir) {
533   ApiDefMap api_map(op_list);
534   if (!api_dirs_.empty()) {
535     // Only load api files that correspond to the requested "op_list"
536     for (const auto& op : op_list.op()) {
537       for (const auto& api_def_dir : api_dirs_) {
538         const std::string api_def_file_pattern =
539             io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
540         if (env_->FileExists(api_def_file_pattern).ok()) {
541           TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
542               << api_def_file_pattern;
543         }
544       }
545     }
546   }
547   api_map.UpdateDocs();
548   for (const auto& op_def : op_list.op()) {
549     const ApiDef* api_def = api_map.GetApiDef(op_def.name());
550     if (CanGenerateOp(op_def, *api_def)) {
551       OpSpec op(OpSpec::Create(op_def, *api_def));
552       for (const EndpointSpec& endpoint : op.endpoints()) {
553         GenerateOp(op, endpoint, base_package, output_dir, env_);
554       }
555     }
556   }
557   return Status::OK();
558 }
559 
560 }  // namespace java
561 }  // namespace tensorflow
562