1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 // @exportToFramework:skipFile()
17 package androidx.appsearch.testutil.flags;
18 
19 import static org.junit.Assume.assumeTrue;
20 
21 import androidx.collection.ArrayMap;
22 
23 import java.lang.annotation.Annotation;
24 import java.util.Collection;
25 import java.util.List;
26 import java.util.Map;
27 
28 import org.jspecify.annotations.NonNull;
29 import org.jspecify.annotations.Nullable;
30 import org.junit.rules.TestRule;
31 import org.junit.runner.Description;
32 import org.junit.runners.model.Statement;
33 
34 /**
35  * Shim for real CheckFlagsRule defined in Framework.
36  *
37  * <p>In Jetpack, this shim only handles invocations for {@link RequiresFlagsEnabled} and
38  * {@link RequiresFlagsDisabled}. This rule does two things:
39  * <ul>
40  *     <li>checks that all {@link RequiresFlagsEnabled} and {@link RequiresFlagsDisabled}
41  *     annotations do not conflict.</li>
42  *     <li>skips any test/test class that has a {@link RequiresFlagsDisabled} annotation.</li>
43  * </ul>
44  */
45 public final class CheckFlagsRule implements TestRule {
46     @Override
apply(@onNull Statement base, @Nullable Description description)47     public @NonNull Statement apply(@NonNull Statement base, @Nullable Description description) {
48         return new Statement() {
49             @Override
50             public void evaluate() throws Throwable {
51                 Map<String, Boolean> requiredFlagValues = getMergedFlagValues(description);
52                 checkFlags(requiredFlagValues);
53                 base.evaluate();
54             }
55         };
56     }
57 
58     /**
59      * Checks that the only required flag values specified are from {@link RequiresFlagsEnabled}.
60      * The presence of any flag value specific in {@link RequiresFlagsDisabled} will result in the
61      * test being skipped.
62      */
63     private static void checkFlags(@NonNull Map<String, Boolean> requiredFlagValues) {
64         for (Map.Entry<String, Boolean> required : requiredFlagValues.entrySet()) {
65             final String flag = required.getKey();
66             assumeTrue(String.format("Flag %s required to be enabled, but is disabled", flag),
67                     required.getValue());
68         }
69     }
70 
71     /**
72      * Retrieves the value of all {@link RequiresFlagsEnabled} and {@link RequiresFlagsDisabled} for
73      * both the test class and the test method.
74      *
75      * @throws AssertionError - if the RequiresFlag annotations conflict with each other.
76      * @return a map holding the flag values and whether they are required to be enabled or
77      * disabled.
78      */
79     private static @NonNull Map<String, Boolean> getMergedFlagValues(
80             @NonNull Description description) {
81         final Map<String, Boolean> flagValues = new ArrayMap<>();
82         getFlagValuesFromAnnotations(description.getMethodName(), description.getAnnotations(),
83                 flagValues);
84         Class<?> testClass = description.getTestClass();
85         if (testClass != null) {
86             getFlagValuesFromAnnotations(testClass.getName(), List.of(testClass.getAnnotations()),
87                     flagValues);
88         }
89         return flagValues;
90     }
91 
92     private static void getFlagValuesFromAnnotations(
93             @NonNull String annotationTarget,
94             @NonNull Collection<Annotation> annotations,
95             @NonNull Map<String, Boolean> flagValues) {
96         for (Annotation annotation : annotations) {
97             if (annotation instanceof RequiresFlagsEnabled) {
98                 RequiresFlagsEnabled enabled = (RequiresFlagsEnabled) annotation;
99                 addFlagValues(annotationTarget, enabled.value(), Boolean.TRUE, flagValues);
100             } else if (annotation instanceof RequiresFlagsDisabled) {
101                 RequiresFlagsDisabled disabled = (RequiresFlagsDisabled) annotation;
102                 addFlagValues(annotationTarget, disabled.value(), Boolean.FALSE, flagValues);
103             }
104         }
105     }
106 
107     private static void addFlagValues(@NonNull String annotationTarget, @NonNull String[] flags,
108             @NonNull Boolean value, @NonNull Map<String, Boolean> flagValues) {
109         for (String flagName : flags) {
110             Boolean existingValue = flagValues.get(flagName);
111             if (existingValue == null) {
112                 flagValues.put(flagName, value);
113             } else if (!existingValue.equals(value)) {
114                 throw new AssertionError(
115                         "Flag '" + flagName + "' are required by " + annotationTarget
116                                 + " to be both enabled and disabled.");
117             }
118         }
119     }
120 }
121