1// Copyright 2017 syzkaller project authors. All rights reserved. 2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4package prog 5 6import ( 7 "fmt" 8) 9 10type CsumChunkKind int 11 12const ( 13 CsumChunkArg CsumChunkKind = iota 14 CsumChunkConst 15) 16 17type CsumInfo struct { 18 Kind CsumKind 19 Chunks []CsumChunk 20} 21 22type CsumChunk struct { 23 Kind CsumChunkKind 24 Arg Arg // for CsumChunkArg 25 Value uint64 // for CsumChunkConst 26 Size uint64 // for CsumChunkConst 27} 28 29func calcChecksumsCall(c *Call) (map[Arg]CsumInfo, map[Arg]struct{}) { 30 var inetCsumFields, pseudoCsumFields []Arg 31 32 // Find all csum fields. 33 ForeachArg(c, func(arg Arg, _ *ArgCtx) { 34 if typ, ok := arg.Type().(*CsumType); ok { 35 switch typ.Kind { 36 case CsumInet: 37 inetCsumFields = append(inetCsumFields, arg) 38 case CsumPseudo: 39 pseudoCsumFields = append(pseudoCsumFields, arg) 40 default: 41 panic(fmt.Sprintf("unknown csum kind %v", typ.Kind)) 42 } 43 } 44 }) 45 46 if len(inetCsumFields) == 0 && len(pseudoCsumFields) == 0 { 47 return nil, nil 48 } 49 50 // Build map of each field to its parent struct. 51 parentsMap := make(map[Arg]Arg) 52 ForeachArg(c, func(arg Arg, _ *ArgCtx) { 53 if _, ok := arg.Type().(*StructType); ok { 54 for _, field := range arg.(*GroupArg).Inner { 55 parentsMap[InnerArg(field)] = arg 56 } 57 } 58 }) 59 60 csumMap := make(map[Arg]CsumInfo) 61 csumUses := make(map[Arg]struct{}) 62 63 // Calculate generic inet checksums. 64 for _, arg := range inetCsumFields { 65 typ, _ := arg.Type().(*CsumType) 66 csummedArg := findCsummedArg(arg, typ, parentsMap) 67 csumUses[csummedArg] = struct{}{} 68 chunk := CsumChunk{CsumChunkArg, csummedArg, 0, 0} 69 csumMap[arg] = CsumInfo{Kind: CsumInet, Chunks: []CsumChunk{chunk}} 70 } 71 72 // No need to continue if there are no pseudo csum fields. 73 if len(pseudoCsumFields) == 0 { 74 return csumMap, csumUses 75 } 76 77 // Extract ipv4 or ipv6 source and destination addresses. 78 var ipSrcAddr, ipDstAddr Arg 79 ForeachArg(c, func(arg Arg, _ *ArgCtx) { 80 groupArg, ok := arg.(*GroupArg) 81 if !ok { 82 return 83 } 84 // syz_csum_* structs are used in tests 85 switch groupArg.Type().Name() { 86 case "ipv4_header", "syz_csum_ipv4_header": 87 ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 4) 88 case "ipv6_packet", "syz_csum_ipv6_header": 89 ipSrcAddr, ipDstAddr = extractHeaderParams(groupArg, 16) 90 } 91 }) 92 if ipSrcAddr == nil || ipDstAddr == nil { 93 panic("no ipv4 nor ipv6 header found") 94 } 95 96 // Calculate pseudo checksums. 97 for _, arg := range pseudoCsumFields { 98 typ, _ := arg.Type().(*CsumType) 99 csummedArg := findCsummedArg(arg, typ, parentsMap) 100 protocol := uint8(typ.Protocol) 101 var info CsumInfo 102 if ipSrcAddr.Size() == 4 { 103 info = composePseudoCsumIPv4(csummedArg, ipSrcAddr, ipDstAddr, protocol) 104 } else { 105 info = composePseudoCsumIPv6(csummedArg, ipSrcAddr, ipDstAddr, protocol) 106 } 107 csumMap[arg] = info 108 } 109 110 return csumMap, csumUses 111} 112 113func findCsummedArg(arg Arg, typ *CsumType, parentsMap map[Arg]Arg) Arg { 114 if typ.Buf == "parent" { 115 if csummedArg, ok := parentsMap[arg]; ok { 116 return csummedArg 117 } 118 panic(fmt.Sprintf("parent for %v is not in parents map", typ.Name())) 119 } else { 120 for parent := parentsMap[arg]; parent != nil; parent = parentsMap[parent] { 121 if typ.Buf == parent.Type().Name() { 122 return parent 123 } 124 } 125 } 126 panic(fmt.Sprintf("csum field '%v' references non existent field '%v'", typ.FieldName(), typ.Buf)) 127} 128 129func composePseudoCsumIPv4(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { 130 info := CsumInfo{Kind: CsumInet} 131 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) 132 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) 133 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(protocol))), 2}) 134 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap16(uint16(tcpPacket.Size()))), 2}) 135 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) 136 return info 137} 138 139func composePseudoCsumIPv6(tcpPacket, srcAddr, dstAddr Arg, protocol uint8) CsumInfo { 140 info := CsumInfo{Kind: CsumInet} 141 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, srcAddr, 0, 0}) 142 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, dstAddr, 0, 0}) 143 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(tcpPacket.Size()))), 4}) 144 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkConst, nil, uint64(swap32(uint32(protocol))), 4}) 145 info.Chunks = append(info.Chunks, CsumChunk{CsumChunkArg, tcpPacket, 0, 0}) 146 return info 147} 148 149func extractHeaderParams(arg *GroupArg, size uint64) (Arg, Arg) { 150 srcAddr := getFieldByName(arg, "src_ip") 151 dstAddr := getFieldByName(arg, "dst_ip") 152 if srcAddr.Size() != size || dstAddr.Size() != size { 153 panic(fmt.Sprintf("src/dst_ip fields in %v must be %v bytes", arg.Type().Name(), size)) 154 } 155 return srcAddr, dstAddr 156} 157 158func getFieldByName(arg *GroupArg, name string) Arg { 159 for _, field := range arg.Inner { 160 if field.Type().FieldName() == name { 161 return field 162 } 163 } 164 panic(fmt.Sprintf("failed to find %v field in %v", name, arg.Type().Name())) 165} 166