1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.processor; 17 18 import com.google.common.base.CaseFormat; 19 import com.google.common.base.Strings; 20 import com.google.common.collect.HashMultimap; 21 import com.google.common.collect.Multimap; 22 import com.squareup.javapoet.ClassName; 23 import com.squareup.javapoet.FieldSpec; 24 import com.squareup.javapoet.JavaFile; 25 import com.squareup.javapoet.MethodSpec; 26 import com.squareup.javapoet.ParameterSpec; 27 import com.squareup.javapoet.TypeName; 28 import com.squareup.javapoet.TypeSpec; 29 import com.squareup.javapoet.TypeVariableName; 30 import java.io.IOException; 31 import java.util.Collection; 32 import java.util.Collections; 33 import java.util.HashMap; 34 import java.util.Map; 35 import java.util.Set; 36 import java.util.regex.Matcher; 37 import java.util.regex.Pattern; 38 import javax.annotation.processing.AbstractProcessor; 39 import javax.annotation.processing.Filer; 40 import javax.annotation.processing.Messager; 41 import javax.annotation.processing.ProcessingEnvironment; 42 import javax.annotation.processing.RoundEnvironment; 43 import javax.lang.model.SourceVersion; 44 import javax.lang.model.element.AnnotationMirror; 45 import javax.lang.model.element.AnnotationValue; 46 import javax.lang.model.element.Element; 47 import javax.lang.model.element.ExecutableElement; 48 import javax.lang.model.element.Modifier; 49 import javax.lang.model.element.TypeElement; 50 import javax.lang.model.element.TypeParameterElement; 51 import javax.lang.model.element.VariableElement; 52 import javax.lang.model.type.TypeMirror; 53 import javax.lang.model.type.TypeVariable; 54 import javax.lang.model.util.ElementFilter; 55 import javax.lang.model.util.Elements; 56 import javax.tools.Diagnostic.Kind; 57 58 /** 59 * A compile-time Processor that aggregates classes annotated with {@link 60 * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please 61 * refer to the {@link org.tensorflow.op.annotation.Operator} annotation for details about the API 62 * generated for each annotated class. 63 * 64 * <p>Note that this processor can only be invoked once, in a single compilation run that includes 65 * all the {@code Operator} annotated source classes. The reason is that the {@code Ops} API is an 66 * "aggregating" API, and annotation processing does not permit modifying an already generated 67 * class. 68 * 69 * @see org.tensorflow.op.annotation.Operator 70 */ 71 public final class OperatorProcessor extends AbstractProcessor { 72 73 @Override getSupportedSourceVersion()74 public SourceVersion getSupportedSourceVersion() { 75 return SourceVersion.latestSupported(); 76 } 77 78 @Override init(ProcessingEnvironment processingEnv)79 public synchronized void init(ProcessingEnvironment processingEnv) { 80 super.init(processingEnv); 81 messager = processingEnv.getMessager(); 82 filer = processingEnv.getFiler(); 83 elements = processingEnv.getElementUtils(); 84 } 85 86 @Override process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv)87 public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) { 88 // Nothing needs to be done at the end of all rounds. 89 if (roundEnv.processingOver()) { 90 return false; 91 } 92 93 // Nothing to look at in this round. 94 if (annotations.size() == 0) { 95 return false; 96 } 97 98 // We expect to be registered for exactly one annotation. 99 if (annotations.size() != 1) { 100 throw new IllegalStateException( 101 "Unexpected - multiple annotations registered: " + annotations); 102 } 103 TypeElement annotation = annotations.iterator().next(); 104 Set<? extends Element> annotated = roundEnv.getElementsAnnotatedWith(annotation); 105 106 // If there are no annotated elements, claim the annotation but do nothing. 107 if (annotated.size() == 0) { 108 return true; 109 } 110 111 // This processor has to aggregate all op classes in one round, as it generates a single Ops 112 // API class which cannot be modified once generated. If we find an annotation after we've 113 // generated our code, flag the location of each such class. 114 if (hasRun) { 115 for (Element e : annotated) { 116 error( 117 e, 118 "The Operator processor has already processed @Operator annotated sources\n" 119 + "and written out an Ops API. It cannot process additional @Operator sources.\n" 120 + "One reason this can happen is if other annotation processors generate\n" 121 + "new @Operator source files."); 122 } 123 return true; 124 } 125 126 // Collect all classes tagged with our annotation. 127 Multimap<String, MethodSpec> groupedMethods = HashMultimap.create(); 128 if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { 129 return true; 130 } 131 132 // Nothing to do when there are no tagged classes. 133 if (groupedMethods.isEmpty()) { 134 return true; 135 } 136 137 // Validate operator classes and generate Op API. 138 writeApi(groupedMethods); 139 140 hasRun = true; 141 return true; 142 } 143 144 @Override getSupportedAnnotationTypes()145 public Set<String> getSupportedAnnotationTypes() { 146 return Collections.singleton("org.tensorflow.op.annotation.Operator"); 147 } 148 149 private static final Pattern JAVADOC_TAG_PATTERN = 150 Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); 151 private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); 152 private static final TypeName T_OPERATOR = 153 ClassName.get("org.tensorflow.op.annotation", "Operator"); 154 private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); 155 private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); 156 private static final TypeName T_STRING = ClassName.get(String.class); 157 158 private Filer filer; 159 private Messager messager; 160 private Elements elements; 161 private boolean hasRun = false; 162 error(Element e, String message, Object... args)163 private void error(Element e, String message, Object... args) { 164 if (args != null && args.length > 0) { 165 message = String.format(message, args); 166 } 167 messager.printMessage(Kind.ERROR, message, e); 168 } 169 write(TypeSpec spec)170 private void write(TypeSpec spec) { 171 try { 172 JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); 173 } catch (IOException e) { 174 throw new AssertionError(e); 175 } 176 } 177 writeApi(Multimap<String, MethodSpec> groupedMethods)178 private void writeApi(Multimap<String, MethodSpec> groupedMethods) { 179 Map<String, ClassName> groups = new HashMap<>(); 180 181 // Generate a API class for each group collected other than the default one (= empty string) 182 for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) { 183 if (!entry.getKey().isEmpty()) { 184 TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); 185 write(groupClass); 186 groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name)); 187 } 188 } 189 // Generate the top API class, adding any methods added to the default group 190 TypeSpec topClass = buildTopClass(groups, groupedMethods.get("")); 191 write(topClass); 192 } 193 collectOpsMethods( RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation)194 private boolean collectOpsMethods( 195 RoundEnvironment roundEnv, 196 Multimap<String, MethodSpec> groupedMethods, 197 TypeElement annotation) { 198 boolean result = true; 199 for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { 200 // @Operator can only apply to types, so e must be a TypeElement. 201 if (!(e instanceof TypeElement)) { 202 error( 203 e, 204 "@Operator can only be applied to classes, but this is a %s", 205 e.getKind().toString()); 206 result = false; 207 continue; 208 } 209 TypeElement opClass = (TypeElement) e; 210 // Skip deprecated operations for now, as we do not guarantee API stability yet 211 if (opClass.getAnnotation(Deprecated.class) == null) { 212 collectOpMethods(groupedMethods, opClass, annotation); 213 } 214 } 215 return result; 216 } 217 collectOpMethods( Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation)218 private void collectOpMethods( 219 Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { 220 AnnotationMirror am = getAnnotationMirror(opClass, annotation); 221 String groupName = getAnnotationElementValueAsString("group", am); 222 String methodName = getAnnotationElementValueAsString("name", am); 223 ClassName opClassName = ClassName.get(opClass); 224 if (Strings.isNullOrEmpty(methodName)) { 225 methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); 226 } 227 // Build a method for each @Operator found in the class path. There should be one method per 228 // operation factory called 229 // "create", which takes in parameter a scope and, optionally, a list of arguments 230 for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { 231 if (opMethod.getModifiers().contains(Modifier.STATIC) 232 && opMethod.getSimpleName().contentEquals("create")) { 233 MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); 234 groupedMethods.put(groupName, method); 235 } 236 } 237 } 238 buildOpMethod( String methodName, ClassName opClassName, ExecutableElement factoryMethod)239 private MethodSpec buildOpMethod( 240 String methodName, ClassName opClassName, ExecutableElement factoryMethod) { 241 MethodSpec.Builder builder = 242 MethodSpec.methodBuilder(methodName) 243 .addModifiers(Modifier.PUBLIC) 244 .returns(TypeName.get(factoryMethod.getReturnType())) 245 .varargs(factoryMethod.isVarArgs()) 246 .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); 247 248 for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { 249 TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); 250 builder.addTypeVariable(tvn); 251 } 252 for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { 253 builder.addException(TypeName.get(thrownType)); 254 } 255 StringBuilder call = new StringBuilder("return $T.create(scope"); 256 boolean first = true; 257 for (VariableElement param : factoryMethod.getParameters()) { 258 ParameterSpec p = ParameterSpec.get(param); 259 if (first) { 260 first = false; 261 continue; 262 } 263 call.append(", "); 264 call.append(p.name); 265 builder.addParameter(p); 266 } 267 call.append(")"); 268 builder.addStatement(call.toString(), opClassName); 269 return builder.build(); 270 } 271 buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod)272 private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { 273 StringBuilder javadoc = new StringBuilder(); 274 javadoc 275 .append("Adds an {@link ") 276 .append(opClassName.simpleName()) 277 .append("} operation to the graph\n\n"); 278 279 // Add all javadoc tags found in the operator factory method but the first one, which should be 280 // in all cases the 281 // 'scope' parameter that is implicitly passed by this API 282 Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); 283 boolean firstParam = true; 284 285 while (tagMatcher.find()) { 286 String tag = tagMatcher.group(); 287 if (tag.startsWith("@param") && firstParam) { 288 firstParam = false; 289 } else { 290 javadoc.append(tag).append('\n'); 291 } 292 } 293 javadoc.append("@see ").append(opClassName).append("\n"); 294 295 return javadoc.toString(); 296 } 297 buildGroupClass(String group, Collection<MethodSpec> methods)298 private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) { 299 MethodSpec.Builder ctorBuilder = 300 MethodSpec.constructorBuilder() 301 .addParameter(T_SCOPE, "scope") 302 .addStatement("this.scope = scope"); 303 304 TypeSpec.Builder builder = 305 TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") 306 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 307 .addJavadoc( 308 "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" 309 + "@see {@link $T}\n", 310 group, 311 T_GRAPH, 312 T_OPS) 313 .addMethods(methods) 314 .addMethod(ctorBuilder.build()); 315 316 builder.addField( 317 FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); 318 319 return builder.build(); 320 } 321 buildTopClass( Map<String, ClassName> groupToClass, Collection<MethodSpec> methods)322 private static TypeSpec buildTopClass( 323 Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { 324 MethodSpec.Builder ctorBuilder = 325 MethodSpec.constructorBuilder() 326 .addModifiers(Modifier.PRIVATE) 327 .addParameter(T_SCOPE, "scope") 328 .addStatement("this.scope = scope", T_SCOPE); 329 330 for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { 331 ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); 332 } 333 334 TypeSpec.Builder opsBuilder = 335 TypeSpec.classBuilder("Ops") 336 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 337 .addJavadoc( 338 "An API for building a {@link $T} with operation wrappers\n<p>\n" 339 + "Any operation wrapper found in the classpath properly annotated as an" 340 + "{@link $T @Operator} is exposed\n" 341 + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" 342 + "try (Graph g = new Graph()) {\n" 343 + " Ops ops = Ops.create(g);\n" 344 + " // Operations are typed classes with convenience\n" 345 + " // builders in Ops.\n" 346 + " Constant three = ops.constant(3);\n" 347 + " // Single-result operations implement the Operand\n" 348 + " // interface, so this works too.\n" 349 + " Operand four = ops.constant(4);\n" 350 + " // Most builders are found within a group, and accept\n" 351 + " // Operand types as operands\n" 352 + " Operand nine = ops.math().add(four, ops.constant(5));\n" 353 + " // Multi-result operations however offer methods to\n" 354 + " // select a particular result for use.\n" 355 + " Operand result = \n" 356 + " ops.math().add(ops.array().unique(s, a).y(), b);\n" 357 + " // Optional attributes\n" 358 + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" 359 + " // Naming operators\n" 360 + " ops.withName(\"foo\").constant(5); // name \"foo\"\n" 361 + " // Names can exist in a hierarchy\n" 362 + " Ops sub = ops.withSubScope(\"sub\");\n" 363 + " sub.withName(\"bar\").constant(4); // \"sub/bar\"\n" 364 + "}\n" 365 + "}</pre>\n", 366 T_GRAPH, 367 T_OPERATOR) 368 .addMethods(methods) 369 .addMethod(ctorBuilder.build()); 370 371 opsBuilder.addMethod( 372 MethodSpec.methodBuilder("withSubScope") 373 .addModifiers(Modifier.PUBLIC) 374 .addParameter(T_STRING, "childScopeName") 375 .returns(T_OPS) 376 .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) 377 .addJavadoc( 378 "Returns an API that adds operations to the graph with the provided name prefix.\n" 379 + "\n@see {@link $T#withSubScope(String)}\n", 380 T_SCOPE) 381 .build()); 382 383 opsBuilder.addMethod( 384 MethodSpec.methodBuilder("withName") 385 .addModifiers(Modifier.PUBLIC) 386 .addParameter(T_STRING, "opName") 387 .returns(T_OPS) 388 .addStatement("return new Ops(scope.withName(opName))") 389 .addJavadoc( 390 "Returns an API that uses the provided name for an op.\n\n" 391 + "@see {@link $T#withName(String)}\n", 392 T_SCOPE) 393 .build()); 394 395 opsBuilder.addField( 396 FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); 397 398 opsBuilder.addMethod( 399 MethodSpec.methodBuilder("scope") 400 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 401 .returns(T_SCOPE) 402 .addStatement("return scope") 403 .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) 404 .build()); 405 406 for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { 407 opsBuilder.addField( 408 FieldSpec.builder(entry.getValue(), entry.getKey()) 409 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 410 .build()); 411 412 opsBuilder.addMethod( 413 MethodSpec.methodBuilder(entry.getKey()) 414 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 415 .returns(entry.getValue()) 416 .addStatement("return $L", entry.getKey()) 417 .addJavadoc( 418 "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) 419 .build()); 420 } 421 422 opsBuilder.addMethod( 423 MethodSpec.methodBuilder("create") 424 .addModifiers(Modifier.PUBLIC, Modifier.STATIC) 425 .addParameter(T_GRAPH, "graph") 426 .returns(T_OPS) 427 .addStatement("return new Ops(new $T(graph))", T_SCOPE) 428 .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") 429 .build()); 430 431 return opsBuilder.build(); 432 } 433 getAnnotationMirror(Element element, TypeElement annotation)434 private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) { 435 for (AnnotationMirror am : element.getAnnotationMirrors()) { 436 if (am.getAnnotationType().asElement().equals(annotation)) { 437 return am; 438 } 439 } 440 throw new IllegalArgumentException( 441 "Annotation " 442 + annotation.getSimpleName() 443 + " not present on element " 444 + element.getSimpleName()); 445 } 446 getAnnotationElementValueAsString(String elementName, AnnotationMirror am)447 private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { 448 for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : 449 am.getElementValues().entrySet()) { 450 if (entry.getKey().getSimpleName().contentEquals(elementName)) { 451 return entry.getValue().getValue().toString(); 452 } 453 } 454 return ""; 455 } 456 } 457