/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include "jvmti.h" #include #include #include #include #include #include #include using namespace dex; using namespace lir; namespace com_android_dx_mockito_inline { static jvmtiEnv* localJvmtiEnv; static jobject sTransformer; // Converts a class name to a type descriptor // (ex. "java.lang.String" to "Ljava/lang/String;") static std::string ClassNameToDescriptor(const char* class_name) { std::stringstream ss; ss << "L"; for (auto p = class_name; *p != '\0'; ++p) { ss << (*p == '.' ? '/' : *p); } ss << ";"; return ss.str(); } // Takes the full dex file for class 'classBeingRedefined' // - isolates the dex code for the class out of the dex file // - calls sTransformer.runTransformers on the isolated dex code // - send the transformed code back to the runtime static void Transform(jvmtiEnv* jvmti_env, JNIEnv* env, jclass classBeingRedefined, jobject loader, const char* name, jobject protectionDomain, jint classDataLen, const unsigned char* classData, jint* newClassDataLen, unsigned char** newClassData) { if (sTransformer != NULL) { // Even reading the classData array is expensive as the data is only generated when the // memory is touched. Hence call JvmtiAgent#shouldTransform to check if we need to transform // the class. jclass cls = env->GetObjectClass(sTransformer); jmethodID shouldTransformMethod = env->GetMethodID(cls, "shouldTransform", "(Ljava/lang/Class;)Z"); jboolean shouldTransform = env->CallBooleanMethod(sTransformer, shouldTransformMethod, classBeingRedefined); if (!shouldTransform) { return; } // Isolate byte code of class class. This is needed as Android usually gives us more // than the class we need. Reader reader(classData, classDataLen); u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str()); reader.CreateClassIr(index); std::shared_ptr ir = reader.GetIr(); struct Allocator : public Writer::Allocator { virtual void* Allocate(size_t size) {return ::malloc(size);} virtual void Free(void* ptr) {::free(ptr);} }; Allocator allocator; Writer writer(ir); size_t isolatedClassLen = 0; std::shared_ptr isolatedClass((jbyte*)writer.CreateImage(&allocator, &isolatedClassLen)); // Create jbyteArray with isolated byte code of class jbyteArray isolatedClassArr = env->NewByteArray(isolatedClassLen); env->SetByteArrayRegion(isolatedClassArr, 0, isolatedClassLen, isolatedClass.get()); jstring nameStr = env->NewStringUTF(name); // Call JvmtiAgent#runTransformers jmethodID runTransformersMethod = env->GetMethodID(cls, "runTransformers", "(Ljava/lang/ClassLoader;" "Ljava/lang/String;" "Ljava/lang/Class;" "Ljava/security/ProtectionDomain;" "[B)[B"); jbyteArray transformedArr = (jbyteArray) env->CallObjectMethod(sTransformer, runTransformersMethod, loader, nameStr, classBeingRedefined, protectionDomain, isolatedClassArr); // Set transformed byte code if (!env->ExceptionOccurred() && transformedArr != NULL) { *newClassDataLen = env->GetArrayLength(transformedArr); jbyte* transformed = env->GetByteArrayElements(transformedArr, 0); jvmti_env->Allocate(*newClassDataLen, newClassData); std::memcpy(*newClassData, transformed, *newClassDataLen); env->ReleaseByteArrayElements(transformedArr, transformed, 0); } } } // Add a label before instructionAfter static void addLabel(CodeIr& c, lir::Instruction* instructionAfter, Label* returnTrueLabel) { c.instructions.InsertBefore(instructionAfter, returnTrueLabel); } // Add a byte code before instructionAfter static void addInstr(CodeIr& c, lir::Instruction* instructionAfter, Opcode opcode, const std::list& operands) { auto instruction = c.Alloc(); instruction->opcode = opcode; for (auto it = operands.begin(); it != operands.end(); it++) { instruction->operands.push_back(*it); } c.instructions.InsertBefore(instructionAfter, instruction); } // Add a method call byte code before instructionAfter static void addCall(ir::Builder& b, CodeIr& c, lir::Instruction* instructionAfter, Opcode opcode, ir::Type* type, const char* methodName, ir::Type* returnType, const std::vector& types, const std::list& regs) { auto proto = b.GetProto(returnType, b.GetTypeList(types)); auto method = b.GetMethodDecl(b.GetAsciiString(methodName), proto, type); VRegList* param_regs = c.Alloc(); for (auto it = regs.begin(); it != regs.end(); it++) { param_regs->registers.push_back(*it); } addInstr(c, instructionAfter, opcode, {param_regs, c.Alloc(method, method->orig_index)}); } typedef struct { ir::Type* boxedType; ir::Type* scalarType; std::string unboxMethod; } BoxingInfo; // Get boxing / unboxing info for a type static BoxingInfo getBoxingInfo(ir::Builder &b, char typeCode) { BoxingInfo boxingInfo; if (typeCode != 'L' && typeCode != '[') { std::stringstream tmp; tmp << typeCode; boxingInfo.scalarType = b.GetType(tmp.str().c_str()); } switch (typeCode) { case 'B': boxingInfo.boxedType = b.GetType("Ljava/lang/Byte;"); boxingInfo.unboxMethod = "byteValue"; break; case 'S': boxingInfo.boxedType = b.GetType("Ljava/lang/Short;"); boxingInfo.unboxMethod = "shortValue"; break; case 'I': boxingInfo.boxedType = b.GetType("Ljava/lang/Integer;"); boxingInfo.unboxMethod = "intValue"; break; case 'C': boxingInfo.boxedType = b.GetType("Ljava/lang/Character;"); boxingInfo.unboxMethod = "charValue"; break; case 'F': boxingInfo.boxedType = b.GetType("Ljava/lang/Float;"); boxingInfo.unboxMethod = "floatValue"; break; case 'Z': boxingInfo.boxedType = b.GetType("Ljava/lang/Boolean;"); boxingInfo.unboxMethod = "booleanValue"; break; case 'J': boxingInfo.boxedType = b.GetType("Ljava/lang/Long;"); boxingInfo.unboxMethod = "longValue"; break; case 'D': boxingInfo.boxedType = b.GetType("Ljava/lang/Double;"); boxingInfo.unboxMethod = "doubleValue"; break; default: // real object break; } return boxingInfo; } static size_t getNumParams(ir::EncodedMethod *method) { if (method->decl->prototype->param_types == NULL) { return 0; } return method->decl->prototype->param_types->types.size(); } static bool canBeTransformed(ir::EncodedMethod *method) { std::string type = method->decl->parent->Decl(); ir::String* methodName = method->decl->name; return !(((method->access_flags & (kAccAbstract | kAccPrivate | kAccBridge | kAccNative | kAccStatic)) != 0) || (Utf8Cmp(methodName->c_str(), "") == 0) || (Utf8Cmp(methodName->c_str(), "") == 0) || (Utf8Cmp(type.c_str(), "java.lang.Object") == 0 && Utf8Cmp(methodName->c_str(), "finalize") == 0 && getNumParams(method) == 0) || (strncmp(type.c_str(), "java.", 5) == 0 && (method->access_flags & (kAccPrivate | kAccPublic | kAccProtected)) == 0) // getClass is used by MockMethodAdvice.isOverridden || (Utf8Cmp(methodName->c_str(), "getClass") == 0)); } static bool isHashCode(ir::EncodedMethod *method) { return Utf8Cmp(method->decl->name->c_str(), "hashCode") == 0 && getNumParams(method) == 0; } static bool isEquals(ir::EncodedMethod *method) { return Utf8Cmp(method->decl->name->c_str(), "equals") == 0 && getNumParams(method) == 1 && Utf8Cmp(method->decl->prototype->param_types->types[0]->Decl().c_str(), "java.lang.Object") == 0; } // Transforms the classes to add the mockito hooks // - equals and hashcode are handled in a special way extern "C" JNIEXPORT jbyteArray JNICALL Java_com_android_dx_mockito_inline_ClassTransformer_nativeRedefine(JNIEnv* env, jobject generator, jstring idStr, jbyteArray originalArr) { unsigned char* original = (unsigned char*)env->GetByteArrayElements(originalArr, 0); Reader reader(original, env->GetArrayLength(originalArr)); reader.CreateClassIr(0); std::shared_ptr dex_ir = reader.GetIr(); ir::Builder b(dex_ir); ir::Type* booleanScalarT = b.GetType("Z"); ir::Type* intScalarT = b.GetType("I"); ir::Type* objectT = b.GetType("Ljava/lang/Object;"); ir::Type* objectArrayT = b.GetType("[Ljava/lang/Object;"); ir::Type* stringT = b.GetType("Ljava/lang/String;"); ir::Type* methodT = b.GetType("Ljava/lang/reflect/Method;"); ir::Type* systemT = b.GetType("Ljava/lang/System;"); ir::Type* callableT = b.GetType("Ljava/util/concurrent/Callable;"); ir::Type* dispatcherT = b.GetType("Lcom/android/dx/mockito/inline/MockMethodDispatcher;"); // Add id to dex file const char* idNative = env->GetStringUTFChars(idStr, 0); ir::String* id = b.GetAsciiString(idNative); env->ReleaseStringUTFChars(idStr, idNative); for (auto& method : dex_ir->encoded_methods) { if (!canBeTransformed(method.get())) { continue; } if (isEquals(method.get())) { /* equals_original(Object other) { T t = foo(other); return bar(t); } equals_transformed(params) { // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this); const-string v0, "65463hg34t" move-objectfrom16 v1, THIS invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher move-result-object v2 // if (dispatcher == null || ) { // goto original_method; // } if-eqz v2, original_method // if (!dispatcher.isMock(this)) { // goto original_method; // } invoke-virtual {v2, v1}, MockMethodDispatcher.isMock(Object):Method move-result v2 if-eqz v2, original_method // return self == other move-objectfrom16 v0, ARG1 if-eq v0, v1, return_true const v0, 0 return v0 return true: const v0, 1 return v0 original_method: // Move all method arguments down so that they match what the original code expects. move-object16 v4, v5 # THIS move-object16 v5, v6 # ARG1 T t = foo(other); return bar(t); } */ CodeIr c(method.get(), dex_ir); // Make sure there are at least 5 local registers to use int originalNumRegisters = method->code->registers - method->code->ins_count; int numAdditionalRegs = std::max(0, 3 - originalNumRegisters); int thisReg = numAdditionalRegs + method->code->registers - method->code->ins_count; if (numAdditionalRegs > 0) { c.ir_method->code->registers += numAdditionalRegs; } lir::Instruction* fi = *(c.instructions.begin()); Label* originalMethodLabel = c.Alloc