1// Copyright 2017 The Wuffs Authors. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package check 16 17import ( 18 "errors" 19 "fmt" 20 21 a "github.com/google/wuffs/lang/ast" 22 t "github.com/google/wuffs/lang/token" 23) 24 25// otherHandSide returns the operator and other hand side when n is an 26// binary-op expression like "thisHS == thatHS" or "thatHS < thisHS" (which is 27// equivalent to "thisHS > thatHS"). If not, it returns (0, nil). 28func otherHandSide(n *a.Expr, thisHS *a.Expr) (op t.ID, thatHS *a.Expr) { 29 op = n.Operator() 30 31 reverseOp := t.ID(0) 32 switch op { 33 case t.IDXBinaryNotEq: 34 reverseOp = t.IDXBinaryNotEq 35 case t.IDXBinaryLessThan: 36 reverseOp = t.IDXBinaryGreaterThan 37 case t.IDXBinaryLessEq: 38 reverseOp = t.IDXBinaryGreaterEq 39 case t.IDXBinaryEqEq: 40 reverseOp = t.IDXBinaryEqEq 41 case t.IDXBinaryGreaterEq: 42 reverseOp = t.IDXBinaryLessEq 43 case t.IDXBinaryGreaterThan: 44 reverseOp = t.IDXBinaryLessThan 45 } 46 47 if reverseOp != 0 { 48 if thisHS.Eq(n.LHS().AsExpr()) { 49 return op, n.RHS().AsExpr() 50 } 51 if thisHS.Eq(n.RHS().AsExpr()) { 52 return reverseOp, n.LHS().AsExpr() 53 } 54 } 55 return 0, nil 56} 57 58type facts []*a.Expr 59 60func (z *facts) appendBinaryOpFact(op t.ID, lhs *a.Expr, rhs *a.Expr) { 61 o := a.NewExpr(0, op, 0, 0, lhs.AsNode(), nil, rhs.AsNode(), nil) 62 o.SetMBounds(bounds{zero, one}) 63 o.SetMType(typeExprBool) 64 z.appendFact(o) 65} 66 67func (z *facts) appendFact(fact *a.Expr) { 68 // TODO: make this faster than O(N) by keeping facts sorted somehow? 69 for _, x := range *z { 70 if x.Eq(fact) { 71 return 72 } 73 } 74 75 switch fact.Operator() { 76 case t.IDXBinaryAnd: 77 z.appendFact(fact.LHS().AsExpr()) 78 z.appendFact(fact.RHS().AsExpr()) 79 return 80 case t.IDXAssociativeAnd: 81 for _, a := range fact.Args() { 82 z.appendFact(a.AsExpr()) 83 } 84 return 85 } 86 87 *z = append(*z, fact) 88} 89 90// update applies f to each fact, replacing the slice element with the result 91// of the function call. The slice is then compacted to remove all nils. 92func (z *facts) update(f func(*a.Expr) (*a.Expr, error)) error { 93 i := 0 94 for _, x := range *z { 95 x, err := f(x) 96 if err != nil { 97 return err 98 } 99 if x != nil { 100 (*z)[i] = x 101 i++ 102 } 103 } 104 for j := i; j < len(*z); j++ { 105 (*z)[j] = nil 106 } 107 *z = (*z)[:i] 108 return nil 109} 110 111func (z facts) refine(n *a.Expr, nb bounds, tm *t.Map) (bounds, error) { 112 if nb[0] == nil || nb[1] == nil { 113 return nb, nil 114 } 115 116 for _, x := range z { 117 op, other := otherHandSide(x, n) 118 if op == 0 { 119 continue 120 } 121 cv := other.ConstValue() 122 if cv == nil { 123 continue 124 } 125 126 originalNB, changed := nb, false 127 switch op { 128 case t.IDXBinaryNotEq: 129 if nb[0].Cmp(cv) == 0 { 130 nb[0] = add1(nb[0]) 131 changed = true 132 } else if nb[1].Cmp(cv) == 0 { 133 nb[1] = sub1(nb[1]) 134 changed = true 135 } 136 case t.IDXBinaryLessThan: 137 if nb[1].Cmp(cv) >= 0 { 138 nb[1] = sub1(cv) 139 changed = true 140 } 141 case t.IDXBinaryLessEq: 142 if nb[1].Cmp(cv) > 0 { 143 nb[1] = cv 144 changed = true 145 } 146 case t.IDXBinaryEqEq: 147 nb[0], nb[1] = cv, cv 148 changed = true 149 case t.IDXBinaryGreaterEq: 150 if nb[0].Cmp(cv) < 0 { 151 nb[0] = cv 152 changed = true 153 } 154 case t.IDXBinaryGreaterThan: 155 if nb[0].Cmp(cv) <= 0 { 156 nb[0] = add1(cv) 157 changed = true 158 } 159 } 160 161 if changed && nb[0].Cmp(nb[1]) > 0 { 162 return bounds{}, fmt.Errorf("check: expression %q bounds %v inconsistent with fact %q", 163 n.Str(tm), originalNB, x.Str(tm)) 164 } 165 } 166 167 return nb, nil 168} 169 170// simplify returns a simplified form of n. For example, (x - x) becomes 0. 171func simplify(tm *t.Map, n *a.Expr) (*a.Expr, error) { 172 // TODO: be rigorous about this, not ad hoc. 173 op, lhs, rhs := parseBinaryOp(n) 174 if lhs != nil && rhs != nil { 175 if lcv, rcv := lhs.ConstValue(), rhs.ConstValue(); lcv != nil && rcv != nil { 176 ncv, err := evalConstValueBinaryOp(tm, n, lcv, rcv) 177 if err != nil { 178 return nil, err 179 } 180 return makeConstValueExpr(tm, ncv) 181 } 182 } 183 184 switch op { 185 case t.IDXBinaryPlus: 186 // TODO: more constant folding, so ((x + 1) + 1) becomes (x + 2). 187 188 case t.IDXBinaryMinus: 189 if lhs.Eq(rhs) { 190 return zeroExpr, nil 191 } 192 if lOp, lLHS, lRHS := parseBinaryOp(lhs); lOp == t.IDXBinaryPlus { 193 if lLHS.Eq(rhs) { 194 return lRHS, nil 195 } 196 if lRHS.Eq(rhs) { 197 return lLHS, nil 198 } 199 } 200 201 case t.IDXBinaryNotEq, t.IDXBinaryLessThan, t.IDXBinaryLessEq, 202 t.IDXBinaryEqEq, t.IDXBinaryGreaterEq, t.IDXBinaryGreaterThan: 203 204 l, err := simplify(tm, lhs) 205 if err != nil { 206 return nil, err 207 } 208 r, err := simplify(tm, rhs) 209 if err != nil { 210 return nil, err 211 } 212 if l != lhs || r != rhs { 213 o := a.NewExpr(0, op, 0, 0, l.AsNode(), nil, r.AsNode(), nil) 214 o.SetConstValue(n.ConstValue()) 215 o.SetMType(n.MType()) 216 return o, nil 217 } 218 } 219 return n, nil 220} 221 222func argValue(tm *t.Map, args []*a.Node, name string) *a.Expr { 223 if x := tm.ByName(name); x != 0 { 224 for _, a := range args { 225 if a.AsArg().Name() == x { 226 return a.AsArg().Value() 227 } 228 } 229 } 230 return nil 231} 232 233// parseBinaryOp parses n as "lhs op rhs". 234func parseBinaryOp(n *a.Expr) (op t.ID, lhs *a.Expr, rhs *a.Expr) { 235 if !n.Operator().IsBinaryOp() { 236 return 0, nil, nil 237 } 238 op = n.Operator() 239 if op == t.IDAs { 240 return 0, nil, nil 241 } 242 return op, n.LHS().AsExpr(), n.RHS().AsExpr() 243} 244 245func proveBinaryOpConstValues(op t.ID, lb bounds, rb bounds) (ok bool) { 246 switch op { 247 case t.IDXBinaryNotEq: 248 return lb[1].Cmp(rb[0]) < 0 || lb[0].Cmp(rb[1]) > 0 249 case t.IDXBinaryLessThan: 250 return lb[1].Cmp(rb[0]) < 0 251 case t.IDXBinaryLessEq: 252 return lb[1].Cmp(rb[0]) <= 0 253 case t.IDXBinaryEqEq: 254 return lb[0].Cmp(rb[1]) == 0 && lb[1].Cmp(rb[0]) == 0 255 case t.IDXBinaryGreaterEq: 256 return lb[0].Cmp(rb[1]) >= 0 257 case t.IDXBinaryGreaterThan: 258 return lb[0].Cmp(rb[1]) > 0 259 } 260 return false 261} 262 263func (q *checker) proveBinaryOp(op t.ID, lhs *a.Expr, rhs *a.Expr) error { 264 lcv := lhs.ConstValue() 265 if lcv != nil { 266 rb, err := q.bcheckExpr(rhs, 0) 267 if err != nil { 268 return err 269 } 270 if proveBinaryOpConstValues(op, bounds{lcv, lcv}, rb) { 271 return nil 272 } 273 } 274 rcv := rhs.ConstValue() 275 if rcv != nil { 276 lb, err := q.bcheckExpr(lhs, 0) 277 if err != nil { 278 return err 279 } 280 if proveBinaryOpConstValues(op, lb, bounds{rcv, rcv}) { 281 return nil 282 } 283 } 284 285 for _, x := range q.facts { 286 if !x.LHS().AsExpr().Eq(lhs) { 287 continue 288 } 289 factOp := x.Operator() 290 if opImpliesOp(factOp, op) && x.RHS().AsExpr().Eq(rhs) { 291 return nil 292 } 293 294 if factOp == t.IDXBinaryEqEq && rcv != nil { 295 if factCV := x.RHS().AsExpr().ConstValue(); factCV != nil { 296 switch op { 297 case t.IDXBinaryNotEq: 298 return errFailedOrNil(factCV.Cmp(rcv) != 0) 299 case t.IDXBinaryLessThan: 300 return errFailedOrNil(factCV.Cmp(rcv) < 0) 301 case t.IDXBinaryLessEq: 302 return errFailedOrNil(factCV.Cmp(rcv) <= 0) 303 case t.IDXBinaryEqEq: 304 return errFailedOrNil(factCV.Cmp(rcv) == 0) 305 case t.IDXBinaryGreaterEq: 306 return errFailedOrNil(factCV.Cmp(rcv) >= 0) 307 case t.IDXBinaryGreaterThan: 308 return errFailedOrNil(factCV.Cmp(rcv) > 0) 309 } 310 } 311 } 312 } 313 return errFailed 314} 315 316// opImpliesOp returns whether the first op implies the second. For example, 317// knowing "x < y" implies that "x != y" and "x <= y". 318func opImpliesOp(op0 t.ID, op1 t.ID) bool { 319 if op0 == op1 { 320 return true 321 } 322 switch op0 { 323 case t.IDXBinaryLessThan: 324 return op1 == t.IDXBinaryNotEq || op1 == t.IDXBinaryLessEq 325 case t.IDXBinaryGreaterThan: 326 return op1 == t.IDXBinaryNotEq || op1 == t.IDXBinaryGreaterEq 327 } 328 return false 329} 330 331func errFailedOrNil(ok bool) error { 332 if ok { 333 return nil 334 } 335 return errFailed 336} 337 338var errFailed = errors.New("failed") 339 340func proveReasonRequirement(q *checker, op t.ID, lhs *a.Expr, rhs *a.Expr) error { 341 if !op.IsXBinaryOp() { 342 return fmt.Errorf( 343 "check: internal error: proveReasonRequirement token (0x%02X) is not an XBinaryOp", op) 344 } 345 if err := q.proveBinaryOp(op, lhs, rhs); err != nil { 346 n := a.NewExpr(0, op, 0, 0, lhs.AsNode(), nil, rhs.AsNode(), nil) 347 return fmt.Errorf("cannot prove %q: %v", n.Str(q.tm), err) 348 } 349 return nil 350} 351 352func proveReasonRequirementForRHSLength(q *checker, op t.ID, lhs *a.Expr, rhs *a.Expr) error { 353 if err := proveReasonRequirement(q, op, lhs, rhs); err != nil { 354 if (op == t.IDXBinaryLessThan) || (op == t.IDXBinaryLessEq) { 355 for _, x := range q.facts { 356 // Try to prove "lhs op rhs" by proving "lhs op const", given a 357 // fact x of the form "rhs >= const". 358 if (x.Operator() == t.IDXBinaryGreaterEq) && x.LHS().AsExpr().Eq(rhs) && 359 (x.RHS().AsExpr().ConstValue() != nil) && 360 (proveReasonRequirement(q, op, lhs, x.RHS().AsExpr()) == nil) { 361 362 return nil 363 } 364 } 365 } 366 return err 367 } 368 return nil 369} 370