1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <cstdlib> 18 #include <sstream> 19 #include <cstring> 20 #include <cassert> 21 #include <cstdarg> 22 #include <algorithm> 23 24 #include <jni.h> 25 26 #include "jvmti.h" 27 28 #include <slicer/dex_ir.h> 29 #include <slicer/code_ir.h> 30 #include <slicer/dex_ir_builder.h> 31 #include <slicer/dex_utf8.h> 32 #include <slicer/writer.h> 33 #include <slicer/reader.h> 34 #include <slicer/instrumentation.h> 35 36 using namespace dex; 37 using namespace lir; 38 39 namespace com_android_dx_mockito_inline { 40 static jvmtiEnv* localJvmtiEnv; 41 42 static jobject sTransformer; 43 44 // Converts a class name to a type descriptor 45 // (ex. "java.lang.String" to "Ljava/lang/String;") 46 static std::string ClassNameToDescriptor(const char * class_name)47 ClassNameToDescriptor(const char* class_name) { 48 std::stringstream ss; 49 ss << "L"; 50 for (auto p = class_name; *p != '\0'; ++p) { 51 ss << (*p == '.' ? '/' : *p); 52 } 53 ss << ";"; 54 return ss.str(); 55 } 56 57 // Takes the full dex file for class 'classBeingRedefined' 58 // - isolates the dex code for the class out of the dex file 59 // - calls sTransformer.runTransformers on the isolated dex code 60 // - send the transformed code back to the runtime 61 static void Transform(jvmtiEnv * jvmti_env,JNIEnv * env,jclass classBeingRedefined,jobject loader,const char * name,jobject protectionDomain,jint classDataLen,const unsigned char * classData,jint * newClassDataLen,unsigned char ** newClassData)62 Transform(jvmtiEnv* jvmti_env, 63 JNIEnv* env, 64 jclass classBeingRedefined, 65 jobject loader, 66 const char* name, 67 jobject protectionDomain, 68 jint classDataLen, 69 const unsigned char* classData, 70 jint* newClassDataLen, 71 unsigned char** newClassData) { 72 if (sTransformer != nullptr) { 73 // Even reading the classData array is expensive as the data is only generated when the 74 // memory is touched. Hence call JvmtiAgent#shouldTransform to check if we need to transform 75 // the class. 76 jclass cls = env->GetObjectClass(sTransformer); 77 jmethodID shouldTransformMethod = env->GetMethodID(cls, "shouldTransform", 78 "(Ljava/lang/Class;)Z"); 79 80 jboolean shouldTransform = env->CallBooleanMethod(sTransformer, shouldTransformMethod, 81 classBeingRedefined); 82 if (!shouldTransform) { 83 return; 84 } 85 86 // Isolate byte code of class class. This is needed as Android usually gives us more 87 // than the class we need. 88 Reader reader(classData, classDataLen); 89 90 u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str()); 91 reader.CreateClassIr(index); 92 std::shared_ptr<ir::DexFile> ir = reader.GetIr(); 93 94 struct Allocator : public Writer::Allocator { 95 virtual void* Allocate(size_t size) {return ::malloc(size);} 96 virtual void Free(void* ptr) {::free(ptr);} 97 }; 98 99 Allocator allocator; 100 Writer writer(ir); 101 size_t isolatedClassLen = 0; 102 std::shared_ptr<jbyte> isolatedClass((jbyte*)writer.CreateImage(&allocator, 103 &isolatedClassLen)); 104 105 // Create jbyteArray with isolated byte code of class 106 jbyteArray isolatedClassArr = env->NewByteArray(isolatedClassLen); 107 env->SetByteArrayRegion(isolatedClassArr, 0, isolatedClassLen, 108 isolatedClass.get()); 109 110 jstring nameStr = env->NewStringUTF(name); 111 112 // Call JvmtiAgent#runTransformers 113 jmethodID runTransformersMethod = env->GetMethodID(cls, "runTransformers", 114 "(Ljava/lang/ClassLoader;" 115 "Ljava/lang/String;" 116 "Ljava/lang/Class;" 117 "Ljava/security/ProtectionDomain;" 118 "[B)[B"); 119 120 jbyteArray transformedArr = (jbyteArray) env->CallObjectMethod(sTransformer, 121 runTransformersMethod, 122 loader, nameStr, 123 classBeingRedefined, 124 protectionDomain, 125 isolatedClassArr); 126 127 // Set transformed byte code 128 if (!env->ExceptionOccurred() && transformedArr != nullptr) { 129 *newClassDataLen = env->GetArrayLength(transformedArr); 130 131 jbyte* transformed = env->GetByteArrayElements(transformedArr, 0); 132 133 jvmti_env->Allocate(*newClassDataLen, newClassData); 134 std::memcpy(*newClassData, transformed, *newClassDataLen); 135 136 env->ReleaseByteArrayElements(transformedArr, transformed, 0); 137 } 138 } 139 } 140 141 // Add a label before instructionAfter 142 static void addLabel(CodeIr & c,lir::Instruction * instructionAfter,Label * returnTrueLabel)143 addLabel(CodeIr& c, 144 lir::Instruction* instructionAfter, 145 Label* returnTrueLabel) { 146 c.instructions.InsertBefore(instructionAfter, returnTrueLabel); 147 } 148 149 // Add a byte code before instructionAfter 150 static void addInstr(CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,const std::list<Operand * > & operands)151 addInstr(CodeIr& c, 152 lir::Instruction* instructionAfter, 153 Opcode opcode, 154 const std::list<Operand*>& operands) { 155 auto instruction = c.Alloc<Bytecode>(); 156 157 instruction->opcode = opcode; 158 159 for (auto it = operands.begin(); it != operands.end(); it++) { 160 instruction->operands.push_back(*it); 161 } 162 163 c.instructions.InsertBefore(instructionAfter, instruction); 164 } 165 166 // Add a method call byte code before instructionAfter 167 static void addCall(ir::Builder & b,CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,ir::Type * type,const char * methodName,ir::Type * returnType,const std::vector<ir::Type * > & types,const std::list<int> & regs)168 addCall(ir::Builder& b, 169 CodeIr& c, 170 lir::Instruction* instructionAfter, 171 Opcode opcode, 172 ir::Type* type, 173 const char* methodName, 174 ir::Type* returnType, 175 const std::vector<ir::Type*>& types, 176 const std::list<int>& regs) { 177 auto proto = b.GetProto(returnType, b.GetTypeList(types)); 178 auto method = b.GetMethodDecl(b.GetAsciiString(methodName), proto, type); 179 180 VRegList* param_regs = c.Alloc<VRegList>(); 181 for (auto it = regs.begin(); it != regs.end(); it++) { 182 param_regs->registers.push_back(*it); 183 } 184 185 addInstr(c, instructionAfter, opcode, {param_regs, c.Alloc<Method>(method, 186 method->orig_index)}); 187 } 188 189 typedef struct { 190 ir::Type* boxedType; 191 ir::Type* scalarType; 192 std::string unboxMethod; 193 } BoxingInfo; 194 195 // Get boxing / unboxing info for a type 196 static BoxingInfo getBoxingInfo(ir::Builder & b,char typeCode)197 getBoxingInfo(ir::Builder &b, 198 char typeCode) { 199 BoxingInfo boxingInfo; 200 201 if (typeCode != 'L' && typeCode != '[') { 202 std::stringstream tmp; 203 tmp << typeCode; 204 boxingInfo.scalarType = b.GetType(tmp.str().c_str()); 205 } 206 207 switch (typeCode) { 208 case 'B': 209 boxingInfo.boxedType = b.GetType("Ljava/lang/Byte;"); 210 boxingInfo.unboxMethod = "byteValue"; 211 break; 212 case 'S': 213 boxingInfo.boxedType = b.GetType("Ljava/lang/Short;"); 214 boxingInfo.unboxMethod = "shortValue"; 215 break; 216 case 'I': 217 boxingInfo.boxedType = b.GetType("Ljava/lang/Integer;"); 218 boxingInfo.unboxMethod = "intValue"; 219 break; 220 case 'C': 221 boxingInfo.boxedType = b.GetType("Ljava/lang/Character;"); 222 boxingInfo.unboxMethod = "charValue"; 223 break; 224 case 'F': 225 boxingInfo.boxedType = b.GetType("Ljava/lang/Float;"); 226 boxingInfo.unboxMethod = "floatValue"; 227 break; 228 case 'Z': 229 boxingInfo.boxedType = b.GetType("Ljava/lang/Boolean;"); 230 boxingInfo.unboxMethod = "booleanValue"; 231 break; 232 case 'J': 233 boxingInfo.boxedType = b.GetType("Ljava/lang/Long;"); 234 boxingInfo.unboxMethod = "longValue"; 235 break; 236 case 'D': 237 boxingInfo.boxedType = b.GetType("Ljava/lang/Double;"); 238 boxingInfo.unboxMethod = "doubleValue"; 239 break; 240 default: 241 // real object 242 break; 243 } 244 245 return boxingInfo; 246 } 247 248 static size_t getNumParams(ir::EncodedMethod * method)249 getNumParams(ir::EncodedMethod *method) { 250 if (method->decl->prototype->param_types == nullptr) { 251 return 0; 252 } 253 254 return method->decl->prototype->param_types->types.size(); 255 } 256 257 static bool canBeTransformed(ir::EncodedMethod * method)258 canBeTransformed(ir::EncodedMethod *method) { 259 std::string type = method->decl->parent->Decl(); 260 ir::String* methodName = method->decl->name; 261 262 return ((method->access_flags & kAccStatic) != 0) 263 && !(((method->access_flags & (kAccPrivate | kAccBridge | kAccNative)) != 0) 264 || (Utf8Cmp(methodName->c_str(), "<clinit>") == 0) 265 || (strncmp(type.c_str(), "java.", 5) == 0 266 && (method->access_flags & (kAccPrivate | kAccPublic | kAccProtected)) 267 == 0)); 268 } 269 270 // Transforms the classes to add the mockito hooks 271 // - equals and hashcode are handled in a special way 272 extern "C" JNIEXPORT jbyteArray JNICALL Java_com_android_dx_mockito_inline_StaticClassTransformer_nativeRedefine(JNIEnv * env,jobject generator,jstring idStr,jbyteArray originalArr)273 Java_com_android_dx_mockito_inline_StaticClassTransformer_nativeRedefine(JNIEnv* env, 274 jobject generator, 275 jstring idStr, 276 jbyteArray originalArr) { 277 unsigned char* original = (unsigned char*)env->GetByteArrayElements(originalArr, 0); 278 279 Reader reader(original, env->GetArrayLength(originalArr)); 280 reader.CreateClassIr(0); 281 std::shared_ptr<ir::DexFile> dex_ir = reader.GetIr(); 282 ir::Builder b(dex_ir); 283 284 ir::Type* objectT = b.GetType("Ljava/lang/Object;"); 285 ir::Type* objectArrayT = b.GetType("[Ljava/lang/Object;"); 286 ir::Type* stringT = b.GetType("Ljava/lang/String;"); 287 ir::Type* methodT = b.GetType("Ljava/lang/reflect/Method;"); 288 ir::Type* callableT = b.GetType("Ljava/util/concurrent/Callable;"); 289 ir::Type* dispatcherT = b.GetType("Lcom/android/dx/mockito/inline/MockMethodDispatcher;"); 290 291 // Add id to dex file 292 const char* idNative = env->GetStringUTFChars(idStr, 0); 293 ir::String* id = b.GetAsciiString(idNative); 294 env->ReleaseStringUTFChars(idStr, idNative); 295 296 for (auto& method : dex_ir->encoded_methods) { 297 if (!canBeTransformed(method.get())) { 298 continue; 299 } 300 /* 301 static long method_original(int param1, long param2, String param3) { 302 foo(); 303 return bar(); 304 } 305 306 static long method_transformed(int param1, long param2, String param3) { 307 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this); 308 const-string v0, "65463hg34t" 309 const v1, 0 310 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher 311 move-result-object v0 312 313 // if (dispatcher == null) { 314 // goto original_method; 315 // } 316 if-eqz v0, original_method 317 318 // Method origin = dispatcher.getOrigin(this, methodDesc); 319 const-string v1 "fully.qualified.ClassName#original_method(int, long, String)" 320 const v2, 0 321 invoke-virtual {v0, v2, v1}, MockMethodDispatcher.getOrigin(Object, String):Method 322 move-result-object v1 323 324 // if (origin == null) { 325 // goto original_method; 326 // } 327 if-eqz v1, original_method 328 329 // Create an array with Objects of all parameters. 330 331 // Object[] arguments = new Object[3] 332 const v3, 3 333 new-array v2, v3, Object[] 334 335 // Integer param1Integer = Integer.valueOf(param1) 336 move-from16 v3, ARG1 # this is necessary as invoke-static cannot deal with high 337 # registers and ARG1 might be high 338 invoke-static {v3}, Integer.valueOf(int):Integer 339 move-result-object v3 340 341 // arguments[0] = param1Integer 342 const v4, 0 343 aput-object v3, v2, v4 344 345 // Long param2Long = Long.valueOf(param2) 346 move-widefrom16 v3:v4, ARG2.1:ARG2.2 # this is necessary as invoke-static cannot 347 # deal with high registers and ARG2 might be 348 # high 349 invoke-static {v3, v4}, Long.valueOf(long):Long 350 move-result-object v3 351 352 // arguments[1] = param2Long 353 const v4, 1 354 aput-object v3, v2, v4 355 356 // arguments[2] = param3 357 const v4, 2 358 move-objectfrom16 v3, ARG3 # this is necessary as aput-object cannot deal with 359 # high registers and ARG3 might be high 360 aput-object v3, v2, v4 361 362 // Callable<?> mocked = dispatcher.handle(methodDesc --as this parameter--, 363 // origin, arguments); 364 const-string v3 "fully.qualified.ClassName#original_method(int, long, String)" 365 invoke-virtual {v0,v3,v1,v2}, MockMethodDispatcher.handle(Object, Method, 366 Object[]):Callable 367 move-result-object v0 368 369 // if (mocked != null) { 370 if-eqz v0, original_method 371 372 // Object ret = mocked.call(); 373 invoke-interface {v0}, Callable.call():Object 374 move-result-object v0 375 376 // Long retLong = (Long)ret 377 check-cast v0, Long 378 379 // long retlong = retLong.longValue(); 380 invoke-virtual {v0}, Long.longValue():long 381 move-result-wide v0:v1 382 383 // return retlong; 384 return-wide v0:v1 385 386 // } 387 388 original_method: 389 // Move all method arguments down so that they match what the original code expects. 390 // Let's assume three arguments, one int, one long, one String and the and used to 391 // use 4 registers 392 move16 v5, v6 # ARG1 393 move-wide16 v6:v7, v7:v8 # ARG2 (overlapping moves are allowed) 394 move-object16 v8, v9 # ARG3 395 396 // foo(); 397 // return bar(); 398 unmodified original byte code 399 } 400 */ 401 402 CodeIr c(method.get(), dex_ir); 403 404 // Make sure there are at least 5 local registers to use 405 int originalNumRegisters = method->code->registers - method->code->ins_count; 406 int numAdditionalRegs = std::max(0, 5 - originalNumRegisters); 407 int firstArg = originalNumRegisters + numAdditionalRegs; 408 409 if (numAdditionalRegs > 0) { 410 c.ir_method->code->registers += numAdditionalRegs; 411 } 412 413 lir::Instruction* fi = *(c.instructions.begin()); 414 415 // Add methodDesc to dex file 416 std::stringstream ss; 417 ss << method->decl->parent->Decl() << "#" << method->decl->name->c_str() << "(" ; 418 bool first = true; 419 if (method->decl->prototype->param_types != nullptr) { 420 for (const auto& type : method->decl->prototype->param_types->types) { 421 if (first) { 422 first = false; 423 } else { 424 ss << ","; 425 } 426 427 ss << type->Decl().c_str(); 428 } 429 } 430 ss << ")"; 431 std::string methodDescStr = ss.str(); 432 ir::String* methodDesc = b.GetAsciiString(methodDescStr.c_str()); 433 434 size_t numParams = getNumParams(method.get()); 435 436 Label* originalMethodLabel = c.Alloc<Label>(0); 437 CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel); 438 VReg* v0 = c.Alloc<VReg>(0); 439 VReg* v1 = c.Alloc<VReg>(1); 440 VReg* v2 = c.Alloc<VReg>(2); 441 VReg* v3 = c.Alloc<VReg>(3); 442 VReg* v4 = c.Alloc<VReg>(4); 443 444 addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)}); 445 addInstr(c, fi, OP_CONST, {v1, c.Alloc<Const32>(0)}); 446 addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT, {stringT, objectT}, 447 {0, 1}); 448 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0}); 449 addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod}); 450 addInstr(c, fi, OP_CONST_STRING, 451 {v1, c.Alloc<String>(methodDesc, methodDesc->orig_index)}); 452 addInstr(c, fi, OP_CONST, {v2, c.Alloc<Const32>(0)}); 453 addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "getOrigin", methodT, 454 {objectT, stringT}, {0, 2, 1}); 455 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v1}); 456 addInstr(c, fi, OP_IF_EQZ, {v1, originalMethod}); 457 addInstr(c, fi, OP_CONST, {v3, c.Alloc<Const32>(numParams)}); 458 addInstr(c, fi, OP_NEW_ARRAY, {v2, v3, c.Alloc<Type>(objectArrayT, 459 objectArrayT->orig_index)}); 460 461 if (numParams > 0) { 462 int argReg = firstArg; 463 464 for (int argNum = 0; argNum < numParams; argNum++) { 465 const auto& type = method->decl->prototype->param_types->types[argNum]; 466 BoxingInfo boxingInfo = getBoxingInfo(b, type->descriptor->c_str()[0]); 467 468 switch (type->GetCategory()) { 469 case ir::Type::Category::Scalar: 470 addInstr(c, fi, OP_MOVE_FROM16, {v3, c.Alloc<VReg>(argReg)}); 471 addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf", 472 boxingInfo.boxedType, {type}, {3}); 473 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3}); 474 475 argReg++; 476 break; 477 case ir::Type::Category::WideScalar: { 478 VRegPair* v3v4 = c.Alloc<VRegPair>(3); 479 VRegPair* argRegPair = c.Alloc<VRegPair>(argReg); 480 481 addInstr(c, fi, OP_MOVE_WIDE_FROM16, {v3v4, argRegPair}); 482 addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf", 483 boxingInfo.boxedType, {type}, {3, 4}); 484 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3}); 485 486 argReg += 2; 487 break; 488 } 489 case ir::Type::Category::Reference: 490 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v3, c.Alloc<VReg>(argReg)}); 491 492 argReg++; 493 break; 494 case ir::Type::Category::Void: 495 assert(false); 496 } 497 498 addInstr(c, fi, OP_CONST, {v4, c.Alloc<Const32>(argNum)}); 499 addInstr(c, fi, OP_APUT_OBJECT, {v3, v2, v4}); 500 } 501 } 502 503 // NASTY Hack: Push in method name as "mock" 504 addInstr(c, fi, OP_CONST_STRING, 505 {v3, c.Alloc<String>(methodDesc, methodDesc->orig_index)}); 506 addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "handle", callableT, 507 {objectT, methodT, objectArrayT}, {0, 3, 1, 2}); 508 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0}); 509 addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod}); 510 addCall(b, c, fi, OP_INVOKE_INTERFACE, callableT, "call", objectT, {}, {0}); 511 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0}); 512 513 ir::Type *returnType = method->decl->prototype->return_type; 514 BoxingInfo boxingInfo = getBoxingInfo(b, returnType->descriptor->c_str()[0]); 515 516 switch (returnType->GetCategory()) { 517 case ir::Type::Category::Scalar: 518 addInstr(c, fi, OP_CHECK_CAST, {v0, 519 c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)}); 520 addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType, 521 boxingInfo.unboxMethod.c_str(), returnType, {}, {0}); 522 addInstr(c, fi, OP_MOVE_RESULT, {v0}); 523 addInstr(c, fi, OP_RETURN, {v0}); 524 break; 525 case ir::Type::Category::WideScalar: { 526 VRegPair* v0v1 = c.Alloc<VRegPair>(0); 527 528 addInstr(c, fi, OP_CHECK_CAST, {v0, 529 c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)}); 530 addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType, 531 boxingInfo.unboxMethod.c_str(), returnType, {}, {0}); 532 addInstr(c, fi, OP_MOVE_RESULT_WIDE, {v0v1}); 533 addInstr(c, fi, OP_RETURN_WIDE, {v0v1}); 534 break; 535 } 536 case ir::Type::Category::Reference: 537 addInstr(c, fi, OP_CHECK_CAST, {v0, c.Alloc<Type>(returnType, 538 returnType->orig_index)}); 539 addInstr(c, fi, OP_RETURN_OBJECT, {v0}); 540 break; 541 case ir::Type::Category::Void: 542 addInstr(c, fi, OP_RETURN_VOID, {}); 543 break; 544 } 545 546 addLabel(c, fi, originalMethodLabel); 547 548 if (numParams > 0) { 549 int argReg = firstArg; 550 551 for (int argNum = 0; argNum < numParams; argNum++) { 552 const auto& type = method->decl->prototype->param_types->types[argNum]; 553 int origReg = argReg - numAdditionalRegs; 554 switch (type->GetCategory()) { 555 case ir::Type::Category::Scalar: 556 addInstr(c, fi, OP_MOVE_16, {c.Alloc<VReg>(origReg), 557 c.Alloc<VReg>(argReg)}); 558 argReg++; 559 break; 560 case ir::Type::Category::WideScalar: 561 addInstr(c, fi, OP_MOVE_WIDE_16,{c.Alloc<VRegPair>(origReg), 562 c.Alloc<VRegPair>(argReg)}); 563 argReg +=2; 564 break; 565 case ir::Type::Category::Reference: 566 addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(origReg), 567 c.Alloc<VReg>(argReg)}); 568 argReg++; 569 break; 570 } 571 } 572 } 573 574 c.Assemble(); 575 } 576 577 struct Allocator : public Writer::Allocator { 578 virtual void* Allocate(size_t size) {return ::malloc(size);} 579 virtual void Free(void* ptr) {::free(ptr);} 580 }; 581 582 Allocator allocator; 583 Writer writer(dex_ir); 584 size_t transformedLen = 0; 585 std::shared_ptr<jbyte> transformed((jbyte*)writer.CreateImage(&allocator, &transformedLen)); 586 587 jbyteArray transformedArr = env->NewByteArray(transformedLen); 588 env->SetByteArrayRegion(transformedArr, 0, transformedLen, transformed.get()); 589 590 return transformedArr; 591 } 592 593 // Initializes the agent Agent_OnAttach(JavaVM * vm,char * options,void * reserved)594 extern "C" jint Agent_OnAttach(JavaVM* vm, 595 char* options, 596 void* reserved) { 597 jint jvmError = vm->GetEnv(reinterpret_cast<void**>(&localJvmtiEnv), JVMTI_VERSION_1_2); 598 if (jvmError != JNI_OK) { 599 return jvmError; 600 } 601 602 jvmtiCapabilities caps; 603 memset(&caps, 0, sizeof(caps)); 604 caps.can_retransform_classes = 1; 605 606 jvmtiError error = localJvmtiEnv->AddCapabilities(&caps); 607 if (error != JVMTI_ERROR_NONE) { 608 return error; 609 } 610 611 jvmtiEventCallbacks cb; 612 memset(&cb, 0, sizeof(cb)); 613 cb.ClassFileLoadHook = Transform; 614 615 error = localJvmtiEnv->SetEventCallbacks(&cb, sizeof(cb)); 616 if (error != JVMTI_ERROR_NONE) { 617 return error; 618 } 619 620 error = localJvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, 621 JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, nullptr); 622 if (error != JVMTI_ERROR_NONE) { 623 return error; 624 } 625 626 return JVMTI_ERROR_NONE; 627 } 628 629 // Throw runtime exception throwRuntimeExpection(JNIEnv * env,const char * fmt,...)630 static void throwRuntimeExpection(JNIEnv* env, const char* fmt, ...) { 631 char msgBuf[512]; 632 633 va_list args; 634 va_start (args, fmt); 635 vsnprintf(msgBuf, sizeof(msgBuf), fmt, args); 636 va_end (args); 637 638 jclass exceptionClass = env->FindClass("java/lang/RuntimeException"); 639 env->ThrowNew(exceptionClass, msgBuf); 640 } 641 642 // Register transformer hook 643 extern "C" JNIEXPORT void JNICALL Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRegisterTransformerHook(JNIEnv * env,jobject thiz)644 Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRegisterTransformerHook(JNIEnv* env, 645 jobject thiz) { 646 sTransformer = env->NewGlobalRef(thiz); 647 } 648 649 // Unregister transformer hook 650 extern "C" JNIEXPORT void JNICALL Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeUnregisterTransformerHook(JNIEnv * env,jobject thiz)651 Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeUnregisterTransformerHook(JNIEnv* env, 652 jobject thiz) { 653 env->DeleteGlobalRef(sTransformer); 654 sTransformer = nullptr; 655 } 656 657 // Triggers retransformation of classes via this file's Transform method 658 extern "C" JNIEXPORT void JNICALL Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRetransformClasses(JNIEnv * env,jobject thiz,jobjectArray classes)659 Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRetransformClasses(JNIEnv* env, 660 jobject thiz, 661 jobjectArray classes) { 662 jsize numTransformedClasses = env->GetArrayLength(classes); 663 jclass *transformedClasses = (jclass*) malloc(numTransformedClasses * sizeof(jclass)); 664 for (int i = 0; i < numTransformedClasses; i++) { 665 transformedClasses[i] = (jclass) env->NewGlobalRef(env->GetObjectArrayElement(classes, i)); 666 } 667 668 jvmtiError error = localJvmtiEnv->RetransformClasses(numTransformedClasses, 669 transformedClasses); 670 671 for (int i = 0; i < numTransformedClasses; i++) { 672 env->DeleteGlobalRef(transformedClasses[i]); 673 } 674 free(transformedClasses); 675 676 if (error != JVMTI_ERROR_NONE) { 677 throwRuntimeExpection(env, "Could not retransform classes: %d", error); 678 } 679 } 680 681 static jvmtiFrameInfo* frameToInspect; 682 static std::string calledClass; 683 684 // Takes the full dex file for class 'classBeingRedefined' 685 // - isolates the dex code for the class out of the dex file 686 // - calls sTransformer.runTransformers on the isolated dex code 687 // - send the transformed code back to the runtime 688 static void InspectClass(jvmtiEnv * jvmtiEnv,JNIEnv * env,jclass classBeingRedefined,jobject loader,const char * name,jobject protectionDomain,jint classDataLen,const unsigned char * classData,jint * newClassDataLen,unsigned char ** newClassData)689 InspectClass(jvmtiEnv* jvmtiEnv, 690 JNIEnv* env, 691 jclass classBeingRedefined, 692 jobject loader, 693 const char* name, 694 jobject protectionDomain, 695 jint classDataLen, 696 const unsigned char* classData, 697 jint* newClassDataLen, 698 unsigned char** newClassData) { 699 calledClass = "none"; 700 701 Reader reader(classData, classDataLen); 702 703 char *calledMethodName; 704 char *calledMethodSignature; 705 jvmtiError error = jvmtiEnv->GetMethodName(frameToInspect->method, &calledMethodName, 706 &calledMethodSignature, nullptr); 707 if (error != JVMTI_ERROR_NONE) { 708 return; 709 } 710 711 u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str()); 712 reader.CreateClassIr(index); 713 std::shared_ptr<ir::DexFile> class_ir = reader.GetIr(); 714 715 for (auto& method : class_ir->encoded_methods) { 716 if (Utf8Cmp(method->decl->name->c_str(), calledMethodName) == 0 717 && Utf8Cmp(method->decl->prototype->Signature().c_str(), calledMethodSignature) == 0) { 718 CodeIr method_ir(method.get(), class_ir); 719 720 for (auto instruction : method_ir.instructions) { 721 Bytecode* bytecode = dynamic_cast<Bytecode*>(instruction); 722 if (bytecode != nullptr && bytecode->offset == frameToInspect->location) { 723 Method *method = bytecode->CastOperand<Method>(1); 724 calledClass = method->ir_method->parent->Decl().c_str(); 725 726 goto exit; 727 } 728 } 729 } 730 } 731 732 exit: 733 free(calledMethodName); 734 free(calledMethodSignature); 735 } 736 737 #define GOTO_ON_ERROR(label) \ 738 if (error != JVMTI_ERROR_NONE) { \ 739 goto label; \ 740 } 741 742 // stack frame of the caller if method was called directly 743 #define DIRECT_CALL_STACK_FRAME (6) 744 745 // stack frame of the caller if method was called as 'real method' 746 #define REALMETHOD_CALL_STACK_FRAME (23) 747 748 extern "C" JNIEXPORT jstring JNICALL Java_com_android_dx_mockito_inline_StaticMockMethodAdvice_nativeGetCalledClassName(JNIEnv * env,jclass klass,jthread currentThread)749 Java_com_android_dx_mockito_inline_StaticMockMethodAdvice_nativeGetCalledClassName(JNIEnv* env, 750 jclass klass, 751 jthread currentThread) { 752 753 JavaVM *vm; 754 jint jvmError = env->GetJavaVM(&vm); 755 if (jvmError != JNI_OK) { 756 return nullptr; 757 } 758 759 jvmtiEnv *jvmtiEnv; 760 jvmError = vm->GetEnv(reinterpret_cast<void**>(&jvmtiEnv), JVMTI_VERSION_1_2); 761 if (jvmError != JNI_OK) { 762 return nullptr; 763 } 764 765 jvmtiCapabilities caps; 766 memset(&caps, 0, sizeof(caps)); 767 caps.can_retransform_classes = 1; 768 769 jvmtiError error = jvmtiEnv->AddCapabilities(&caps); 770 GOTO_ON_ERROR(unregister_env_and_exit); 771 772 jvmtiEventCallbacks cb; 773 memset(&cb, 0, sizeof(cb)); 774 cb.ClassFileLoadHook = InspectClass; 775 776 jvmtiFrameInfo frameInfo[REALMETHOD_CALL_STACK_FRAME + 1]; 777 jint numFrames; 778 error = jvmtiEnv->GetStackTrace(nullptr, 0, REALMETHOD_CALL_STACK_FRAME + 1, frameInfo, 779 &numFrames); 780 GOTO_ON_ERROR(unregister_env_and_exit); 781 782 // Method might be called directly or as 'real method' (see 783 // StaticMockMethodAdvice.SuperMethodCall#invoke). Hence the real caller might be in stack 784 // frame DIRECT_CALL_STACK_FRAME for a direct call or REALMETHOD_CALL_STACK_FRAME for a 785 // call through the 'real method' mechanism. 786 int callingFrameNum; 787 if (numFrames < REALMETHOD_CALL_STACK_FRAME) { 788 callingFrameNum = DIRECT_CALL_STACK_FRAME; 789 } else { 790 char *directCallMethodName; 791 792 jvmtiEnv->GetMethodName(frameInfo[DIRECT_CALL_STACK_FRAME].method, 793 &directCallMethodName, nullptr, nullptr); 794 if (strcmp(directCallMethodName, "invoke") == 0) { 795 callingFrameNum = REALMETHOD_CALL_STACK_FRAME; 796 } else { 797 callingFrameNum = DIRECT_CALL_STACK_FRAME; 798 } 799 } 800 801 jclass callingClass; 802 error = jvmtiEnv->GetMethodDeclaringClass(frameInfo[callingFrameNum].method, &callingClass); 803 GOTO_ON_ERROR(unregister_env_and_exit); 804 805 error = jvmtiEnv->SetEventCallbacks(&cb, sizeof(cb)); 806 GOTO_ON_ERROR(unregister_env_and_exit); 807 808 error = jvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, 809 currentThread); 810 GOTO_ON_ERROR(unset_cb_and_exit); 811 812 frameToInspect = &frameInfo[callingFrameNum]; 813 error = jvmtiEnv->RetransformClasses(1, &callingClass); 814 GOTO_ON_ERROR(disable_hook_and_exit); 815 816 disable_hook_and_exit: 817 jvmtiEnv->SetEventNotificationMode(JVMTI_DISABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, 818 currentThread); 819 820 unset_cb_and_exit: 821 memset(&cb, 0, sizeof(cb)); 822 jvmtiEnv->SetEventCallbacks(&cb, sizeof(cb)); 823 824 unregister_env_and_exit: 825 jvmtiEnv->DisposeEnvironment(); 826 827 if (error != JVMTI_ERROR_NONE) { 828 return nullptr; 829 } 830 831 return env->NewStringUTF(calledClass.c_str()); 832 } 833 834 } // namespace com_android_dx_mockito_inline 835 836