/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include "jvmti.h" #include #include #include #include #include #include #include using namespace dex; using namespace lir; namespace com_android_dx_mockito_inline { static jvmtiEnv* localJvmtiEnv; static jobject sTransformer; // Converts a class name to a type descriptor // (ex. "java.lang.String" to "Ljava/lang/String;") static std::string ClassNameToDescriptor(const char* class_name) { std::stringstream ss; ss << "L"; for (auto p = class_name; *p != '\0'; ++p) { ss << (*p == '.' ? '/' : *p); } ss << ";"; return ss.str(); } // Takes the full dex file for class 'classBeingRedefined' // - isolates the dex code for the class out of the dex file // - calls sTransformer.runTransformers on the isolated dex code // - send the transformed code back to the runtime static void Transform(jvmtiEnv* jvmti_env, JNIEnv* env, jclass classBeingRedefined, jobject loader, const char* name, jobject protectionDomain, jint classDataLen, const unsigned char* classData, jint* newClassDataLen, unsigned char** newClassData) { if (sTransformer != nullptr) { // Even reading the classData array is expensive as the data is only generated when the // memory is touched. Hence call JvmtiAgent#shouldTransform to check if we need to transform // the class. jclass cls = env->GetObjectClass(sTransformer); jmethodID shouldTransformMethod = env->GetMethodID(cls, "shouldTransform", "(Ljava/lang/Class;)Z"); jboolean shouldTransform = env->CallBooleanMethod(sTransformer, shouldTransformMethod, classBeingRedefined); if (!shouldTransform) { return; } // Isolate byte code of class class. This is needed as Android usually gives us more // than the class we need. Reader reader(classData, classDataLen); u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str()); reader.CreateClassIr(index); std::shared_ptr ir = reader.GetIr(); struct Allocator : public Writer::Allocator { virtual void* Allocate(size_t size) {return ::malloc(size);} virtual void Free(void* ptr) {::free(ptr);} }; Allocator allocator; Writer writer(ir); size_t isolatedClassLen = 0; std::shared_ptr isolatedClass((jbyte*)writer.CreateImage(&allocator, &isolatedClassLen)); // Create jbyteArray with isolated byte code of class jbyteArray isolatedClassArr = env->NewByteArray(isolatedClassLen); env->SetByteArrayRegion(isolatedClassArr, 0, isolatedClassLen, isolatedClass.get()); jstring nameStr = env->NewStringUTF(name); // Call JvmtiAgent#runTransformers jmethodID runTransformersMethod = env->GetMethodID(cls, "runTransformers", "(Ljava/lang/ClassLoader;" "Ljava/lang/String;" "Ljava/lang/Class;" "Ljava/security/ProtectionDomain;" "[B)[B"); jbyteArray transformedArr = (jbyteArray) env->CallObjectMethod(sTransformer, runTransformersMethod, loader, nameStr, classBeingRedefined, protectionDomain, isolatedClassArr); // Set transformed byte code if (!env->ExceptionOccurred() && transformedArr != nullptr) { *newClassDataLen = env->GetArrayLength(transformedArr); jbyte* transformed = env->GetByteArrayElements(transformedArr, 0); jvmti_env->Allocate(*newClassDataLen, newClassData); std::memcpy(*newClassData, transformed, *newClassDataLen); env->ReleaseByteArrayElements(transformedArr, transformed, 0); } } } // Add a label before instructionAfter static void addLabel(CodeIr& c, lir::Instruction* instructionAfter, Label* returnTrueLabel) { c.instructions.InsertBefore(instructionAfter, returnTrueLabel); } // Add a byte code before instructionAfter static void addInstr(CodeIr& c, lir::Instruction* instructionAfter, Opcode opcode, const std::list& operands) { auto instruction = c.Alloc(); instruction->opcode = opcode; for (auto it = operands.begin(); it != operands.end(); it++) { instruction->operands.push_back(*it); } c.instructions.InsertBefore(instructionAfter, instruction); } // Add a method call byte code before instructionAfter static void addCall(ir::Builder& b, CodeIr& c, lir::Instruction* instructionAfter, Opcode opcode, ir::Type* type, const char* methodName, ir::Type* returnType, const std::vector& types, const std::list& regs) { auto proto = b.GetProto(returnType, b.GetTypeList(types)); auto method = b.GetMethodDecl(b.GetAsciiString(methodName), proto, type); VRegList* param_regs = c.Alloc(); for (auto it = regs.begin(); it != regs.end(); it++) { param_regs->registers.push_back(*it); } addInstr(c, instructionAfter, opcode, {param_regs, c.Alloc(method, method->orig_index)}); } typedef struct { ir::Type* boxedType; ir::Type* scalarType; std::string unboxMethod; } BoxingInfo; // Get boxing / unboxing info for a type static BoxingInfo getBoxingInfo(ir::Builder &b, char typeCode) { BoxingInfo boxingInfo; if (typeCode != 'L' && typeCode != '[') { std::stringstream tmp; tmp << typeCode; boxingInfo.scalarType = b.GetType(tmp.str().c_str()); } switch (typeCode) { case 'B': boxingInfo.boxedType = b.GetType("Ljava/lang/Byte;"); boxingInfo.unboxMethod = "byteValue"; break; case 'S': boxingInfo.boxedType = b.GetType("Ljava/lang/Short;"); boxingInfo.unboxMethod = "shortValue"; break; case 'I': boxingInfo.boxedType = b.GetType("Ljava/lang/Integer;"); boxingInfo.unboxMethod = "intValue"; break; case 'C': boxingInfo.boxedType = b.GetType("Ljava/lang/Character;"); boxingInfo.unboxMethod = "charValue"; break; case 'F': boxingInfo.boxedType = b.GetType("Ljava/lang/Float;"); boxingInfo.unboxMethod = "floatValue"; break; case 'Z': boxingInfo.boxedType = b.GetType("Ljava/lang/Boolean;"); boxingInfo.unboxMethod = "booleanValue"; break; case 'J': boxingInfo.boxedType = b.GetType("Ljava/lang/Long;"); boxingInfo.unboxMethod = "longValue"; break; case 'D': boxingInfo.boxedType = b.GetType("Ljava/lang/Double;"); boxingInfo.unboxMethod = "doubleValue"; break; default: // real object break; } return boxingInfo; } static size_t getNumParams(ir::EncodedMethod *method) { if (method->decl->prototype->param_types == nullptr) { return 0; } return method->decl->prototype->param_types->types.size(); } static bool canBeTransformed(ir::EncodedMethod *method) { std::string type = method->decl->parent->Decl(); ir::String* methodName = method->decl->name; return ((method->access_flags & kAccStatic) != 0) && !(((method->access_flags & (kAccPrivate | kAccBridge | kAccNative)) != 0) || (Utf8Cmp(methodName->c_str(), "") == 0) || (strncmp(type.c_str(), "java.", 5) == 0 && (method->access_flags & (kAccPrivate | kAccPublic | kAccProtected)) == 0)); } // Transforms the classes to add the mockito hooks // - equals and hashcode are handled in a special way extern "C" JNIEXPORT jbyteArray JNICALL Java_com_android_dx_mockito_inline_StaticClassTransformer_nativeRedefine(JNIEnv* env, jobject generator, jstring idStr, jbyteArray originalArr) { unsigned char* original = (unsigned char*)env->GetByteArrayElements(originalArr, 0); Reader reader(original, env->GetArrayLength(originalArr)); reader.CreateClassIr(0); std::shared_ptr dex_ir = reader.GetIr(); ir::Builder b(dex_ir); ir::Type* objectT = b.GetType("Ljava/lang/Object;"); ir::Type* objectArrayT = b.GetType("[Ljava/lang/Object;"); ir::Type* stringT = b.GetType("Ljava/lang/String;"); ir::Type* methodT = b.GetType("Ljava/lang/reflect/Method;"); ir::Type* callableT = b.GetType("Ljava/util/concurrent/Callable;"); ir::Type* dispatcherT = b.GetType("Lcom/android/dx/mockito/inline/MockMethodDispatcher;"); // Add id to dex file const char* idNative = env->GetStringUTFChars(idStr, 0); ir::String* id = b.GetAsciiString(idNative); env->ReleaseStringUTFChars(idStr, idNative); for (auto& method : dex_ir->encoded_methods) { if (!canBeTransformed(method.get())) { continue; } /* static long method_original(int param1, long param2, String param3) { foo(); return bar(); } static long method_transformed(int param1, long param2, String param3) { // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this); const-string v0, "65463hg34t" const v1, 0 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher move-result-object v0 // if (dispatcher == null) { // goto original_method; // } if-eqz v0, original_method // Method origin = dispatcher.getOrigin(this, methodDesc); const-string v1 "fully.qualified.ClassName#original_method(int, long, String)" const v2, 0 invoke-virtual {v0, v2, v1}, MockMethodDispatcher.getOrigin(Object, String):Method move-result-object v1 // if (origin == null) { // goto original_method; // } if-eqz v1, original_method // Create an array with Objects of all parameters. // Object[] arguments = new Object[3] const v3, 3 new-array v2, v3, Object[] // Integer param1Integer = Integer.valueOf(param1) move-from16 v3, ARG1 # this is necessary as invoke-static cannot deal with high # registers and ARG1 might be high invoke-static {v3}, Integer.valueOf(int):Integer move-result-object v3 // arguments[0] = param1Integer const v4, 0 aput-object v3, v2, v4 // Long param2Long = Long.valueOf(param2) move-widefrom16 v3:v4, ARG2.1:ARG2.2 # this is necessary as invoke-static cannot # deal with high registers and ARG2 might be # high invoke-static {v3, v4}, Long.valueOf(long):Long move-result-object v3 // arguments[1] = param2Long const v4, 1 aput-object v3, v2, v4 // arguments[2] = param3 const v4, 2 move-objectfrom16 v3, ARG3 # this is necessary as aput-object cannot deal with # high registers and ARG3 might be high aput-object v3, v2, v4 // Callable mocked = dispatcher.handle(methodDesc --as this parameter--, // origin, arguments); const-string v3 "fully.qualified.ClassName#original_method(int, long, String)" invoke-virtual {v0,v3,v1,v2}, MockMethodDispatcher.handle(Object, Method, Object[]):Callable move-result-object v0 // if (mocked != null) { if-eqz v0, original_method // Object ret = mocked.call(); invoke-interface {v0}, Callable.call():Object move-result-object v0 // Long retLong = (Long)ret check-cast v0, Long // long retlong = retLong.longValue(); invoke-virtual {v0}, Long.longValue():long move-result-wide v0:v1 // return retlong; return-wide v0:v1 // } original_method: // Move all method arguments down so that they match what the original code expects. // Let's assume three arguments, one int, one long, one String and the and used to // use 4 registers move16 v5, v6 # ARG1 move-wide16 v6:v7, v7:v8 # ARG2 (overlapping moves are allowed) move-object16 v8, v9 # ARG3 // foo(); // return bar(); unmodified original byte code } */ CodeIr c(method.get(), dex_ir); // Make sure there are at least 5 local registers to use int originalNumRegisters = method->code->registers - method->code->ins_count; int numAdditionalRegs = std::max(0, 5 - originalNumRegisters); int firstArg = originalNumRegisters + numAdditionalRegs; if (numAdditionalRegs > 0) { c.ir_method->code->registers += numAdditionalRegs; } lir::Instruction* fi = *(c.instructions.begin()); // Add methodDesc to dex file std::stringstream ss; ss << method->decl->parent->Decl() << "#" << method->decl->name->c_str() << "(" ; bool first = true; if (method->decl->prototype->param_types != nullptr) { for (const auto& type : method->decl->prototype->param_types->types) { if (first) { first = false; } else { ss << ","; } ss << type->Decl().c_str(); } } ss << ")"; std::string methodDescStr = ss.str(); ir::String* methodDesc = b.GetAsciiString(methodDescStr.c_str()); size_t numParams = getNumParams(method.get()); Label* originalMethodLabel = c.Alloc