1// Copyright 2015 syzkaller project authors. All rights reserved. 2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4// Package csource generates [almost] equivalent C programs from syzkaller programs. 5package csource 6 7import ( 8 "bytes" 9 "fmt" 10 "regexp" 11 "sort" 12 "strings" 13 14 "github.com/google/syzkaller/prog" 15 "github.com/google/syzkaller/sys/targets" 16) 17 18func Write(p *prog.Prog, opts Options) ([]byte, error) { 19 if err := opts.Check(p.Target.OS); err != nil { 20 return nil, fmt.Errorf("csource: invalid opts: %v", err) 21 } 22 ctx := &context{ 23 p: p, 24 opts: opts, 25 target: p.Target, 26 sysTarget: targets.Get(p.Target.OS, p.Target.Arch), 27 calls: make(map[string]uint64), 28 } 29 30 calls, vars, err := ctx.generateProgCalls(ctx.p, opts.Trace) 31 if err != nil { 32 return nil, err 33 } 34 35 mmapProg := p.Target.GenerateUberMmapProg() 36 mmapCalls, _, err := ctx.generateProgCalls(mmapProg, false) 37 if err != nil { 38 return nil, err 39 } 40 41 for _, c := range append(mmapProg.Calls, p.Calls...) { 42 ctx.calls[c.Meta.CallName] = c.Meta.NR 43 } 44 45 varsBuf := new(bytes.Buffer) 46 if len(vars) != 0 { 47 fmt.Fprintf(varsBuf, "uint64 r[%v] = {", len(vars)) 48 for i, v := range vars { 49 if i != 0 { 50 fmt.Fprintf(varsBuf, ", ") 51 } 52 fmt.Fprintf(varsBuf, "0x%x", v) 53 } 54 fmt.Fprintf(varsBuf, "};\n") 55 } 56 57 sandboxFunc := "loop();" 58 if opts.Sandbox != "" { 59 sandboxFunc = "do_sandbox_" + opts.Sandbox + "();" 60 } 61 replacements := map[string]string{ 62 "PROCS": fmt.Sprint(opts.Procs), 63 "REPEAT_TIMES": fmt.Sprint(opts.RepeatTimes), 64 "NUM_CALLS": fmt.Sprint(len(p.Calls)), 65 "MMAP_DATA": strings.Join(mmapCalls, ""), 66 "SYSCALL_DEFINES": ctx.generateSyscallDefines(), 67 "SANDBOX_FUNC": sandboxFunc, 68 "RESULTS": varsBuf.String(), 69 "SYSCALLS": ctx.generateSyscalls(calls, len(vars) != 0), 70 } 71 if !opts.Threaded && !opts.Repeat && opts.Sandbox == "" { 72 // This inlines syscalls right into main for the simplest case. 73 replacements["SANDBOX_FUNC"] = replacements["SYSCALLS"] 74 replacements["SYSCALLS"] = "unused" 75 } 76 result, err := createCommonHeader(p, mmapProg, replacements, opts) 77 if err != nil { 78 return nil, err 79 } 80 const header = "// autogenerated by syzkaller (https://github.com/google/syzkaller)\n\n" 81 result = append([]byte(header), result...) 82 result = ctx.postProcess(result) 83 return result, nil 84} 85 86type context struct { 87 p *prog.Prog 88 opts Options 89 target *prog.Target 90 sysTarget *targets.Target 91 calls map[string]uint64 // CallName -> NR 92} 93 94func (ctx *context) generateSyscalls(calls []string, hasVars bool) string { 95 opts := ctx.opts 96 buf := new(bytes.Buffer) 97 if !opts.Threaded && !opts.Collide { 98 if hasVars || opts.Trace { 99 fmt.Fprintf(buf, "\tlong res = 0;\n") 100 } 101 if opts.Repro { 102 fmt.Fprintf(buf, "\tif (write(1, \"executing program\\n\", sizeof(\"executing program\\n\") - 1)) {}\n") 103 } 104 if opts.Trace { 105 fmt.Fprintf(buf, "\tprintf(\"### start\\n\");\n") 106 } 107 for _, c := range calls { 108 fmt.Fprintf(buf, "%s", c) 109 } 110 } else { 111 if hasVars || opts.Trace { 112 fmt.Fprintf(buf, "\tlong res;") 113 } 114 fmt.Fprintf(buf, "\tswitch (call) {\n") 115 for i, c := range calls { 116 fmt.Fprintf(buf, "\tcase %v:\n", i) 117 fmt.Fprintf(buf, "%s", strings.Replace(c, "\t", "\t\t", -1)) 118 fmt.Fprintf(buf, "\t\tbreak;\n") 119 } 120 fmt.Fprintf(buf, "\t}\n") 121 } 122 return buf.String() 123} 124 125func (ctx *context) generateSyscallDefines() string { 126 var calls []string 127 for name, nr := range ctx.calls { 128 if !ctx.sysTarget.SyscallNumbers || 129 strings.HasPrefix(name, "syz_") || !ctx.sysTarget.NeedSyscallDefine(nr) { 130 continue 131 } 132 calls = append(calls, name) 133 } 134 sort.Strings(calls) 135 buf := new(bytes.Buffer) 136 prefix := ctx.sysTarget.SyscallPrefix 137 for _, name := range calls { 138 fmt.Fprintf(buf, "#ifndef %v%v\n", prefix, name) 139 fmt.Fprintf(buf, "#define %v%v %v\n", prefix, name, ctx.calls[name]) 140 fmt.Fprintf(buf, "#endif\n") 141 } 142 if ctx.target.OS == "linux" && ctx.target.PtrSize == 4 { 143 // This is a dirty hack. 144 // On 32-bit linux mmap translated to old_mmap syscall which has a different signature. 145 // mmap2 has the right signature. syz-extract translates mmap to mmap2, do the same here. 146 fmt.Fprintf(buf, "#undef __NR_mmap\n") 147 fmt.Fprintf(buf, "#define __NR_mmap __NR_mmap2\n") 148 } 149 return buf.String() 150} 151 152func (ctx *context) generateProgCalls(p *prog.Prog, trace bool) ([]string, []uint64, error) { 153 exec := make([]byte, prog.ExecBufferSize) 154 progSize, err := p.SerializeForExec(exec) 155 if err != nil { 156 return nil, nil, fmt.Errorf("failed to serialize program: %v", err) 157 } 158 decoded, err := ctx.target.DeserializeExec(exec[:progSize]) 159 if err != nil { 160 return nil, nil, err 161 } 162 calls, vars := ctx.generateCalls(decoded, trace) 163 return calls, vars, nil 164} 165 166func (ctx *context) generateCalls(p prog.ExecProg, trace bool) ([]string, []uint64) { 167 var calls []string 168 csumSeq := 0 169 for ci, call := range p.Calls { 170 w := new(bytes.Buffer) 171 // Copyin. 172 for _, copyin := range call.Copyin { 173 ctx.copyin(w, &csumSeq, copyin) 174 } 175 176 if ctx.opts.Fault && ctx.opts.FaultCall == ci { 177 fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/failslab/ignore-gfp-wait\", \"N\");\n") 178 fmt.Fprintf(w, "\twrite_file(\"/sys/kernel/debug/fail_futex/ignore-private\", \"N\");\n") 179 fmt.Fprintf(w, "\tinject_fault(%v);\n", ctx.opts.FaultNth) 180 } 181 // Call itself. 182 callName := call.Meta.CallName 183 resCopyout := call.Index != prog.ExecNoCopyout 184 argCopyout := len(call.Copyout) != 0 185 emitCall := ctx.opts.EnableTun || 186 callName != "syz_emit_ethernet" && 187 callName != "syz_extract_tcp_res" 188 // TODO: if we don't emit the call we must also not emit copyin, copyout and fault injection. 189 // However, simply skipping whole iteration breaks tests due to unused static functions. 190 if emitCall { 191 ctx.emitCall(w, call, ci, resCopyout || argCopyout, trace) 192 } else if trace { 193 fmt.Fprintf(w, "\t(void)res;\n") 194 } 195 196 // Copyout. 197 if resCopyout || argCopyout { 198 ctx.copyout(w, call, resCopyout) 199 } 200 calls = append(calls, w.String()) 201 } 202 return calls, p.Vars 203} 204 205func (ctx *context) emitCall(w *bytes.Buffer, call prog.ExecCall, ci int, haveCopyout, trace bool) { 206 callName := call.Meta.CallName 207 native := ctx.sysTarget.SyscallNumbers && !strings.HasPrefix(callName, "syz_") 208 fmt.Fprintf(w, "\t") 209 if haveCopyout || trace { 210 fmt.Fprintf(w, "res = ") 211 } 212 if native { 213 fmt.Fprintf(w, "syscall(%v%v", ctx.sysTarget.SyscallPrefix, callName) 214 } else if strings.HasPrefix(callName, "syz_") { 215 fmt.Fprintf(w, "%v(", callName) 216 } else { 217 args := strings.Repeat(",long", len(call.Args)) 218 if args != "" { 219 args = args[1:] 220 } 221 fmt.Fprintf(w, "((long(*)(%v))%v)(", args, callName) 222 } 223 for ai, arg := range call.Args { 224 if native || ai > 0 { 225 fmt.Fprintf(w, ", ") 226 } 227 switch arg := arg.(type) { 228 case prog.ExecArgConst: 229 if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { 230 panic("sring format in syscall argument") 231 } 232 fmt.Fprintf(w, "%v", ctx.constArgToStr(arg)) 233 case prog.ExecArgResult: 234 if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { 235 panic("sring format in syscall argument") 236 } 237 val := ctx.resultArgToStr(arg) 238 if native && ctx.target.PtrSize == 4 { 239 // syscall accepts args as ellipsis, resources are uint64 240 // and take 2 slots without the cast, which would be wrong. 241 val = "(long)" + val 242 } 243 fmt.Fprintf(w, "%v", val) 244 default: 245 panic(fmt.Sprintf("unknown arg type: %+v", arg)) 246 } 247 } 248 fmt.Fprintf(w, ");\n") 249 if trace { 250 fmt.Fprintf(w, "\tprintf(\"### call=%v errno=%%u\\n\", res == -1 ? errno : 0);\n", ci) 251 } 252} 253 254func (ctx *context) generateCsumInet(w *bytes.Buffer, addr uint64, arg prog.ExecArgCsum, csumSeq int) { 255 fmt.Fprintf(w, "\tstruct csum_inet csum_%d;\n", csumSeq) 256 fmt.Fprintf(w, "\tcsum_inet_init(&csum_%d);\n", csumSeq) 257 for i, chunk := range arg.Chunks { 258 switch chunk.Kind { 259 case prog.ExecArgCsumChunkData: 260 fmt.Fprintf(w, "\tNONFAILING(csum_inet_update(&csum_%d, (const uint8*)0x%x, %d));\n", 261 csumSeq, chunk.Value, chunk.Size) 262 case prog.ExecArgCsumChunkConst: 263 fmt.Fprintf(w, "\tuint%d csum_%d_chunk_%d = 0x%x;\n", 264 chunk.Size*8, csumSeq, i, chunk.Value) 265 fmt.Fprintf(w, "\tcsum_inet_update(&csum_%d, (const uint8*)&csum_%d_chunk_%d, %d);\n", 266 csumSeq, csumSeq, i, chunk.Size) 267 default: 268 panic(fmt.Sprintf("unknown checksum chunk kind %v", chunk.Kind)) 269 } 270 } 271 fmt.Fprintf(w, "\tNONFAILING(*(uint16*)0x%x = csum_inet_digest(&csum_%d));\n", 272 addr, csumSeq) 273} 274 275func (ctx *context) copyin(w *bytes.Buffer, csumSeq *int, copyin prog.ExecCopyin) { 276 switch arg := copyin.Arg.(type) { 277 case prog.ExecArgConst: 278 if arg.BitfieldOffset == 0 && arg.BitfieldLength == 0 { 279 ctx.copyinVal(w, copyin.Addr, arg.Size, ctx.constArgToStr(arg), arg.Format) 280 } else { 281 if arg.Format != prog.FormatNative && arg.Format != prog.FormatBigEndian { 282 panic("bitfield+string format") 283 } 284 fmt.Fprintf(w, "\tNONFAILING(STORE_BY_BITMASK(uint%v, 0x%x, %v, %v, %v));\n", 285 arg.Size*8, copyin.Addr, ctx.constArgToStr(arg), 286 arg.BitfieldOffset, arg.BitfieldLength) 287 } 288 case prog.ExecArgResult: 289 ctx.copyinVal(w, copyin.Addr, arg.Size, ctx.resultArgToStr(arg), arg.Format) 290 case prog.ExecArgData: 291 fmt.Fprintf(w, "\tNONFAILING(memcpy((void*)0x%x, \"%s\", %v));\n", 292 copyin.Addr, toCString(arg.Data), len(arg.Data)) 293 case prog.ExecArgCsum: 294 switch arg.Kind { 295 case prog.ExecArgCsumInet: 296 *csumSeq++ 297 ctx.generateCsumInet(w, copyin.Addr, arg, *csumSeq) 298 default: 299 panic(fmt.Sprintf("unknown csum kind %v", arg.Kind)) 300 } 301 default: 302 panic(fmt.Sprintf("bad argument type: %+v", arg)) 303 } 304} 305 306func (ctx *context) copyinVal(w *bytes.Buffer, addr, size uint64, val string, bf prog.BinaryFormat) { 307 switch bf { 308 case prog.FormatNative, prog.FormatBigEndian: 309 fmt.Fprintf(w, "\tNONFAILING(*(uint%v*)0x%x = %v);\n", size*8, addr, val) 310 case prog.FormatStrDec: 311 if size != 20 { 312 panic("bad strdec size") 313 } 314 fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%020llu\", (long long)%v));\n", addr, val) 315 case prog.FormatStrHex: 316 if size != 18 { 317 panic("bad strdec size") 318 } 319 fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"0x%%016llx\", (long long)%v));\n", addr, val) 320 case prog.FormatStrOct: 321 if size != 23 { 322 panic("bad strdec size") 323 } 324 fmt.Fprintf(w, "\tNONFAILING(sprintf((char*)0x%x, \"%%023llo\", (long long)%v));\n", addr, val) 325 default: 326 panic("unknown binary format") 327 } 328} 329 330func (ctx *context) copyout(w *bytes.Buffer, call prog.ExecCall, resCopyout bool) { 331 if ctx.sysTarget.OS == "fuchsia" { 332 // On fuchsia we have real system calls that return ZX_OK on success, 333 // and libc calls that are casted to function returning long, 334 // as the result int -1 is returned as 0x00000000ffffffff rather than full -1. 335 if strings.HasPrefix(call.Meta.CallName, "zx_") { 336 fmt.Fprintf(w, "\tif (res == ZX_OK)") 337 } else { 338 fmt.Fprintf(w, "\tif ((int)res != -1)") 339 } 340 } else { 341 fmt.Fprintf(w, "\tif (res != -1)") 342 } 343 copyoutMultiple := len(call.Copyout) > 1 || resCopyout && len(call.Copyout) > 0 344 if copyoutMultiple { 345 fmt.Fprintf(w, " {") 346 } 347 fmt.Fprintf(w, "\n") 348 if resCopyout { 349 fmt.Fprintf(w, "\t\tr[%v] = res;\n", call.Index) 350 } 351 for _, copyout := range call.Copyout { 352 fmt.Fprintf(w, "\t\tNONFAILING(r[%v] = *(uint%v*)0x%x);\n", 353 copyout.Index, copyout.Size*8, copyout.Addr) 354 } 355 if copyoutMultiple { 356 fmt.Fprintf(w, "\t}\n") 357 } 358} 359 360func (ctx *context) constArgToStr(arg prog.ExecArgConst) string { 361 mask := (uint64(1) << (arg.Size * 8)) - 1 362 v := arg.Value & mask 363 val := fmt.Sprintf("%v", v) 364 if v == ^uint64(0)&mask { 365 val = "-1" 366 } else if v >= 10 { 367 val = fmt.Sprintf("0x%x", v) 368 } 369 if ctx.opts.Procs > 1 && arg.PidStride != 0 { 370 val += fmt.Sprintf(" + procid*%v", arg.PidStride) 371 } 372 if arg.Format == prog.FormatBigEndian { 373 val = fmt.Sprintf("htobe%v(%v)", arg.Size*8, val) 374 } 375 return val 376} 377 378func (ctx *context) resultArgToStr(arg prog.ExecArgResult) string { 379 res := fmt.Sprintf("r[%v]", arg.Index) 380 if arg.DivOp != 0 { 381 res = fmt.Sprintf("%v/%v", res, arg.DivOp) 382 } 383 if arg.AddOp != 0 { 384 res = fmt.Sprintf("%v+%v", res, arg.AddOp) 385 } 386 if arg.Format == prog.FormatBigEndian { 387 res = fmt.Sprintf("htobe%v(%v)", arg.Size*8, res) 388 } 389 return res 390} 391 392func (ctx *context) postProcess(result []byte) []byte { 393 // Remove NONFAILING, debug, fail, etc calls. 394 if !ctx.opts.HandleSegv { 395 result = regexp.MustCompile(`\t*NONFAILING\((.*)\);\n`).ReplaceAll(result, []byte("$1;\n")) 396 } 397 result = bytes.Replace(result, []byte("NORETURN"), nil, -1) 398 result = bytes.Replace(result, []byte("PRINTF"), nil, -1) 399 result = bytes.Replace(result, []byte("doexit("), []byte("exit("), -1) 400 result = regexp.MustCompile(`\t*debug\((.*\n)*?.*\);\n`).ReplaceAll(result, nil) 401 result = regexp.MustCompile(`\t*debug_dump_data\((.*\n)*?.*\);\n`).ReplaceAll(result, nil) 402 result = regexp.MustCompile(`\t*exitf\((.*\n)*?.*\);\n`).ReplaceAll(result, []byte("\texit(1);\n")) 403 result = regexp.MustCompile(`\t*fail\((.*\n)*?.*\);\n`).ReplaceAll(result, []byte("\texit(1);\n")) 404 result = regexp.MustCompile(`\t*error\((.*\n)*?.*\);\n`).ReplaceAll(result, []byte("\texit(1);\n")) 405 406 result = ctx.hoistIncludes(result) 407 result = ctx.removeEmptyLines(result) 408 return result 409} 410 411// hoistIncludes moves all includes to the top, removes dups and sorts. 412func (ctx *context) hoistIncludes(result []byte) []byte { 413 includesStart := bytes.Index(result, []byte("#include")) 414 if includesStart == -1 { 415 return result 416 } 417 includes := make(map[string]bool) 418 includeRe := regexp.MustCompile("#include <.*>\n") 419 for _, match := range includeRe.FindAll(result, -1) { 420 includes[string(match)] = true 421 } 422 result = includeRe.ReplaceAll(result, nil) 423 // Linux headers are broken, so we have to move all linux includes to the bottom. 424 var sorted, sortedLinux []string 425 for include := range includes { 426 if strings.Contains(include, "<linux/") { 427 sortedLinux = append(sortedLinux, include) 428 } else { 429 sorted = append(sorted, include) 430 } 431 } 432 sort.Strings(sorted) 433 sort.Strings(sortedLinux) 434 newResult := append([]byte{}, result[:includesStart]...) 435 newResult = append(newResult, strings.Join(sorted, "")...) 436 newResult = append(newResult, '\n') 437 newResult = append(newResult, strings.Join(sortedLinux, "")...) 438 newResult = append(newResult, result[includesStart:]...) 439 return newResult 440} 441 442// removeEmptyLines removes duplicate new lines. 443func (ctx *context) removeEmptyLines(result []byte) []byte { 444 for { 445 newResult := bytes.Replace(result, []byte{'\n', '\n', '\n'}, []byte{'\n', '\n'}, -1) 446 newResult = bytes.Replace(newResult, []byte{'\n', '\n', '\t'}, []byte{'\n', '\t'}, -1) 447 newResult = bytes.Replace(newResult, []byte{'\n', '\n', ' '}, []byte{'\n', ' '}, -1) 448 if len(newResult) == len(result) { 449 return result 450 } 451 result = newResult 452 } 453} 454 455func toCString(data []byte) []byte { 456 if len(data) == 0 { 457 return nil 458 } 459 readable := true 460 for i, v := range data { 461 // Allow 0 only as last byte. 462 if !isReadable(v) && (i != len(data)-1 || v != 0) { 463 readable = false 464 break 465 } 466 } 467 if !readable { 468 buf := new(bytes.Buffer) 469 for _, v := range data { 470 buf.Write([]byte{'\\', 'x', toHex(v >> 4), toHex(v << 4 >> 4)}) 471 } 472 return buf.Bytes() 473 } 474 if data[len(data)-1] == 0 { 475 // Don't serialize last 0, C strings are 0-terminated anyway. 476 data = data[:len(data)-1] 477 } 478 buf := new(bytes.Buffer) 479 for _, v := range data { 480 switch v { 481 case '\t': 482 buf.Write([]byte{'\\', 't'}) 483 case '\r': 484 buf.Write([]byte{'\\', 'r'}) 485 case '\n': 486 buf.Write([]byte{'\\', 'n'}) 487 case '\\': 488 buf.Write([]byte{'\\', '\\'}) 489 case '"': 490 buf.Write([]byte{'\\', '"'}) 491 default: 492 if v < 0x20 || v >= 0x7f { 493 panic("unexpected char during data serialization") 494 } 495 buf.WriteByte(v) 496 } 497 } 498 return buf.Bytes() 499} 500 501func isReadable(v byte) bool { 502 return v >= 0x20 && v < 0x7f || v == '\t' || v == '\r' || v == '\n' 503} 504 505func toHex(v byte) byte { 506 if v < 10 { 507 return '0' + v 508 } 509 return 'a' + v - 10 510} 511