1 package org.junit.runners.model; 2 3 import static java.lang.reflect.Modifier.isStatic; 4 import static org.junit.internal.MethodSorter.NAME_ASCENDING; 5 6 import java.lang.annotation.Annotation; 7 import java.lang.reflect.Constructor; 8 import java.lang.reflect.Field; 9 import java.lang.reflect.Method; 10 import java.lang.reflect.Modifier; 11 import java.util.ArrayList; 12 import java.util.Arrays; 13 import java.util.Collections; 14 import java.util.Comparator; 15 import java.util.LinkedHashMap; 16 import java.util.LinkedHashSet; 17 import java.util.List; 18 import java.util.Map; 19 import java.util.Set; 20 21 import org.junit.Assert; 22 import org.junit.Before; 23 import org.junit.BeforeClass; 24 import org.junit.internal.MethodSorter; 25 26 /** 27 * Wraps a class to be run, providing method validation and annotation searching 28 * 29 * @since 4.5 30 */ 31 public class TestClass implements Annotatable { 32 private static final FieldComparator FIELD_COMPARATOR = new FieldComparator(); 33 private static final MethodComparator METHOD_COMPARATOR = new MethodComparator(); 34 35 private final Class<?> clazz; 36 private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations; 37 private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations; 38 39 /** 40 * Creates a {@code TestClass} wrapping {@code clazz}. Each time this 41 * constructor executes, the class is scanned for annotations, which can be 42 * an expensive process (we hope in future JDK's it will not be.) Therefore, 43 * try to share instances of {@code TestClass} where possible. 44 */ TestClass(Class<?> clazz)45 public TestClass(Class<?> clazz) { 46 this.clazz = clazz; 47 if (clazz != null && clazz.getConstructors().length > 1) { 48 throw new IllegalArgumentException( 49 "Test class can only have one constructor"); 50 } 51 52 Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations = 53 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>(); 54 Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations = 55 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>(); 56 57 scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations); 58 59 this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations); 60 this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations); 61 } 62 scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations)63 protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) { 64 for (Class<?> eachClass : getSuperClasses(clazz)) { 65 for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) { 66 addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations); 67 } 68 // ensuring fields are sorted to make sure that entries are inserted 69 // and read from fieldForAnnotations in a deterministic order 70 for (Field eachField : getSortedDeclaredFields(eachClass)) { 71 addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations); 72 } 73 } 74 } 75 getSortedDeclaredFields(Class<?> clazz)76 private static Field[] getSortedDeclaredFields(Class<?> clazz) { 77 Field[] declaredFields = clazz.getDeclaredFields(); 78 Arrays.sort(declaredFields, FIELD_COMPARATOR); 79 return declaredFields; 80 } 81 addToAnnotationLists(T member, Map<Class<? extends Annotation>, List<T>> map)82 protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member, 83 Map<Class<? extends Annotation>, List<T>> map) { 84 for (Annotation each : member.getAnnotations()) { 85 Class<? extends Annotation> type = each.annotationType(); 86 List<T> members = getAnnotatedMembers(map, type, true); 87 T memberToAdd = member.handlePossibleBridgeMethod(members); 88 if (memberToAdd == null) { 89 return; 90 } 91 if (runsTopToBottom(type)) { 92 members.add(0, memberToAdd); 93 } else { 94 members.add(memberToAdd); 95 } 96 } 97 } 98 99 private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>> makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source)100 makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) { 101 Map<Class<? extends Annotation>, List<T>> copy = 102 new LinkedHashMap<Class<? extends Annotation>, List<T>>(); 103 for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) { 104 copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue())); 105 } 106 return Collections.unmodifiableMap(copy); 107 } 108 109 /** 110 * Returns, efficiently, all the non-overridden methods in this class and 111 * its superclasses that are annotated}. 112 * 113 * @since 4.12 114 */ getAnnotatedMethods()115 public List<FrameworkMethod> getAnnotatedMethods() { 116 List<FrameworkMethod> methods = collectValues(methodsForAnnotations); 117 Collections.sort(methods, METHOD_COMPARATOR); 118 return methods; 119 } 120 121 /** 122 * Returns, efficiently, all the non-overridden methods in this class and 123 * its superclasses that are annotated with {@code annotationClass}. 124 */ getAnnotatedMethods( Class<? extends Annotation> annotationClass)125 public List<FrameworkMethod> getAnnotatedMethods( 126 Class<? extends Annotation> annotationClass) { 127 return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false)); 128 } 129 130 /** 131 * Returns, efficiently, all the non-overridden fields in this class and its 132 * superclasses that are annotated. 133 * 134 * @since 4.12 135 */ getAnnotatedFields()136 public List<FrameworkField> getAnnotatedFields() { 137 return collectValues(fieldsForAnnotations); 138 } 139 140 /** 141 * Returns, efficiently, all the non-overridden fields in this class and its 142 * superclasses that are annotated with {@code annotationClass}. 143 */ getAnnotatedFields( Class<? extends Annotation> annotationClass)144 public List<FrameworkField> getAnnotatedFields( 145 Class<? extends Annotation> annotationClass) { 146 return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false)); 147 } 148 collectValues(Map<?, List<T>> map)149 private <T> List<T> collectValues(Map<?, List<T>> map) { 150 Set<T> values = new LinkedHashSet<T>(); 151 for (List<T> additionalValues : map.values()) { 152 values.addAll(additionalValues); 153 } 154 return new ArrayList<T>(values); 155 } 156 getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map, Class<? extends Annotation> type, boolean fillIfAbsent)157 private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map, 158 Class<? extends Annotation> type, boolean fillIfAbsent) { 159 if (!map.containsKey(type) && fillIfAbsent) { 160 map.put(type, new ArrayList<T>()); 161 } 162 List<T> members = map.get(type); 163 return members == null ? Collections.<T>emptyList() : members; 164 } 165 runsTopToBottom(Class<? extends Annotation> annotation)166 private static boolean runsTopToBottom(Class<? extends Annotation> annotation) { 167 return annotation.equals(Before.class) 168 || annotation.equals(BeforeClass.class); 169 } 170 getSuperClasses(Class<?> testClass)171 private static List<Class<?>> getSuperClasses(Class<?> testClass) { 172 List<Class<?>> results = new ArrayList<Class<?>>(); 173 Class<?> current = testClass; 174 while (current != null) { 175 results.add(current); 176 current = current.getSuperclass(); 177 } 178 return results; 179 } 180 181 /** 182 * Returns the underlying Java class. 183 */ getJavaClass()184 public Class<?> getJavaClass() { 185 return clazz; 186 } 187 188 /** 189 * Returns the class's name. 190 */ getName()191 public String getName() { 192 if (clazz == null) { 193 return "null"; 194 } 195 return clazz.getName(); 196 } 197 198 /** 199 * Returns the only public constructor in the class, or throws an {@code 200 * AssertionError} if there are more or less than one. 201 */ 202 getOnlyConstructor()203 public Constructor<?> getOnlyConstructor() { 204 Constructor<?>[] constructors = clazz.getConstructors(); 205 Assert.assertEquals(1, constructors.length); 206 return constructors[0]; 207 } 208 209 /** 210 * Returns the annotations on this class 211 */ getAnnotations()212 public Annotation[] getAnnotations() { 213 if (clazz == null) { 214 return new Annotation[0]; 215 } 216 return clazz.getAnnotations(); 217 } 218 getAnnotation(Class<T> annotationType)219 public <T extends Annotation> T getAnnotation(Class<T> annotationType) { 220 if (clazz == null) { 221 return null; 222 } 223 return clazz.getAnnotation(annotationType); 224 } 225 getAnnotatedFieldValues(Object test, Class<? extends Annotation> annotationClass, Class<T> valueClass)226 public <T> List<T> getAnnotatedFieldValues(Object test, 227 Class<? extends Annotation> annotationClass, Class<T> valueClass) { 228 final List<T> results = new ArrayList<T>(); 229 collectAnnotatedFieldValues(test, annotationClass, valueClass, 230 new MemberValueConsumer<T>() { 231 public void accept(FrameworkMember<?> member, T value) { 232 results.add(value); 233 } 234 }); 235 return results; 236 } 237 238 /** 239 * Finds the fields annotated with the specified annotation and having the specified type, 240 * retrieves the values and passes those to the specified consumer. 241 * 242 * @since 4.13 243 */ collectAnnotatedFieldValues(Object test, Class<? extends Annotation> annotationClass, Class<T> valueClass, MemberValueConsumer<T> consumer)244 public <T> void collectAnnotatedFieldValues(Object test, 245 Class<? extends Annotation> annotationClass, Class<T> valueClass, 246 MemberValueConsumer<T> consumer) { 247 for (FrameworkField each : getAnnotatedFields(annotationClass)) { 248 try { 249 Object fieldValue = each.get(test); 250 if (valueClass.isInstance(fieldValue)) { 251 consumer.accept(each, valueClass.cast(fieldValue)); 252 } 253 } catch (IllegalAccessException e) { 254 throw new RuntimeException( 255 "How did getFields return a field we couldn't access?", e); 256 } 257 } 258 } 259 getAnnotatedMethodValues(Object test, Class<? extends Annotation> annotationClass, Class<T> valueClass)260 public <T> List<T> getAnnotatedMethodValues(Object test, 261 Class<? extends Annotation> annotationClass, Class<T> valueClass) { 262 final List<T> results = new ArrayList<T>(); 263 collectAnnotatedMethodValues(test, annotationClass, valueClass, 264 new MemberValueConsumer<T>() { 265 public void accept(FrameworkMember<?> member, T value) { 266 results.add(value); 267 } 268 }); 269 return results; 270 } 271 272 /** 273 * Finds the methods annotated with the specified annotation and returning the specified type, 274 * invokes it and pass the return value to the specified consumer. 275 * 276 * @since 4.13 277 */ collectAnnotatedMethodValues(Object test, Class<? extends Annotation> annotationClass, Class<T> valueClass, MemberValueConsumer<T> consumer)278 public <T> void collectAnnotatedMethodValues(Object test, 279 Class<? extends Annotation> annotationClass, Class<T> valueClass, 280 MemberValueConsumer<T> consumer) { 281 for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) { 282 try { 283 /* 284 * A method annotated with @Rule may return a @TestRule or a @MethodRule, 285 * we cannot call the method to check whether the return type matches our 286 * expectation i.e. subclass of valueClass. If we do that then the method 287 * will be invoked twice and we do not want to do that. So we first check 288 * whether return type matches our expectation and only then call the method 289 * to fetch the MethodRule 290 */ 291 if (valueClass.isAssignableFrom(each.getReturnType())) { 292 Object fieldValue = each.invokeExplosively(test); 293 consumer.accept(each, valueClass.cast(fieldValue)); 294 } 295 } catch (Throwable e) { 296 throw new RuntimeException( 297 "Exception in " + each.getName(), e); 298 } 299 } 300 } 301 isPublic()302 public boolean isPublic() { 303 return Modifier.isPublic(clazz.getModifiers()); 304 } 305 isANonStaticInnerClass()306 public boolean isANonStaticInnerClass() { 307 return clazz.isMemberClass() && !isStatic(clazz.getModifiers()); 308 } 309 310 @Override hashCode()311 public int hashCode() { 312 return (clazz == null) ? 0 : clazz.hashCode(); 313 } 314 315 @Override equals(Object obj)316 public boolean equals(Object obj) { 317 if (this == obj) { 318 return true; 319 } 320 if (obj == null) { 321 return false; 322 } 323 if (getClass() != obj.getClass()) { 324 return false; 325 } 326 TestClass other = (TestClass) obj; 327 return clazz == other.clazz; 328 } 329 330 /** 331 * Compares two fields by its name. 332 */ 333 private static class FieldComparator implements Comparator<Field> { compare(Field left, Field right)334 public int compare(Field left, Field right) { 335 return left.getName().compareTo(right.getName()); 336 } 337 } 338 339 /** 340 * Compares two methods by its name. 341 */ 342 private static class MethodComparator implements 343 Comparator<FrameworkMethod> { compare(FrameworkMethod left, FrameworkMethod right)344 public int compare(FrameworkMethod left, FrameworkMethod right) { 345 return NAME_ASCENDING.compare(left.getMethod(), right.getMethod()); 346 } 347 } 348 } 349