1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <list>
17 #include <map>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/core/framework/op_gen_lib.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/java/src/gen/cc/java_defs.h"
31 #include "tensorflow/java/src/gen/cc/op_generator.h"
32 #include "tensorflow/java/src/gen/cc/op_specs.h"
33 #include "tensorflow/java/src/gen/cc/source_writer.h"
34
35 namespace tensorflow {
36 namespace java {
37 namespace {
38
// License header written verbatim at the top of every generated Java source
// file (GenerateOp writes it before the class declaration).
constexpr const char kLicense[] =
    "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
    "\n"
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    "you may not use this file except in compliance with the License.\n"
    "You may obtain a copy of the License at\n"
    "\n"
    " http://www.apache.org/licenses/LICENSE-2.0\n"
    "\n"
    "Unless required by applicable law or agreed to in writing, software\n"
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    "See the License for the specific language governing permissions and\n"
    "limitations under the License.\n"
    "=======================================================================*/"
    "\n";
55
// There are three different modes to render an op class, depending on the
// number and type of outputs it has:
//
// DEFAULT: This mode does not provide any specialization for the op class; it
//          is applied when the operation does not comply with any other mode.
//
// OPERAND: The op class implements the Operand<T> interface, allowing an
//          instance to be passed directly as an input to another operation.
//
// LIST_OPERAND: The op class implements the Iterable<Operand<T>> interface,
//               allowing an instance to be passed directly as a list input to
//               another operation.
//
// GenerateOp selects OPERAND or LIST_OPERAND only when the op has exactly one
// output (LIST_OPERAND if that output is iterable), and DEFAULT otherwise.
enum RenderMode { DEFAULT, OPERAND, LIST_OPERAND };
70
AddArgument(const Variable & var,const string & description,Method * method_out,Javadoc * javadoc_out)71 void AddArgument(const Variable& var, const string& description,
72 Method* method_out, Javadoc* javadoc_out) {
73 method_out->add_argument(var);
74 javadoc_out->add_param_tag(var.name(), description);
75 }
76
CollectOpDependencies(const OpSpec & op,RenderMode mode,std::list<Type> * out)77 void CollectOpDependencies(const OpSpec& op, RenderMode mode,
78 std::list<Type>* out) {
79 out->push_back(Type::Class("Operation", "org.tensorflow"));
80 out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
81 out->push_back(Type::Class("Scope", "org.tensorflow.op"));
82 if (mode == OPERAND) {
83 out->push_back(Type::Class("Output", "org.tensorflow"));
84 } else if (mode == LIST_OPERAND) {
85 out->push_back(Type::Interface("Iterator", "java.util"));
86 }
87 // Don't pay attention to duplicate types in the dependency list, they will
88 // be filtered out by the SourceWriter.
89 for (const ArgumentSpec& input : op.inputs()) {
90 out->push_back(input.var().type());
91 if (input.iterable()) {
92 out->push_back(Type::Class("Operands", "org.tensorflow.op"));
93 }
94 }
95 for (const ArgumentSpec& output : op.outputs()) {
96 out->push_back(output.var().type());
97 if (output.iterable()) {
98 out->push_back(Type::Class("Arrays", "java.util"));
99 }
100 }
101 for (const AttributeSpec& attribute : op.attributes()) {
102 out->push_back(attribute.var().type());
103 out->push_back(attribute.jni_type());
104 if (attribute.has_default_value() &&
105 attribute.type().kind() == Type::GENERIC) {
106 out->push_back(Type::ForDataType(attribute.default_value()->type()));
107 }
108 }
109 for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
110 out->push_back(optional_attribute.var().type());
111 }
112 }
113
WriteSetAttrDirective(const AttributeSpec & attr,bool optional,SourceWriter * writer)114 void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
115 SourceWriter* writer) {
116 string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
117 if (attr.iterable()) {
118 string array_name = attr.var().name() + "Array";
119 writer->AppendType(attr.jni_type())
120 .Append("[] " + array_name + " = new ")
121 .AppendType(attr.jni_type())
122 .Append("[" + var_name + ".size()];")
123 .EndLine()
124 .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
125 .Append(array_name + "[i] = ");
126 if (attr.type().kind() == Type::GENERIC) {
127 writer->Append("DataType.fromClass(" + var_name + ".get(i));");
128 } else {
129 writer->Append(var_name + ".get(i);");
130 }
131 writer->EndLine()
132 .EndBlock()
133 .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
134 .Append(array_name + ");")
135 .EndLine();
136 } else {
137 writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
138 if (attr.var().type().name() == "Class") {
139 writer->Append("DataType.fromClass(" + var_name + "));");
140 } else {
141 writer->Append(var_name + ");");
142 }
143 writer->EndLine();
144 }
145 }
146
RenderSecondaryFactoryMethod(const OpSpec & op,const Type & op_class,std::map<string,Type> default_types,SourceWriter * writer)147 void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
148 std::map<string, Type> default_types,
149 SourceWriter* writer) {
150 // Build the return type for the secondary factory, replacing generic
151 // parameters with their default value if any
152 Type return_type = Type::Class(op_class.name(), op_class.package());
153 for (const Type& parameter : op_class.parameters()) {
154 if (parameter.kind() == Type::GENERIC &&
155 default_types.find(parameter.name()) != default_types.end()) {
156 return_type.add_parameter(default_types.at(parameter.name()));
157 } else {
158 return_type.add_parameter(parameter);
159 }
160 }
161 Method factory = Method::Create("create", return_type);
162 Javadoc factory_doc = Javadoc::Create(
163 "Factory method to create a class wrapping a new " + op_class.name() +
164 " operation using default output types.");
165 Variable scope =
166 Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
167 AddArgument(scope, "current scope", &factory, &factory_doc);
168 std::stringstream factory_statement;
169 factory_statement << "return create(scope";
170 for (const ArgumentSpec& input : op.inputs()) {
171 AddArgument(input.var(), input.description(), &factory, &factory_doc);
172 factory_statement << ", " << input.var().name();
173 }
174 for (const AttributeSpec& attr : op.attributes()) {
175 // Only add attributes that are not types or have no default value to the
176 // signature of the secondary factory
177 factory_statement << ", ";
178 if (attr.type().kind() == Type::GENERIC &&
179 default_types.find(attr.type().name()) != default_types.end()) {
180 factory_statement << default_types.at(attr.type().name()).name()
181 << ".class";
182 } else {
183 AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
184 factory_statement << attr.var().name();
185 }
186 }
187 if (!op.optional_attributes().empty()) {
188 Variable options_var = Variable::Varargs("options", Type::Class("Options"));
189 AddArgument(options_var, "carries optional attributes values", &factory,
190 &factory_doc);
191 factory_statement << ", " << options_var.name();
192 }
193 factory_doc.add_tag("return", "a new instance of " + op_class.name());
194
195 writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
196 writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
197 writer->EndMethod();
198 }
199
// Renders the primary static `create` factory method of `op_class`.
//
// The generated method takes, in order: the current scope, one argument per
// input of `op`, one argument per attribute in `op.attributes()` and, if the
// op has optional attributes, an `Options...` varargs parameter.
//
// Its body does the following:
//   - builds the operation with an OperationBuilder named after the op's graph name;
//   - adds the inputs, converting iterable inputs with Operands.asOutputs;
//   - applies the scope's control dependencies;
//   - sets the attributes, plus the optional attributes found non-null in each Options;
//   - returns a new `op_class` instance wrapping the built operation.
//
// If some type attributes have a default value mapping to a concrete
// (non-wildcard) type, a secondary factory omitting those attributes is also
// rendered by RenderSecondaryFactoryMethod.
void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
                          SourceWriter* writer) {
  Method factory = Method::Create("create", op_class);
  Javadoc factory_doc =
      Javadoc::Create("Factory method to create a class wrapping a new " +
                      op_class.name() + " operation.");
  Variable scope =
      Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
  AddArgument(scope, "current scope", &factory, &factory_doc);
  for (const ArgumentSpec& input : op.inputs()) {
    AddArgument(input.var(), input.description(), &factory, &factory_doc);
  }
  // Maps a generic type parameter name to the Java type of its default value.
  std::map<string, Type> default_types;
  for (const AttributeSpec& attr : op.attributes()) {
    AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
    // If this attribute is a type with a default value, save its value
    // for passing it implicitly in a secondary factory method
    if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
      Type default_type = Type::ForDataType(attr.default_value()->type());
      // Wildcard defaults are skipped: they cannot be passed as a concrete
      // `<Type>.class` literal.
      if (!default_type.wildcard()) {
        default_types.insert(std::make_pair(attr.type().name(), default_type));
      }
    }
  }
  if (!op.optional_attributes().empty()) {
    AddArgument(Variable::Varargs("options", Type::Class("Options")),
                "carries optional attributes values", &factory, &factory_doc);
  }
  factory_doc.add_tag("return", "a new instance of " + op_class.name());

  writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
  // The operation is registered under the op's graph name, with a unique name
  // derived from the class name by the scope.
  writer->Append("OperationBuilder opBuilder = scope.env().opBuilder(\"" +
                 op.graph_op_name() + "\", scope.makeOpName(\"" +
                 op_class.name() + "\"));");
  writer->EndLine();
  for (const ArgumentSpec& input : op.inputs()) {
    if (input.iterable()) {
      writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
                     input.var().name() + "));");
      writer->EndLine();
    } else {
      writer->Append("opBuilder.addInput(" + input.var().name() +
                     ".asOutput());");
      writer->EndLine();
    }
  }
  // Add control dependencies, if any.
  writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
  writer->EndLine();

  for (const AttributeSpec& attribute : op.attributes()) {
    WriteSetAttrDirective(attribute, false, writer);
  }
  if (!op.optional_attributes().empty()) {
    // Each Options instance passed in may carry any subset of the optional
    // attributes; only the non-null ones are set on the builder.
    writer->BeginBlock("if (options != null)")
        .BeginBlock("for (Options opts : options)");
    for (const AttributeSpec& attribute : op.optional_attributes()) {
      writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
      WriteSetAttrDirective(attribute, true, writer);
      writer->EndBlock();
    }
    writer->EndBlock().EndBlock();
  }
  writer->Append("return new ")
      .AppendType(op_class)
      .Append("(opBuilder.build());")
      .EndLine();
  writer->EndMethod();

  // If this operation has type attributes with a default value, create a
  // second factory method that infers those values implicitly
  if (!default_types.empty()) {
    RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
  }
}
275
// Renders the private constructor of `op_class`.
//
// The generated constructor takes the built Operation, passes it to `super`,
// then initializes one field per output of `op`, in declaration order. A Java
// `outputIdx` counter tracks the position of the next output:
//   - iterable outputs read their length with operation.outputListLength()
//     and are wrapped in Arrays.asList(operation.outputList(...));
//   - other outputs are read with operation.output(outputIdx++).
void RenderConstructor(const OpSpec& op, const Type& op_class,
                       SourceWriter* writer) {
  Variable operation =
      Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
  Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
  // An iterable output with a non-wildcard type gets an array cast below
  // (see the `!output.type().wildcard()` branch), which is unchecked in Java.
  for (const ArgumentSpec& output : op.outputs()) {
    if (output.iterable() && !output.type().wildcard()) {
      constructor.add_annotation(
          Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
      break;
    }
  }
  writer->BeginMethod(constructor, PRIVATE)
      .Append("super(operation);")
      .EndLine();
  if (!op.outputs().empty()) {
    writer->Append("int outputIdx = 0;").EndLine();
    for (const ArgumentSpec& output : op.outputs()) {
      if (output.iterable()) {
        string var_length = output.var().name() + "Length";
        writer->Append("int " + var_length)
            .Append(" = operation.outputListLength(\"" + output.op_def_name() +
                    "\");")
            .EndLine()
            .Append(output.var().name() + " = Arrays.asList(");
        if (!output.type().wildcard()) {
          // Cast to an array of the first type parameter of the field type
          // (presumably Output<T>; TODO confirm against op_specs).
          writer->Append("(")
              .AppendType(output.var().type().parameters().front())
              .Append("[])");
        }
        writer->Append("operation.outputList(outputIdx, " + var_length + "));")
            .EndLine()
            .Append("outputIdx += " + var_length + ";")
            .EndLine();
      } else {
        writer
            ->Append(output.var().name() + " = operation.output(outputIdx++);")
            .EndLine();
      }
    }
  }
  writer->EndMethod();
}
319
RenderGettersAndSetters(const OpSpec & op,SourceWriter * writer)320 void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
321 for (const AttributeSpec& attr : op.optional_attributes()) {
322 Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
323 Javadoc setter_doc = Javadoc::Create();
324 AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
325 writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
326 .Append("return new Options()." + attr.var().name() + "(" +
327 attr.var().name() + ");")
328 .EndLine()
329 .EndMethod();
330 }
331 for (const ArgumentSpec& output : op.outputs()) {
332 Method getter = Method::Create(output.var().name(), output.var().type());
333 Javadoc getter_doc = Javadoc::Create(output.description());
334 writer->BeginMethod(getter, PUBLIC, &getter_doc)
335 .Append("return " + output.var().name() + ";")
336 .EndLine()
337 .EndMethod();
338 }
339 }
340
RenderInterfaceImpl(const OpSpec & op,RenderMode mode,SourceWriter * writer)341 void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
342 SourceWriter* writer) {
343 ArgumentSpec output = op.outputs().front();
344
345 if (mode == OPERAND) {
346 bool cast2obj = output.type().wildcard();
347 Type return_type =
348 Type::Class("Output", "org.tensorflow")
349 .add_parameter(cast2obj ? Type::Class("Object") : output.type());
350 Method as_output = Method::Create("asOutput", return_type)
351 .add_annotation(Annotation::Create("Override"));
352 if (cast2obj) {
353 as_output.add_annotation(
354 Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
355 }
356 writer->BeginMethod(as_output, PUBLIC);
357 if (cast2obj) {
358 writer->Append("return (").AppendType(return_type).Append(") ");
359 } else {
360 writer->Append("return ");
361 }
362 writer->Append(output.var().name() + ";").EndLine().EndMethod();
363
364 } else if (mode == LIST_OPERAND) {
365 Type operand = Type::Interface("Operand", "org.tensorflow");
366 if (output.type().wildcard()) {
367 operand.add_parameter(Type::Class("Object"));
368 } else {
369 operand.add_parameter(output.type());
370 }
371 Type return_type =
372 Type::Interface("Iterator", "java.util").add_parameter(operand);
373 Method iterator =
374 Method::Create("iterator", return_type)
375 .add_annotation(Annotation::Create("Override"))
376 .add_annotation(Annotation::Create("SuppressWarnings")
377 .attributes("{\"rawtypes\", \"unchecked\"}"));
378 // cast the output list using a raw List
379 writer->BeginMethod(iterator, PUBLIC)
380 .Append("return (" + return_type.name() + ") ")
381 .Append(output.var().name() + ".iterator();")
382 .EndLine()
383 .EndMethod();
384 }
385 }
386
RenderOptionsClass(const OpSpec & op,const Type & op_class,SourceWriter * writer)387 void RenderOptionsClass(const OpSpec& op, const Type& op_class,
388 SourceWriter* writer) {
389 Type options_class = Type::Class("Options");
390 Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
391 op_class.canonical_name() + "}");
392 writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
393 for (const AttributeSpec& attr : op.optional_attributes()) {
394 Method setter = Method::Create(attr.var().name(), options_class);
395 Javadoc setter_doc = Javadoc::Create();
396 AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
397 writer->BeginMethod(setter, PUBLIC, &setter_doc)
398 .Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
399 .EndLine()
400 .Append("return this;")
401 .EndLine()
402 .EndMethod();
403 }
404 writer->EndLine();
405 for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
406 writer->WriteField(optional_attribute.var(), PRIVATE);
407 }
408 Method constructor = Method::ConstructorFor(options_class);
409 writer->BeginMethod(constructor, PRIVATE).EndMethod();
410 writer->EndType();
411 }
412
ClassOf(const EndpointSpec & endpoint,const string & base_package)413 inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
414 return Type::Class(
415 endpoint.name(),
416 base_package + "." + absl::AsciiStrToLower(endpoint.package()));
417 }
418
GenerateOp(const OpSpec & op,const EndpointSpec & endpoint,const string & base_package,const string & output_dir,Env * env)419 void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
420 const string& base_package, const string& output_dir,
421 Env* env) {
422 Type op_class(
423 ClassOf(endpoint, base_package)
424 .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
425 Javadoc op_javadoc(endpoint.javadoc());
426
427 // op interfaces
428 RenderMode mode = DEFAULT;
429 if (op.outputs().size() == 1) {
430 const ArgumentSpec& output = op.outputs().front();
431 Type operand_type(output.type().wildcard() ? Type::Class("Object")
432 : output.type());
433 Type operand_inf(Type::Interface("Operand", "org.tensorflow")
434 .add_parameter(operand_type));
435 if (output.iterable()) {
436 mode = LIST_OPERAND;
437 op_class.add_supertype(Type::IterableOf(operand_inf));
438 } else {
439 mode = OPERAND;
440 op_class.add_supertype(operand_inf);
441 }
442 }
443 // op generic parameters
444 std::set<string> generics;
445 for (const ArgumentSpec& output : op.outputs()) {
446 if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
447 generics.find(output.type().name()) == generics.end()) {
448 op_class.add_parameter(output.type());
449 op_javadoc.add_param_tag(
450 "<" + output.type().name() + ">",
451 "data type for {@code " + output.var().name() + "()} output");
452 generics.insert(output.type().name());
453 }
454 }
455 // op annotations
456 if (endpoint.deprecated()) {
457 op_class.add_annotation(Annotation::Create("Deprecated"));
458 string explanation;
459 if (!op.endpoints().front().deprecated()) {
460 explanation =
461 "use {@link " +
462 ClassOf(op.endpoints().front(), base_package).canonical_name() +
463 "} instead";
464 } else {
465 explanation = op.deprecation_explanation();
466 }
467 op_javadoc.add_tag("deprecated", explanation);
468 }
469 if (!op.hidden()) {
470 // expose the op in the Ops Graph API only if it is visible
471 Annotation oper_annot =
472 Annotation::Create("Operator", "org.tensorflow.op.annotation");
473 if (endpoint.package() != kDefaultEndpointPackage) {
474 oper_annot.attributes("group = \"" + endpoint.package() + "\"");
475 }
476 op_class.add_annotation(oper_annot);
477 }
478 // create op class file
479 const string op_dir_name = io::JoinPath(
480 output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
481 if (!env->FileExists(op_dir_name).ok()) {
482 TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
483 << op_dir_name;
484 }
485 const string op_file_name = op_class.name() + ".java";
486 std::unique_ptr<tensorflow::WritableFile> op_file;
487 TF_CHECK_OK(
488 env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
489 << op_file_name;
490
491 // render endpoint source code
492 SourceFileWriter writer(op_file.get());
493 std::list<Type> dependencies;
494 CollectOpDependencies(op, mode, &dependencies);
495 writer.Write(kLicense)
496 .EndLine()
497 .Write("// This class has been generated, DO NOT EDIT!")
498 .EndLine()
499 .EndLine()
500 .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc);
501 if (!op.optional_attributes().empty()) {
502 RenderOptionsClass(op, op_class, &writer);
503 }
504 RenderFactoryMethods(op, op_class, &writer);
505 RenderGettersAndSetters(op, &writer);
506 if (mode != DEFAULT) {
507 RenderInterfaceImpl(op, mode, &writer);
508 }
509 writer.EndLine();
510 for (const ArgumentSpec& output : op.outputs()) {
511 writer.WriteField(output.var(), PRIVATE);
512 }
513 RenderConstructor(op, op_class, &writer);
514 writer.EndType();
515 }
516
CanGenerateOp(const OpDef & op_def,const ApiDef & api_def)517 bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
518 if (api_def.visibility() == ApiDef::SKIP) {
519 return false;
520 }
521 for (const auto& attr : op_def.attr()) {
522 if (attr.type() == "func" || attr.type() == "list(func)") {
523 return false; // TODO(karllessard) add support for function attributes
524 }
525 }
526 return true;
527 }
528
529 } // namespace
530
Run(const OpList & op_list,const string & base_package,const string & output_dir)531 Status OpGenerator::Run(const OpList& op_list, const string& base_package,
532 const string& output_dir) {
533 ApiDefMap api_map(op_list);
534 if (!api_dirs_.empty()) {
535 // Only load api files that correspond to the requested "op_list"
536 for (const auto& op : op_list.op()) {
537 for (const auto& api_def_dir : api_dirs_) {
538 const std::string api_def_file_pattern =
539 io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
540 if (env_->FileExists(api_def_file_pattern).ok()) {
541 TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
542 << api_def_file_pattern;
543 }
544 }
545 }
546 }
547 api_map.UpdateDocs();
548 for (const auto& op_def : op_list.op()) {
549 const ApiDef* api_def = api_map.GetApiDef(op_def.name());
550 if (CanGenerateOp(op_def, *api_def)) {
551 OpSpec op(OpSpec::Create(op_def, *api_def));
552 for (const EndpointSpec& endpoint : op.endpoints()) {
553 GenerateOp(op, endpoint, base_package, output_dir, env_);
554 }
555 }
556 }
557 return Status::OK();
558 }
559
560 } // namespace java
561 } // namespace tensorflow
562