• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.processor;
17 
18 import com.google.common.base.CaseFormat;
19 import com.google.common.base.Strings;
20 import com.google.common.collect.HashMultimap;
21 import com.google.common.collect.Multimap;
22 import com.squareup.javapoet.ClassName;
23 import com.squareup.javapoet.FieldSpec;
24 import com.squareup.javapoet.JavaFile;
25 import com.squareup.javapoet.MethodSpec;
26 import com.squareup.javapoet.ParameterSpec;
27 import com.squareup.javapoet.TypeName;
28 import com.squareup.javapoet.TypeSpec;
29 import com.squareup.javapoet.TypeVariableName;
30 import java.io.IOException;
31 import java.util.Collection;
32 import java.util.Collections;
33 import java.util.HashMap;
34 import java.util.Map;
35 import java.util.Set;
36 import java.util.regex.Matcher;
37 import java.util.regex.Pattern;
38 import javax.annotation.processing.AbstractProcessor;
39 import javax.annotation.processing.Filer;
40 import javax.annotation.processing.Messager;
41 import javax.annotation.processing.ProcessingEnvironment;
42 import javax.annotation.processing.RoundEnvironment;
43 import javax.lang.model.SourceVersion;
44 import javax.lang.model.element.AnnotationMirror;
45 import javax.lang.model.element.AnnotationValue;
46 import javax.lang.model.element.Element;
47 import javax.lang.model.element.ExecutableElement;
48 import javax.lang.model.element.Modifier;
49 import javax.lang.model.element.TypeElement;
50 import javax.lang.model.element.TypeParameterElement;
51 import javax.lang.model.element.VariableElement;
52 import javax.lang.model.type.TypeMirror;
53 import javax.lang.model.type.TypeVariable;
54 import javax.lang.model.util.ElementFilter;
55 import javax.lang.model.util.Elements;
56 import javax.tools.Diagnostic.Kind;
57 
58 /**
59  * A compile-time Processor that aggregates classes annotated with {@link
60  * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
61  * refer to the {@link org.tensorflow.op.annotation.Operator} annotation for details about the API
62  * generated for each annotated class.
63  *
64  * <p>Note that this processor can only be invoked once, in a single compilation run that includes
65  * all the {@code Operator} annotated source classes. The reason is that the {@code Ops} API is an
66  * "aggregating" API, and annotation processing does not permit modifying an already generated
67  * class.
68  *
69  * @see org.tensorflow.op.annotation.Operator
70  */
71 public final class OperatorProcessor extends AbstractProcessor {
72 
73   @Override
getSupportedSourceVersion()74   public SourceVersion getSupportedSourceVersion() {
75     return SourceVersion.latestSupported();
76   }
77 
78   @Override
init(ProcessingEnvironment processingEnv)79   public synchronized void init(ProcessingEnvironment processingEnv) {
80     super.init(processingEnv);
81     messager = processingEnv.getMessager();
82     filer = processingEnv.getFiler();
83     elements = processingEnv.getElementUtils();
84   }
85 
86   @Override
process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv)87   public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
88     // Nothing needs to be done at the end of all rounds.
89     if (roundEnv.processingOver()) {
90       return false;
91     }
92 
93     // Nothing to look at in this round.
94     if (annotations.size() == 0) {
95       return false;
96     }
97 
98     // We expect to be registered for exactly one annotation.
99     if (annotations.size() != 1) {
100       throw new IllegalStateException(
101           "Unexpected - multiple annotations registered: " + annotations);
102     }
103     TypeElement annotation = annotations.iterator().next();
104     Set<? extends Element> annotated = roundEnv.getElementsAnnotatedWith(annotation);
105 
106     // If there are no annotated elements, claim the annotation but do nothing.
107     if (annotated.size() == 0) {
108       return true;
109     }
110 
111     // This processor has to aggregate all op classes in one round, as it generates a single Ops
112     // API class which cannot be modified once generated. If we find an annotation after we've
113     // generated our code, flag the location of each such class.
114     if (hasRun) {
115       for (Element e : annotated) {
116         error(
117             e,
118             "The Operator processor has already processed @Operator annotated sources\n"
119                 + "and written out an Ops API. It cannot process additional @Operator sources.\n"
120                 + "One reason this can happen is if other annotation processors generate\n"
121                 + "new @Operator source files.");
122       }
123       return true;
124     }
125 
126     // Collect all classes tagged with our annotation.
127     Multimap<String, MethodSpec> groupedMethods = HashMultimap.create();
128     if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) {
129       return true;
130     }
131 
132     // Nothing to do when there are no tagged classes.
133     if (groupedMethods.isEmpty()) {
134       return true;
135     }
136 
137     // Validate operator classes and generate Op API.
138     writeApi(groupedMethods);
139 
140     hasRun = true;
141     return true;
142   }
143 
144   @Override
getSupportedAnnotationTypes()145   public Set<String> getSupportedAnnotationTypes() {
146     return Collections.singleton("org.tensorflow.op.annotation.Operator");
147   }
148 
149   private static final Pattern JAVADOC_TAG_PATTERN =
150       Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
151   private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
152   private static final TypeName T_OPERATOR =
153       ClassName.get("org.tensorflow.op.annotation", "Operator");
154   private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
155   private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
156   private static final TypeName T_STRING = ClassName.get(String.class);
157 
158   private Filer filer;
159   private Messager messager;
160   private Elements elements;
161   private boolean hasRun = false;
162 
error(Element e, String message, Object... args)163   private void error(Element e, String message, Object... args) {
164     if (args != null && args.length > 0) {
165       message = String.format(message, args);
166     }
167     messager.printMessage(Kind.ERROR, message, e);
168   }
169 
write(TypeSpec spec)170   private void write(TypeSpec spec) {
171     try {
172       JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
173     } catch (IOException e) {
174       throw new AssertionError(e);
175     }
176   }
177 
writeApi(Multimap<String, MethodSpec> groupedMethods)178   private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
179     Map<String, ClassName> groups = new HashMap<>();
180 
181     // Generate a API class for each group collected other than the default one (= empty string)
182     for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
183       if (!entry.getKey().isEmpty()) {
184         TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
185         write(groupClass);
186         groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name));
187       }
188     }
189     // Generate the top API class, adding any methods added to the default group
190     TypeSpec topClass = buildTopClass(groups, groupedMethods.get(""));
191     write(topClass);
192   }
193 
collectOpsMethods( RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation)194   private boolean collectOpsMethods(
195       RoundEnvironment roundEnv,
196       Multimap<String, MethodSpec> groupedMethods,
197       TypeElement annotation) {
198     boolean result = true;
199     for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
200       // @Operator can only apply to types, so e must be a TypeElement.
201       if (!(e instanceof TypeElement)) {
202         error(
203             e,
204             "@Operator can only be applied to classes, but this is a %s",
205             e.getKind().toString());
206         result = false;
207         continue;
208       }
209       TypeElement opClass = (TypeElement) e;
210       // Skip deprecated operations for now, as we do not guarantee API stability yet
211       if (opClass.getAnnotation(Deprecated.class) == null) {
212         collectOpMethods(groupedMethods, opClass, annotation);
213       }
214     }
215     return result;
216   }
217 
collectOpMethods( Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation)218   private void collectOpMethods(
219       Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
220     AnnotationMirror am = getAnnotationMirror(opClass, annotation);
221     String groupName = getAnnotationElementValueAsString("group", am);
222     String methodName = getAnnotationElementValueAsString("name", am);
223     ClassName opClassName = ClassName.get(opClass);
224     if (Strings.isNullOrEmpty(methodName)) {
225       methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
226     }
227     // Build a method for each @Operator found in the class path. There should be one method per
228     // operation factory called
229     // "create", which takes in parameter a scope and, optionally, a list of arguments
230     for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
231       if (opMethod.getModifiers().contains(Modifier.STATIC)
232           && opMethod.getSimpleName().contentEquals("create")) {
233         MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
234         groupedMethods.put(groupName, method);
235       }
236     }
237   }
238 
buildOpMethod( String methodName, ClassName opClassName, ExecutableElement factoryMethod)239   private MethodSpec buildOpMethod(
240       String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
241     MethodSpec.Builder builder =
242         MethodSpec.methodBuilder(methodName)
243             .addModifiers(Modifier.PUBLIC)
244             .returns(TypeName.get(factoryMethod.getReturnType()))
245             .varargs(factoryMethod.isVarArgs())
246             .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
247 
248     for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
249       TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
250       builder.addTypeVariable(tvn);
251     }
252     for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
253       builder.addException(TypeName.get(thrownType));
254     }
255     StringBuilder call = new StringBuilder("return $T.create(scope");
256     boolean first = true;
257     for (VariableElement param : factoryMethod.getParameters()) {
258       ParameterSpec p = ParameterSpec.get(param);
259       if (first) {
260         first = false;
261         continue;
262       }
263       call.append(", ");
264       call.append(p.name);
265       builder.addParameter(p);
266     }
267     call.append(")");
268     builder.addStatement(call.toString(), opClassName);
269     return builder.build();
270   }
271 
buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod)272   private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
273     StringBuilder javadoc = new StringBuilder();
274     javadoc
275         .append("Adds an {@link ")
276         .append(opClassName.simpleName())
277         .append("} operation to the graph\n\n");
278 
279     // Add all javadoc tags found in the operator factory method but the first one, which should be
280     // in all cases the
281     // 'scope' parameter that is implicitly passed by this API
282     Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
283     boolean firstParam = true;
284 
285     while (tagMatcher.find()) {
286       String tag = tagMatcher.group();
287       if (tag.startsWith("@param") && firstParam) {
288         firstParam = false;
289       } else {
290         javadoc.append(tag).append('\n');
291       }
292     }
293     javadoc.append("@see ").append(opClassName).append("\n");
294 
295     return javadoc.toString();
296   }
297 
buildGroupClass(String group, Collection<MethodSpec> methods)298   private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
299     MethodSpec.Builder ctorBuilder =
300         MethodSpec.constructorBuilder()
301             .addParameter(T_SCOPE, "scope")
302             .addStatement("this.scope = scope");
303 
304     TypeSpec.Builder builder =
305         TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
306             .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
307             .addJavadoc(
308                 "An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
309                     + "@see {@link $T}\n",
310                 group,
311                 T_GRAPH,
312                 T_OPS)
313             .addMethods(methods)
314             .addMethod(ctorBuilder.build());
315 
316     builder.addField(
317         FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
318 
319     return builder.build();
320   }
321 
buildTopClass( Map<String, ClassName> groupToClass, Collection<MethodSpec> methods)322   private static TypeSpec buildTopClass(
323       Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
324     MethodSpec.Builder ctorBuilder =
325         MethodSpec.constructorBuilder()
326             .addModifiers(Modifier.PRIVATE)
327             .addParameter(T_SCOPE, "scope")
328             .addStatement("this.scope = scope", T_SCOPE);
329 
330     for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
331       ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
332     }
333 
334     TypeSpec.Builder opsBuilder =
335         TypeSpec.classBuilder("Ops")
336             .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
337             .addJavadoc(
338                 "An API for building a {@link $T} with operation wrappers\n<p>\n"
339                     + "Any operation wrapper found in the classpath properly annotated as an"
340                     + "{@link $T @Operator} is exposed\n"
341                     + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
342                     + "try (Graph g = new Graph()) {\n"
343                     + "  Ops ops = Ops.create(g);\n"
344                     + "  // Operations are typed classes with convenience\n"
345                     + "  // builders in Ops.\n"
346                     + "  Constant three = ops.constant(3);\n"
347                     + "  // Single-result operations implement the Operand\n"
348                     + "  // interface, so this works too.\n"
349                     + "  Operand four = ops.constant(4);\n"
350                     + "  // Most builders are found within a group, and accept\n"
351                     + "  // Operand types as operands\n"
352                     + "  Operand nine = ops.math().add(four, ops.constant(5));\n"
353                     + "  // Multi-result operations however offer methods to\n"
354                     + "  // select a particular result for use.\n"
355                     + "  Operand result = \n"
356                     + "      ops.math().add(ops.array().unique(s, a).y(), b);\n"
357                     + "  // Optional attributes\n"
358                     + "  ops.math().matMul(a, b, MatMul.transposeA(true));\n"
359                     + "  // Naming operators\n"
360                     + "  ops.withName(\"foo\").constant(5); // name \"foo\"\n"
361                     + "  // Names can exist in a hierarchy\n"
362                     + "  Ops sub = ops.withSubScope(\"sub\");\n"
363                     + "  sub.withName(\"bar\").constant(4); // \"sub/bar\"\n"
364                     + "}\n"
365                     + "}</pre>\n",
366                 T_GRAPH,
367                 T_OPERATOR)
368             .addMethods(methods)
369             .addMethod(ctorBuilder.build());
370 
371     opsBuilder.addMethod(
372         MethodSpec.methodBuilder("withSubScope")
373             .addModifiers(Modifier.PUBLIC)
374             .addParameter(T_STRING, "childScopeName")
375             .returns(T_OPS)
376             .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
377             .addJavadoc(
378                 "Returns an API that adds operations to the graph with the provided name prefix.\n"
379                     + "\n@see {@link $T#withSubScope(String)}\n",
380                 T_SCOPE)
381             .build());
382 
383     opsBuilder.addMethod(
384         MethodSpec.methodBuilder("withName")
385             .addModifiers(Modifier.PUBLIC)
386             .addParameter(T_STRING, "opName")
387             .returns(T_OPS)
388             .addStatement("return new Ops(scope.withName(opName))")
389             .addJavadoc(
390                 "Returns an API that uses the provided name for an op.\n\n"
391                     + "@see {@link $T#withName(String)}\n",
392                 T_SCOPE)
393             .build());
394 
395     opsBuilder.addField(
396         FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
397 
398     opsBuilder.addMethod(
399         MethodSpec.methodBuilder("scope")
400             .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
401             .returns(T_SCOPE)
402             .addStatement("return scope")
403             .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
404             .build());
405 
406     for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
407       opsBuilder.addField(
408           FieldSpec.builder(entry.getValue(), entry.getKey())
409               .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
410               .build());
411 
412       opsBuilder.addMethod(
413           MethodSpec.methodBuilder(entry.getKey())
414               .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
415               .returns(entry.getValue())
416               .addStatement("return $L", entry.getKey())
417               .addJavadoc(
418                   "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
419               .build());
420     }
421 
422     opsBuilder.addMethod(
423         MethodSpec.methodBuilder("create")
424             .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
425             .addParameter(T_GRAPH, "graph")
426             .returns(T_OPS)
427             .addStatement("return new Ops(new $T(graph))", T_SCOPE)
428             .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
429             .build());
430 
431     return opsBuilder.build();
432   }
433 
getAnnotationMirror(Element element, TypeElement annotation)434   private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) {
435     for (AnnotationMirror am : element.getAnnotationMirrors()) {
436       if (am.getAnnotationType().asElement().equals(annotation)) {
437         return am;
438       }
439     }
440     throw new IllegalArgumentException(
441         "Annotation "
442             + annotation.getSimpleName()
443             + " not present on element "
444             + element.getSimpleName());
445   }
446 
getAnnotationElementValueAsString(String elementName, AnnotationMirror am)447   private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
448     for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
449         am.getElementValues().entrySet()) {
450       if (entry.getKey().getSimpleName().contentEquals(elementName)) {
451         return entry.getValue().getValue().toString();
452       }
453     }
454     return "";
455   }
456 }
457