• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2013 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.google.auto.factory.processor;
17 
18 import static com.google.auto.common.GeneratedAnnotationSpecs.generatedAnnotationSpec;
19 import static com.squareup.javapoet.MethodSpec.constructorBuilder;
20 import static com.squareup.javapoet.MethodSpec.methodBuilder;
21 import static com.squareup.javapoet.TypeSpec.classBuilder;
22 import static java.util.Objects.requireNonNull;
23 import static java.util.stream.Collectors.joining;
24 import static java.util.stream.Collectors.toList;
25 import static javax.lang.model.element.Modifier.FINAL;
26 import static javax.lang.model.element.Modifier.PRIVATE;
27 import static javax.lang.model.element.Modifier.PUBLIC;
28 import static javax.lang.model.element.Modifier.STATIC;
29 
30 import com.google.auto.common.AnnotationMirrors;
31 import com.google.auto.common.AnnotationValues;
32 import com.google.auto.common.MoreTypes;
33 import com.google.common.collect.ImmutableList;
34 import com.google.common.collect.ImmutableSet;
35 import com.google.common.collect.ImmutableSetMultimap;
36 import com.google.common.collect.Iterables;
37 import com.google.common.collect.Sets;
38 import com.google.common.collect.Streams;
39 import com.squareup.javapoet.AnnotationSpec;
40 import com.squareup.javapoet.ClassName;
41 import com.squareup.javapoet.CodeBlock;
42 import com.squareup.javapoet.JavaFile;
43 import com.squareup.javapoet.MethodSpec;
44 import com.squareup.javapoet.ParameterSpec;
45 import com.squareup.javapoet.ParameterizedTypeName;
46 import com.squareup.javapoet.TypeName;
47 import com.squareup.javapoet.TypeSpec;
48 import com.squareup.javapoet.TypeVariableName;
49 import java.io.IOException;
50 import java.lang.annotation.Target;
51 import java.util.Iterator;
52 import java.util.List;
53 import java.util.Optional;
54 import java.util.stream.Stream;
55 import javax.annotation.processing.Filer;
56 import javax.annotation.processing.ProcessingEnvironment;
57 import javax.inject.Inject;
58 import javax.inject.Provider;
59 import javax.lang.model.SourceVersion;
60 import javax.lang.model.element.AnnotationMirror;
61 import javax.lang.model.element.Element;
62 import javax.lang.model.element.VariableElement;
63 import javax.lang.model.type.DeclaredType;
64 import javax.lang.model.type.TypeKind;
65 import javax.lang.model.type.TypeMirror;
66 import javax.lang.model.type.TypeVariable;
67 import javax.lang.model.util.Elements;
68 
69 final class FactoryWriter {
70 
71   private final Filer filer;
72   private final Elements elements;
73   private final SourceVersion sourceVersion;
74   private final ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated;
75 
FactoryWriter( ProcessingEnvironment processingEnv, ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated)76   FactoryWriter(
77       ProcessingEnvironment processingEnv,
78       ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated) {
79     this.filer = processingEnv.getFiler();
80     this.elements = processingEnv.getElementUtils();
81     this.sourceVersion = processingEnv.getSourceVersion();
82     this.factoriesBeingCreated = factoriesBeingCreated;
83   }
84 
writeFactory(FactoryDescriptor descriptor)85   void writeFactory(FactoryDescriptor descriptor) throws IOException {
86     String factoryName = descriptor.name().className();
87     TypeSpec.Builder factory =
88         classBuilder(factoryName).addOriginatingElement(descriptor.declaration().targetType());
89     generatedAnnotationSpec(
90             elements,
91             sourceVersion,
92             AutoFactoryProcessor.class,
93             "https://github.com/google/auto/tree/master/factory")
94         .ifPresent(factory::addAnnotation);
95     if (!descriptor.allowSubclasses()) {
96       factory.addModifiers(FINAL);
97     }
98     if (descriptor.publicType()) {
99       factory.addModifiers(PUBLIC);
100     }
101 
102     factory.superclass(TypeName.get(descriptor.extendingType()));
103     for (TypeMirror implementingType : descriptor.implementingTypes()) {
104       factory.addSuperinterface(TypeName.get(implementingType));
105     }
106 
107     ImmutableSet<TypeVariableName> factoryTypeVariables = getFactoryTypeVariables(descriptor);
108 
109     addFactoryTypeParameters(factory, factoryTypeVariables);
110     addConstructorAndProviderFields(factory, descriptor);
111     addFactoryMethods(factory, descriptor, factoryTypeVariables);
112     addImplementationMethods(factory, descriptor);
113     addCheckNotNullMethod(factory, descriptor);
114 
115     JavaFile.builder(descriptor.name().packageName(), factory.build())
116         .skipJavaLangImports(true)
117         .build()
118         .writeTo(filer);
119   }
120 
addFactoryTypeParameters( TypeSpec.Builder factory, ImmutableSet<TypeVariableName> typeVariableNames)121   private static void addFactoryTypeParameters(
122       TypeSpec.Builder factory, ImmutableSet<TypeVariableName> typeVariableNames) {
123     factory.addTypeVariables(typeVariableNames);
124   }
125 
addConstructorAndProviderFields( TypeSpec.Builder factory, FactoryDescriptor descriptor)126   private void addConstructorAndProviderFields(
127       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
128     MethodSpec.Builder constructor = constructorBuilder().addAnnotation(Inject.class);
129     if (descriptor.publicType()) {
130       constructor.addModifiers(PUBLIC);
131     }
132     Iterator<ProviderField> providerFields = descriptor.providers().values().iterator();
133     for (int argumentIndex = 1; providerFields.hasNext(); argumentIndex++) {
134       ProviderField provider = providerFields.next();
135       TypeName typeName = resolveTypeName(provider.key().type().get()).box();
136       TypeName providerType = ParameterizedTypeName.get(ClassName.get(Provider.class), typeName);
137       factory.addField(providerType, provider.name(), PRIVATE, FINAL);
138       if (provider.key().qualifier().isPresent()) {
139         // only qualify the constructor parameter
140         providerType = providerType.annotated(AnnotationSpec.get(provider.key().qualifier().get()));
141       }
142       constructor.addParameter(providerType, provider.name());
143       constructor.addStatement("this.$1L = checkNotNull($1L, $2L)", provider.name(), argumentIndex);
144     }
145 
146     factory.addMethod(constructor.build());
147   }
148 
addFactoryMethods( TypeSpec.Builder factory, FactoryDescriptor descriptor, ImmutableSet<TypeVariableName> factoryTypeVariables)149   private void addFactoryMethods(
150       TypeSpec.Builder factory,
151       FactoryDescriptor descriptor,
152       ImmutableSet<TypeVariableName> factoryTypeVariables) {
153     for (FactoryMethodDescriptor methodDescriptor : descriptor.methodDescriptors()) {
154       MethodSpec.Builder method =
155           methodBuilder(methodDescriptor.name())
156               .addTypeVariables(getMethodTypeVariables(methodDescriptor, factoryTypeVariables))
157               .returns(TypeName.get(methodDescriptor.returnType()))
158               .varargs(methodDescriptor.isVarArgs());
159       if (methodDescriptor.overridingMethod()) {
160         method.addAnnotation(Override.class);
161       }
162       if (methodDescriptor.publicMethod()) {
163         method.addModifiers(PUBLIC);
164       }
165       method.addExceptions(
166           methodDescriptor.exceptions().stream().map(TypeName::get).collect(toList()));
167       CodeBlock.Builder args = CodeBlock.builder();
168       method.addParameters(parameters(methodDescriptor.passedParameters()));
169       Iterator<Parameter> parameters = methodDescriptor.creationParameters().iterator();
170       for (int argumentIndex = 1; parameters.hasNext(); argumentIndex++) {
171         Parameter parameter = parameters.next();
172         boolean checkNotNull = !parameter.nullable().isPresent();
173         CodeBlock argument;
174         if (methodDescriptor.passedParameters().contains(parameter)) {
175           argument = CodeBlock.of(parameter.name());
176           if (parameter.isPrimitive()) {
177             checkNotNull = false;
178           }
179         } else {
180           ProviderField provider = requireNonNull(descriptor.providers().get(parameter.key()));
181           argument = CodeBlock.of(provider.name());
182           if (parameter.isProvider()) {
183             // Providers are checked for nullness in the Factory's constructor.
184             checkNotNull = false;
185           } else {
186             argument = CodeBlock.of("$L.get()", argument);
187           }
188         }
189         if (checkNotNull) {
190           argument = CodeBlock.of("checkNotNull($L, $L)", argument, argumentIndex);
191         }
192         args.add(argument);
193         if (parameters.hasNext()) {
194           args.add(", ");
195         }
196       }
197       method.addStatement("return new $T($L)", methodDescriptor.returnType(), args.build());
198       factory.addMethod(method.build());
199     }
200   }
201 
addImplementationMethods(TypeSpec.Builder factory, FactoryDescriptor descriptor)202   private void addImplementationMethods(TypeSpec.Builder factory, FactoryDescriptor descriptor) {
203     for (ImplementationMethodDescriptor methodDescriptor :
204         descriptor.implementationMethodDescriptors()) {
205       MethodSpec.Builder implementationMethod =
206           methodBuilder(methodDescriptor.name())
207               .addAnnotation(Override.class)
208               .returns(TypeName.get(methodDescriptor.returnType()))
209               .varargs(methodDescriptor.isVarArgs());
210       if (methodDescriptor.publicMethod()) {
211         implementationMethod.addModifiers(PUBLIC);
212       }
213       implementationMethod.addExceptions(
214           methodDescriptor.exceptions().stream().map(TypeName::get).collect(toList()));
215       implementationMethod.addParameters(parameters(methodDescriptor.passedParameters()));
216       implementationMethod.addStatement(
217           "return create($L)",
218           methodDescriptor.passedParameters().stream().map(Parameter::name).collect(joining(", ")));
219       factory.addMethod(implementationMethod.build());
220     }
221   }
222 
223   /**
224    * {@link ParameterSpec}s to match {@code parameters}. Note that the type of the {@link
225    * ParameterSpec}s match {@link Parameter#type()} and not {@link Key#type()}.
226    */
parameters(Iterable<Parameter> parameters)227   private ImmutableList<ParameterSpec> parameters(Iterable<Parameter> parameters) {
228     ImmutableList.Builder<ParameterSpec> builder = ImmutableList.builder();
229     for (Parameter parameter : parameters) {
230       TypeName type = resolveTypeName(parameter.type().get());
231       // Remove TYPE_USE annotations, since resolveTypeName will already have included those in
232       // the TypeName it returns.
233       List<AnnotationSpec> annotations =
234           Stream.of(parameter.nullable(), parameter.key().qualifier())
235               .flatMap(Streams::stream)
236               .filter(a -> !isTypeUseAnnotation(a))
237               .map(AnnotationSpec::get)
238               .collect(toList());
239       ParameterSpec parameterSpec =
240           ParameterSpec.builder(type, parameter.name()).addAnnotations(annotations).build();
241       builder.add(parameterSpec);
242     }
243     return builder.build();
244   }
245 
isTypeUseAnnotation(AnnotationMirror mirror)246   private static boolean isTypeUseAnnotation(AnnotationMirror mirror) {
247     Element annotationElement = mirror.getAnnotationType().asElement();
248     // This is basically equivalent to:
249     //    Target target = annotationElement.getAnnotation(Target.class);
250     //    return target != null
251     //        && Arrays.asList(annotationElement.getAnnotation(Target.class)).contains(TYPE_USE);
252     // but that might blow up if the annotation is being compiled at the same time and has an
253     // undefined identifier in its @Target values. The rigmarole below avoids that problem.
254     Optional<AnnotationMirror> maybeTargetMirror =
255         Mirrors.getAnnotationMirror(annotationElement, Target.class);
256     return maybeTargetMirror
257         .map(
258             targetMirror ->
259                 AnnotationValues.getEnums(
260                         AnnotationMirrors.getAnnotationValue(targetMirror, "value"))
261                     .stream()
262                     .map(VariableElement::getSimpleName)
263                     .anyMatch(name -> name.contentEquals("TYPE_USE")))
264         .orElse(false);
265   }
266 
addCheckNotNullMethod( TypeSpec.Builder factory, FactoryDescriptor descriptor)267   private static void addCheckNotNullMethod(
268       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
269     if (shouldGenerateCheckNotNull(descriptor)) {
270       TypeVariableName typeVariable = TypeVariableName.get("T");
271       factory.addMethod(
272           methodBuilder("checkNotNull")
273               .addModifiers(PRIVATE, STATIC)
274               .addTypeVariable(typeVariable)
275               .returns(typeVariable)
276               .addParameter(typeVariable, "reference")
277               .addParameter(TypeName.INT, "argumentIndex")
278               .beginControlFlow("if (reference == null)")
279               .addStatement(
280                   "throw new $T($S + argumentIndex)",
281                   NullPointerException.class,
282                   "@AutoFactory method argument is null but is not marked @Nullable. Argument "
283                       + "index: ")
284               .endControlFlow()
285               .addStatement("return reference")
286               .build());
287     }
288   }
289 
shouldGenerateCheckNotNull(FactoryDescriptor descriptor)290   private static boolean shouldGenerateCheckNotNull(FactoryDescriptor descriptor) {
291     if (!descriptor.providers().isEmpty()) {
292       return true;
293     }
294     for (FactoryMethodDescriptor method : descriptor.methodDescriptors()) {
295       for (Parameter parameter : method.creationParameters()) {
296         if (!parameter.nullable().isPresent() && !parameter.type().get().getKind().isPrimitive()) {
297           return true;
298         }
299       }
300     }
301     return false;
302   }
303 
304   /**
305    * Returns an appropriate {@code TypeName} for the given type. If the type is an
306    * {@code ErrorType}, and if it is a simple-name reference to one of the {@code *Factory}
307    * classes that we are going to generate, then we return its fully-qualified name. In every other
308    * case we just return {@code TypeName.get(type)}. Specifically, if it is an {@code ErrorType}
309    * referencing some other type, or referencing one of the classes we are going to generate but
310    * using its fully-qualified name, then we leave it as-is. JavaPoet treats {@code TypeName.get(t)}
311    * the same for {@code ErrorType} as for {@code DeclaredType}, which means that if this is a name
312    * that will eventually be generated then the code we write that references the type will in fact
313    * compile.
314    *
315    * <p>A simpler alternative would be to defer processing to a later round if we find an
316    * {@code @AutoFactory} class that references undefined types, under the assumption that something
317    * else will generate those types in the meanwhile. However, this would fail if for example
318    * {@code @AutoFactory class Foo} has a constructor parameter of type {@code BarFactory} and
319    * {@code @AutoFactory class Bar} has a constructor parameter of type {@code FooFactory}. We did
320    * in fact find instances of this in Google's source base.
321    *
322    * <p>If the type has type annotations then include those in the returned {@link TypeName}.
323    */
resolveTypeName(TypeMirror type)324   private TypeName resolveTypeName(TypeMirror type) {
325     TypeName typeName = TypeName.get(type);
326     if (type.getKind() == TypeKind.ERROR) {
327       ImmutableSet<PackageAndClass> factoryNames = factoriesBeingCreated.get(type.toString());
328       if (factoryNames.size() == 1) {
329         PackageAndClass packageAndClass = Iterables.getOnlyElement(factoryNames);
330         typeName = ClassName.get(packageAndClass.packageName(), packageAndClass.className());
331       }
332     }
333     return typeName.annotated(
334         type.getAnnotationMirrors().stream().map(AnnotationSpec::get).collect(toList()));
335   }
336 
getFactoryTypeVariables( FactoryDescriptor descriptor)337   private static ImmutableSet<TypeVariableName> getFactoryTypeVariables(
338       FactoryDescriptor descriptor) {
339     ImmutableSet.Builder<TypeVariableName> typeVariables = ImmutableSet.builder();
340     for (ProviderField provider : descriptor.providers().values()) {
341       typeVariables.addAll(getReferencedTypeParameterNames(provider.key().type().get()));
342     }
343     // If a parent type has a type parameter, like FooFactory<T>, then the generated factory needs
344     // to have the same parameter, like FooImplFactory<T> extends FooFactory<T>. This is a little
345     // approximate, at least in the case where there is more than one parent type that has a type
346     // parameter. But that should be pretty rare, so let's keep it simple for now.
347     typeVariables.addAll(typeVariablesFrom(descriptor.extendingType()));
348     for (TypeMirror implementing : descriptor.implementingTypes()) {
349       typeVariables.addAll(typeVariablesFrom(implementing));
350     }
351     return typeVariables.build();
352   }
353 
typeVariablesFrom(TypeMirror type)354   private static List<TypeVariableName> typeVariablesFrom(TypeMirror type) {
355     if (type.getKind().equals(TypeKind.DECLARED)) {
356       DeclaredType declaredType = MoreTypes.asDeclared(type);
357       return declaredType.getTypeArguments().stream()
358           .filter(t -> t.getKind().equals(TypeKind.TYPEVAR))
359           .map(t -> TypeVariableName.get(MoreTypes.asTypeVariable(t)))
360           .collect(toList());
361     }
362     return ImmutableList.of();
363   }
364 
getMethodTypeVariables( FactoryMethodDescriptor methodDescriptor, ImmutableSet<TypeVariableName> factoryTypeVariables)365   private static ImmutableSet<TypeVariableName> getMethodTypeVariables(
366       FactoryMethodDescriptor methodDescriptor,
367       ImmutableSet<TypeVariableName> factoryTypeVariables) {
368     ImmutableSet.Builder<TypeVariableName> typeVariables = ImmutableSet.builder();
369     typeVariables.addAll(getReferencedTypeParameterNames(methodDescriptor.returnType()));
370     for (Parameter parameter : methodDescriptor.passedParameters()) {
371       typeVariables.addAll(getReferencedTypeParameterNames(parameter.type().get()));
372     }
373     return Sets.difference(typeVariables.build(), factoryTypeVariables).immutableCopy();
374   }
375 
getReferencedTypeParameterNames(TypeMirror type)376   private static ImmutableSet<TypeVariableName> getReferencedTypeParameterNames(TypeMirror type) {
377     ImmutableSet.Builder<TypeVariableName> typeVariableNames = ImmutableSet.builder();
378     for (TypeVariable typeVariable : TypeVariables.getReferencedTypeVariables(type)) {
379       typeVariableNames.add(TypeVariableName.get(typeVariable));
380     }
381     return typeVariableNames.build();
382   }
383 }
384