• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2017 The Wuffs Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package check
16
17import (
18	"errors"
19	"fmt"
20
21	a "github.com/google/wuffs/lang/ast"
22	t "github.com/google/wuffs/lang/token"
23)
24
25// otherHandSide returns the operator and other hand side when n is an
26// binary-op expression like "thisHS == thatHS" or "thatHS < thisHS" (which is
27// equivalent to "thisHS > thatHS"). If not, it returns (0, nil).
28func otherHandSide(n *a.Expr, thisHS *a.Expr) (op t.ID, thatHS *a.Expr) {
29	op = n.Operator()
30
31	reverseOp := t.ID(0)
32	switch op {
33	case t.IDXBinaryNotEq:
34		reverseOp = t.IDXBinaryNotEq
35	case t.IDXBinaryLessThan:
36		reverseOp = t.IDXBinaryGreaterThan
37	case t.IDXBinaryLessEq:
38		reverseOp = t.IDXBinaryGreaterEq
39	case t.IDXBinaryEqEq:
40		reverseOp = t.IDXBinaryEqEq
41	case t.IDXBinaryGreaterEq:
42		reverseOp = t.IDXBinaryLessEq
43	case t.IDXBinaryGreaterThan:
44		reverseOp = t.IDXBinaryLessThan
45	}
46
47	if reverseOp != 0 {
48		if thisHS.Eq(n.LHS().AsExpr()) {
49			return op, n.RHS().AsExpr()
50		}
51		if thisHS.Eq(n.RHS().AsExpr()) {
52			return reverseOp, n.LHS().AsExpr()
53		}
54	}
55	return 0, nil
56}
57
58type facts []*a.Expr
59
60func (z *facts) appendBinaryOpFact(op t.ID, lhs *a.Expr, rhs *a.Expr) {
61	o := a.NewExpr(0, op, 0, 0, lhs.AsNode(), nil, rhs.AsNode(), nil)
62	o.SetMBounds(bounds{zero, one})
63	o.SetMType(typeExprBool)
64	z.appendFact(o)
65}
66
67func (z *facts) appendFact(fact *a.Expr) {
68	// TODO: make this faster than O(N) by keeping facts sorted somehow?
69	for _, x := range *z {
70		if x.Eq(fact) {
71			return
72		}
73	}
74
75	switch fact.Operator() {
76	case t.IDXBinaryAnd:
77		z.appendFact(fact.LHS().AsExpr())
78		z.appendFact(fact.RHS().AsExpr())
79		return
80	case t.IDXAssociativeAnd:
81		for _, a := range fact.Args() {
82			z.appendFact(a.AsExpr())
83		}
84		return
85	}
86
87	*z = append(*z, fact)
88}
89
90// update applies f to each fact, replacing the slice element with the result
91// of the function call. The slice is then compacted to remove all nils.
92func (z *facts) update(f func(*a.Expr) (*a.Expr, error)) error {
93	i := 0
94	for _, x := range *z {
95		x, err := f(x)
96		if err != nil {
97			return err
98		}
99		if x != nil {
100			(*z)[i] = x
101			i++
102		}
103	}
104	for j := i; j < len(*z); j++ {
105		(*z)[j] = nil
106	}
107	*z = (*z)[:i]
108	return nil
109}
110
111func (z facts) refine(n *a.Expr, nb bounds, tm *t.Map) (bounds, error) {
112	if nb[0] == nil || nb[1] == nil {
113		return nb, nil
114	}
115
116	for _, x := range z {
117		op, other := otherHandSide(x, n)
118		if op == 0 {
119			continue
120		}
121		cv := other.ConstValue()
122		if cv == nil {
123			continue
124		}
125
126		originalNB, changed := nb, false
127		switch op {
128		case t.IDXBinaryNotEq:
129			if nb[0].Cmp(cv) == 0 {
130				nb[0] = add1(nb[0])
131				changed = true
132			} else if nb[1].Cmp(cv) == 0 {
133				nb[1] = sub1(nb[1])
134				changed = true
135			}
136		case t.IDXBinaryLessThan:
137			if nb[1].Cmp(cv) >= 0 {
138				nb[1] = sub1(cv)
139				changed = true
140			}
141		case t.IDXBinaryLessEq:
142			if nb[1].Cmp(cv) > 0 {
143				nb[1] = cv
144				changed = true
145			}
146		case t.IDXBinaryEqEq:
147			nb[0], nb[1] = cv, cv
148			changed = true
149		case t.IDXBinaryGreaterEq:
150			if nb[0].Cmp(cv) < 0 {
151				nb[0] = cv
152				changed = true
153			}
154		case t.IDXBinaryGreaterThan:
155			if nb[0].Cmp(cv) <= 0 {
156				nb[0] = add1(cv)
157				changed = true
158			}
159		}
160
161		if changed && nb[0].Cmp(nb[1]) > 0 {
162			return bounds{}, fmt.Errorf("check: expression %q bounds %v inconsistent with fact %q",
163				n.Str(tm), originalNB, x.Str(tm))
164		}
165	}
166
167	return nb, nil
168}
169
170// simplify returns a simplified form of n. For example, (x - x) becomes 0.
171func simplify(tm *t.Map, n *a.Expr) (*a.Expr, error) {
172	// TODO: be rigorous about this, not ad hoc.
173	op, lhs, rhs := parseBinaryOp(n)
174	if lhs != nil && rhs != nil {
175		if lcv, rcv := lhs.ConstValue(), rhs.ConstValue(); lcv != nil && rcv != nil {
176			ncv, err := evalConstValueBinaryOp(tm, n, lcv, rcv)
177			if err != nil {
178				return nil, err
179			}
180			return makeConstValueExpr(tm, ncv)
181		}
182	}
183
184	switch op {
185	case t.IDXBinaryPlus:
186		// TODO: more constant folding, so ((x + 1) + 1) becomes (x + 2).
187
188	case t.IDXBinaryMinus:
189		if lhs.Eq(rhs) {
190			return zeroExpr, nil
191		}
192		if lOp, lLHS, lRHS := parseBinaryOp(lhs); lOp == t.IDXBinaryPlus {
193			if lLHS.Eq(rhs) {
194				return lRHS, nil
195			}
196			if lRHS.Eq(rhs) {
197				return lLHS, nil
198			}
199		}
200
201	case t.IDXBinaryNotEq, t.IDXBinaryLessThan, t.IDXBinaryLessEq,
202		t.IDXBinaryEqEq, t.IDXBinaryGreaterEq, t.IDXBinaryGreaterThan:
203
204		l, err := simplify(tm, lhs)
205		if err != nil {
206			return nil, err
207		}
208		r, err := simplify(tm, rhs)
209		if err != nil {
210			return nil, err
211		}
212		if l != lhs || r != rhs {
213			o := a.NewExpr(0, op, 0, 0, l.AsNode(), nil, r.AsNode(), nil)
214			o.SetConstValue(n.ConstValue())
215			o.SetMType(n.MType())
216			return o, nil
217		}
218	}
219	return n, nil
220}
221
222func argValue(tm *t.Map, args []*a.Node, name string) *a.Expr {
223	if x := tm.ByName(name); x != 0 {
224		for _, a := range args {
225			if a.AsArg().Name() == x {
226				return a.AsArg().Value()
227			}
228		}
229	}
230	return nil
231}
232
233// parseBinaryOp parses n as "lhs op rhs".
234func parseBinaryOp(n *a.Expr) (op t.ID, lhs *a.Expr, rhs *a.Expr) {
235	if !n.Operator().IsBinaryOp() {
236		return 0, nil, nil
237	}
238	op = n.Operator()
239	if op == t.IDAs {
240		return 0, nil, nil
241	}
242	return op, n.LHS().AsExpr(), n.RHS().AsExpr()
243}
244
245func proveBinaryOpConstValues(op t.ID, lb bounds, rb bounds) (ok bool) {
246	switch op {
247	case t.IDXBinaryNotEq:
248		return lb[1].Cmp(rb[0]) < 0 || lb[0].Cmp(rb[1]) > 0
249	case t.IDXBinaryLessThan:
250		return lb[1].Cmp(rb[0]) < 0
251	case t.IDXBinaryLessEq:
252		return lb[1].Cmp(rb[0]) <= 0
253	case t.IDXBinaryEqEq:
254		return lb[0].Cmp(rb[1]) == 0 && lb[1].Cmp(rb[0]) == 0
255	case t.IDXBinaryGreaterEq:
256		return lb[0].Cmp(rb[1]) >= 0
257	case t.IDXBinaryGreaterThan:
258		return lb[0].Cmp(rb[1]) > 0
259	}
260	return false
261}
262
263func (q *checker) proveBinaryOp(op t.ID, lhs *a.Expr, rhs *a.Expr) error {
264	lcv := lhs.ConstValue()
265	if lcv != nil {
266		rb, err := q.bcheckExpr(rhs, 0)
267		if err != nil {
268			return err
269		}
270		if proveBinaryOpConstValues(op, bounds{lcv, lcv}, rb) {
271			return nil
272		}
273	}
274	rcv := rhs.ConstValue()
275	if rcv != nil {
276		lb, err := q.bcheckExpr(lhs, 0)
277		if err != nil {
278			return err
279		}
280		if proveBinaryOpConstValues(op, lb, bounds{rcv, rcv}) {
281			return nil
282		}
283	}
284
285	for _, x := range q.facts {
286		if !x.LHS().AsExpr().Eq(lhs) {
287			continue
288		}
289		factOp := x.Operator()
290		if opImpliesOp(factOp, op) && x.RHS().AsExpr().Eq(rhs) {
291			return nil
292		}
293
294		if factOp == t.IDXBinaryEqEq && rcv != nil {
295			if factCV := x.RHS().AsExpr().ConstValue(); factCV != nil {
296				switch op {
297				case t.IDXBinaryNotEq:
298					return errFailedOrNil(factCV.Cmp(rcv) != 0)
299				case t.IDXBinaryLessThan:
300					return errFailedOrNil(factCV.Cmp(rcv) < 0)
301				case t.IDXBinaryLessEq:
302					return errFailedOrNil(factCV.Cmp(rcv) <= 0)
303				case t.IDXBinaryEqEq:
304					return errFailedOrNil(factCV.Cmp(rcv) == 0)
305				case t.IDXBinaryGreaterEq:
306					return errFailedOrNil(factCV.Cmp(rcv) >= 0)
307				case t.IDXBinaryGreaterThan:
308					return errFailedOrNil(factCV.Cmp(rcv) > 0)
309				}
310			}
311		}
312	}
313	return errFailed
314}
315
316// opImpliesOp returns whether the first op implies the second. For example,
317// knowing "x < y" implies that "x != y" and "x <= y".
318func opImpliesOp(op0 t.ID, op1 t.ID) bool {
319	if op0 == op1 {
320		return true
321	}
322	switch op0 {
323	case t.IDXBinaryLessThan:
324		return op1 == t.IDXBinaryNotEq || op1 == t.IDXBinaryLessEq
325	case t.IDXBinaryGreaterThan:
326		return op1 == t.IDXBinaryNotEq || op1 == t.IDXBinaryGreaterEq
327	}
328	return false
329}
330
// errFailedOrNil maps a proof outcome to an error: nil when ok is true and
// errFailed otherwise.
func errFailedOrNil(ok bool) error {
	if !ok {
		return errFailed
	}
	return nil
}

// errFailed is the sentinel error returned when a proof attempt does not
// succeed.
var errFailed = errors.New("failed")
339
340func proveReasonRequirement(q *checker, op t.ID, lhs *a.Expr, rhs *a.Expr) error {
341	if !op.IsXBinaryOp() {
342		return fmt.Errorf(
343			"check: internal error: proveReasonRequirement token (0x%02X) is not an XBinaryOp", op)
344	}
345	if err := q.proveBinaryOp(op, lhs, rhs); err != nil {
346		n := a.NewExpr(0, op, 0, 0, lhs.AsNode(), nil, rhs.AsNode(), nil)
347		return fmt.Errorf("cannot prove %q: %v", n.Str(q.tm), err)
348	}
349	return nil
350}
351
352func proveReasonRequirementForRHSLength(q *checker, op t.ID, lhs *a.Expr, rhs *a.Expr) error {
353	if err := proveReasonRequirement(q, op, lhs, rhs); err != nil {
354		if (op == t.IDXBinaryLessThan) || (op == t.IDXBinaryLessEq) {
355			for _, x := range q.facts {
356				// Try to prove "lhs op rhs" by proving "lhs op const", given a
357				// fact x of the form "rhs >= const".
358				if (x.Operator() == t.IDXBinaryGreaterEq) && x.LHS().AsExpr().Eq(rhs) &&
359					(x.RHS().AsExpr().ConstValue() != nil) &&
360					(proveReasonRequirement(q, op, lhs, x.RHS().AsExpr()) == nil) {
361
362					return nil
363				}
364			}
365		}
366		return err
367	}
368	return nil
369}
370