• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <cassert>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <jni.h>

#include "jvmti.h"

#include <slicer/dex_ir.h>
#include <slicer/code_ir.h>
#include <slicer/dex_ir_builder.h>
#include <slicer/dex_utf8.h>
#include <slicer/writer.h>
#include <slicer/reader.h>
#include <slicer/instrumentation.h>
35 
36 using namespace dex;
37 using namespace lir;
38 
39 namespace com_android_dx_mockito_inline {
40     static jvmtiEnv* localJvmtiEnv;
41 
42     static jobject sTransformer;
43 
// Converts a class name to a type descriptor by turning every '.' into '/' and wrapping the
// result in "L" ... ";" (ex. "java.lang.String" to "Ljava/lang/String;").
//
// @param class_name NUL-terminated class name; must not be null.
// @return the type descriptor.
static std::string
ClassNameToDescriptor(const char* class_name) {
    std::string body(class_name);
    std::replace(body.begin(), body.end(), '.', '/');
    return "L" + body + ";";
}
56 
57     // Takes the full dex file for class 'classBeingRedefined'
58     // - isolates the dex code for the class out of the dex file
59     // - calls sTransformer.runTransformers on the isolated dex code
60     // - send the transformed code back to the runtime
61     static void
Transform(jvmtiEnv * jvmti_env,JNIEnv * env,jclass classBeingRedefined,jobject loader,const char * name,jobject protectionDomain,jint classDataLen,const unsigned char * classData,jint * newClassDataLen,unsigned char ** newClassData)62     Transform(jvmtiEnv* jvmti_env,
63               JNIEnv* env,
64               jclass classBeingRedefined,
65               jobject loader,
66               const char* name,
67               jobject protectionDomain,
68               jint classDataLen,
69               const unsigned char* classData,
70               jint* newClassDataLen,
71               unsigned char** newClassData) {
72         if (sTransformer != nullptr) {
73             // Even reading the classData array is expensive as the data is only generated when the
74             // memory is touched. Hence call JvmtiAgent#shouldTransform to check if we need to transform
75             // the class.
76             jclass cls = env->GetObjectClass(sTransformer);
77             jmethodID shouldTransformMethod = env->GetMethodID(cls, "shouldTransform",
78                                                                "(Ljava/lang/Class;)Z");
79 
80             jboolean shouldTransform = env->CallBooleanMethod(sTransformer, shouldTransformMethod,
81                                                               classBeingRedefined);
82             if (!shouldTransform) {
83                 return;
84             }
85 
86             // Isolate byte code of class class. This is needed as Android usually gives us more
87             // than the class we need.
88             Reader reader(classData, classDataLen);
89 
90             u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str());
91             reader.CreateClassIr(index);
92             std::shared_ptr<ir::DexFile> ir = reader.GetIr();
93 
94             struct Allocator : public Writer::Allocator {
95                 virtual void* Allocate(size_t size) {return ::malloc(size);}
96                 virtual void Free(void* ptr) {::free(ptr);}
97             };
98 
99             Allocator allocator;
100             Writer writer(ir);
101             size_t isolatedClassLen = 0;
102             std::shared_ptr<jbyte> isolatedClass((jbyte*)writer.CreateImage(&allocator,
103                                                                             &isolatedClassLen));
104 
105             // Create jbyteArray with isolated byte code of class
106             jbyteArray isolatedClassArr = env->NewByteArray(isolatedClassLen);
107             env->SetByteArrayRegion(isolatedClassArr, 0, isolatedClassLen,
108                                     isolatedClass.get());
109 
110             jstring nameStr = env->NewStringUTF(name);
111 
112             // Call JvmtiAgent#runTransformers
113             jmethodID runTransformersMethod = env->GetMethodID(cls, "runTransformers",
114                                                                "(Ljava/lang/ClassLoader;"
115                                                                "Ljava/lang/String;"
116                                                                "Ljava/lang/Class;"
117                                                                "Ljava/security/ProtectionDomain;"
118                                                                "[B)[B");
119 
120             jbyteArray transformedArr = (jbyteArray) env->CallObjectMethod(sTransformer,
121                                                                            runTransformersMethod,
122                                                                            loader, nameStr,
123                                                                            classBeingRedefined,
124                                                                            protectionDomain,
125                                                                            isolatedClassArr);
126 
127             // Set transformed byte code
128             if (!env->ExceptionOccurred() && transformedArr != nullptr) {
129                 *newClassDataLen = env->GetArrayLength(transformedArr);
130 
131                 jbyte* transformed = env->GetByteArrayElements(transformedArr, 0);
132 
133                 jvmti_env->Allocate(*newClassDataLen, newClassData);
134                 std::memcpy(*newClassData, transformed, *newClassDataLen);
135 
136                 env->ReleaseByteArrayElements(transformedArr, transformed, 0);
137             }
138         }
139     }
140 
141     // Add a label before instructionAfter
142     static void
addLabel(CodeIr & c,lir::Instruction * instructionAfter,Label * returnTrueLabel)143     addLabel(CodeIr& c,
144              lir::Instruction* instructionAfter,
145              Label* returnTrueLabel) {
146         c.instructions.InsertBefore(instructionAfter, returnTrueLabel);
147     }
148 
149     // Add a byte code before instructionAfter
150     static void
addInstr(CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,const std::list<Operand * > & operands)151     addInstr(CodeIr& c,
152              lir::Instruction* instructionAfter,
153              Opcode opcode,
154              const std::list<Operand*>& operands) {
155         auto instruction = c.Alloc<Bytecode>();
156 
157         instruction->opcode = opcode;
158 
159         for (auto it = operands.begin(); it != operands.end(); it++) {
160             instruction->operands.push_back(*it);
161         }
162 
163         c.instructions.InsertBefore(instructionAfter, instruction);
164     }
165 
166     // Add a method call byte code before instructionAfter
167     static void
addCall(ir::Builder & b,CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,ir::Type * type,const char * methodName,ir::Type * returnType,const std::vector<ir::Type * > & types,const std::list<int> & regs)168     addCall(ir::Builder& b,
169             CodeIr& c,
170             lir::Instruction* instructionAfter,
171             Opcode opcode,
172             ir::Type* type,
173             const char* methodName,
174             ir::Type* returnType,
175             const std::vector<ir::Type*>& types,
176             const std::list<int>& regs) {
177         auto proto = b.GetProto(returnType, b.GetTypeList(types));
178         auto method = b.GetMethodDecl(b.GetAsciiString(methodName), proto, type);
179 
180         VRegList* param_regs = c.Alloc<VRegList>();
181         for (auto it = regs.begin(); it != regs.end(); it++) {
182             param_regs->registers.push_back(*it);
183         }
184 
185         addInstr(c, instructionAfter, opcode, {param_regs, c.Alloc<Method>(method,
186                                                                            method->orig_index)});
187     }
188 
189     typedef struct {
190         ir::Type* boxedType;
191         ir::Type* scalarType;
192         std::string unboxMethod;
193     } BoxingInfo;
194 
195     // Get boxing / unboxing info for a type
196     static BoxingInfo
getBoxingInfo(ir::Builder & b,char typeCode)197     getBoxingInfo(ir::Builder &b,
198                   char typeCode) {
199         BoxingInfo boxingInfo;
200 
201         if (typeCode != 'L' && typeCode !=  '[') {
202             std::stringstream tmp;
203             tmp << typeCode;
204             boxingInfo.scalarType = b.GetType(tmp.str().c_str());
205         }
206 
207         switch (typeCode) {
208             case 'B':
209                 boxingInfo.boxedType = b.GetType("Ljava/lang/Byte;");
210                 boxingInfo.unboxMethod = "byteValue";
211                 break;
212             case 'S':
213                 boxingInfo.boxedType = b.GetType("Ljava/lang/Short;");
214                 boxingInfo.unboxMethod = "shortValue";
215                 break;
216             case 'I':
217                 boxingInfo.boxedType = b.GetType("Ljava/lang/Integer;");
218                 boxingInfo.unboxMethod = "intValue";
219                 break;
220             case 'C':
221                 boxingInfo.boxedType = b.GetType("Ljava/lang/Character;");
222                 boxingInfo.unboxMethod = "charValue";
223                 break;
224             case 'F':
225                 boxingInfo.boxedType = b.GetType("Ljava/lang/Float;");
226                 boxingInfo.unboxMethod = "floatValue";
227                 break;
228             case 'Z':
229                 boxingInfo.boxedType = b.GetType("Ljava/lang/Boolean;");
230                 boxingInfo.unboxMethod = "booleanValue";
231                 break;
232             case 'J':
233                 boxingInfo.boxedType = b.GetType("Ljava/lang/Long;");
234                 boxingInfo.unboxMethod = "longValue";
235                 break;
236             case 'D':
237                 boxingInfo.boxedType = b.GetType("Ljava/lang/Double;");
238                 boxingInfo.unboxMethod = "doubleValue";
239                 break;
240             default:
241                 // real object
242                 break;
243         }
244 
245         return boxingInfo;
246     }
247 
248     static size_t
getNumParams(ir::EncodedMethod * method)249     getNumParams(ir::EncodedMethod *method) {
250         if (method->decl->prototype->param_types == nullptr) {
251             return 0;
252         }
253 
254         return method->decl->prototype->param_types->types.size();
255     }
256 
257     static bool
canBeTransformed(ir::EncodedMethod * method)258     canBeTransformed(ir::EncodedMethod *method) {
259         std::string type = method->decl->parent->Decl();
260         ir::String* methodName = method->decl->name;
261 
262         return ((method->access_flags & kAccStatic) != 0)
263                && !(((method->access_flags & (kAccPrivate | kAccBridge | kAccNative)) != 0)
264                     || (Utf8Cmp(methodName->c_str(), "<clinit>") == 0)
265                     || (strncmp(type.c_str(), "java.", 5) == 0
266                         && (method->access_flags & (kAccPrivate | kAccPublic | kAccProtected))
267                            == 0));
268     }
269 
270     // Transforms the classes to add the mockito hooks
271     // - equals and hashcode are handled in a special way
272     extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_android_dx_mockito_inline_StaticClassTransformer_nativeRedefine(JNIEnv * env,jobject generator,jstring idStr,jbyteArray originalArr)273     Java_com_android_dx_mockito_inline_StaticClassTransformer_nativeRedefine(JNIEnv* env,
274                                                                            jobject generator,
275                                                                            jstring idStr,
276                                                                            jbyteArray originalArr) {
277         unsigned char* original = (unsigned char*)env->GetByteArrayElements(originalArr, 0);
278 
279         Reader reader(original, env->GetArrayLength(originalArr));
280         reader.CreateClassIr(0);
281         std::shared_ptr<ir::DexFile> dex_ir = reader.GetIr();
282         ir::Builder b(dex_ir);
283 
284         ir::Type* objectT = b.GetType("Ljava/lang/Object;");
285         ir::Type* objectArrayT = b.GetType("[Ljava/lang/Object;");
286         ir::Type* stringT = b.GetType("Ljava/lang/String;");
287         ir::Type* methodT = b.GetType("Ljava/lang/reflect/Method;");
288         ir::Type* callableT = b.GetType("Ljava/util/concurrent/Callable;");
289         ir::Type* dispatcherT = b.GetType("Lcom/android/dx/mockito/inline/MockMethodDispatcher;");
290 
291         // Add id to dex file
292         const char* idNative = env->GetStringUTFChars(idStr, 0);
293         ir::String* id = b.GetAsciiString(idNative);
294         env->ReleaseStringUTFChars(idStr, idNative);
295 
296         for (auto& method : dex_ir->encoded_methods) {
297             if (!canBeTransformed(method.get())) {
298                 continue;
299             }
300             /*
301             static long method_original(int param1, long param2, String param3) {
302                 foo();
303                 return bar();
304             }
305 
306             static long method_transformed(int param1, long param2, String param3) {
307                 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
308                 const-string v0, "65463hg34t"
309                 const v1, 0
310                 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
311                 move-result-object v0
312 
313                 // if (dispatcher == null) {
314                 //    goto original_method;
315                 // }
316                 if-eqz v0, original_method
317 
318                 // Method origin = dispatcher.getOrigin(this, methodDesc);
319                 const-string v1 "fully.qualified.ClassName#original_method(int, long, String)"
320                 const v2, 0
321                 invoke-virtual {v0, v2, v1}, MockMethodDispatcher.getOrigin(Object, String):Method
322                 move-result-object v1
323 
324                 // if (origin == null) {
325                 //     goto original_method;
326                 // }
327                 if-eqz v1, original_method
328 
329                 // Create an array with Objects of all parameters.
330 
331                 //     Object[] arguments = new Object[3]
332                 const v3, 3
333                 new-array v2, v3, Object[]
334 
335                 //     Integer param1Integer = Integer.valueOf(param1)
336                 move-from16 v3, ARG1     # this is necessary as invoke-static cannot deal with high
337                                          # registers and ARG1 might be high
338                 invoke-static {v3}, Integer.valueOf(int):Integer
339                 move-result-object v3
340 
341                 //     arguments[0] = param1Integer
342                 const v4, 0
343                 aput-object v3, v2, v4
344 
345                 //     Long param2Long = Long.valueOf(param2)
346                 move-widefrom16 v3:v4, ARG2.1:ARG2.2 # this is necessary as invoke-static cannot
347                                                      # deal with high registers and ARG2 might be
348                                                      # high
349                 invoke-static {v3, v4}, Long.valueOf(long):Long
350                 move-result-object v3
351 
352                 //     arguments[1] = param2Long
353                 const v4, 1
354                 aput-object v3, v2, v4
355 
356                 //     arguments[2] = param3
357                 const v4, 2
358                 move-objectfrom16 v3, ARG3     # this is necessary as aput-object cannot deal with
359                                                # high registers and ARG3 might be high
360                 aput-object v3, v2, v4
361 
362                 // Callable<?> mocked = dispatcher.handle(methodDesc --as this parameter--,
363                 //                                        origin, arguments);
364                 const-string v3 "fully.qualified.ClassName#original_method(int, long, String)"
365                 invoke-virtual {v0,v3,v1,v2}, MockMethodDispatcher.handle(Object, Method,
366                                                                           Object[]):Callable
367                 move-result-object v0
368 
369                 //  if (mocked != null) {
370                 if-eqz v0, original_method
371 
372                 //      Object ret = mocked.call();
373                 invoke-interface {v0}, Callable.call():Object
374                 move-result-object v0
375 
376                 //      Long retLong = (Long)ret
377                 check-cast v0, Long
378 
379                 //      long retlong = retLong.longValue();
380                 invoke-virtual {v0}, Long.longValue():long
381                 move-result-wide v0:v1
382 
383                 //      return retlong;
384                 return-wide v0:v1
385 
386                 //  }
387 
388             original_method:
389                 // Move all method arguments down so that they match what the original code expects.
390                 // Let's assume three arguments, one int, one long, one String and the and used to
391                 // use 4 registers
392                 move16 v5, v6             # ARG1
393                 move-wide16 v6:v7, v7:v8  # ARG2 (overlapping moves are allowed)
394                 move-object16 v8, v9      # ARG3
395 
396                 // foo();
397                 // return bar();
398                 unmodified original byte code
399             }
400             */
401 
402             CodeIr c(method.get(), dex_ir);
403 
404             // Make sure there are at least 5 local registers to use
405             int originalNumRegisters = method->code->registers - method->code->ins_count;
406             int numAdditionalRegs = std::max(0, 5 - originalNumRegisters);
407             int firstArg = originalNumRegisters + numAdditionalRegs;
408 
409             if (numAdditionalRegs > 0) {
410                 c.ir_method->code->registers += numAdditionalRegs;
411             }
412 
413             lir::Instruction* fi = *(c.instructions.begin());
414 
415             // Add methodDesc to dex file
416             std::stringstream ss;
417             ss << method->decl->parent->Decl() << "#" << method->decl->name->c_str() << "(" ;
418             bool first = true;
419             if (method->decl->prototype->param_types != nullptr) {
420                 for (const auto& type : method->decl->prototype->param_types->types) {
421                     if (first) {
422                         first = false;
423                     } else {
424                         ss << ",";
425                     }
426 
427                     ss << type->Decl().c_str();
428                 }
429             }
430             ss << ")";
431             std::string methodDescStr = ss.str();
432             ir::String* methodDesc = b.GetAsciiString(methodDescStr.c_str());
433 
434             size_t numParams = getNumParams(method.get());
435 
436             Label* originalMethodLabel = c.Alloc<Label>(0);
437             CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
438             VReg* v0 = c.Alloc<VReg>(0);
439             VReg* v1 = c.Alloc<VReg>(1);
440             VReg* v2 = c.Alloc<VReg>(2);
441             VReg* v3 = c.Alloc<VReg>(3);
442             VReg* v4 = c.Alloc<VReg>(4);
443 
444             addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
445             addInstr(c, fi, OP_CONST, {v1, c.Alloc<Const32>(0)});
446             addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT, {stringT, objectT},
447                     {0, 1});
448             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
449             addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod});
450             addInstr(c, fi, OP_CONST_STRING,
451                      {v1, c.Alloc<String>(methodDesc, methodDesc->orig_index)});
452             addInstr(c, fi, OP_CONST, {v2, c.Alloc<Const32>(0)});
453             addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "getOrigin", methodT,
454                     {objectT, stringT}, {0, 2, 1});
455             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v1});
456             addInstr(c, fi, OP_IF_EQZ, {v1, originalMethod});
457             addInstr(c, fi, OP_CONST, {v3, c.Alloc<Const32>(numParams)});
458             addInstr(c, fi, OP_NEW_ARRAY, {v2, v3, c.Alloc<Type>(objectArrayT,
459                                                                  objectArrayT->orig_index)});
460 
461             if (numParams > 0) {
462                 int argReg = firstArg;
463 
464                 for (int argNum = 0; argNum < numParams; argNum++) {
465                     const auto& type = method->decl->prototype->param_types->types[argNum];
466                     BoxingInfo boxingInfo = getBoxingInfo(b, type->descriptor->c_str()[0]);
467 
468                     switch (type->GetCategory()) {
469                         case ir::Type::Category::Scalar:
470                             addInstr(c, fi, OP_MOVE_FROM16, {v3, c.Alloc<VReg>(argReg)});
471                             addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf",
472                                     boxingInfo.boxedType, {type}, {3});
473                             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3});
474 
475                             argReg++;
476                             break;
477                         case ir::Type::Category::WideScalar: {
478                             VRegPair* v3v4 = c.Alloc<VRegPair>(3);
479                             VRegPair* argRegPair = c.Alloc<VRegPair>(argReg);
480 
481                             addInstr(c, fi, OP_MOVE_WIDE_FROM16, {v3v4, argRegPair});
482                             addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf",
483                                     boxingInfo.boxedType, {type}, {3, 4});
484                             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3});
485 
486                             argReg += 2;
487                             break;
488                         }
489                         case ir::Type::Category::Reference:
490                             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v3, c.Alloc<VReg>(argReg)});
491 
492                             argReg++;
493                             break;
494                         case ir::Type::Category::Void:
495                             assert(false);
496                     }
497 
498                     addInstr(c, fi, OP_CONST, {v4, c.Alloc<Const32>(argNum)});
499                     addInstr(c, fi, OP_APUT_OBJECT, {v3, v2, v4});
500                 }
501             }
502 
503             // NASTY Hack: Push in method name as "mock"
504             addInstr(c, fi, OP_CONST_STRING,
505                      {v3, c.Alloc<String>(methodDesc, methodDesc->orig_index)});
506             addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "handle", callableT,
507                     {objectT, methodT, objectArrayT}, {0, 3, 1, 2});
508             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
509             addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod});
510             addCall(b, c, fi, OP_INVOKE_INTERFACE, callableT, "call", objectT, {}, {0});
511             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
512 
513             ir::Type *returnType = method->decl->prototype->return_type;
514             BoxingInfo boxingInfo = getBoxingInfo(b, returnType->descriptor->c_str()[0]);
515 
516             switch (returnType->GetCategory()) {
517                 case ir::Type::Category::Scalar:
518                     addInstr(c, fi, OP_CHECK_CAST, {v0,
519                                                     c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)});
520                     addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType,
521                             boxingInfo.unboxMethod.c_str(), returnType, {}, {0});
522                     addInstr(c, fi, OP_MOVE_RESULT, {v0});
523                     addInstr(c, fi, OP_RETURN, {v0});
524                     break;
525                 case ir::Type::Category::WideScalar: {
526                     VRegPair* v0v1 = c.Alloc<VRegPair>(0);
527 
528                     addInstr(c, fi, OP_CHECK_CAST, {v0,
529                                                     c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)});
530                     addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType,
531                             boxingInfo.unboxMethod.c_str(), returnType, {}, {0});
532                     addInstr(c, fi, OP_MOVE_RESULT_WIDE, {v0v1});
533                     addInstr(c, fi, OP_RETURN_WIDE, {v0v1});
534                     break;
535                 }
536                 case ir::Type::Category::Reference:
537                     addInstr(c, fi, OP_CHECK_CAST, {v0, c.Alloc<Type>(returnType,
538                                                                       returnType->orig_index)});
539                     addInstr(c, fi, OP_RETURN_OBJECT, {v0});
540                     break;
541                 case ir::Type::Category::Void:
542                     addInstr(c, fi, OP_RETURN_VOID, {});
543                     break;
544             }
545 
546             addLabel(c, fi, originalMethodLabel);
547 
548             if (numParams > 0) {
549                 int argReg = firstArg;
550 
551                 for (int argNum = 0; argNum < numParams; argNum++) {
552                     const auto& type = method->decl->prototype->param_types->types[argNum];
553                     int origReg = argReg - numAdditionalRegs;
554                     switch (type->GetCategory()) {
555                         case ir::Type::Category::Scalar:
556                             addInstr(c, fi, OP_MOVE_16, {c.Alloc<VReg>(origReg),
557                                                          c.Alloc<VReg>(argReg)});
558                             argReg++;
559                             break;
560                         case ir::Type::Category::WideScalar:
561                             addInstr(c, fi, OP_MOVE_WIDE_16,{c.Alloc<VRegPair>(origReg),
562                                                              c.Alloc<VRegPair>(argReg)});
563                             argReg +=2;
564                             break;
565                         case ir::Type::Category::Reference:
566                             addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(origReg),
567                                                                 c.Alloc<VReg>(argReg)});
568                             argReg++;
569                             break;
570                     }
571                 }
572             }
573 
574             c.Assemble();
575         }
576 
577         struct Allocator : public Writer::Allocator {
578             virtual void* Allocate(size_t size) {return ::malloc(size);}
579             virtual void Free(void* ptr) {::free(ptr);}
580         };
581 
582         Allocator allocator;
583         Writer writer(dex_ir);
584         size_t transformedLen = 0;
585         std::shared_ptr<jbyte> transformed((jbyte*)writer.CreateImage(&allocator, &transformedLen));
586 
587         jbyteArray transformedArr = env->NewByteArray(transformedLen);
588         env->SetByteArrayRegion(transformedArr, 0, transformedLen, transformed.get());
589 
590         return transformedArr;
591     }
592 
593     // Initializes the agent
Agent_OnAttach(JavaVM * vm,char * options,void * reserved)594     extern "C" jint Agent_OnAttach(JavaVM* vm,
595                                    char* options,
596                                    void* reserved) {
597         jint jvmError = vm->GetEnv(reinterpret_cast<void**>(&localJvmtiEnv), JVMTI_VERSION_1_2);
598         if (jvmError != JNI_OK) {
599             return jvmError;
600         }
601 
602         jvmtiCapabilities caps;
603         memset(&caps, 0, sizeof(caps));
604         caps.can_retransform_classes = 1;
605 
606         jvmtiError error = localJvmtiEnv->AddCapabilities(&caps);
607         if (error != JVMTI_ERROR_NONE) {
608             return error;
609         }
610 
611         jvmtiEventCallbacks cb;
612         memset(&cb, 0, sizeof(cb));
613         cb.ClassFileLoadHook = Transform;
614 
615         error = localJvmtiEnv->SetEventCallbacks(&cb, sizeof(cb));
616         if (error != JVMTI_ERROR_NONE) {
617             return error;
618         }
619 
620         error = localJvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE,
621                                                         JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, nullptr);
622         if (error != JVMTI_ERROR_NONE) {
623             return error;
624         }
625 
626         return JVMTI_ERROR_NONE;
627     }
628 
629     // Throw runtime exception
throwRuntimeExpection(JNIEnv * env,const char * fmt,...)630     static void throwRuntimeExpection(JNIEnv* env, const char* fmt, ...) {
631         char msgBuf[512];
632 
633         va_list args;
634         va_start (args, fmt);
635         vsnprintf(msgBuf, sizeof(msgBuf), fmt, args);
636         va_end (args);
637 
638         jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
639         env->ThrowNew(exceptionClass, msgBuf);
640     }
641 
642     // Register transformer hook
643     extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRegisterTransformerHook(JNIEnv * env,jobject thiz)644     Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRegisterTransformerHook(JNIEnv* env,
645                                                                                      jobject thiz) {
646         sTransformer = env->NewGlobalRef(thiz);
647     }
648 
649     // Unregister transformer hook
650     extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeUnregisterTransformerHook(JNIEnv * env,jobject thiz)651     Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeUnregisterTransformerHook(JNIEnv* env,
652                                                                                      jobject thiz) {
653         env->DeleteGlobalRef(sTransformer);
654         sTransformer = nullptr;
655     }
656 
657     // Triggers retransformation of classes via this file's Transform method
658     extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRetransformClasses(JNIEnv * env,jobject thiz,jobjectArray classes)659     Java_com_android_dx_mockito_inline_StaticJvmtiAgent_nativeRetransformClasses(JNIEnv* env,
660                                                                              jobject thiz,
661                                                                              jobjectArray classes) {
662         jsize numTransformedClasses = env->GetArrayLength(classes);
663         jclass *transformedClasses = (jclass*) malloc(numTransformedClasses * sizeof(jclass));
664         for (int i = 0; i < numTransformedClasses; i++) {
665             transformedClasses[i] = (jclass) env->NewGlobalRef(env->GetObjectArrayElement(classes, i));
666         }
667 
668         jvmtiError error = localJvmtiEnv->RetransformClasses(numTransformedClasses,
669                                                              transformedClasses);
670 
671         for (int i = 0; i < numTransformedClasses; i++) {
672             env->DeleteGlobalRef(transformedClasses[i]);
673         }
674         free(transformedClasses);
675 
676         if (error != JVMTI_ERROR_NONE) {
677             throwRuntimeExpection(env, "Could not retransform classes: %d", error);
678         }
679     }
680 
681     static jvmtiFrameInfo* frameToInspect;
682     static std::string calledClass;
683 
684     // Takes the full dex file for class 'classBeingRedefined'
685     // - isolates the dex code for the class out of the dex file
686     // - calls sTransformer.runTransformers on the isolated dex code
687     // - send the transformed code back to the runtime
688     static void
InspectClass(jvmtiEnv * jvmtiEnv,JNIEnv * env,jclass classBeingRedefined,jobject loader,const char * name,jobject protectionDomain,jint classDataLen,const unsigned char * classData,jint * newClassDataLen,unsigned char ** newClassData)689     InspectClass(jvmtiEnv* jvmtiEnv,
690                  JNIEnv* env,
691                  jclass classBeingRedefined,
692                  jobject loader,
693                  const char* name,
694                  jobject protectionDomain,
695                  jint classDataLen,
696                  const unsigned char* classData,
697                  jint* newClassDataLen,
698                  unsigned char** newClassData) {
699         calledClass = "none";
700 
701         Reader reader(classData, classDataLen);
702 
703         char *calledMethodName;
704         char *calledMethodSignature;
705         jvmtiError error = jvmtiEnv->GetMethodName(frameToInspect->method, &calledMethodName,
706                                                    &calledMethodSignature, nullptr);
707         if (error != JVMTI_ERROR_NONE) {
708             return;
709         }
710 
711         u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str());
712         reader.CreateClassIr(index);
713         std::shared_ptr<ir::DexFile> class_ir = reader.GetIr();
714 
715         for (auto& method : class_ir->encoded_methods) {
716             if (Utf8Cmp(method->decl->name->c_str(), calledMethodName) == 0
717                 && Utf8Cmp(method->decl->prototype->Signature().c_str(), calledMethodSignature) == 0) {
718                 CodeIr method_ir(method.get(), class_ir);
719 
720                 for (auto instruction : method_ir.instructions) {
721                     Bytecode* bytecode = dynamic_cast<Bytecode*>(instruction);
722                     if (bytecode != nullptr && bytecode->offset == frameToInspect->location) {
723                         Method *method = bytecode->CastOperand<Method>(1);
724                         calledClass = method->ir_method->parent->Decl().c_str();
725 
726                         goto exit;
727                     }
728                 }
729             }
730         }
731 
732         exit:
733         free(calledMethodName);
734         free(calledMethodSignature);
735     }
736 
// Jumps to 'label' if the local variable 'error' is not JVMTI_ERROR_NONE. The do/while(0)
// wrapper makes `GOTO_ON_ERROR(label);` a single statement, so it is safe inside an unbraced
// if/else.
#define GOTO_ON_ERROR(label) \
    do { \
        if (error != JVMTI_ERROR_NONE) { \
            goto label; \
        } \
    } while (0)

    // stack frame of the caller if method was called directly
    static constexpr int DIRECT_CALL_STACK_FRAME = 6;

    // stack frame of the caller if method was called as 'real method'
    static constexpr int REALMETHOD_CALL_STACK_FRAME = 23;
747 
748     extern "C" JNIEXPORT jstring JNICALL
Java_com_android_dx_mockito_inline_StaticMockMethodAdvice_nativeGetCalledClassName(JNIEnv * env,jclass klass,jthread currentThread)749     Java_com_android_dx_mockito_inline_StaticMockMethodAdvice_nativeGetCalledClassName(JNIEnv* env,
750                                                                             jclass klass,
751                                                                             jthread currentThread) {
752 
753         JavaVM *vm;
754         jint jvmError = env->GetJavaVM(&vm);
755         if (jvmError != JNI_OK) {
756             return nullptr;
757         }
758 
759         jvmtiEnv *jvmtiEnv;
760         jvmError = vm->GetEnv(reinterpret_cast<void**>(&jvmtiEnv), JVMTI_VERSION_1_2);
761         if (jvmError != JNI_OK) {
762             return nullptr;
763         }
764 
765         jvmtiCapabilities caps;
766         memset(&caps, 0, sizeof(caps));
767         caps.can_retransform_classes = 1;
768 
769         jvmtiError error = jvmtiEnv->AddCapabilities(&caps);
770         GOTO_ON_ERROR(unregister_env_and_exit);
771 
772         jvmtiEventCallbacks cb;
773         memset(&cb, 0, sizeof(cb));
774         cb.ClassFileLoadHook = InspectClass;
775 
776         jvmtiFrameInfo frameInfo[REALMETHOD_CALL_STACK_FRAME + 1];
777         jint numFrames;
778         error = jvmtiEnv->GetStackTrace(nullptr, 0, REALMETHOD_CALL_STACK_FRAME + 1, frameInfo,
779                                         &numFrames);
780         GOTO_ON_ERROR(unregister_env_and_exit);
781 
782         // Method might be called directly or as 'real method' (see
783         // StaticMockMethodAdvice.SuperMethodCall#invoke). Hence the real caller might be in stack
784         // frame DIRECT_CALL_STACK_FRAME for a direct call or REALMETHOD_CALL_STACK_FRAME for a
785         // call through the 'real method' mechanism.
786         int callingFrameNum;
787         if (numFrames < REALMETHOD_CALL_STACK_FRAME) {
788             callingFrameNum = DIRECT_CALL_STACK_FRAME;
789         } else {
790             char *directCallMethodName;
791 
792             jvmtiEnv->GetMethodName(frameInfo[DIRECT_CALL_STACK_FRAME].method,
793                                     &directCallMethodName, nullptr, nullptr);
794             if (strcmp(directCallMethodName, "invoke") == 0) {
795                 callingFrameNum = REALMETHOD_CALL_STACK_FRAME;
796             } else {
797                 callingFrameNum = DIRECT_CALL_STACK_FRAME;
798             }
799         }
800 
801         jclass callingClass;
802         error = jvmtiEnv->GetMethodDeclaringClass(frameInfo[callingFrameNum].method, &callingClass);
803         GOTO_ON_ERROR(unregister_env_and_exit);
804 
805         error = jvmtiEnv->SetEventCallbacks(&cb, sizeof(cb));
806         GOTO_ON_ERROR(unregister_env_and_exit);
807 
808         error = jvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK,
809                                                    currentThread);
810         GOTO_ON_ERROR(unset_cb_and_exit);
811 
812         frameToInspect = &frameInfo[callingFrameNum];
813         error = jvmtiEnv->RetransformClasses(1, &callingClass);
814         GOTO_ON_ERROR(disable_hook_and_exit);
815 
816         disable_hook_and_exit:
817         jvmtiEnv->SetEventNotificationMode(JVMTI_DISABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK,
818                                            currentThread);
819 
820         unset_cb_and_exit:
821         memset(&cb, 0, sizeof(cb));
822         jvmtiEnv->SetEventCallbacks(&cb, sizeof(cb));
823 
824         unregister_env_and_exit:
825         jvmtiEnv->DisposeEnvironment();
826 
827         if (error != JVMTI_ERROR_NONE) {
828             return nullptr;
829         }
830 
831         return env->NewStringUTF(calledClass.c_str());
832     }
833 
834 }  // namespace com_android_dx_mockito_inline
835 
836