<lambda>null1 // Copyright 2021 Code Intelligence GmbH
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package com.code_intelligence.jazzer.instrumentor
16
17 import com.code_intelligence.jazzer.runtime.TraceDataFlowNativeCallbacks
18 import org.objectweb.asm.ClassReader
19 import org.objectweb.asm.ClassWriter
20 import org.objectweb.asm.Opcodes
21 import org.objectweb.asm.tree.AbstractInsnNode
22 import org.objectweb.asm.tree.ClassNode
23 import org.objectweb.asm.tree.InsnList
24 import org.objectweb.asm.tree.InsnNode
25 import org.objectweb.asm.tree.IntInsnNode
26 import org.objectweb.asm.tree.LdcInsnNode
27 import org.objectweb.asm.tree.LookupSwitchInsnNode
28 import org.objectweb.asm.tree.MethodInsnNode
29 import org.objectweb.asm.tree.MethodNode
30 import org.objectweb.asm.tree.TableSwitchInsnNode
31
/**
 * Instruments class files so that the operands of selected bytecode instructions are reported to static
 * callback methods on a callback class.
 *
 * Depending on the requested [types], the following instructions are instrumented:
 *  - [InstrumentationType.CMP]: `lcmp`, `if_icmp*`, `if*` and `{lookup,table}switch`,
 *  - [InstrumentationType.DIV]: `idiv` and `ldiv`,
 *  - [InstrumentationType.GEP]: array loads and calls to the methods in [GEP_LOAD_METHODS], but only if the
 *    directly preceding instruction is one of [CONSTANT_INTEGER_PUSH_OPCODES].
 *
 * @param types the kinds of data flow instrumentation to apply
 * @param callbackClass the class providing the static `trace*` callbacks that the emitted bytecode invokes
 */
internal class TraceDataFlowInstrumentor(
    private val types: Set<InstrumentationType>,
    callbackClass: Class<*> = TraceDataFlowNativeCallbacks::class.java
) : Instrumentor {

    // JVM internal name (slash-separated) of the callback class, as required by INVOKESTATIC.
    private val callbackInternalClassName = callbackClass.name.replace('.', '/')

    // Re-created for every instrumented class in [instrument], seeded with the class name.
    private lateinit var random: DeterministicRandom

    /**
     * Parses [bytecode], adds data flow instrumentation to every method accepted by `shouldInstrument` and
     * returns the bytecode of the modified class.
     */
    override fun instrument(bytecode: ByteArray): ByteArray {
        val node = ClassNode()
        val reader = ClassReader(bytecode)
        reader.accept(node, 0)
        random = DeterministicRandom("trace", node.name)
        for (method in node.methods) {
            if (shouldInstrument(method)) {
                addDataFlowInstrumentation(method)
            }
        }

        // The inserted callback invocations need additional operand stack slots, so let ASM recompute the
        // maximum stack size and number of locals.
        val writer = ClassWriter(ClassWriter.COMPUTE_MAXS)
        node.accept(writer)
        return writer.toByteArray()
    }

    /**
     * Walks over all instructions of [method] and inserts the instrumentation matching each opcode, subject
     * to the enabled [types].
     */
    @OptIn(ExperimentalUnsignedTypes::class)
    private fun addDataFlowInstrumentation(method: MethodNode) {
        // Iterate over a snapshot (toArray) since the instruction list is modified while looping.
        loop@ for (inst in method.instructions.toArray()) {
            when (inst.opcode) {
                Opcodes.LCMP -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The callback consumes both long operands and pushes the int result of the comparison,
                    // so it replaces the original LCMP instead of preceding it.
                    method.instructions.insertBefore(inst, longCmpInstrumentation())
                    method.instructions.remove(inst)
                }
                Opcodes.IF_ICMPEQ, Opcodes.IF_ICMPNE,
                Opcodes.IF_ICMPLT, Opcodes.IF_ICMPLE,
                Opcodes.IF_ICMPGT, Opcodes.IF_ICMPGE -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    method.instructions.insertBefore(inst, intCmpInstrumentation())
                }
                Opcodes.IFEQ, Opcodes.IFNE,
                Opcodes.IFLT, Opcodes.IFLE,
                Opcodes.IFGT, Opcodes.IFGE -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // The IF* opcodes are often used to branch based on the result of a compare
                    // instruction for a type other than int. The operands of this compare will
                    // already be reported via the instrumentation above (for non-floating point
                    // numbers) and the follow-up compare does not provide a good signal as all
                    // operands will be in {-1, 0, 1}. Skip instrumentation for it.
                    if (inst.previous?.opcode in FLOATING_POINT_CMP_OPCODES ||
                        (inst.previous as? MethodInsnNode)?.name == "traceCmpLongWrapper"
                    )
                        continue@loop
                    method.instructions.insertBefore(inst, ifInstrumentation())
                }
                Opcodes.LOOKUPSWITCH, Opcodes.TABLESWITCH -> {
                    if (InstrumentationType.CMP !in types) continue@loop
                    // Mimic the exclusion logic for small label values in libFuzzer:
                    // https://github.com/llvm-mirror/compiler-rt/blob/69445f095c22aac2388f939bedebf224a6efcdaf/lib/fuzzer/FuzzerTracePC.cpp#L520
                    // Case values are reported to libFuzzer via an array of unsigned long values and thus need to be
                    // sorted by unsigned value.
                    val caseValues = when (inst) {
                        is LookupSwitchInsnNode -> {
                            // The first and last key are used as the smallest and largest key here.
                            if (inst.keys.isEmpty() || (0 <= inst.keys.first() && inst.keys.last() < 256))
                                continue@loop
                            inst.keys
                        }
                        is TableSwitchInsnNode -> {
                            if (0 <= inst.min && inst.max < 256)
                                continue@loop
                            (inst.min..inst.max).filter { caseValue ->
                                val index = caseValue - inst.min
                                // Filter out "gap cases".
                                inst.labels[index].label != inst.dflt.label
                            }.toList()
                        }
                        // Not reached.
                        else -> continue@loop
                    }.sortedBy { it.toUInt() }.map { it.toLong() }.toLongArray()
                    method.instructions.insertBefore(inst, switchInstrumentation(caseValues))
                }
                Opcodes.IDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, intDivInstrumentation())
                }
                Opcodes.LDIV -> {
                    if (InstrumentationType.DIV !in types) continue@loop
                    method.instructions.insertBefore(inst, longDivInstrumentation())
                }
                Opcodes.AALOAD, Opcodes.BALOAD,
                Opcodes.CALOAD, Opcodes.DALOAD,
                Opcodes.FALOAD, Opcodes.IALOAD,
                Opcodes.LALOAD, Opcodes.SALOAD -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    // Only indices pushed by the directly preceding constant push instruction are reported.
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
                Opcodes.INVOKEINTERFACE, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC, Opcodes.INVOKEVIRTUAL -> {
                    if (InstrumentationType.GEP !in types) continue@loop
                    if (!isGepLoadMethodInsn(inst as MethodInsnNode)) continue@loop
                    if (!isConstantIntegerPushInsn(inst.previous)) continue@loop
                    method.instructions.insertBefore(inst, gepLoadInstrumentation())
                }
            }
        }
    }

    /**
     * Appends an instruction pushing an int in `[0, 512)` drawn from [random]. It is passed as the trailing
     * int argument of every callback.
     */
    private fun InsnList.pushFakePc() {
        add(LdcInsnNode(random.nextInt(512)))
    }

    /**
     * Replacement for `lcmp`: calls `traceCmpLongWrapper(long, long, int)`, which consumes both operands and
     * leaves an int on the stack.
     */
    private fun longCmpInstrumentation() = InsnList().apply {
        pushFakePc()
        // traceCmpLongWrapper returns the result of the comparison as duplicating two longs on the stack
        // is not possible without local variables.
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpLongWrapper", "(JJI)I", false))
    }

    /**
     * Instrumentation inserted before `if_icmp*`: duplicates both int operands and passes them to
     * `traceCmpInt(int, int, int)`.
     */
    private fun intCmpInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceCmpInt", "(III)V", false))
    }

    /**
     * Instrumentation inserted before `if*`: duplicates the int operand and passes the constant 0 followed
     * by the operand to `traceConstCmpInt(int, int, int)`.
     */
    private fun ifInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        // All if* instructions are compares to the constant 0.
        add(InsnNode(Opcodes.ICONST_0))
        // Move the constant below the duplicated operand so that it becomes the first argument.
        add(InsnNode(Opcodes.SWAP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceConstCmpInt", "(III)V", false))
    }

    /**
     * Instrumentation inserted before `idiv`: duplicates the top int operand (the divisor) and passes it to
     * `traceDivInt(int, int)`.
     */
    private fun intDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivInt", "(II)V", false))
    }

    /**
     * Instrumentation inserted before `ldiv`: duplicates the top long operand (the divisor, which occupies
     * two stack slots) and passes it to `traceDivLong(long, int)`.
     */
    private fun longDivInstrumentation() = InsnList().apply {
        add(InsnNode(Opcodes.DUP2))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceDivLong", "(JI)V", false))
    }

    /**
     * Instrumentation inserted before `{lookup,table}switch`: passes the switch key (widened to long) and a
     * freshly created long array describing the cases to `traceSwitch(long, long[], int)`.
     *
     * @param caseValues the case values of the switch in the order in which they should be reported
     */
    private fun switchInstrumentation(caseValues: LongArray) = InsnList().apply {
        // duplicate {lookup,table}switch key for use as first function argument
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        // Set up array with switch case values. The format libfuzzer expects is created here directly, i.e., the first
        // two entries are the number of cases and the bit size of values (always 32).
        add(IntInsnNode(Opcodes.SIPUSH, caseValues.size + 2))
        add(IntInsnNode(Opcodes.NEWARRAY, Opcodes.T_LONG))
        // Store number of cases
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 0))
        add(LdcInsnNode(caseValues.size.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store bit size of keys
        add(InsnNode(Opcodes.DUP))
        add(IntInsnNode(Opcodes.SIPUSH, 1))
        add(LdcInsnNode(32.toLong()))
        add(InsnNode(Opcodes.LASTORE))
        // Store {lookup,table}switch case values
        for ((i, caseValue) in caseValues.withIndex()) {
            add(InsnNode(Opcodes.DUP))
            add(IntInsnNode(Opcodes.SIPUSH, 2 + i))
            add(LdcInsnNode(caseValue))
            add(InsnNode(Opcodes.LASTORE))
        }
        pushFakePc()
        // call the native callback function
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceSwitch", "(J[JI)V", false))
    }

    /**
     * Returns true if [node] is an instruction with one of the [CONSTANT_INTEGER_PUSH_OPCODES], i.e., one
     * that possibly pushes a constant array index onto the stack. The dedicated opcodes for the constants
     * 0 and 1 are deliberately not matched. Returns false if [node] is null.
     */
    private fun isConstantIntegerPushInsn(node: AbstractInsnNode?) = node?.opcode in CONSTANT_INTEGER_PUSH_OPCODES

    /**
     * Returns true if [node] represents a call to a method that performs an indexed lookup into an array-like
     * structure, i.e., a method that takes a single int argument and is listed in [GEP_LOAD_METHODS].
     */
    private fun isGepLoadMethodInsn(node: MethodInsnNode): Boolean {
        if (!node.desc.startsWith("(I)")) return false
        val returnType = node.desc.removePrefix("(I)")
        return MethodInfo(node.owner, node.name, returnType) in GEP_LOAD_METHODS
    }

    /**
     * Instrumentation inserted before an indexed load: duplicates the int index on top of the stack, widens
     * it to long and passes it to `traceGep(long, int)`.
     */
    private fun gepLoadInstrumentation() = InsnList().apply {
        // Duplicate the index and convert to long.
        add(InsnNode(Opcodes.DUP))
        add(InsnNode(Opcodes.I2L))
        pushFakePc()
        add(MethodInsnNode(Opcodes.INVOKESTATIC, callbackInternalClassName, "traceGep", "(JI)V", false))
    }

    companion object {
        // Low constants (0, 1) are omitted as they create a lot of noise.
        val CONSTANT_INTEGER_PUSH_OPCODES = listOf(
            Opcodes.BIPUSH, Opcodes.SIPUSH,
            Opcodes.LDC,
            Opcodes.ICONST_2, Opcodes.ICONST_3, Opcodes.ICONST_4, Opcodes.ICONST_5
        )

        // Compare instructions for float and double operands. An IF* instruction directly following one of
        // these only sees their int result and is thus not instrumented.
        private val FLOATING_POINT_CMP_OPCODES = setOf(
            Opcodes.DCMPG, Opcodes.DCMPL,
            Opcodes.FCMPG, Opcodes.FCMPL
        )

        /**
         * Identifies a method taking a single int argument by the internal name of its owner class, its name
         * and the descriptor of its return type.
         */
        data class MethodInfo(val internalClassName: String, val name: String, val returnType: String)

        // Methods that are treated like array loads for the purpose of GEP instrumentation.
        val GEP_LOAD_METHODS = setOf(
            MethodInfo("java/util/AbstractList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/ArrayList", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/List", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Stack", "get", "Ljava/lang/Object;"),
            MethodInfo("java/util/Vector", "get", "Ljava/lang/Object;"),
            MethodInfo("java/lang/CharSequence", "charAt", "C"),
            MethodInfo("java/lang/String", "charAt", "C"),
            MethodInfo("java/lang/StringBuffer", "charAt", "C"),
            MethodInfo("java/lang/StringBuilder", "charAt", "C"),
            MethodInfo("java/lang/String", "codePointAt", "I"),
            MethodInfo("java/lang/String", "codePointBefore", "I"),
            MethodInfo("java/nio/ByteBuffer", "get", "B"),
            MethodInfo("java/nio/ByteBuffer", "getChar", "C"),
            MethodInfo("java/nio/ByteBuffer", "getDouble", "D"),
            MethodInfo("java/nio/ByteBuffer", "getFloat", "F"),
            MethodInfo("java/nio/ByteBuffer", "getInt", "I"),
            MethodInfo("java/nio/ByteBuffer", "getLong", "J"),
            MethodInfo("java/nio/ByteBuffer", "getShort", "S"),
        )
    }
}
259