// Copyright 2017 syzkaller project authors. All rights reserved. // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. package prog import ( "fmt" ) type CsumChunkKind int const ( CsumChunkArg CsumChunkKind = iota CsumChunkConst ) type CsumInfo struct { Kind CsumKind Chunks []CsumChunk } type CsumChunk struct { Kind CsumChunkKind Arg Arg // for CsumChunkArg Value uint64 // for CsumChunkConst Size uint64 // for CsumChunkConst } func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) { var inetCsumFields, pseudoCsumFields []Arg // Find all csum fields. ForeachArg(c, func(arg Arg, _ *ArgCtx) { if typ, ok := arg.Type().(*CsumType); ok { switch typ.Kind { case CsumInet: inetCsumFields = append(inetCsumFields, arg) case CsumPseudo: pseudoCsumFields = append(pseudoCsumFields, arg) default: panic(fmt.Sprintf("unknown csum kind %v", typ.Kind)) } } }) if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 { return nil, nil } // Build map of each field to its parent struct. parentsMap := make(map[Arg]Arg) ForeachArg(c, func(arg Arg, _ *ArgCtx) { if _, ok := arg.Type().(*StructType); ok { for _, field := range arg.(*GroupArg).Inner { parentsMap[InnerArg(field)] = arg } } }) csumMap := make(map[Arg]CsumInfo) csumUses := make(map[Arg]struct{}) // Calculate generic inet checksums. for _, arg := range inetCsumFields { typ, _ := arg.Type().(*CsumType) csummedArg := findCsummedArg(arg, typ, parentsMap) csumUses[csummedArg] = struct{}{} chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0} csumMap[arg] = CsumInfo{Kind: CsumInet, Chunks: []CsumChunk{chunk}} } // No need to continue if there are no pseudo csum fields. if len(pseudoCsumFields) == 0 { return csumMap, csumUses } // Extract ipv4 or ipv6 source and destination addresses. var ipSrcAddr, ipDstAddr Arg ForeachArg(c, func(arg Arg, _ *ArgCtx) { groupArg, ok := arg.(*GroupArg) if !ok { return } // syz_csum_* structs are used in tests switch groupArg.Type().Name() { case "ipv4_header", "syz_csum_ipv4_header": ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4) case "ipv6_packet", "syz_csum_ipv6_header": ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16) } }) if ipSrcAddr == nil || ipDstAddr == nil { panic("no ipv4 nor ipv6 header found") } // Calculate pseudo checksums. for _, arg := range pseudoCsumFields { typ, _ := arg.Type().(*CsumType) csummedArg := findCsummedArg(arg, typ, parentsMap) protocol := uint8(typ.Protocol) var info CsumInfo if ipSrcAddr.Size() == 4 { info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol) } else { info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol) } csumMap[arg] = info } return csumMap, csumUses } func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg { if typ.Buf == "parent" { if csummedArg, ok := parentsMap[arg]; ok { return csummedArg } panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name())) } else { for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] { if typ.Buf == parent.Type().Name() { return parent } } } panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf)) } func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { info := CsumInfo{Kind: CsumInet} info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(protocol))), 2}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(tcpPacket.Size()))), 2}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) return info } func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { info := CsumInfo{Kind: CsumInet} info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(tcpPacket.Size()))), 4}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(protocol))), 4}) info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) return info } func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) { srcAddr := getFieldByName(arg, "src_ip") dstAddr := getFieldByName(arg, "dst_ip") if srcAddr.Size() != size || dstAddr.Size() != size { panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size)) } return srcAddr, dstAddr } func getFieldByName(arg *GroupArg, name string) Arg { for _, field := range arg.Inner { if field.Type().FieldName() == name { return field } } panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name())) }