1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <algorithm>
#include <cassert>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
23
24 #include <jni.h>
25
26 #include "jvmti.h"
27
28 #include <slicer/dex_ir.h>
29 #include <slicer/code_ir.h>
30 #include <slicer/dex_ir_builder.h>
31 #include <slicer/dex_utf8.h>
32 #include <slicer/writer.h>
33 #include <slicer/reader.h>
34 #include <slicer/instrumentation.h>
35
36 using namespace dex;
37 using namespace lir;
38
39 namespace com_android_dx_mockito_inline {
40 static jvmtiEnv* localJvmtiEnv;
41
42 static jobject sTransformer;
43
// Converts a class name to a dex type descriptor by replacing every '.' with '/'
// and wrapping the result in 'L' ... ';'
// (ex. "java.lang.String" to "Ljava/lang/String;"; names already using '/' are kept).
static std::string
ClassNameToDescriptor(const char* class_name) {
    std::string descriptor = "L";
    descriptor += class_name;
    // The leading 'L' contains no '.', so replacing over the whole string is safe.
    std::replace(descriptor.begin(), descriptor.end(), '.', '/');
    descriptor.push_back(';');
    return descriptor;
}
56
57 // Takes the full dex file for class 'classBeingRedefined'
58 // - isolates the dex code for the class out of the dex file
59 // - calls sTransformer.runTransformers on the isolated dex code
60 // - send the transformed code back to the runtime
61 static void
Transform(jvmtiEnv * jvmti_env,JNIEnv * env,jclass classBeingRedefined,jobject loader,const char * name,jobject protectionDomain,jint classDataLen,const unsigned char * classData,jint * newClassDataLen,unsigned char ** newClassData)62 Transform(jvmtiEnv* jvmti_env,
63 JNIEnv* env,
64 jclass classBeingRedefined,
65 jobject loader,
66 const char* name,
67 jobject protectionDomain,
68 jint classDataLen,
69 const unsigned char* classData,
70 jint* newClassDataLen,
71 unsigned char** newClassData) {
72 if (sTransformer != NULL) {
73 // Even reading the classData array is expensive as the data is only generated when the
74 // memory is touched. Hence call JvmtiAgent#shouldTransform to check if we need to transform
75 // the class.
76 jclass cls = env->GetObjectClass(sTransformer);
77 jmethodID shouldTransformMethod = env->GetMethodID(cls, "shouldTransform",
78 "(Ljava/lang/Class;)Z");
79
80 jboolean shouldTransform = env->CallBooleanMethod(sTransformer, shouldTransformMethod,
81 classBeingRedefined);
82 if (!shouldTransform) {
83 return;
84 }
85
86 // Isolate byte code of class class. This is needed as Android usually gives us more
87 // than the class we need.
88 Reader reader(classData, classDataLen);
89
90 u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str());
91 reader.CreateClassIr(index);
92 std::shared_ptr<ir::DexFile> ir = reader.GetIr();
93
94 struct Allocator : public Writer::Allocator {
95 virtual void* Allocate(size_t size) {return ::malloc(size);}
96 virtual void Free(void* ptr) {::free(ptr);}
97 };
98
99 Allocator allocator;
100 Writer writer(ir);
101 size_t isolatedClassLen = 0;
102 std::shared_ptr<jbyte> isolatedClass((jbyte*)writer.CreateImage(&allocator,
103 &isolatedClassLen));
104
105 // Create jbyteArray with isolated byte code of class
106 jbyteArray isolatedClassArr = env->NewByteArray(isolatedClassLen);
107 env->SetByteArrayRegion(isolatedClassArr, 0, isolatedClassLen,
108 isolatedClass.get());
109
110 jstring nameStr = env->NewStringUTF(name);
111
112 // Call JvmtiAgent#runTransformers
113 jmethodID runTransformersMethod = env->GetMethodID(cls, "runTransformers",
114 "(Ljava/lang/ClassLoader;"
115 "Ljava/lang/String;"
116 "Ljava/lang/Class;"
117 "Ljava/security/ProtectionDomain;"
118 "[B)[B");
119
120 jbyteArray transformedArr = (jbyteArray) env->CallObjectMethod(sTransformer,
121 runTransformersMethod,
122 loader, nameStr,
123 classBeingRedefined,
124 protectionDomain,
125 isolatedClassArr);
126
127 // Set transformed byte code
128 if (!env->ExceptionOccurred() && transformedArr != NULL) {
129 *newClassDataLen = env->GetArrayLength(transformedArr);
130
131 jbyte* transformed = env->GetByteArrayElements(transformedArr, 0);
132
133 jvmti_env->Allocate(*newClassDataLen, newClassData);
134 std::memcpy(*newClassData, transformed, *newClassDataLen);
135
136 env->ReleaseByteArrayElements(transformedArr, transformed, 0);
137 }
138 }
139 }
140
141 // Add a label before instructionAfter
142 static void
addLabel(CodeIr & c,lir::Instruction * instructionAfter,Label * returnTrueLabel)143 addLabel(CodeIr& c,
144 lir::Instruction* instructionAfter,
145 Label* returnTrueLabel) {
146 c.instructions.InsertBefore(instructionAfter, returnTrueLabel);
147 }
148
149 // Add a byte code before instructionAfter
150 static void
addInstr(CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,const std::list<Operand * > & operands)151 addInstr(CodeIr& c,
152 lir::Instruction* instructionAfter,
153 Opcode opcode,
154 const std::list<Operand*>& operands) {
155 auto instruction = c.Alloc<Bytecode>();
156
157 instruction->opcode = opcode;
158
159 for (auto it = operands.begin(); it != operands.end(); it++) {
160 instruction->operands.push_back(*it);
161 }
162
163 c.instructions.InsertBefore(instructionAfter, instruction);
164 }
165
166 // Add a method call byte code before instructionAfter
167 static void
addCall(ir::Builder & b,CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,ir::Type * type,const char * methodName,ir::Type * returnType,const std::vector<ir::Type * > & types,const std::list<int> & regs)168 addCall(ir::Builder& b,
169 CodeIr& c,
170 lir::Instruction* instructionAfter,
171 Opcode opcode,
172 ir::Type* type,
173 const char* methodName,
174 ir::Type* returnType,
175 const std::vector<ir::Type*>& types,
176 const std::list<int>& regs) {
177 auto proto = b.GetProto(returnType, b.GetTypeList(types));
178 auto method = b.GetMethodDecl(b.GetAsciiString(methodName), proto, type);
179
180 VRegList* param_regs = c.Alloc<VRegList>();
181 for (auto it = regs.begin(); it != regs.end(); it++) {
182 param_regs->registers.push_back(*it);
183 }
184
185 addInstr(c, instructionAfter, opcode, {param_regs, c.Alloc<Method>(method,
186 method->orig_index)});
187 }
188
189 typedef struct {
190 ir::Type* boxedType;
191 ir::Type* scalarType;
192 std::string unboxMethod;
193 } BoxingInfo;
194
195 // Get boxing / unboxing info for a type
196 static BoxingInfo
getBoxingInfo(ir::Builder & b,char typeCode)197 getBoxingInfo(ir::Builder &b,
198 char typeCode) {
199 BoxingInfo boxingInfo;
200
201 if (typeCode != 'L' && typeCode != '[') {
202 std::stringstream tmp;
203 tmp << typeCode;
204 boxingInfo.scalarType = b.GetType(tmp.str().c_str());
205 }
206
207 switch (typeCode) {
208 case 'B':
209 boxingInfo.boxedType = b.GetType("Ljava/lang/Byte;");
210 boxingInfo.unboxMethod = "byteValue";
211 break;
212 case 'S':
213 boxingInfo.boxedType = b.GetType("Ljava/lang/Short;");
214 boxingInfo.unboxMethod = "shortValue";
215 break;
216 case 'I':
217 boxingInfo.boxedType = b.GetType("Ljava/lang/Integer;");
218 boxingInfo.unboxMethod = "intValue";
219 break;
220 case 'C':
221 boxingInfo.boxedType = b.GetType("Ljava/lang/Character;");
222 boxingInfo.unboxMethod = "charValue";
223 break;
224 case 'F':
225 boxingInfo.boxedType = b.GetType("Ljava/lang/Float;");
226 boxingInfo.unboxMethod = "floatValue";
227 break;
228 case 'Z':
229 boxingInfo.boxedType = b.GetType("Ljava/lang/Boolean;");
230 boxingInfo.unboxMethod = "booleanValue";
231 break;
232 case 'J':
233 boxingInfo.boxedType = b.GetType("Ljava/lang/Long;");
234 boxingInfo.unboxMethod = "longValue";
235 break;
236 case 'D':
237 boxingInfo.boxedType = b.GetType("Ljava/lang/Double;");
238 boxingInfo.unboxMethod = "doubleValue";
239 break;
240 default:
241 // real object
242 break;
243 }
244
245 return boxingInfo;
246 }
247
248 static size_t
getNumParams(ir::EncodedMethod * method)249 getNumParams(ir::EncodedMethod *method) {
250 if (method->decl->prototype->param_types == NULL) {
251 return 0;
252 }
253
254 return method->decl->prototype->param_types->types.size();
255 }
256
257 static bool
canBeTransformed(ir::EncodedMethod * method)258 canBeTransformed(ir::EncodedMethod *method) {
259 std::string type = method->decl->parent->Decl();
260 ir::String* methodName = method->decl->name;
261
262 return !(((method->access_flags & (kAccAbstract | kAccPrivate | kAccBridge | kAccNative
263 | kAccStatic)) != 0)
264 || (Utf8Cmp(methodName->c_str(), "<init>") == 0)
265 || (Utf8Cmp(methodName->c_str(), "<clinit>") == 0)
266 || (Utf8Cmp(type.c_str(), "java.lang.Object") == 0
267 && Utf8Cmp(methodName->c_str(), "finalize") == 0
268 && getNumParams(method) == 0)
269 || (strncmp(type.c_str(), "java.", 5) == 0
270 && (method->access_flags & (kAccPrivate | kAccPublic | kAccProtected)) == 0)
271 // getClass is used by MockMethodAdvice.isOverridden
272 || (Utf8Cmp(methodName->c_str(), "getClass") == 0));
273 }
274
275 static bool
isHashCode(ir::EncodedMethod * method)276 isHashCode(ir::EncodedMethod *method) {
277 return Utf8Cmp(method->decl->name->c_str(), "hashCode") == 0
278 && getNumParams(method) == 0;
279 }
280
281 static bool
isEquals(ir::EncodedMethod * method)282 isEquals(ir::EncodedMethod *method) {
283 return Utf8Cmp(method->decl->name->c_str(), "equals") == 0
284 && getNumParams(method) == 1
285 && Utf8Cmp(method->decl->prototype->param_types->types[0]->Decl().c_str(),
286 "java.lang.Object") == 0;
287 }
288
289 // Transforms the classes to add the mockito hooks
290 // - equals and hashcode are handled in a special way
291 extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_android_dx_mockito_inline_ClassTransformer_nativeRedefine(JNIEnv * env,jobject generator,jstring idStr,jbyteArray originalArr)292 Java_com_android_dx_mockito_inline_ClassTransformer_nativeRedefine(JNIEnv* env,
293 jobject generator,
294 jstring idStr,
295 jbyteArray originalArr) {
296 unsigned char* original = (unsigned char*)env->GetByteArrayElements(originalArr, 0);
297
298 Reader reader(original, env->GetArrayLength(originalArr));
299 reader.CreateClassIr(0);
300 std::shared_ptr<ir::DexFile> dex_ir = reader.GetIr();
301 ir::Builder b(dex_ir);
302
303 ir::Type* booleanScalarT = b.GetType("Z");
304 ir::Type* intScalarT = b.GetType("I");
305 ir::Type* objectT = b.GetType("Ljava/lang/Object;");
306 ir::Type* objectArrayT = b.GetType("[Ljava/lang/Object;");
307 ir::Type* stringT = b.GetType("Ljava/lang/String;");
308 ir::Type* methodT = b.GetType("Ljava/lang/reflect/Method;");
309 ir::Type* systemT = b.GetType("Ljava/lang/System;");
310 ir::Type* callableT = b.GetType("Ljava/util/concurrent/Callable;");
311 ir::Type* dispatcherT = b.GetType("Lcom/android/dx/mockito/inline/MockMethodDispatcher;");
312
313 // Add id to dex file
314 const char* idNative = env->GetStringUTFChars(idStr, 0);
315 ir::String* id = b.GetAsciiString(idNative);
316 env->ReleaseStringUTFChars(idStr, idNative);
317
318 for (auto& method : dex_ir->encoded_methods) {
319 if (!canBeTransformed(method.get())) {
320 continue;
321 }
322
323 if (isEquals(method.get())) {
324 /*
325 equals_original(Object other) {
326 T t = foo(other);
327 return bar(t);
328 }
329
330 equals_transformed(params) {
331 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
332 const-string v0, "65463hg34t"
333 move-objectfrom16 v1, THIS
334 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
335 move-result-object v2
336
337 // if (dispatcher == null || ) {
338 // goto original_method;
339 // }
340 if-eqz v2, original_method
341
342 // if (!dispatcher.isMock(this)) {
343 // goto original_method;
344 // }
345 invoke-virtual {v2, v1}, MockMethodDispatcher.isMock(Object):Method
346 move-result v2
347 if-eqz v2, original_method
348
349 // return self == other
350 move-objectfrom16 v0, ARG1
351 if-eq v0, v1, return_true
352
353 const v0, 0
354 return v0
355
356 return true:
357 const v0, 1
358 return v0
359
360 original_method:
361 // Move all method arguments down so that they match what the original code expects.
362 move-object16 v4, v5 # THIS
363 move-object16 v5, v6 # ARG1
364
365 T t = foo(other);
366 return bar(t);
367 }
368 */
369
370 CodeIr c(method.get(), dex_ir);
371
372 // Make sure there are at least 5 local registers to use
373 int originalNumRegisters = method->code->registers - method->code->ins_count;
374 int numAdditionalRegs = std::max(0, 3 - originalNumRegisters);
375 int thisReg = numAdditionalRegs + method->code->registers
376 - method->code->ins_count;
377
378 if (numAdditionalRegs > 0) {
379 c.ir_method->code->registers += numAdditionalRegs;
380 }
381
382 lir::Instruction* fi = *(c.instructions.begin());
383
384 Label* originalMethodLabel = c.Alloc<Label>(0);
385 Label* returnTrueLabel = c.Alloc<Label>(0);
386 CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
387 VReg* v0 = c.Alloc<VReg>(0);
388 VReg* v1 = c.Alloc<VReg>(1);
389 VReg* v2 = c.Alloc<VReg>(2);
390 VReg* thiz = c.Alloc<VReg>(thisReg);
391
392 addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
393 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v1, thiz});
394 addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT,
395 {stringT, objectT}, {0, 1});
396 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v2});
397 addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
398 addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "isMock", booleanScalarT, {objectT},
399 {2, 1});
400 addInstr(c, fi, OP_MOVE_RESULT, {v2});
401 addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
402 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v0, c.Alloc<VReg>(thisReg + 1)});
403 addInstr(c, fi, OP_IF_EQ, {v0, v1, c.Alloc<CodeLocation>(returnTrueLabel)});
404 addInstr(c, fi, OP_CONST, {v0, c.Alloc<Const32>(0)});
405 addInstr(c, fi, OP_RETURN, {v0});
406 addLabel(c, fi, returnTrueLabel);
407 addInstr(c, fi, OP_CONST, {v0, c.Alloc<Const32>(1)});
408 addInstr(c, fi, OP_RETURN, {v0});
409 addLabel(c, fi, originalMethodLabel);
410 addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs), thiz});
411 addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs + 1),
412 c.Alloc<VReg>(thisReg + 1)});
413
414 c.Assemble();
415 } else if (isHashCode(method.get())) {
416 /*
417 hashCode_original(Object other) {
418 T t = foo(other);
419 return bar(t);
420 }
421
422 hashCode_transformed(params) {
423 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
424 const-string v0, "65463hg34t"
425 move-objectfrom16 v1, THIS
426 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
427 move-result-object v2
428
429 // if (dispatcher == null || ) {
430 // goto original_method;
431 // }
432 if-eqz v2, original_method
433
434 // if (!dispatcher.isMock(this)) {
435 // goto original_method;
436 // }
437 invoke-interface {v2, v1}, MockMethodDispatcher.isMock(Object):Method
438 move-result v2
439 if-eqz v2, original_method
440
441 // return System.identityHashCode(this);
442 invoke-static {v1}, System.identityHashCode(Object):int
443 move-result v2
444 return v2
445
446 original_method:
447 // Move all method arguments down so that they match what the original code expects.
448 move-object16 v4, v5 # THIS
449
450 T t = foo(other);
451 return bar(t);
452 }
453 */
454
455 CodeIr c(method.get(), dex_ir);
456
457 // Make sure there are at least 5 local registers to use
458 int originalNumRegisters = method->code->registers - method->code->ins_count;
459 int numAdditionalRegs = std::max(0, 3 - originalNumRegisters);
460 int thisReg = numAdditionalRegs + method->code->registers - method->code->ins_count;
461
462 if (numAdditionalRegs > 0) {
463 c.ir_method->code->registers += numAdditionalRegs;
464 }
465
466 lir::Instruction* fi = *(c.instructions.begin());
467
468 Label* originalMethodLabel = c.Alloc<Label>(0);
469 CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
470 VReg* v0 = c.Alloc<VReg>(0);
471 VReg* v1 = c.Alloc<VReg>(1);
472 VReg* v2 = c.Alloc<VReg>(2);
473 VReg* thiz = c.Alloc<VReg>(thisReg);
474
475 addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
476 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v1, thiz});
477 addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT,
478 {stringT, objectT}, {0, 1});
479 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v2});
480 addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
481 addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "isMock", booleanScalarT, {objectT},
482 {2, 1});
483 addInstr(c, fi, OP_MOVE_RESULT, {v2});
484 addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
485 addCall(b, c, fi, OP_INVOKE_STATIC, systemT, "identityHashCode", intScalarT, {objectT},
486 {1});
487 addInstr(c, fi, OP_MOVE_RESULT, {v2});
488 addInstr(c, fi, OP_RETURN, {v2});
489 addLabel(c, fi, originalMethodLabel);
490 addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs), thiz});
491
492 c.Assemble();
493 } else {
494 /*
495 long method_original(int param1, long param2, String param3) {
496 foo();
497 return bar();
498 }
499
500 long method_transformed(int param1, long param2, String param3) {
501 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
502 const-string v0, "65463hg34t"
503 move-objectfrom16 v1, THIS # this is necessary as invoke-static cannot deal
504 # with medium or high registers and THIS might not
505 # be low
506 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
507 move-result-object v0
508
509 // if (dispatcher == null) {
510 // goto original_method;
511 // }
512 if-eqz v0, original_method
513
514 // Method origin = dispatcher.getOrigin(this, methodDesc);
515 const-string v1 "fully.qualified.ClassName#original_method(int, long, String)"
516 move-objectfrom16 v2, THIS # this is necessary as invoke-static cannot deal
517 # with medium or high registers and THIS might not
518 # be low
519 invoke-virtual {v0, v2, v1}, MockMethodDispatcher.getOrigin(Object, String):Method
520 move-result-object v1
521
522 // if (origin == null) {
523 // goto original_method;
524 // }
525 if-eqz v1, original_method
526
527 // Create an array with Objects of all parameters.
528
529 // Object[] arguments = new Object[3]
530 const v3, 3
531 new-array v2, v3, Object[]
532
533 // Integer param1Integer = Integer.valueOf(param1)
534 move-from16 v3, ARG1 # this is necessary as invoke-static cannot deal with high
535 # registers and ARG1 might be high
536 invoke-static {v3}, Integer.valueOf(int):Integer
537 move-result-object v3
538
539 // arguments[0] = param1Integer
540 const v4, 0
541 aput-object v3, v2, v4
542
543 // Long param2Long = Long.valueOf(param2)
544 move-widefrom16 v3:v4, ARG2.1:ARG2.2 # this is necessary as invoke-static cannot
545 # deal with high registers and ARG2 might be
546 # high
547 invoke-static {v3, v4}, Long.valueOf(long):Long
548 move-result-object v3
549
550 // arguments[1] = param2Long
551 const v4, 1
552 aput-object v3, v2, v4
553
554 // arguments[2] = param3
555 const v4, 2
556 move-objectfrom16 v3, ARG3 # this is necessary as aput-object cannot deal with
557 # high registers and ARG3 might be high
558 aput-object v3, v2, v4
559
560 // Callable<?> mocked = dispatcher.handle(this, origin, arguments);
561 move-objectfrom16 v3, THIS # this is necessary as invoke-virtual cannot deal
562 # with medium or high registers and THIS might not
563 # be low
564 invoke-virtual {v0,v3,v1,v2}, MockMethodDispatcher.handle(Object, Method,
565 Object[]):Callable
566 move-result-object v0
567
568 // if (mocked != null) {
569 if-eqz v0, original_method
570
571 // Object ret = mocked.call();
572 invoke-interface {v0}, Callable.call():Object
573 move-result-object v0
574
575 // Long retLong = (Long)ret
576 check-cast v0, Long
577
578 // long retlong = retLong.longValue();
579 invoke-virtual {v0}, Long.longValue():long
580 move-result-wide v0:v1
581
582 // return retlong;
583 return-wide v0:v1
584
585 // }
586
587 original_method:
588 // Move all method arguments down so that they match what the original code expects.
589 // Let's assume three arguments, one int, one long, one String and the and used to
590 // use 4 registers
591 move-object16 v4, v5 # THIS
592 move16 v5, v6 # ARG1
593 move-wide16 v6:v7, v7:v8 # ARG2 (overlapping moves are allowed)
594 move-object16 v8, v9 # ARG3
595
596 // foo();
597 // return bar();
598 unmodified original byte code
599 }
600 */
601
602 CodeIr c(method.get(), dex_ir);
603
604 // Make sure there are at least 5 local registers to use
605 int originalNumRegisters = method->code->registers - method->code->ins_count;
606 int numAdditionalRegs = std::max(0, 5 - originalNumRegisters);
607 int thisReg = originalNumRegisters + numAdditionalRegs;
608
609 if (numAdditionalRegs > 0) {
610 c.ir_method->code->registers += numAdditionalRegs;
611 }
612
613 lir::Instruction* fi = *(c.instructions.begin());
614
615 // Add methodDesc to dex file
616 std::stringstream ss;
617 ss << method->decl->parent->Decl() << "#" << method->decl->name->c_str() << "(" ;
618 bool first = true;
619 if (method->decl->prototype->param_types != NULL) {
620 for (const auto& type : method->decl->prototype->param_types->types) {
621 if (first) {
622 first = false;
623 } else {
624 ss << ",";
625 }
626
627 ss << type->Decl().c_str();
628 }
629 }
630 ss << ")";
631 std::string methodDescStr = ss.str();
632 ir::String* methodDesc = b.GetAsciiString(methodDescStr.c_str());
633
634 size_t numParams = getNumParams(method.get());
635
636 Label* originalMethodLabel = c.Alloc<Label>(0);
637 CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
638 VReg* v0 = c.Alloc<VReg>(0);
639 VReg* v1 = c.Alloc<VReg>(1);
640 VReg* v2 = c.Alloc<VReg>(2);
641 VReg* v3 = c.Alloc<VReg>(3);
642 VReg* v4 = c.Alloc<VReg>(4);
643 VReg* thiz = c.Alloc<VReg>(thisReg);
644
645 addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
646 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v1, thiz});
647 addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT, {stringT, objectT},
648 {0, 1});
649 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
650 addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod});
651 addInstr(c, fi, OP_CONST_STRING,
652 {v1, c.Alloc<String>(methodDesc, methodDesc->orig_index)});
653 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v2, thiz});
654 addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "getOrigin", methodT,
655 {objectT, stringT}, {0, 2, 1});
656 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v1});
657 addInstr(c, fi, OP_IF_EQZ, {v1, originalMethod});
658 addInstr(c, fi, OP_CONST, {v3, c.Alloc<Const32>(numParams)});
659 addInstr(c, fi, OP_NEW_ARRAY, {v2, v3, c.Alloc<Type>(objectArrayT,
660 objectArrayT->orig_index)});
661
662 if (numParams > 0) {
663 int argReg = thisReg + 1;
664
665 for (int argNum = 0; argNum < numParams; argNum++) {
666 const auto& type = method->decl->prototype->param_types->types[argNum];
667 BoxingInfo boxingInfo = getBoxingInfo(b, type->descriptor->c_str()[0]);
668
669 switch (type->GetCategory()) {
670 case ir::Type::Category::Scalar:
671 addInstr(c, fi, OP_MOVE_FROM16, {v3, c.Alloc<VReg>(argReg)});
672 addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf",
673 boxingInfo.boxedType, {type}, {3});
674 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3});
675
676 argReg++;
677 break;
678 case ir::Type::Category::WideScalar: {
679 VRegPair* v3v4 = c.Alloc<VRegPair>(3);
680 VRegPair* argRegPair = c.Alloc<VRegPair>(argReg);
681
682 addInstr(c, fi, OP_MOVE_WIDE_FROM16, {v3v4, argRegPair});
683 addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf",
684 boxingInfo.boxedType, {type}, {3, 4});
685 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3});
686
687 argReg += 2;
688 break;
689 }
690 case ir::Type::Category::Reference:
691 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v3, c.Alloc<VReg>(argReg)});
692
693 argReg++;
694 break;
695 case ir::Type::Category::Void:
696 assert(false);
697 }
698
699 addInstr(c, fi, OP_CONST, {v4, c.Alloc<Const32>(argNum)});
700 addInstr(c, fi, OP_APUT_OBJECT, {v3, v2, v4});
701 }
702 }
703
704 addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v3, thiz});
705 addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "handle", callableT,
706 {objectT, methodT, objectArrayT}, {0, 3, 1, 2});
707 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
708 addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod});
709 addCall(b, c, fi, OP_INVOKE_INTERFACE, callableT, "call", objectT, {}, {0});
710 addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
711
712 ir::Type *returnType = method->decl->prototype->return_type;
713 BoxingInfo boxingInfo = getBoxingInfo(b, returnType->descriptor->c_str()[0]);
714
715 switch (returnType->GetCategory()) {
716 case ir::Type::Category::Scalar:
717 addInstr(c, fi, OP_CHECK_CAST, {v0,
718 c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)});
719 addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType,
720 boxingInfo.unboxMethod.c_str(), returnType, {}, {0});
721 addInstr(c, fi, OP_MOVE_RESULT, {v0});
722 addInstr(c, fi, OP_RETURN, {v0});
723 break;
724 case ir::Type::Category::WideScalar: {
725 VRegPair* v0v1 = c.Alloc<VRegPair>(0);
726
727 addInstr(c, fi, OP_CHECK_CAST, {v0,
728 c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)});
729 addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType,
730 boxingInfo.unboxMethod.c_str(), returnType, {}, {0});
731 addInstr(c, fi, OP_MOVE_RESULT_WIDE, {v0v1});
732 addInstr(c, fi, OP_RETURN_WIDE, {v0v1});
733 break;
734 }
735 case ir::Type::Category::Reference:
736 addInstr(c, fi, OP_CHECK_CAST, {v0, c.Alloc<Type>(returnType,
737 returnType->orig_index)});
738 addInstr(c, fi, OP_RETURN_OBJECT, {v0});
739 break;
740 case ir::Type::Category::Void:
741 addInstr(c, fi, OP_RETURN_VOID, {});
742 break;
743 }
744
745 addLabel(c, fi, originalMethodLabel);
746 addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs), thiz});
747
748 if (numParams > 0) {
749 int argReg = thisReg + 1;
750
751 for (int argNum = 0; argNum < numParams; argNum++) {
752 const auto& type = method->decl->prototype->param_types->types[argNum];
753 int origReg = argReg - numAdditionalRegs;
754 switch (type->GetCategory()) {
755 case ir::Type::Category::Scalar:
756 addInstr(c, fi, OP_MOVE_16, {c.Alloc<VReg>(origReg),
757 c.Alloc<VReg>(argReg)});
758 argReg++;
759 break;
760 case ir::Type::Category::WideScalar:
761 addInstr(c, fi, OP_MOVE_WIDE_16,{c.Alloc<VRegPair>(origReg),
762 c.Alloc<VRegPair>(argReg)});
763 argReg +=2;
764 break;
765 case ir::Type::Category::Reference:
766 addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(origReg),
767 c.Alloc<VReg>(argReg)});
768 argReg++;
769 break;
770 }
771 }
772 }
773
774 c.Assemble();
775 }
776 }
777
778 struct Allocator : public Writer::Allocator {
779 virtual void* Allocate(size_t size) {return ::malloc(size);}
780 virtual void Free(void* ptr) {::free(ptr);}
781 };
782
783 Allocator allocator;
784 Writer writer(dex_ir);
785 size_t transformedLen = 0;
786 std::shared_ptr<jbyte> transformed((jbyte*)writer.CreateImage(&allocator, &transformedLen));
787
788 jbyteArray transformedArr = env->NewByteArray(transformedLen);
789 env->SetByteArrayRegion(transformedArr, 0, transformedLen, transformed.get());
790
791 return transformedArr;
792 }
793
794 // Initializes the agent
Agent_OnAttach(JavaVM * vm,char * options,void * reserved)795 extern "C" jint Agent_OnAttach(JavaVM* vm,
796 char* options,
797 void* reserved) {
798 jint jvmError = vm->GetEnv(reinterpret_cast<void**>(&localJvmtiEnv), JVMTI_VERSION_1_2);
799 if (jvmError != JNI_OK) {
800 return jvmError;
801 }
802
803 jvmtiCapabilities caps;
804 memset(&caps, 0, sizeof(caps));
805 caps.can_retransform_classes = 1;
806
807 jvmtiError error = localJvmtiEnv->AddCapabilities(&caps);
808 if (error != JVMTI_ERROR_NONE) {
809 return error;
810 }
811
812 jvmtiEventCallbacks cb;
813 memset(&cb, 0, sizeof(cb));
814 cb.ClassFileLoadHook = Transform;
815
816 error = localJvmtiEnv->SetEventCallbacks(&cb, sizeof(cb));
817 if (error != JVMTI_ERROR_NONE) {
818 return error;
819 }
820
821 error = localJvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK,
822 NULL);
823 if (error != JVMTI_ERROR_NONE) {
824 return error;
825 }
826
827 return JVMTI_ERROR_NONE;
828 }
829
830 // Throw runtime exception
throwRuntimeExpection(JNIEnv * env,const char * fmt,...)831 static void throwRuntimeExpection(JNIEnv* env, const char* fmt, ...) {
832 char msgBuf[512];
833
834 va_list args;
835 va_start (args, fmt);
836 vsnprintf(msgBuf, sizeof(msgBuf), fmt, args);
837 va_end (args);
838
839 jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
840 env->ThrowNew(exceptionClass, msgBuf);
841 }
842
843 // Register transformer hook
844 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRegisterTransformerHook(JNIEnv * env,jobject thiz)845 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRegisterTransformerHook(JNIEnv* env,
846 jobject thiz) {
847 sTransformer = env->NewGlobalRef(thiz);
848 }
849
850 // Unregister transformer hook
851 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeUnregisterTransformerHook(JNIEnv * env,jobject thiz)852 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeUnregisterTransformerHook(JNIEnv* env,
853 jobject thiz) {
854 env->DeleteGlobalRef(sTransformer);
855 sTransformer = NULL;
856 }
857
858 // Triggers retransformation of classes via this file's Transform method
859 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRetransformClasses(JNIEnv * env,jobject thiz,jobjectArray classes)860 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRetransformClasses(JNIEnv* env,
861 jobject thiz,
862 jobjectArray classes) {
863 jsize numTransformedClasses = env->GetArrayLength(classes);
864 jclass *transformedClasses = (jclass*) malloc(numTransformedClasses * sizeof(jclass));
865 for (int i = 0; i < numTransformedClasses; i++) {
866 transformedClasses[i] = (jclass) env->NewGlobalRef(env->GetObjectArrayElement(classes, i));
867 }
868
869 jvmtiError error = localJvmtiEnv->RetransformClasses(numTransformedClasses,
870 transformedClasses);
871
872 for (int i = 0; i < numTransformedClasses; i++) {
873 env->DeleteGlobalRef(transformedClasses[i]);
874 }
875 free(transformedClasses);
876
877 if (error != JVMTI_ERROR_NONE) {
878 throwRuntimeExpection(env, "Could not retransform classes: %d", error);
879 }
880 }
881
882 // Adds a jar file to the bootstrap class loader
883 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeAppendToBootstrapClassLoaderSearch(JNIEnv * env,jclass klass,jstring jarFile)884 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeAppendToBootstrapClassLoaderSearch(JNIEnv* env,
885 jclass klass,
886 jstring jarFile) {
887 const char *jarFileNative = env->GetStringUTFChars(jarFile, 0);
888 jvmtiError error = localJvmtiEnv->AddToBootstrapClassLoaderSearch(jarFileNative);
889
890 if (error != JVMTI_ERROR_NONE) {
891 throwRuntimeExpection(env, "Could not add %s to bootstrap class path: %d", jarFileNative,
892 error);
893 }
894
895 env->ReleaseStringUTFChars(jarFile, jarFileNative);
896 }
897 } // namespace com_android_dx_mockito_inline
898
899