• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2013 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 package com.google.auto.factory.processor;
17 
18 import static com.google.auto.common.GeneratedAnnotationSpecs.generatedAnnotationSpec;
19 import static com.google.auto.common.MoreStreams.toImmutableList;
20 import static com.squareup.javapoet.MethodSpec.constructorBuilder;
21 import static com.squareup.javapoet.MethodSpec.methodBuilder;
22 import static com.squareup.javapoet.TypeSpec.classBuilder;
23 import static java.util.Objects.requireNonNull;
24 import static java.util.stream.Collectors.joining;
25 import static java.util.stream.Collectors.toList;
26 import static javax.lang.model.element.Modifier.FINAL;
27 import static javax.lang.model.element.Modifier.PRIVATE;
28 import static javax.lang.model.element.Modifier.PUBLIC;
29 import static javax.lang.model.element.Modifier.STATIC;
30 
31 import com.google.auto.common.MoreTypes;
32 import com.google.common.collect.ImmutableCollection;
33 import com.google.common.collect.ImmutableList;
34 import com.google.common.collect.ImmutableSet;
35 import com.google.common.collect.ImmutableSetMultimap;
36 import com.google.common.collect.Iterables;
37 import com.google.common.collect.Sets;
38 import com.squareup.javapoet.AnnotationSpec;
39 import com.squareup.javapoet.ClassName;
40 import com.squareup.javapoet.CodeBlock;
41 import com.squareup.javapoet.JavaFile;
42 import com.squareup.javapoet.MethodSpec;
43 import com.squareup.javapoet.ParameterSpec;
44 import com.squareup.javapoet.ParameterizedTypeName;
45 import com.squareup.javapoet.TypeName;
46 import com.squareup.javapoet.TypeSpec;
47 import com.squareup.javapoet.TypeVariableName;
48 import java.io.IOException;
49 import java.util.List;
50 import javax.annotation.processing.Filer;
51 import javax.annotation.processing.ProcessingEnvironment;
52 import javax.inject.Inject;
53 import javax.inject.Provider;
54 import javax.lang.model.SourceVersion;
55 import javax.lang.model.type.DeclaredType;
56 import javax.lang.model.type.TypeKind;
57 import javax.lang.model.type.TypeMirror;
58 import javax.lang.model.type.TypeVariable;
59 import javax.lang.model.util.Elements;
60 
61 final class FactoryWriter {
62 
63   private final Filer filer;
64   private final Elements elements;
65   private final SourceVersion sourceVersion;
66   private final ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated;
67 
FactoryWriter( ProcessingEnvironment processingEnv, ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated)68   FactoryWriter(
69       ProcessingEnvironment processingEnv,
70       ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated) {
71     this.filer = processingEnv.getFiler();
72     this.elements = processingEnv.getElementUtils();
73     this.sourceVersion = processingEnv.getSourceVersion();
74     this.factoriesBeingCreated = factoriesBeingCreated;
75   }
76 
writeFactory(FactoryDescriptor descriptor)77   void writeFactory(FactoryDescriptor descriptor) throws IOException {
78     String factoryName = descriptor.name().className();
79     TypeSpec.Builder factory =
80         classBuilder(factoryName).addOriginatingElement(descriptor.declaration().targetType());
81     generatedAnnotationSpec(
82             elements,
83             sourceVersion,
84             AutoFactoryProcessor.class,
85             "https://github.com/google/auto/tree/main/factory")
86         .ifPresent(factory::addAnnotation);
87     descriptor.annotations().forEach(a -> factory.addAnnotation(AnnotationSpec.get(a)));
88     if (!descriptor.allowSubclasses()) {
89       factory.addModifiers(FINAL);
90     }
91     if (descriptor.publicType()) {
92       factory.addModifiers(PUBLIC);
93     }
94 
95     factory.superclass(TypeName.get(descriptor.extendingType()));
96     for (TypeMirror implementingType : descriptor.implementingTypes()) {
97       factory.addSuperinterface(TypeName.get(implementingType));
98     }
99 
100     ImmutableSet<TypeVariableName> factoryTypeVariables = getFactoryTypeVariables(descriptor);
101 
102     addFactoryTypeParameters(factory, factoryTypeVariables);
103     addConstructorAndProviderFields(factory, descriptor);
104     addFactoryMethods(factory, descriptor, factoryTypeVariables);
105     addImplementationMethods(factory, descriptor);
106     addCheckNotNullMethod(factory, descriptor);
107 
108     JavaFile.builder(descriptor.name().packageName(), factory.build())
109         .skipJavaLangImports(true)
110         .build()
111         .writeTo(filer);
112   }
113 
addFactoryTypeParameters( TypeSpec.Builder factory, ImmutableSet<TypeVariableName> typeVariableNames)114   private static void addFactoryTypeParameters(
115       TypeSpec.Builder factory, ImmutableSet<TypeVariableName> typeVariableNames) {
116     factory.addTypeVariables(typeVariableNames);
117   }
118 
addConstructorAndProviderFields( TypeSpec.Builder factory, FactoryDescriptor descriptor)119   private void addConstructorAndProviderFields(
120       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
121     MethodSpec.Builder constructor = constructorBuilder().addAnnotation(Inject.class);
122     if (descriptor.publicType()) {
123       constructor.addModifiers(PUBLIC);
124     }
125     ImmutableCollection<ProviderField> providerFields = descriptor.providers().values();
126     int argumentNumber = 0;
127     for (ProviderField provider : providerFields) {
128       ++argumentNumber;
129       TypeName typeName = resolveTypeName(provider.key().type().get()).box();
130       TypeName providerType = ParameterizedTypeName.get(ClassName.get(Provider.class), typeName);
131       factory.addField(providerType, provider.name(), PRIVATE, FINAL);
132       if (provider.key().qualifier().isPresent()) {
133         // only qualify the constructor parameter
134         providerType = providerType.annotated(AnnotationSpec.get(provider.key().qualifier().get()));
135       }
136       constructor.addParameter(providerType, provider.name());
137       constructor.addStatement(
138           "this.$1L = checkNotNull($1L, $2L, $3L)",
139           provider.name(),
140           argumentNumber,
141           providerFields.size());
142     }
143 
144     factory.addMethod(constructor.build());
145   }
146 
addFactoryMethods( TypeSpec.Builder factory, FactoryDescriptor descriptor, ImmutableSet<TypeVariableName> factoryTypeVariables)147   private void addFactoryMethods(
148       TypeSpec.Builder factory,
149       FactoryDescriptor descriptor,
150       ImmutableSet<TypeVariableName> factoryTypeVariables) {
151     for (FactoryMethodDescriptor methodDescriptor : descriptor.methodDescriptors()) {
152       MethodSpec.Builder method =
153           methodBuilder(methodDescriptor.name())
154               .addTypeVariables(getMethodTypeVariables(methodDescriptor, factoryTypeVariables))
155               .returns(TypeName.get(methodDescriptor.returnType()))
156               .varargs(methodDescriptor.isVarArgs());
157       if (methodDescriptor.overridingMethod()) {
158         method.addAnnotation(Override.class);
159       }
160       if (methodDescriptor.publicMethod()) {
161         method.addModifiers(PUBLIC);
162       }
163       method.addExceptions(
164           methodDescriptor.exceptions().stream().map(TypeName::get).collect(toList()));
165       CodeBlock.Builder args = CodeBlock.builder();
166       method.addParameters(parameters(methodDescriptor.passedParameters()));
167       ImmutableSet<Parameter> parameters = methodDescriptor.creationParameters();
168       int argumentNumber = 0;
169       String sep = "";
170       for (Parameter parameter : parameters) {
171         ++argumentNumber;
172         args.add(sep);
173         sep = ", ";
174         boolean checkNotNull = !parameter.nullable().isPresent();
175         CodeBlock argument;
176         if (methodDescriptor.passedParameters().contains(parameter)) {
177           argument = CodeBlock.of(parameter.name());
178           if (parameter.isPrimitive()) {
179             checkNotNull = false;
180           }
181         } else {
182           ProviderField provider = requireNonNull(descriptor.providers().get(parameter.key()));
183           argument = CodeBlock.of(provider.name());
184           if (parameter.isProvider()) {
185             // Providers are checked for nullness in the Factory's constructor.
186             checkNotNull = false;
187           } else {
188             argument = CodeBlock.of("$L.get()", argument);
189           }
190         }
191         if (checkNotNull) {
192           argument =
193               CodeBlock.of("checkNotNull($L, $L, $L)", argument, argumentNumber, parameters.size());
194         }
195         args.add(argument);
196       }
197       method.addStatement("return new $T($L)", methodDescriptor.returnType(), args.build());
198       factory.addMethod(method.build());
199     }
200   }
201 
addImplementationMethods(TypeSpec.Builder factory, FactoryDescriptor descriptor)202   private void addImplementationMethods(TypeSpec.Builder factory, FactoryDescriptor descriptor) {
203     for (ImplementationMethodDescriptor methodDescriptor :
204         descriptor.implementationMethodDescriptors()) {
205       MethodSpec.Builder implementationMethod =
206           methodBuilder(methodDescriptor.name())
207               .addAnnotation(Override.class)
208               .returns(TypeName.get(methodDescriptor.returnType()))
209               .varargs(methodDescriptor.isVarArgs());
210       if (methodDescriptor.publicMethod()) {
211         implementationMethod.addModifiers(PUBLIC);
212       }
213       implementationMethod.addExceptions(
214           methodDescriptor.exceptions().stream().map(TypeName::get).collect(toList()));
215       implementationMethod.addParameters(parameters(methodDescriptor.passedParameters()));
216       implementationMethod.addStatement(
217           "return create($L)",
218           methodDescriptor.passedParameters().stream().map(Parameter::name).collect(joining(", ")));
219       factory.addMethod(implementationMethod.build());
220     }
221   }
222 
223   /**
224    * {@link ParameterSpec}s to match {@code parameters}. Note that the type of the {@link
225    * ParameterSpec}s match {@link Parameter#type()} and not {@link Key#type()}.
226    */
parameters(Iterable<Parameter> parameters)227   private ImmutableList<ParameterSpec> parameters(Iterable<Parameter> parameters) {
228     ImmutableList.Builder<ParameterSpec> builder = ImmutableList.builder();
229     for (Parameter parameter : parameters) {
230       TypeName type = resolveTypeName(parameter.type().get());
231       ImmutableList<AnnotationSpec> annotations =
232           parameter.annotations().stream().map(AnnotationSpec::get).collect(toImmutableList());
233       ParameterSpec parameterSpec =
234           ParameterSpec.builder(type, parameter.name()).addAnnotations(annotations).build();
235       builder.add(parameterSpec);
236     }
237     return builder.build();
238   }
239 
addCheckNotNullMethod( TypeSpec.Builder factory, FactoryDescriptor descriptor)240   private static void addCheckNotNullMethod(
241       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
242     if (shouldGenerateCheckNotNull(descriptor)) {
243       TypeVariableName typeVariable = TypeVariableName.get("T");
244       factory.addMethod(
245           methodBuilder("checkNotNull")
246               .addModifiers(PRIVATE, STATIC)
247               .addTypeVariable(typeVariable)
248               .returns(typeVariable)
249               .addParameter(typeVariable, "reference")
250               .addParameter(TypeName.INT, "argumentNumber")
251               .addParameter(TypeName.INT, "argumentCount")
252               .beginControlFlow("if (reference == null)")
253               .addStatement(
254                   "throw new $T($S + argumentNumber + $S + argumentCount)",
255                   NullPointerException.class,
256                   "@AutoFactory method argument is null but is not marked @Nullable. Argument ",
257                   " of ")
258               .endControlFlow()
259               .addStatement("return reference")
260               .build());
261     }
262   }
263 
shouldGenerateCheckNotNull(FactoryDescriptor descriptor)264   private static boolean shouldGenerateCheckNotNull(FactoryDescriptor descriptor) {
265     if (!descriptor.providers().isEmpty()) {
266       return true;
267     }
268     for (FactoryMethodDescriptor method : descriptor.methodDescriptors()) {
269       for (Parameter parameter : method.creationParameters()) {
270         if (!parameter.nullable().isPresent() && !parameter.type().get().getKind().isPrimitive()) {
271           return true;
272         }
273       }
274     }
275     return false;
276   }
277 
278   /**
279    * Returns an appropriate {@code TypeName} for the given type. If the type is an {@code
280    * ErrorType}, and if it is a simple-name reference to one of the {@code *Factory} classes that we
281    * are going to generate, then we return its fully-qualified name. In every other case we just
282    * return {@code TypeName.get(type)}. Specifically, if it is an {@code ErrorType} referencing some
283    * other type, or referencing one of the classes we are going to generate but using its
284    * fully-qualified name, then we leave it as-is. JavaPoet treats {@code TypeName.get(t)} the same
285    * for {@code ErrorType} as for {@code DeclaredType}, which means that if this is a name that will
286    * eventually be generated then the code we write that references the type will in fact compile.
287    *
288    * <p>A simpler alternative would be to defer processing to a later round if we find an
289    * {@code @AutoFactory} class that references undefined types, under the assumption that something
290    * else will generate those types in the meanwhile. However, this would fail if for example
291    * {@code @AutoFactory class Foo} has a constructor parameter of type {@code BarFactory} and
292    * {@code @AutoFactory class Bar} has a constructor parameter of type {@code FooFactory}. We did
293    * in fact find instances of this in Google's source base.
294    *
295    * <p>If the type has type annotations then include those in the returned {@link TypeName}.
296    */
resolveTypeName(TypeMirror type)297   private TypeName resolveTypeName(TypeMirror type) {
298     TypeName typeName = TypeName.get(type);
299     if (type.getKind() == TypeKind.ERROR) {
300       ImmutableSet<PackageAndClass> factoryNames = factoriesBeingCreated.get(type.toString());
301       if (factoryNames.size() == 1) {
302         PackageAndClass packageAndClass = Iterables.getOnlyElement(factoryNames);
303         typeName = ClassName.get(packageAndClass.packageName(), packageAndClass.className());
304       }
305     }
306     return typeName.annotated(
307         type.getAnnotationMirrors().stream().map(AnnotationSpec::get).collect(toList()));
308   }
309 
getFactoryTypeVariables( FactoryDescriptor descriptor)310   private static ImmutableSet<TypeVariableName> getFactoryTypeVariables(
311       FactoryDescriptor descriptor) {
312     ImmutableSet.Builder<TypeVariableName> typeVariables = ImmutableSet.builder();
313     for (ProviderField provider : descriptor.providers().values()) {
314       typeVariables.addAll(getReferencedTypeParameterNames(provider.key().type().get()));
315     }
316     // If a parent type has a type parameter, like FooFactory<T>, then the generated factory needs
317     // to have the same parameter, like FooImplFactory<T> extends FooFactory<T>. This is a little
318     // approximate, at least in the case where there is more than one parent type that has a type
319     // parameter. But that should be pretty rare, so let's keep it simple for now.
320     typeVariables.addAll(typeVariablesFrom(descriptor.extendingType()));
321     for (TypeMirror implementing : descriptor.implementingTypes()) {
322       typeVariables.addAll(typeVariablesFrom(implementing));
323     }
324     return typeVariables.build();
325   }
326 
typeVariablesFrom(TypeMirror type)327   private static List<TypeVariableName> typeVariablesFrom(TypeMirror type) {
328     if (type.getKind().equals(TypeKind.DECLARED)) {
329       DeclaredType declaredType = MoreTypes.asDeclared(type);
330       return declaredType.getTypeArguments().stream()
331           .filter(t -> t.getKind().equals(TypeKind.TYPEVAR))
332           .map(t -> TypeVariableName.get(MoreTypes.asTypeVariable(t)))
333           .collect(toList());
334     }
335     return ImmutableList.of();
336   }
337 
getMethodTypeVariables( FactoryMethodDescriptor methodDescriptor, ImmutableSet<TypeVariableName> factoryTypeVariables)338   private static ImmutableSet<TypeVariableName> getMethodTypeVariables(
339       FactoryMethodDescriptor methodDescriptor,
340       ImmutableSet<TypeVariableName> factoryTypeVariables) {
341     ImmutableSet.Builder<TypeVariableName> typeVariables = ImmutableSet.builder();
342     typeVariables.addAll(getReferencedTypeParameterNames(methodDescriptor.returnType()));
343     for (Parameter parameter : methodDescriptor.passedParameters()) {
344       typeVariables.addAll(getReferencedTypeParameterNames(parameter.type().get()));
345     }
346     return Sets.difference(typeVariables.build(), factoryTypeVariables).immutableCopy();
347   }
348 
getReferencedTypeParameterNames(TypeMirror type)349   private static ImmutableSet<TypeVariableName> getReferencedTypeParameterNames(TypeMirror type) {
350     ImmutableSet.Builder<TypeVariableName> typeVariableNames = ImmutableSet.builder();
351     for (TypeVariable typeVariable : TypeVariables.getReferencedTypeVariables(type)) {
352       typeVariableNames.add(TypeVariableName.get(typeVariable));
353     }
354     return typeVariableNames.build();
355   }
356 }
357