• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2015 syzkaller project authors. All rights reserved.
2// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
3
// Conservative resource-related analysis of programs.
// The analysis figures out what file descriptors are [potentially] opened
// at a particular point in a program, what pages are [potentially] mapped,
// what files were already referenced in calls, etc.
8
9package prog
10
11import (
12	"fmt"
13)
14
// state accumulates a conservative summary of what a program has done:
// which resources it may have produced, which strings and file names it
// passed to calls, and which memory/VMA ranges its pointers occupy.
type state struct {
	target    *Target                 // target passed to newState
	ct        *ChoiceTable            // choice table passed to newState
	files     map[string]bool         // contents of input BufferFilename args (one trailing NUL stripped)
	resources map[string][]*ResultArg // resource descriptor name -> non-DirIn result args of that resource
	strings   map[string]bool         // contents of input BufferString args
	ma        *memAlloc               // non-VMA pointer allocations, noted by address and size
	va        *vmaAlloc               // VMA pointer allocations, noted in page units
}
24
25// analyze analyzes the program p up to but not including call c.
26func analyze(ct *ChoiceTable, p *Prog, c *Call) *state {
27	s := newState(p.Target, ct)
28	resources := true
29	for _, c1 := range p.Calls {
30		if c1 == c {
31			resources = false
32		}
33		s.analyzeImpl(c1, resources)
34	}
35	return s
36}
37
38func newState(target *Target, ct *ChoiceTable) *state {
39	s := &state{
40		target:    target,
41		ct:        ct,
42		files:     make(map[string]bool),
43		resources: make(map[string][]*ResultArg),
44		strings:   make(map[string]bool),
45		ma:        newMemAlloc(target.NumPages * target.PageSize),
46		va:        newVmaAlloc(target.NumPages),
47	}
48	return s
49}
50
51func (s *state) analyze(c *Call) {
52	s.analyzeImpl(c, true)
53}
54
55func (s *state) analyzeImpl(c *Call, resources bool) {
56	ForeachArg(c, func(arg Arg, _ *ArgCtx) {
57		switch a := arg.(type) {
58		case *PointerArg:
59			switch {
60			case a.IsNull():
61			case a.VmaSize != 0:
62				s.va.noteAlloc(a.Address/s.target.PageSize, a.VmaSize/s.target.PageSize)
63			default:
64				s.ma.noteAlloc(a.Address, a.Res.Size())
65			}
66		}
67		switch typ := arg.Type().(type) {
68		case *ResourceType:
69			a := arg.(*ResultArg)
70			if resources && typ.Dir() != DirIn {
71				s.resources[typ.Desc.Name] = append(s.resources[typ.Desc.Name], a)
72				// TODO: negative PIDs and add them as well (that's process groups).
73			}
74		case *BufferType:
75			a := arg.(*DataArg)
76			if typ.Dir() != DirOut && len(a.Data()) != 0 {
77				val := string(a.Data())
78				// Remove trailing zero padding.
79				for len(val) >= 2 && val[len(val)-1] == 0 && val[len(val)-2] == 0 {
80					val = val[:len(val)-1]
81				}
82				switch typ.Kind {
83				case BufferString:
84					s.strings[val] = true
85				case BufferFilename:
86					if len(val) < 3 {
87						// This is not our file, probalby one of specialFiles.
88						return
89					}
90					if val[len(val)-1] == 0 {
91						val = val[:len(val)-1]
92					}
93					s.files[val] = true
94				}
95			}
96		}
97	})
98}
99
// ArgCtx is passed to ForeachArg/ForeachSubArg callbacks and describes where
// the visited argument sits in the argument tree. Parent is only replaced when
// the walk descends into a struct. Values reached through pointers, unions or
// non-struct groups therefore keep the Parent of the enclosing value, and a
// call's return value is visited with a nil Parent.
type ArgCtx struct {
	Parent *[]Arg      // GroupArg.Inner (for structs) or Call.Args containing this arg
	Base   *PointerArg // pointer to the base of the heap object containing this arg
	Offset uint64      // offset of this arg from the base
	Stop   bool        // if set by the callback, subargs of this arg are not visited
}
106
107func ForeachSubArg(arg Arg, f func(Arg, *ArgCtx)) {
108	foreachArgImpl(arg, ArgCtx{}, f)
109}
110
111func ForeachArg(c *Call, f func(Arg, *ArgCtx)) {
112	ctx := ArgCtx{}
113	if c.Ret != nil {
114		foreachArgImpl(c.Ret, ctx, f)
115	}
116	ctx.Parent = &c.Args
117	for _, arg := range c.Args {
118		foreachArgImpl(arg, ctx, f)
119	}
120}
121
122func foreachArgImpl(arg Arg, ctx ArgCtx, f func(Arg, *ArgCtx)) {
123	f(arg, &ctx)
124	if ctx.Stop {
125		return
126	}
127	switch a := arg.(type) {
128	case *GroupArg:
129		if _, ok := a.Type().(*StructType); ok {
130			ctx.Parent = &a.Inner
131		}
132		var totalSize uint64
133		for _, arg1 := range a.Inner {
134			foreachArgImpl(arg1, ctx, f)
135			if !arg1.Type().BitfieldMiddle() {
136				size := arg1.Size()
137				ctx.Offset += size
138				totalSize += size
139			}
140		}
141		claimedSize := a.Size()
142		varlen := a.Type().Varlen()
143		if varlen && totalSize > claimedSize || !varlen && totalSize != claimedSize {
144			panic(fmt.Sprintf("bad group arg size %v, should be <= %v for %#v type %#v",
145				totalSize, claimedSize, a, a.Type()))
146		}
147	case *PointerArg:
148		if a.Res != nil {
149			ctx.Base = a
150			ctx.Offset = 0
151			foreachArgImpl(a.Res, ctx, f)
152		}
153	case *UnionArg:
154		foreachArgImpl(a.Option, ctx, f)
155	}
156}
157
158func RequiredFeatures(p *Prog) (bitmasks, csums bool) {
159	for _, c := range p.Calls {
160		ForeachArg(c, func(arg Arg, _ *ArgCtx) {
161			if a, ok := arg.(*ConstArg); ok {
162				if a.Type().BitfieldOffset() != 0 || a.Type().BitfieldLength() != 0 {
163					bitmasks = true
164				}
165			}
166			if _, ok := arg.Type().(*CsumType); ok {
167				csums = true
168			}
169		})
170	}
171	return
172}
173
// CallFlags is a bit set describing how far execution of a call progressed.
type CallFlags int

const (
	CallExecuted CallFlags = 1 << 0 // was started at all
	CallFinished CallFlags = 1 << 1 // finished executing (rather than blocked forever)
	CallBlocked  CallFlags = 1 << 2 // finished but blocked during execution
)
181
// CallInfo holds the execution result of a single call.
type CallInfo struct {
	Flags  CallFlags // how far the call progressed (CallExecuted, CallFinished, CallBlocked)
	Errno  int       // errno of the call; FallbackSignal treats non-zero as failure
	Signal []uint32  // signal values for the call; FallbackSignal appends to this
}
187
// Kinds of fallback signal. The kind occupies the low 3 bits of an encoded
// signal, so at most 8 kinds are representable.
const (
	fallbackSignalErrno        = 0 // errno of a call not reported as finished-and-blocked
	fallbackSignalErrnoBlocked = 1 // errno of a call that finished but blocked
	fallbackSignalCtor         = 2 // call used a resource produced by another call
	fallbackSignalFlags        = 3 // summary of a call's top-level argument properties

	// fallbackCallMask limits call IDs to 13 bits in an encoded signal.
	fallbackCallMask = 0x1fff
)
195
196func (p *Prog) FallbackSignal(info []CallInfo) {
197	resources := make(map[*ResultArg]*Call)
198	for i, c := range p.Calls {
199		inf := &info[i]
200		if inf.Flags&CallExecuted == 0 {
201			continue
202		}
203		id := c.Meta.ID
204		typ := fallbackSignalErrno
205		if inf.Flags&CallFinished != 0 && inf.Flags&CallBlocked != 0 {
206			typ = fallbackSignalErrnoBlocked
207		}
208		inf.Signal = append(inf.Signal, encodeFallbackSignal(typ, id, inf.Errno))
209		if inf.Errno != 0 {
210			continue
211		}
212		ForeachArg(c, func(arg Arg, _ *ArgCtx) {
213			if a, ok := arg.(*ResultArg); ok {
214				resources[a] = c
215			}
216		})
217		// Specifically look only at top-level arguments,
218		// deeper arguments can produce too much false signal.
219		flags := 0
220		for _, arg := range c.Args {
221			switch a := arg.(type) {
222			case *ResultArg:
223				flags <<= 1
224				if a.Res != nil {
225					ctor := resources[a.Res]
226					if ctor != nil {
227						inf.Signal = append(inf.Signal,
228							encodeFallbackSignal(fallbackSignalCtor, id, ctor.Meta.ID))
229					}
230				} else {
231					if a.Val != a.Type().(*ResourceType).SpecialValues()[0] {
232						flags |= 1
233					}
234				}
235			case *ConstArg:
236				const width = 3
237				flags <<= width
238				switch typ := a.Type().(type) {
239				case *FlagsType:
240					if typ.BitMask {
241						for i, v := range typ.Vals {
242							if a.Val&v != 0 {
243								flags ^= 1 << (uint(i) % width)
244							}
245						}
246					} else {
247						for i, v := range typ.Vals {
248							if a.Val == v {
249								flags |= i % (1 << width)
250								break
251							}
252						}
253					}
254				case *LenType:
255					flags <<= 1
256					if a.Val == 0 {
257						flags |= 1
258					}
259				}
260			case *PointerArg:
261				flags <<= 1
262				if a.IsNull() {
263					flags |= 1
264				}
265			}
266		}
267		if flags != 0 {
268			inf.Signal = append(inf.Signal,
269				encodeFallbackSignal(fallbackSignalFlags, id, flags))
270		}
271	}
272}
273
274func DecodeFallbackSignal(s uint32) (callID, errno int) {
275	typ, id, aux := decodeFallbackSignal(s)
276	switch typ {
277	case fallbackSignalErrno, fallbackSignalErrnoBlocked:
278		return id, aux
279	case fallbackSignalCtor, fallbackSignalFlags:
280		return id, 0
281	default:
282		panic(fmt.Sprintf("bad fallback signal type %v", typ))
283	}
284}
285
286func encodeFallbackSignal(typ, id, aux int) uint32 {
287	if typ & ^7 != 0 {
288		panic(fmt.Sprintf("bad fallback signal type %v", typ))
289	}
290	if id & ^fallbackCallMask != 0 {
291		panic(fmt.Sprintf("bad call id in fallback signal %v", id))
292	}
293	return uint32(typ) | uint32(id&fallbackCallMask)<<3 | uint32(aux)<<16
294}
295
296func decodeFallbackSignal(s uint32) (typ, id, aux int) {
297	return int(s & 7), int((s >> 3) & fallbackCallMask), int(s >> 16)
298}
299