1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <jni.h>
#include <nativehelper/ScopedLocalRef.h>
#include <gtest/gtest.h>
24 
25 static JavaVM* gVm = nullptr;
GetJavaVM()26 JavaVM* GetJavaVM() {
27   return gVm;
28 }
RegisterJavaVm(JNIEnv * env)29 static void RegisterJavaVm(JNIEnv* env) {
30   (void)env->GetJavaVM(&gVm);
31 }
32 
33 namespace {
34 
35     struct {
36         jclass clazz;
37 
38         /** static methods **/
39         jmethodID createTestDescription;
40 
41         /** methods **/
42         jmethodID addChild;
43     } gDescription;
44 
45     struct {
46         jclass clazz;
47 
48         jmethodID fireTestStarted;
49         jmethodID fireTestIgnored;
50         jmethodID fireTestFailure;
51         jmethodID fireTestFinished;
52 
53     } gRunNotifier;
54 
55     struct {
56         jclass clazz;
57         jmethodID ctor;
58     } gAssertionFailure;
59 
60     struct {
61         jclass clazz;
62         jmethodID ctor;
63     } gFailure;
64 
65     jobject gEmptyAnnotationsArray;
66 
67     struct TestNameInfo {
68         std::string nativeName;
69         bool run;
70     };
71 // Maps mangled test names to native test names.
72     std::unordered_map<std::string, TestNameInfo> gNativeTestNames;
73 
74 // Return the full native test name as a Java method name, which does not allow
75 // slashes or dots. Store the original name for later lookup.
registerAndMangleTestName(const std::string & nativeName)76     std::string registerAndMangleTestName(const std::string& nativeName) {
77       std::string mangledName = nativeName;
78       std::replace(mangledName.begin(), mangledName.end(), '.', '_');
79       std::replace(mangledName.begin(), mangledName.end(), '/', '_');
80       gNativeTestNames.insert(std::make_pair(mangledName, TestNameInfo{nativeName, false}));
81       return mangledName;
82     }
83 
84 // Creates org.junit.runner.Description object for a GTest given its name.
createTestDescription(JNIEnv * env,jstring className,const std::string & mangledName)85     jobject createTestDescription(JNIEnv* env, jstring className, const std::string& mangledName) {
86       ScopedLocalRef<jstring> jTestName(env, env->NewStringUTF(mangledName.c_str()));
87       return env->CallStaticObjectMethod(gDescription.clazz, gDescription.createTestDescription,
88                                          className, jTestName.get(), gEmptyAnnotationsArray);
89     }
90 
createTestDescription(JNIEnv * env,jstring className,const char * testCaseName,const char * testName)91     jobject createTestDescription(JNIEnv* env, jstring className, const char* testCaseName, const char* testName) {
92       std::ostringstream nativeNameStream;
93       nativeNameStream << testCaseName << "." << testName;
94       std::string mangledName = registerAndMangleTestName(nativeNameStream.str());
95       return createTestDescription(env, className, mangledName);
96     }
97 
addChild(JNIEnv * env,jobject description,jobject childDescription)98     void addChild(JNIEnv* env, jobject description, jobject childDescription) {
99       env->CallVoidMethod(description, gDescription.addChild, childDescription);
100     }
101 
102 
103     class JUnitNotifyingListener : public ::testing::EmptyTestEventListener {
104     public:
105 
JUnitNotifyingListener(JNIEnv * env,jstring className,jobject runNotifier)106         JUnitNotifyingListener(JNIEnv* env, jstring className, jobject runNotifier)
107                 : mEnv(env)
108                 , mRunNotifier(runNotifier)
109                 , mClassName(className)
110                 , mCurrentTestDescription{env, nullptr}
111         {}
~JUnitNotifyingListener()112         virtual ~JUnitNotifyingListener() {}
113 
OnTestStart(const testing::TestInfo & testInfo)114         virtual void OnTestStart(const testing::TestInfo &testInfo) override {
115           mCurrentTestDescription.reset(
116                   createTestDescription(mEnv, mClassName, testInfo.test_case_name(), testInfo.name()));
117           notify(gRunNotifier.fireTestStarted);
118         }
119 
OnTestPartResult(const testing::TestPartResult & testPartResult)120         virtual void OnTestPartResult(const testing::TestPartResult &testPartResult) override {
121           if (!testPartResult.passed()) {
122             const char* file_name = testPartResult.file_name() != nullptr ? testPartResult.file_name() : "unknown file";
123             mCurrentTestError << "\n" << file_name << ":" << testPartResult.line_number()
124                               << "\n" << testPartResult.message() << "\n";
125           }
126         }
127 
OnTestEnd(const testing::TestInfo &)128         virtual void OnTestEnd(const testing::TestInfo&) override {
129           const std::string error = mCurrentTestError.str();
130 
131           if (!error.empty()) {
132             ScopedLocalRef<jstring> jmessage(mEnv, mEnv->NewStringUTF(error.c_str()));
133             ScopedLocalRef<jobject> jthrowable(mEnv, mEnv->NewObject(gAssertionFailure.clazz,
134                                                                      gAssertionFailure.ctor, jmessage.get()));
135             ScopedLocalRef<jobject> jfailure(mEnv, mEnv->NewObject(gFailure.clazz,
136                                                                    gFailure.ctor, mCurrentTestDescription.get(), jthrowable.get()));
137             mEnv->CallVoidMethod(mRunNotifier, gRunNotifier.fireTestFailure, jfailure.get());
138           }
139 
140           notify(gRunNotifier.fireTestFinished);
141           mCurrentTestDescription.reset();
142           mCurrentTestError.str("");
143           mCurrentTestError.clear();
144         }
145 
reportDisabledTests(const std::vector<std::string> & mangledNames)146         void reportDisabledTests(const std::vector<std::string>& mangledNames) {
147           for (const std::string& mangledName : mangledNames) {
148             mCurrentTestDescription.reset(createTestDescription(mEnv, mClassName, mangledName));
149             notify(gRunNotifier.fireTestIgnored);
150             mCurrentTestDescription.reset();
151           }
152         }
153 
154     private:
notify(jmethodID method)155         void notify(jmethodID method) {
156           mEnv->CallVoidMethod(mRunNotifier, method, mCurrentTestDescription.get());
157         }
158 
159         JNIEnv* mEnv;
160         jobject mRunNotifier;
161         jstring mClassName;
162         ScopedLocalRef<jobject> mCurrentTestDescription;
163         std::ostringstream mCurrentTestError;
164     };
165 
166 }  // namespace
167 
168 extern "C"
169 JNIEXPORT void JNICALL
Java_androidx_test_ext_junitgtest_GtestRunner_initialize(JNIEnv * env,jclass,jstring className,jobject suite)170 Java_androidx_test_ext_junitgtest_GtestRunner_initialize(JNIEnv *env, jclass, jstring className, jobject suite) {
171   RegisterJavaVm(env);
172 
173   // Initialize gtest, removing the default result printer
174   int argc = 1;
175   const char* argv[] = { "gtest_wrapper" };
176   ::testing::InitGoogleTest(&argc, (char**) argv);
177 
178   auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
179   delete listeners.Release(listeners.default_result_printer());
180 
181   gDescription.clazz = (jclass) env->NewGlobalRef(env->FindClass("org/junit/runner/Description"));
182   gDescription.createTestDescription = env->GetStaticMethodID(gDescription.clazz, "createTestDescription",
183                                                               "(Ljava/lang/String;Ljava/lang/String;[Ljava/lang/annotation/Annotation;)Lorg/junit/runner/Description;");
184   gDescription.addChild = env->GetMethodID(gDescription.clazz, "addChild",
185                                            "(Lorg/junit/runner/Description;)V");
186 
187   jclass annotations = env->FindClass("java/lang/annotation/Annotation");
188   gEmptyAnnotationsArray = env->NewGlobalRef(env->NewObjectArray(0, annotations, nullptr));
189   gNativeTestNames.clear();
190 
191   gAssertionFailure.clazz = (jclass) env->NewGlobalRef(env->FindClass("java/lang/AssertionError"));
192   gAssertionFailure.ctor = env->GetMethodID(gAssertionFailure.clazz, "<init>", "(Ljava/lang/Object;)V");
193 
194   gFailure.clazz = (jclass) env->NewGlobalRef(env->FindClass("org/junit/runner/notification/Failure"));
195   gFailure.ctor = env->GetMethodID(gFailure.clazz, "<init>",
196                                    "(Lorg/junit/runner/Description;Ljava/lang/Throwable;)V");
197 
198   gRunNotifier.clazz = (jclass) env->NewGlobalRef(
199           env->FindClass("org/junit/runner/notification/RunNotifier"));
200   gRunNotifier.fireTestStarted = env->GetMethodID(gRunNotifier.clazz, "fireTestStarted",
201                                                   "(Lorg/junit/runner/Description;)V");
202   gRunNotifier.fireTestIgnored = env->GetMethodID(gRunNotifier.clazz, "fireTestIgnored",
203                                                   "(Lorg/junit/runner/Description;)V");
204   gRunNotifier.fireTestFinished = env->GetMethodID(gRunNotifier.clazz, "fireTestFinished",
205                                                    "(Lorg/junit/runner/Description;)V");
206   gRunNotifier.fireTestFailure = env->GetMethodID(gRunNotifier.clazz, "fireTestFailure",
207                                                   "(Lorg/junit/runner/notification/Failure;)V");
208 
209   auto unitTest = ::testing::UnitTest::GetInstance();
210   for (int testCaseIndex = 0; testCaseIndex < unitTest->total_test_case_count(); testCaseIndex++) {
211     auto testCase = unitTest->GetTestCase(testCaseIndex);
212     for (int testIndex = 0; testIndex < testCase->total_test_count(); testIndex++) {
213       auto testInfo = testCase->GetTestInfo(testIndex);
214       ScopedLocalRef<jobject> testDescription(env,
215                                               createTestDescription(env, className, testCase->name(), testInfo->name()));
216       addChild(env, suite, testDescription.get());
217     }
218   }
219 }
220 
221 extern "C"
222 JNIEXPORT void JNICALL
Java_androidx_test_ext_junitgtest_GtestRunner_addTest(JNIEnv * env,jclass,jstring testName)223 Java_androidx_test_ext_junitgtest_GtestRunner_addTest(JNIEnv *env, jclass, jstring testName) {
224   const char* testNameChars = env->GetStringUTFChars(testName, JNI_FALSE);
225   auto found = gNativeTestNames.find(testNameChars);
226   if (found != gNativeTestNames.end()) {
227     found->second.run = true;
228   }
229   env->ReleaseStringUTFChars(testName, testNameChars);
230 }
231 
232 extern "C"
233 JNIEXPORT jboolean JNICALL
Java_androidx_test_ext_junitgtest_GtestRunner_run(JNIEnv * env,jclass,jstring className,jobject notifier)234 Java_androidx_test_ext_junitgtest_GtestRunner_run(JNIEnv *env, jclass, jstring className, jobject notifier) {
235   // Apply the test filter computed in Java-land. The filter is just a list of test names.
236   std::ostringstream filterStream;
237   std::vector<std::string> mangledNamesOfDisabledTests;
238   for (const auto& entry : gNativeTestNames) {
239     // If the test was not selected for running by the Java layer, ignore it completely.
240     if (!entry.second.run) continue;
241     // If the test has DISABLED_ at the beginning of its name, after a slash or after a dot,
242     // report it as ignored (disabled) to the Java layer.
243     if (entry.second.nativeName.find("DISABLED_") == 0 ||
244         entry.second.nativeName.find("/DISABLED_") != std::string::npos ||
245         entry.second.nativeName.find(".DISABLED_") != std::string::npos) {
246       mangledNamesOfDisabledTests.push_back(entry.first);
247       continue;
248     }
249     filterStream << entry.second.nativeName << ":";
250   }
251   std::string filter = filterStream.str();
252   if (filter.empty()) {
253     // If the string we built is empty, we don't want to run any tests, but GTest runs all tests
254     // when an empty filter is passed. Replace an empty filter with a filter that matches nothing.
255     filter = "-*";
256   } else {
257     // Removes the trailing colon.
258     filter.pop_back();
259   }
260   ::testing::GTEST_FLAG(filter) = filter;
261 
262   auto& listeners = ::testing::UnitTest::GetInstance()->listeners();
263   JUnitNotifyingListener junitListener{env, className, notifier};
264   listeners.Append(&junitListener);
265   int success = RUN_ALL_TESTS();
266   listeners.Release(&junitListener);
267   junitListener.reportDisabledTests(mangledNamesOfDisabledTests);
268   return success == 0;
269 }