1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "ValidateLimitations.h"
16 #include "InfoSink.h"
17 #include "InitializeParseContext.h"
18 #include "ParseHelper.h"
19
20 namespace {
IsLoopIndex(const TIntermSymbol * symbol,const TLoopStack & stack)21 bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
22 for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
23 if (i->index.id == symbol->getId())
24 return true;
25 }
26 return false;
27 }
28
MarkLoopForUnroll(const TIntermSymbol * symbol,TLoopStack & stack)29 void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
30 for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
31 if (i->index.id == symbol->getId()) {
32 ASSERT(i->loop);
33 i->loop->setUnrollFlag(true);
34 return;
35 }
36 }
37 UNREACHABLE(0);
38 }
39
40 // Traverses a node to check if it represents a constant index expression.
41 // Definition:
42 // constant-index-expressions are a superset of constant-expressions.
43 // Constant-index-expressions can include loop indices as defined in
44 // GLSL ES 1.0 spec, Appendix A, section 4.
45 // The following are constant-index-expressions:
46 // - Constant expressions
47 // - Loop indices as defined in section 4
48 // - Expressions composed of both of the above
49 class ValidateConstIndexExpr : public TIntermTraverser {
50 public:
ValidateConstIndexExpr(const TLoopStack & stack)51 ValidateConstIndexExpr(const TLoopStack& stack)
52 : mValid(true), mLoopStack(stack) {}
53
54 // Returns true if the parsed node represents a constant index expression.
isValid() const55 bool isValid() const { return mValid; }
56
visitSymbol(TIntermSymbol * symbol)57 virtual void visitSymbol(TIntermSymbol* symbol) {
58 // Only constants and loop indices are allowed in a
59 // constant index expression.
60 if (mValid) {
61 mValid = (symbol->getQualifier() == EvqConstExpr) ||
62 IsLoopIndex(symbol, mLoopStack);
63 }
64 }
65
66 private:
67 bool mValid;
68 const TLoopStack& mLoopStack;
69 };
70
71 // Traverses a node to check if it uses a loop index.
72 // If an int loop index is used in its body as a sampler array index,
73 // mark the loop for unroll.
74 class ValidateLoopIndexExpr : public TIntermTraverser {
75 public:
ValidateLoopIndexExpr(TLoopStack & stack)76 ValidateLoopIndexExpr(TLoopStack& stack)
77 : mUsesFloatLoopIndex(false),
78 mUsesIntLoopIndex(false),
79 mLoopStack(stack) {}
80
usesFloatLoopIndex() const81 bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
usesIntLoopIndex() const82 bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
83
visitSymbol(TIntermSymbol * symbol)84 virtual void visitSymbol(TIntermSymbol* symbol) {
85 if (IsLoopIndex(symbol, mLoopStack)) {
86 switch (symbol->getBasicType()) {
87 case EbtFloat:
88 mUsesFloatLoopIndex = true;
89 break;
90 case EbtUInt:
91 mUsesIntLoopIndex = true;
92 MarkLoopForUnroll(symbol, mLoopStack);
93 break;
94 case EbtInt:
95 mUsesIntLoopIndex = true;
96 MarkLoopForUnroll(symbol, mLoopStack);
97 break;
98 default:
99 UNREACHABLE(symbol->getBasicType());
100 }
101 }
102 }
103
104 private:
105 bool mUsesFloatLoopIndex;
106 bool mUsesIntLoopIndex;
107 TLoopStack& mLoopStack;
108 };
109 } // namespace
110
ValidateLimitations(GLenum shaderType,TInfoSinkBase & sink)111 ValidateLimitations::ValidateLimitations(GLenum shaderType,
112 TInfoSinkBase& sink)
113 : mShaderType(shaderType),
114 mSink(sink),
115 mNumErrors(0)
116 {
117 }
118
visitBinary(Visit,TIntermBinary * node)119 bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
120 {
121 // Check if loop index is modified in the loop body.
122 validateOperation(node, node->getLeft());
123
124 // Check indexing.
125 switch (node->getOp()) {
126 case EOpIndexDirect:
127 validateIndexing(node);
128 break;
129 case EOpIndexIndirect:
130 validateIndexing(node);
131 break;
132 default: break;
133 }
134 return true;
135 }
136
visitUnary(Visit,TIntermUnary * node)137 bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
138 {
139 // Check if loop index is modified in the loop body.
140 validateOperation(node, node->getOperand());
141
142 return true;
143 }
144
visitAggregate(Visit,TIntermAggregate * node)145 bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
146 {
147 switch (node->getOp()) {
148 case EOpFunctionCall:
149 validateFunctionCall(node);
150 break;
151 default:
152 break;
153 }
154 return true;
155 }
156
visitLoop(Visit,TIntermLoop * node)157 bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
158 {
159 if (!validateLoopType(node))
160 return false;
161
162 TLoopInfo info;
163 memset(&info, 0, sizeof(TLoopInfo));
164 info.loop = node;
165 if (!validateForLoopHeader(node, &info))
166 return false;
167
168 TIntermNode* body = node->getBody();
169 if (body) {
170 mLoopStack.push_back(info);
171 body->traverse(this);
172 mLoopStack.pop_back();
173 }
174
175 // The loop is fully processed - no need to visit children.
176 return false;
177 }
178
error(TSourceLoc loc,const char * reason,const char * token)179 void ValidateLimitations::error(TSourceLoc loc,
180 const char *reason, const char* token)
181 {
182 mSink.prefix(EPrefixError);
183 mSink.location(loc);
184 mSink << "'" << token << "' : " << reason << "\n";
185 ++mNumErrors;
186 }
187
withinLoopBody() const188 bool ValidateLimitations::withinLoopBody() const
189 {
190 return !mLoopStack.empty();
191 }
192
isLoopIndex(const TIntermSymbol * symbol) const193 bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
194 {
195 return IsLoopIndex(symbol, mLoopStack);
196 }
197
validateLoopType(TIntermLoop * node)198 bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
199 TLoopType type = node->getType();
200 if (type == ELoopFor)
201 return true;
202
203 // Reject while and do-while loops.
204 error(node->getLine(),
205 "This type of loop is not allowed",
206 type == ELoopWhile ? "while" : "do");
207 return false;
208 }
209
validateForLoopHeader(TIntermLoop * node,TLoopInfo * info)210 bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
211 TLoopInfo* info)
212 {
213 ASSERT(node->getType() == ELoopFor);
214
215 //
216 // The for statement has the form:
217 // for ( init-declaration ; condition ; expression ) statement
218 //
219 if (!validateForLoopInit(node, info))
220 return false;
221 if (!validateForLoopCond(node, info))
222 return false;
223 if (!validateForLoopExpr(node, info))
224 return false;
225
226 return true;
227 }
228
validateForLoopInit(TIntermLoop * node,TLoopInfo * info)229 bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
230 TLoopInfo* info)
231 {
232 TIntermNode* init = node->getInit();
233 if (!init) {
234 error(node->getLine(), "Missing init declaration", "for");
235 return false;
236 }
237
238 //
239 // init-declaration has the form:
240 // type-specifier identifier = constant-expression
241 //
242 TIntermAggregate* decl = init->getAsAggregate();
243 if (!decl || (decl->getOp() != EOpDeclaration)) {
244 error(init->getLine(), "Invalid init declaration", "for");
245 return false;
246 }
247 // To keep things simple do not allow declaration list.
248 TIntermSequence& declSeq = decl->getSequence();
249 if (declSeq.size() != 1) {
250 error(decl->getLine(), "Invalid init declaration", "for");
251 return false;
252 }
253 TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
254 if (!declInit || (declInit->getOp() != EOpInitialize)) {
255 error(decl->getLine(), "Invalid init declaration", "for");
256 return false;
257 }
258 TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
259 if (!symbol) {
260 error(declInit->getLine(), "Invalid init declaration", "for");
261 return false;
262 }
263 // The loop index has type int or float.
264 TBasicType type = symbol->getBasicType();
265 if (!IsInteger(type) && (type != EbtFloat)) {
266 error(symbol->getLine(),
267 "Invalid type for loop index", getBasicString(type));
268 return false;
269 }
270 // The loop index is initialized with constant expression.
271 if (!isConstExpr(declInit->getRight())) {
272 error(declInit->getLine(),
273 "Loop index cannot be initialized with non-constant expression",
274 symbol->getSymbol().c_str());
275 return false;
276 }
277
278 info->index.id = symbol->getId();
279 return true;
280 }
281
validateForLoopCond(TIntermLoop * node,TLoopInfo * info)282 bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
283 TLoopInfo* info)
284 {
285 TIntermNode* cond = node->getCondition();
286 if (!cond) {
287 error(node->getLine(), "Missing condition", "for");
288 return false;
289 }
290 //
291 // condition has the form:
292 // loop_index relational_operator constant_expression
293 //
294 TIntermBinary* binOp = cond->getAsBinaryNode();
295 if (!binOp) {
296 error(node->getLine(), "Invalid condition", "for");
297 return false;
298 }
299 // Loop index should be to the left of relational operator.
300 TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
301 if (!symbol) {
302 error(binOp->getLine(), "Invalid condition", "for");
303 return false;
304 }
305 if (symbol->getId() != info->index.id) {
306 error(symbol->getLine(),
307 "Expected loop index", symbol->getSymbol().c_str());
308 return false;
309 }
310 // Relational operator is one of: > >= < <= == or !=.
311 switch (binOp->getOp()) {
312 case EOpEqual:
313 case EOpNotEqual:
314 case EOpLessThan:
315 case EOpGreaterThan:
316 case EOpLessThanEqual:
317 case EOpGreaterThanEqual:
318 break;
319 default:
320 error(binOp->getLine(),
321 "Invalid relational operator",
322 getOperatorString(binOp->getOp()));
323 break;
324 }
325 // Loop index must be compared with a constant.
326 if (!isConstExpr(binOp->getRight())) {
327 error(binOp->getLine(),
328 "Loop index cannot be compared with non-constant expression",
329 symbol->getSymbol().c_str());
330 return false;
331 }
332
333 return true;
334 }
335
validateForLoopExpr(TIntermLoop * node,TLoopInfo * info)336 bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
337 TLoopInfo* info)
338 {
339 TIntermNode* expr = node->getExpression();
340 if (!expr) {
341 error(node->getLine(), "Missing expression", "for");
342 return false;
343 }
344
345 // for expression has one of the following forms:
346 // loop_index++
347 // loop_index--
348 // loop_index += constant_expression
349 // loop_index -= constant_expression
350 // ++loop_index
351 // --loop_index
352 // The last two forms are not specified in the spec, but I am assuming
353 // its an oversight.
354 TIntermUnary* unOp = expr->getAsUnaryNode();
355 TIntermBinary* binOp = unOp ? nullptr : expr->getAsBinaryNode();
356
357 TOperator op = EOpNull;
358 TIntermSymbol* symbol = nullptr;
359 if (unOp) {
360 op = unOp->getOp();
361 symbol = unOp->getOperand()->getAsSymbolNode();
362 } else if (binOp) {
363 op = binOp->getOp();
364 symbol = binOp->getLeft()->getAsSymbolNode();
365 }
366
367 // The operand must be loop index.
368 if (!symbol) {
369 error(expr->getLine(), "Invalid expression", "for");
370 return false;
371 }
372 if (symbol->getId() != info->index.id) {
373 error(symbol->getLine(),
374 "Expected loop index", symbol->getSymbol().c_str());
375 return false;
376 }
377
378 // The operator is one of: ++ -- += -=.
379 switch (op) {
380 case EOpPostIncrement:
381 case EOpPostDecrement:
382 case EOpPreIncrement:
383 case EOpPreDecrement:
384 ASSERT((unOp != NULL) && (binOp == NULL));
385 break;
386 case EOpAddAssign:
387 case EOpSubAssign:
388 ASSERT((unOp == NULL) && (binOp != NULL));
389 break;
390 default:
391 error(expr->getLine(), "Invalid operator", getOperatorString(op));
392 return false;
393 }
394
395 // Loop index must be incremented/decremented with a constant.
396 if (binOp != NULL) {
397 if (!isConstExpr(binOp->getRight())) {
398 error(binOp->getLine(),
399 "Loop index cannot be modified by non-constant expression",
400 symbol->getSymbol().c_str());
401 return false;
402 }
403 }
404
405 return true;
406 }
407
validateFunctionCall(TIntermAggregate * node)408 bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
409 {
410 ASSERT(node->getOp() == EOpFunctionCall);
411
412 // If not within loop body, there is nothing to check.
413 if (!withinLoopBody())
414 return true;
415
416 // List of param indices for which loop indices are used as argument.
417 typedef std::vector<int> ParamIndex;
418 ParamIndex pIndex;
419 TIntermSequence& params = node->getSequence();
420 for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
421 TIntermSymbol* symbol = params[i]->getAsSymbolNode();
422 if (symbol && isLoopIndex(symbol))
423 pIndex.push_back(i);
424 }
425 // If none of the loop indices are used as arguments,
426 // there is nothing to check.
427 if (pIndex.empty())
428 return true;
429
430 bool valid = true;
431 TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
432 TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->getShaderVersion());
433 ASSERT(symbol && symbol->isFunction());
434 TFunction* function = static_cast<TFunction*>(symbol);
435 for (ParamIndex::const_iterator i = pIndex.begin();
436 i != pIndex.end(); ++i) {
437 const TParameter& param = function->getParam(*i);
438 TQualifier qual = param.type->getQualifier();
439 if ((qual == EvqOut) || (qual == EvqInOut)) {
440 error(params[*i]->getLine(),
441 "Loop index cannot be used as argument to a function out or inout parameter",
442 params[*i]->getAsSymbolNode()->getSymbol().c_str());
443 valid = false;
444 }
445 }
446
447 return valid;
448 }
449
validateOperation(TIntermOperator * node,TIntermNode * operand)450 bool ValidateLimitations::validateOperation(TIntermOperator* node,
451 TIntermNode* operand) {
452 // Check if loop index is modified in the loop body.
453 if (!withinLoopBody() || !node->modifiesState())
454 return true;
455
456 const TIntermSymbol* symbol = operand->getAsSymbolNode();
457 if (symbol && isLoopIndex(symbol)) {
458 error(node->getLine(),
459 "Loop index cannot be statically assigned to within the body of the loop",
460 symbol->getSymbol().c_str());
461 }
462 return true;
463 }
464
isConstExpr(TIntermNode * node)465 bool ValidateLimitations::isConstExpr(TIntermNode* node)
466 {
467 ASSERT(node);
468 return node->getAsConstantUnion() != nullptr;
469 }
470
isConstIndexExpr(TIntermNode * node)471 bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
472 {
473 ASSERT(node);
474
475 ValidateConstIndexExpr validate(mLoopStack);
476 node->traverse(&validate);
477 return validate.isValid();
478 }
479
validateIndexing(TIntermBinary * node)480 bool ValidateLimitations::validateIndexing(TIntermBinary* node)
481 {
482 ASSERT((node->getOp() == EOpIndexDirect) ||
483 (node->getOp() == EOpIndexIndirect));
484
485 bool valid = true;
486 TIntermTyped* index = node->getRight();
487 // The index expression must have integral type.
488 if (!index->isScalarInt()) {
489 error(index->getLine(),
490 "Index expression must have integral type",
491 index->getCompleteString().c_str());
492 valid = false;
493 }
494 // The index expession must be a constant-index-expression unless
495 // the operand is a uniform in a vertex shader.
496 TIntermTyped* operand = node->getLeft();
497 bool skip = (mShaderType == GL_VERTEX_SHADER) &&
498 (operand->getQualifier() == EvqUniform);
499 if (!skip && !isConstIndexExpr(index)) {
500 error(index->getLine(), "Index expression must be constant", "[]");
501 valid = false;
502 }
503 return valid;
504 }
505
506