• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download

// Copyright 2021 Code Intelligence GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 package com.code_intelligence.jazzer.instrumentor
16 
17 import com.code_intelligence.jazzer.runtime.TraceDataFlowNativeCallbacks
18 import org.objectweb.asm.ClassReader
19 import org.objectweb.asm.ClassWriter
20 import org.objectweb.asm.Opcodes
21 import org.objectweb.asm.tree.AbstractInsnNode
22 import org.objectweb.asm.tree.ClassNode
23 import org.objectweb.asm.tree.InsnList
24 import org.objectweb.asm.tree.InsnNode
25 import org.objectweb.asm.tree.IntInsnNode
26 import org.objectweb.asm.tree.LdcInsnNode
27 import org.objectweb.asm.tree.LookupSwitchInsnNode
28 import org.objectweb.asm.tree.MethodInsnNode
29 import org.objectweb.asm.tree.MethodNode
30 import org.objectweb.asm.tree.TableSwitchInsnNode
31 
/**
 * Inserts calls to static callback methods in front of compare, switch, integer division and
 * constant-index lookup instructions so that their operands are reported at runtime.
 *
 * @param types the kinds of instrumentation (CMP, DIV, GEP) that should be applied.
 * @param callbackClass the class that declares the static callbacks invoked by the inserted code. Its
 * binary name is converted to the JVM internal form (dots replaced by slashes).
 */
internal class TraceDataFlowInstrumentor(
    private val types: Set<InstrumentationType>,
    callbackClass: Class<*> = TraceDataFlowNativeCallbacks::class.java,
) : Instrumentor {

    private val callbackInternalClassName = callbackClass.name.replace('.', '/')

    // Re-created for every instrumented class in [instrument] from the fixed tag "trace" and the class
    // name. NOTE(review): presumably this makes the fake PCs reproducible per class - confirm against
    // DeterministicRandom.
    private lateinit var random: DeterministicRandom

    /**
     * Parses [bytecode] into a class tree, instruments every method accepted by `shouldInstrument` and
     * returns the re-serialized class.
     */
    override fun instrument(bytecode: ByteArray): ByteArray {
        val node = ClassNode()
        val reader = ClassReader(bytecode)
        reader.accept(node, 0)
        random = DeterministicRandom("trace", node.name)
        for (method in node.methods) {
            if (shouldInstrument(method)) {
                addDataFlowInstrumentation(method)
            }
        }

        // The inserted code temporarily pushes additional values, so the maximum stack size has to be
        // recomputed when writing the class.
        val writer = ClassWriter(ClassWriter.COMPUTE_MAXS)
        node.accept(writer)
        return writer.toByteArray()
    }

    /**
     * Walks over all instructions of [method] and inserts the callback invocations matching the enabled
     * instrumentation [types].
     */
    private fun addDataFlowInstrumentation(method: MethodNode) {
        // Iterate over a snapshot of the instructions as the list is modified inside the loop.
        loop@ for (inst in method.instructions.toArray()) {
            when (inst.opcode) {
                Opcodes.LCMP -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The wrapper consumes both longs and returns the comparison result itself, so it
                    // takes the place of the original LCMP.
                    method.instructions.insertBefore(inst, longCmpInstrumentation())
                    method.instructions.remove(inst)
                }
                Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE,
                Opcodes.IF_ICMPLT, Opcodes.IF_ICMPLE,
                Opcodes.IF_ICMPGT, Opcodes.IF_ICMPGE -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    method.instructions.insertBefore(inst, intCmpInstrumentation())
                }
                Opcodes.IFEQ, Opcodes.IFNE,
                Opcodes.IFLT, Opcodes.IFLE,
                Opcodes.IFGT, Opcodes.IFGE -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The IF* opcodes are often used to branch based on the result of a compare
                    // instruction for a type other than int. The operands of this compare will
                    // already be reported via the instrumentation above (for non-floating point
                    // numbers) and the follow-up compare does not provide a good signal as all
                    // operands will be in {-1, 0, 1}. Skip instrumentation for it.
                    if (branchesOnCompareResult(inst)) continue@loop
                    method.instructions.insertBefore(inst, ifInstrumentation())
                }
                Opcodes.LOOKUPSWITCH, Opcodes.TABLESWITCH -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    val caseValues = switchCaseValuesToReport(inst) ?: continue@loop
                    method.instructions.insertBefore(inst, switchInstrumentation(caseValues))
                }
                Opcodes.IDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, intDivInstrumentation())
                }
                Opcodes.LDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, longDivInstrumentation())
                }
                Opcodes.AALOAD, Opcodes.BALOAD,
                Opcodes.CALOAD, Opcodes.DALOAD,
                Opcodes.FALOAD, Opcodes.IALOAD,
                Opcodes.LALOAD, Opcodes.SALOAD -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    // Only report indices that were pushed as a constant right before the load.
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
                Opcodes.INVOKEINTERFACE, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEVIRTUAL -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    if (!isGepLoadMethodInsn(inst as MethodInsnNode)) continue@loop
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
            }
        }
    }

    /**
     * Returns true if the instruction directly preceding [inst] is a floating point compare
     * (FCMPL, FCMPG, DCMPL or DCMPG) or the call that replaces LCMP, i.e., if [inst] only branches on a
     * compare result in {-1, 0, 1}.
     */
    private fun branchesOnCompareResult(inst: AbstractInsnNode): Boolean {
        val previous = inst.previous
        return previous?.opcode in FLOATING_POINT_CMP_OPCODES ||
            (previous as? MethodInsnNode)?.name == LONG_CMP_WRAPPER_NAME
    }

    /**
     * Returns the case values of the switch instruction [inst] that should be reported, sorted by their
     * unsigned value, or null if the switch should not be instrumented.
     */
    @OptIn(ExperimentalUnsignedTypes::class)
    private fun switchCaseValuesToReport(inst: AbstractInsnNode): LongArray? {
        // Mimic the exclusion logic for small label values in libFuzzer:
        // https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/fuzzer/FuzzerTracePC.cpp#L520
        val caseValues: List<Int> = when (inst) {
            is LookupSwitchInsnNode -> {
                if (inst.keys.isEmpty() || (0 <= inst.keys.first() && inst.keys.last() < 256)) return null
                inst.keys
            }
            is TableSwitchInsnNode -> {
                if (0 <= inst.min && inst.max < 256) return null
                // Filter out "gap cases", i.e., values in [min, max] that jump to the default label.
                (inst.min..inst.max).filter { caseValue ->
                    inst.labels[caseValue - inst.min].label != inst.dflt.label
                }
            }
            // Not reached for LOOKUPSWITCH and TABLESWITCH opcodes.
            else -> return null
        }
        // Case values are reported to libFuzzer via an array of unsigned long values and thus need to be
        // sorted by unsigned value.
        return caseValues.sortedBy { it.toUInt() }.map { it.toLong() }.toLongArray()
    }

    /** Appends an instruction that pushes the next value of [random] (requested with bound 512) as a fake PC. */
    private fun InsnList.pushFakePc() {
        add(LdcInsnNode(random.nextInt(512)))
    }

    /** Builds the replacement for LCMP: a call that receives both long operands and a fake PC. */
    private fun longCmpInstrumentation() = InsnList().apply {
        pushFakePc()
        // traceCmpLongWrapper returns the result of the comparison as duplicating two longs on the stack
        // is not possible without local variables.
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, LONG_CMP_WRAPPER_NAME, "(JJI)I", false))
    }

    /** Builds the code inserted before IF_ICMP*: reports copies of both int operands and a fake PC. */
    private fun intCmpInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpInt", "(III)V", false))
    }

    /** Builds the code inserted before IF*: reports the constant 0, a copy of the operand and a fake PC. */
    private fun ifInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        // All if* instructions are compares to the constant 0.
        add(InsnNode(Opcodes.ICONST_0))
        // Move the constant in front of the duplicated operand so that it becomes the first argument.
        add(InsnNode(Opcodes.SWAP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceConstCmpInt", "(III)V", false))
    }

    /** Builds the code inserted before IDIV: reports a copy of the int on top of the stack and a fake PC. */
    private fun intDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivInt", "(II)V", false))
    }

    /** Builds the code inserted before LDIV: reports a copy of the long on top of the stack and a fake PC. */
    private fun longDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivLong", "(JI)V", false))
    }

    /**
     * Builds the code inserted before a switch: reports a copy of the switch key (widened to long), a newly
     * created array describing [caseValues] and a fake PC.
     */
    private fun switchInstrumentation(caseValues: LongArray) = InsnList().apply {
        // duplicate {lookup,table}switch key for use as first function argument
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        // Set up array with switch case values. The format libfuzzer expects is created here directly, i.e., the first
        // two entries are the number of cases and the bit size of values (always 32).
        add(IntInsnNode(Opcodes.SIPUSH, caseValues.size + 2))
        add(IntInsnNode(Opcodes.NEWARRAY, Opcodes.T_LONG))
        // Store number of cases
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 0))
        add(LdcInsnNode(caseValues.size.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store bit size of keys
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 1))
        add(LdcInsnNode(32.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store {lookup,table}switch case values
        for ((i, caseValue) in caseValues.withIndex()) {
            add(InsnNode(Opcodes.DUP))
            add(IntInsnNode(Opcodes.SIPUSH, 2 + i))
            add(LdcInsnNode(caseValue))
            add(InsnNode(Opcodes.LASTORE))
        }
        pushFakePc()
        // call the native callback function
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceSwitch", "(J[JI)V", false))
    }

    /**
     * Returns true if the opcode of [node] is contained in [CONSTANT_INTEGER_PUSH_OPCODES], i.e., if [node]
     * possibly pushes a constant array index onto the stack. The dedicated opcodes for the constants 0 and 1
     * are not part of that list. Returns false if [node] is null.
     */
    private fun isConstantIntegerPushInsn(node: AbstractInsnNode?) = node?.opcode in CONSTANT_INTEGER_PUSH_OPCODES

    /**
     * Returns true if [node] represents a call to a method that performs an indexed lookup into an array-like
     * structure, i.e., a method with a single int parameter whose owner, name and return type are listed in
     * [GEP_LOAD_METHODS].
     */
    private fun isGepLoadMethodInsn(node: MethodInsnNode): Boolean {
        if (!node.desc.startsWith("(I)")) return false
        val returnType = node.desc.removePrefix("(I)")
        return MethodInfo(node.owner, node.name, returnType) in GEP_LOAD_METHODS
    }

    /** Builds the code inserted before an indexed lookup: reports a copy of the index and a fake PC. */
    private fun gepLoadInstrumentation() = InsnList().apply {
        // Duplicate the index and convert to long.
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceGep", "(JI)V", false))
    }

    companion object {
        // Name of the callback that replaces LCMP; also used to recognize already instrumented long compares.
        private const val LONG_CMP_WRAPPER_NAME = "traceCmpLongWrapper"

        // All four floating point compare opcodes. Each pushes a result in {-1, 0, 1} that a following
        // IF* instruction branches on.
        private val FLOATING_POINT_CMP_OPCODES = listOf(
            Opcodes.DCMPG, Opcodes.DCMPL,
            Opcodes.FCMPG, Opcodes.FCMPL
        )

        // Low constants (0, 1) are omitted as they create a lot of noise.
        val CONSTANT_INTEGER_PUSH_OPCODES = listOf(
            Opcodes.BIPUSH, Opcodes.SIPUSH,
            Opcodes.LDC,
            Opcodes.ICONST_2, Opcodes.ICONST_3, Opcodes.ICONST_4, Opcodes.ICONST_5
        )

        /**
         * Identifies a method by the internal name of its owner, its name and its return type descriptor.
         */
        data class MethodInfo(val internalClassName: String, val name: String, val returnType: String)

        // Methods taking a single int index that are treated like array loads.
        val GEP_LOAD_METHODS = setOf(
            MethodInfo("java/util/AbstractList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/ArrayList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/List", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Stack", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Vector", "get", "Ljava/lang/Object;"),
            MethodInfo("java/lang/CharSequence", "charAt", "C"),
            MethodInfo("java/lang/String", "charAt", "C"),
            MethodInfo("java/lang/StringBuffer", "charAt", "C"),
            MethodInfo("java/lang/StringBuilder", "charAt", "C"),
            MethodInfo("java/lang/String", "codePointAt", "I"),
            MethodInfo("java/lang/String", "codePointBefore", "I"),
            MethodInfo("java/nio/ByteBuffer", "get", "B"),
            MethodInfo("java/nio/ByteBuffer", "getChar", "C"),
            MethodInfo("java/nio/ByteBuffer", "getDouble", "D"),
            MethodInfo("java/nio/ByteBuffer", "getFloat", "F"),
            MethodInfo("java/nio/ByteBuffer", "getInt", "I"),
            MethodInfo("java/nio/ByteBuffer", "getLong", "J"),
            MethodInfo("java/nio/ByteBuffer", "getShort", "S"),
        )
    }
}
259