1 //
2 // Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/ValidateLimitations.h"
8 #include "compiler/InfoSink.h"
9 #include "compiler/InitializeParseContext.h"
10 #include "compiler/ParseContext.h"
11
12 namespace {
IsLoopIndex(const TIntermSymbol * symbol,const TLoopStack & stack)13 bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
14 for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
15 if (i->index.id == symbol->getId())
16 return true;
17 }
18 return false;
19 }
20
MarkLoopForUnroll(const TIntermSymbol * symbol,TLoopStack & stack)21 void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
22 for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
23 if (i->index.id == symbol->getId()) {
24 ASSERT(i->loop != NULL);
25 i->loop->setUnrollFlag(true);
26 return;
27 }
28 }
29 UNREACHABLE();
30 }
31
32 // Traverses a node to check if it represents a constant index expression.
33 // Definition:
34 // constant-index-expressions are a superset of constant-expressions.
35 // Constant-index-expressions can include loop indices as defined in
36 // GLSL ES 1.0 spec, Appendix A, section 4.
37 // The following are constant-index-expressions:
38 // - Constant expressions
39 // - Loop indices as defined in section 4
40 // - Expressions composed of both of the above
41 class ValidateConstIndexExpr : public TIntermTraverser {
42 public:
ValidateConstIndexExpr(const TLoopStack & stack)43 ValidateConstIndexExpr(const TLoopStack& stack)
44 : mValid(true), mLoopStack(stack) {}
45
46 // Returns true if the parsed node represents a constant index expression.
isValid() const47 bool isValid() const { return mValid; }
48
visitSymbol(TIntermSymbol * symbol)49 virtual void visitSymbol(TIntermSymbol* symbol) {
50 // Only constants and loop indices are allowed in a
51 // constant index expression.
52 if (mValid) {
53 mValid = (symbol->getQualifier() == EvqConst) ||
54 IsLoopIndex(symbol, mLoopStack);
55 }
56 }
57
58 private:
59 bool mValid;
60 const TLoopStack& mLoopStack;
61 };
62
63 // Traverses a node to check if it uses a loop index.
64 // If an int loop index is used in its body as a sampler array index,
65 // mark the loop for unroll.
66 class ValidateLoopIndexExpr : public TIntermTraverser {
67 public:
ValidateLoopIndexExpr(TLoopStack & stack)68 ValidateLoopIndexExpr(TLoopStack& stack)
69 : mUsesFloatLoopIndex(false),
70 mUsesIntLoopIndex(false),
71 mLoopStack(stack) {}
72
usesFloatLoopIndex() const73 bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
usesIntLoopIndex() const74 bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
75
visitSymbol(TIntermSymbol * symbol)76 virtual void visitSymbol(TIntermSymbol* symbol) {
77 if (IsLoopIndex(symbol, mLoopStack)) {
78 switch (symbol->getBasicType()) {
79 case EbtFloat:
80 mUsesFloatLoopIndex = true;
81 break;
82 case EbtInt:
83 mUsesIntLoopIndex = true;
84 MarkLoopForUnroll(symbol, mLoopStack);
85 break;
86 default:
87 UNREACHABLE();
88 }
89 }
90 }
91
92 private:
93 bool mUsesFloatLoopIndex;
94 bool mUsesIntLoopIndex;
95 TLoopStack& mLoopStack;
96 };
97 } // namespace
98
ValidateLimitations(ShShaderType shaderType,TInfoSinkBase & sink)99 ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
100 TInfoSinkBase& sink)
101 : mShaderType(shaderType),
102 mSink(sink),
103 mNumErrors(0)
104 {
105 }
106
visitBinary(Visit,TIntermBinary * node)107 bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
108 {
109 // Check if loop index is modified in the loop body.
110 validateOperation(node, node->getLeft());
111
112 // Check indexing.
113 switch (node->getOp()) {
114 case EOpIndexDirect:
115 validateIndexing(node);
116 break;
117 case EOpIndexIndirect:
118 #if defined(__APPLE__)
119 // Loop unrolling is a work-around for a Mac Cg compiler bug where it
120 // crashes when a sampler array's index is also the loop index.
121 // Once Apple fixes this bug, we should remove the code in this CL.
122 // See http://codereview.appspot.com/4331048/.
123 if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
124 (node->getLeft()->getAsSymbolNode())) {
125 TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
126 if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
127 ValidateLoopIndexExpr validate(mLoopStack);
128 node->getRight()->traverse(&validate);
129 if (validate.usesFloatLoopIndex()) {
130 error(node->getLine(),
131 "sampler array index is float loop index",
132 "for");
133 }
134 }
135 }
136 #endif
137 validateIndexing(node);
138 break;
139 default: break;
140 }
141 return true;
142 }
143
visitUnary(Visit,TIntermUnary * node)144 bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
145 {
146 // Check if loop index is modified in the loop body.
147 validateOperation(node, node->getOperand());
148
149 return true;
150 }
151
visitAggregate(Visit,TIntermAggregate * node)152 bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
153 {
154 switch (node->getOp()) {
155 case EOpFunctionCall:
156 validateFunctionCall(node);
157 break;
158 default:
159 break;
160 }
161 return true;
162 }
163
visitLoop(Visit,TIntermLoop * node)164 bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
165 {
166 if (!validateLoopType(node))
167 return false;
168
169 TLoopInfo info;
170 memset(&info, 0, sizeof(TLoopInfo));
171 info.loop = node;
172 if (!validateForLoopHeader(node, &info))
173 return false;
174
175 TIntermNode* body = node->getBody();
176 if (body != NULL) {
177 mLoopStack.push_back(info);
178 body->traverse(this);
179 mLoopStack.pop_back();
180 }
181
182 // The loop is fully processed - no need to visit children.
183 return false;
184 }
185
error(TSourceLoc loc,const char * reason,const char * token)186 void ValidateLimitations::error(TSourceLoc loc,
187 const char *reason, const char* token)
188 {
189 mSink.prefix(EPrefixError);
190 mSink.location(loc);
191 mSink << "'" << token << "' : " << reason << "\n";
192 ++mNumErrors;
193 }
194
withinLoopBody() const195 bool ValidateLimitations::withinLoopBody() const
196 {
197 return !mLoopStack.empty();
198 }
199
isLoopIndex(const TIntermSymbol * symbol) const200 bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
201 {
202 return IsLoopIndex(symbol, mLoopStack);
203 }
204
validateLoopType(TIntermLoop * node)205 bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
206 TLoopType type = node->getType();
207 if (type == ELoopFor)
208 return true;
209
210 // Reject while and do-while loops.
211 error(node->getLine(),
212 "This type of loop is not allowed",
213 type == ELoopWhile ? "while" : "do");
214 return false;
215 }
216
validateForLoopHeader(TIntermLoop * node,TLoopInfo * info)217 bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
218 TLoopInfo* info)
219 {
220 ASSERT(node->getType() == ELoopFor);
221
222 //
223 // The for statement has the form:
224 // for ( init-declaration ; condition ; expression ) statement
225 //
226 if (!validateForLoopInit(node, info))
227 return false;
228 if (!validateForLoopCond(node, info))
229 return false;
230 if (!validateForLoopExpr(node, info))
231 return false;
232
233 return true;
234 }
235
validateForLoopInit(TIntermLoop * node,TLoopInfo * info)236 bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
237 TLoopInfo* info)
238 {
239 TIntermNode* init = node->getInit();
240 if (init == NULL) {
241 error(node->getLine(), "Missing init declaration", "for");
242 return false;
243 }
244
245 //
246 // init-declaration has the form:
247 // type-specifier identifier = constant-expression
248 //
249 TIntermAggregate* decl = init->getAsAggregate();
250 if ((decl == NULL) || (decl->getOp() != EOpDeclaration)) {
251 error(init->getLine(), "Invalid init declaration", "for");
252 return false;
253 }
254 // To keep things simple do not allow declaration list.
255 TIntermSequence& declSeq = decl->getSequence();
256 if (declSeq.size() != 1) {
257 error(decl->getLine(), "Invalid init declaration", "for");
258 return false;
259 }
260 TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
261 if ((declInit == NULL) || (declInit->getOp() != EOpInitialize)) {
262 error(decl->getLine(), "Invalid init declaration", "for");
263 return false;
264 }
265 TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
266 if (symbol == NULL) {
267 error(declInit->getLine(), "Invalid init declaration", "for");
268 return false;
269 }
270 // The loop index has type int or float.
271 TBasicType type = symbol->getBasicType();
272 if ((type != EbtInt) && (type != EbtFloat)) {
273 error(symbol->getLine(),
274 "Invalid type for loop index", getBasicString(type));
275 return false;
276 }
277 // The loop index is initialized with constant expression.
278 if (!isConstExpr(declInit->getRight())) {
279 error(declInit->getLine(),
280 "Loop index cannot be initialized with non-constant expression",
281 symbol->getSymbol().c_str());
282 return false;
283 }
284
285 info->index.id = symbol->getId();
286 return true;
287 }
288
validateForLoopCond(TIntermLoop * node,TLoopInfo * info)289 bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
290 TLoopInfo* info)
291 {
292 TIntermNode* cond = node->getCondition();
293 if (cond == NULL) {
294 error(node->getLine(), "Missing condition", "for");
295 return false;
296 }
297 //
298 // condition has the form:
299 // loop_index relational_operator constant_expression
300 //
301 TIntermBinary* binOp = cond->getAsBinaryNode();
302 if (binOp == NULL) {
303 error(node->getLine(), "Invalid condition", "for");
304 return false;
305 }
306 // Loop index should be to the left of relational operator.
307 TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
308 if (symbol == NULL) {
309 error(binOp->getLine(), "Invalid condition", "for");
310 return false;
311 }
312 if (symbol->getId() != info->index.id) {
313 error(symbol->getLine(),
314 "Expected loop index", symbol->getSymbol().c_str());
315 return false;
316 }
317 // Relational operator is one of: > >= < <= == or !=.
318 switch (binOp->getOp()) {
319 case EOpEqual:
320 case EOpNotEqual:
321 case EOpLessThan:
322 case EOpGreaterThan:
323 case EOpLessThanEqual:
324 case EOpGreaterThanEqual:
325 break;
326 default:
327 error(binOp->getLine(),
328 "Invalid relational operator",
329 getOperatorString(binOp->getOp()));
330 break;
331 }
332 // Loop index must be compared with a constant.
333 if (!isConstExpr(binOp->getRight())) {
334 error(binOp->getLine(),
335 "Loop index cannot be compared with non-constant expression",
336 symbol->getSymbol().c_str());
337 return false;
338 }
339
340 return true;
341 }
342
validateForLoopExpr(TIntermLoop * node,TLoopInfo * info)343 bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
344 TLoopInfo* info)
345 {
346 TIntermNode* expr = node->getExpression();
347 if (expr == NULL) {
348 error(node->getLine(), "Missing expression", "for");
349 return false;
350 }
351
352 // for expression has one of the following forms:
353 // loop_index++
354 // loop_index--
355 // loop_index += constant_expression
356 // loop_index -= constant_expression
357 // ++loop_index
358 // --loop_index
359 // The last two forms are not specified in the spec, but I am assuming
360 // its an oversight.
361 TIntermUnary* unOp = expr->getAsUnaryNode();
362 TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
363
364 TOperator op = EOpNull;
365 TIntermSymbol* symbol = NULL;
366 if (unOp != NULL) {
367 op = unOp->getOp();
368 symbol = unOp->getOperand()->getAsSymbolNode();
369 } else if (binOp != NULL) {
370 op = binOp->getOp();
371 symbol = binOp->getLeft()->getAsSymbolNode();
372 }
373
374 // The operand must be loop index.
375 if (symbol == NULL) {
376 error(expr->getLine(), "Invalid expression", "for");
377 return false;
378 }
379 if (symbol->getId() != info->index.id) {
380 error(symbol->getLine(),
381 "Expected loop index", symbol->getSymbol().c_str());
382 return false;
383 }
384
385 // The operator is one of: ++ -- += -=.
386 switch (op) {
387 case EOpPostIncrement:
388 case EOpPostDecrement:
389 case EOpPreIncrement:
390 case EOpPreDecrement:
391 ASSERT((unOp != NULL) && (binOp == NULL));
392 break;
393 case EOpAddAssign:
394 case EOpSubAssign:
395 ASSERT((unOp == NULL) && (binOp != NULL));
396 break;
397 default:
398 error(expr->getLine(), "Invalid operator", getOperatorString(op));
399 return false;
400 }
401
402 // Loop index must be incremented/decremented with a constant.
403 if (binOp != NULL) {
404 if (!isConstExpr(binOp->getRight())) {
405 error(binOp->getLine(),
406 "Loop index cannot be modified by non-constant expression",
407 symbol->getSymbol().c_str());
408 return false;
409 }
410 }
411
412 return true;
413 }
414
validateFunctionCall(TIntermAggregate * node)415 bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
416 {
417 ASSERT(node->getOp() == EOpFunctionCall);
418
419 // If not within loop body, there is nothing to check.
420 if (!withinLoopBody())
421 return true;
422
423 // List of param indices for which loop indices are used as argument.
424 typedef std::vector<size_t> ParamIndex;
425 ParamIndex pIndex;
426 TIntermSequence& params = node->getSequence();
427 for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
428 TIntermSymbol* symbol = params[i]->getAsSymbolNode();
429 if (symbol && isLoopIndex(symbol))
430 pIndex.push_back(i);
431 }
432 // If none of the loop indices are used as arguments,
433 // there is nothing to check.
434 if (pIndex.empty())
435 return true;
436
437 bool valid = true;
438 TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
439 TSymbol* symbol = symbolTable.find(node->getName());
440 ASSERT(symbol && symbol->isFunction());
441 TFunction* function = static_cast<TFunction*>(symbol);
442 for (ParamIndex::const_iterator i = pIndex.begin();
443 i != pIndex.end(); ++i) {
444 const TParameter& param = function->getParam(*i);
445 TQualifier qual = param.type->getQualifier();
446 if ((qual == EvqOut) || (qual == EvqInOut)) {
447 error(params[*i]->getLine(),
448 "Loop index cannot be used as argument to a function out or inout parameter",
449 params[*i]->getAsSymbolNode()->getSymbol().c_str());
450 valid = false;
451 }
452 }
453
454 return valid;
455 }
456
validateOperation(TIntermOperator * node,TIntermNode * operand)457 bool ValidateLimitations::validateOperation(TIntermOperator* node,
458 TIntermNode* operand) {
459 // Check if loop index is modified in the loop body.
460 if (!withinLoopBody() || !node->isAssignment())
461 return true;
462
463 const TIntermSymbol* symbol = operand->getAsSymbolNode();
464 if (symbol && isLoopIndex(symbol)) {
465 error(node->getLine(),
466 "Loop index cannot be statically assigned to within the body of the loop",
467 symbol->getSymbol().c_str());
468 }
469 return true;
470 }
471
isConstExpr(TIntermNode * node)472 bool ValidateLimitations::isConstExpr(TIntermNode* node)
473 {
474 ASSERT(node != NULL);
475 return node->getAsConstantUnion() != NULL;
476 }
477
isConstIndexExpr(TIntermNode * node)478 bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
479 {
480 ASSERT(node != NULL);
481
482 ValidateConstIndexExpr validate(mLoopStack);
483 node->traverse(&validate);
484 return validate.isValid();
485 }
486
validateIndexing(TIntermBinary * node)487 bool ValidateLimitations::validateIndexing(TIntermBinary* node)
488 {
489 ASSERT((node->getOp() == EOpIndexDirect) ||
490 (node->getOp() == EOpIndexIndirect));
491
492 bool valid = true;
493 TIntermTyped* index = node->getRight();
494 // The index expression must have integral type.
495 if (!index->isScalar() || (index->getBasicType() != EbtInt)) {
496 error(index->getLine(),
497 "Index expression must have integral type",
498 index->getCompleteString().c_str());
499 valid = false;
500 }
501 // The index expession must be a constant-index-expression unless
502 // the operand is a uniform in a vertex shader.
503 TIntermTyped* operand = node->getLeft();
504 bool skip = (mShaderType == SH_VERTEX_SHADER) &&
505 (operand->getQualifier() == EvqUniform);
506 if (!skip && !isConstIndexExpr(index)) {
507 error(index->getLine(), "Index expression must be constant", "[]");
508 valid = false;
509 }
510 return valid;
511 }
512
513