1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/ValidateLimitations.h"
8
9 #include "angle_gl.h"
10 #include "compiler/translator/Diagnostics.h"
11 #include "compiler/translator/ParseContext.h"
12 #include "compiler/translator/tree_util/IntermTraverse.h"
13
14 namespace sh
15 {
16
17 namespace
18 {
19
GetLoopSymbolId(TIntermLoop * loop)20 int GetLoopSymbolId(TIntermLoop *loop)
21 {
22 // Here we assume all the operations are valid, because the loop node is
23 // already validated before this call.
24 TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
26 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
27
28 return symbol->uniqueId().get();
29 }
30
31 // Traverses a node to check if it represents a constant index expression.
32 // Definition:
33 // constant-index-expressions are a superset of constant-expressions.
34 // Constant-index-expressions can include loop indices as defined in
35 // GLSL ES 1.0 spec, Appendix A, section 4.
36 // The following are constant-index-expressions:
37 // - Constant expressions
38 // - Loop indices as defined in section 4
39 // - Expressions composed of both of the above
40 class ValidateConstIndexExpr : public TIntermTraverser
41 {
42 public:
ValidateConstIndexExpr(const std::vector<int> & loopSymbols)43 ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44 : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45 {}
46
47 // Returns true if the parsed node represents a constant index expression.
isValid() const48 bool isValid() const { return mValid; }
49
visitSymbol(TIntermSymbol * symbol)50 void visitSymbol(TIntermSymbol *symbol) override
51 {
52 // Only constants and loop indices are allowed in a
53 // constant index expression.
54 if (mValid)
55 {
56 bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
57 symbol->uniqueId().get()) != mLoopSymbolIds.end();
58 mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
59 }
60 }
61
62 private:
63 bool mValid;
64 const std::vector<int> mLoopSymbolIds;
65 };
66
67 // Traverses intermediate tree to ensure that the shader does not exceed the
68 // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
69 class ValidateLimitationsTraverser : public TLValueTrackingTraverser
70 {
71 public:
72 ValidateLimitationsTraverser(sh::GLenum shaderType,
73 TSymbolTable *symbolTable,
74 TDiagnostics *diagnostics);
75
76 void visitSymbol(TIntermSymbol *node) override;
77 bool visitBinary(Visit, TIntermBinary *) override;
78 bool visitLoop(Visit, TIntermLoop *) override;
79
80 private:
81 void error(TSourceLoc loc, const char *reason, const char *token);
82 void error(TSourceLoc loc, const char *reason, const ImmutableString &token);
83
84 bool isLoopIndex(TIntermSymbol *symbol);
85 bool validateLoopType(TIntermLoop *node);
86
87 bool validateForLoopHeader(TIntermLoop *node);
88 // If valid, return the index symbol id; Otherwise, return -1.
89 int validateForLoopInit(TIntermLoop *node);
90 bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
91 bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
92
93 // Returns true if indexing does not exceed the minimum functionality
94 // mandated in GLSL 1.0 spec, Appendix A, Section 5.
95 bool isConstExpr(TIntermNode *node);
96 bool isConstIndexExpr(TIntermNode *node);
97 bool validateIndexing(TIntermBinary *node);
98
99 sh::GLenum mShaderType;
100 TDiagnostics *mDiagnostics;
101 std::vector<int> mLoopSymbolIds;
102 };
103
ValidateLimitationsTraverser(sh::GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)104 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
105 TSymbolTable *symbolTable,
106 TDiagnostics *diagnostics)
107 : TLValueTrackingTraverser(true, false, false, symbolTable),
108 mShaderType(shaderType),
109 mDiagnostics(diagnostics)
110 {
111 ASSERT(diagnostics);
112 }
113
visitSymbol(TIntermSymbol * node)114 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
115 {
116 if (isLoopIndex(node) && isLValueRequiredHere())
117 {
118 error(node->getLine(),
119 "Loop index cannot be statically assigned to within the body of the loop",
120 node->getName());
121 }
122 }
123
visitBinary(Visit,TIntermBinary * node)124 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
125 {
126 // Check indexing.
127 switch (node->getOp())
128 {
129 case EOpIndexDirect:
130 case EOpIndexIndirect:
131 validateIndexing(node);
132 break;
133 default:
134 break;
135 }
136 return true;
137 }
138
visitLoop(Visit,TIntermLoop * node)139 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
140 {
141 if (!validateLoopType(node))
142 return false;
143
144 if (!validateForLoopHeader(node))
145 return false;
146
147 TIntermNode *body = node->getBody();
148 if (body != nullptr)
149 {
150 mLoopSymbolIds.push_back(GetLoopSymbolId(node));
151 body->traverse(this);
152 mLoopSymbolIds.pop_back();
153 }
154
155 // The loop is fully processed - no need to visit children.
156 return false;
157 }
158
error(TSourceLoc loc,const char * reason,const char * token)159 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
160 {
161 mDiagnostics->error(loc, reason, token);
162 }
163
error(TSourceLoc loc,const char * reason,const ImmutableString & token)164 void ValidateLimitationsTraverser::error(TSourceLoc loc,
165 const char *reason,
166 const ImmutableString &token)
167 {
168 error(loc, reason, token.data());
169 }
170
isLoopIndex(TIntermSymbol * symbol)171 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
172 {
173 return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
174 mLoopSymbolIds.end();
175 }
176
validateLoopType(TIntermLoop * node)177 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
178 {
179 TLoopType type = node->getType();
180 if (type == ELoopFor)
181 return true;
182
183 // Reject while and do-while loops.
184 error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
185 return false;
186 }
187
validateForLoopHeader(TIntermLoop * node)188 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
189 {
190 ASSERT(node->getType() == ELoopFor);
191
192 //
193 // The for statement has the form:
194 // for ( init-declaration ; condition ; expression ) statement
195 //
196 int indexSymbolId = validateForLoopInit(node);
197 if (indexSymbolId < 0)
198 return false;
199 if (!validateForLoopCond(node, indexSymbolId))
200 return false;
201 if (!validateForLoopExpr(node, indexSymbolId))
202 return false;
203
204 return true;
205 }
206
validateForLoopInit(TIntermLoop * node)207 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
208 {
209 TIntermNode *init = node->getInit();
210 if (init == nullptr)
211 {
212 error(node->getLine(), "Missing init declaration", "for");
213 return -1;
214 }
215
216 //
217 // init-declaration has the form:
218 // type-specifier identifier = constant-expression
219 //
220 TIntermDeclaration *decl = init->getAsDeclarationNode();
221 if (decl == nullptr)
222 {
223 error(init->getLine(), "Invalid init declaration", "for");
224 return -1;
225 }
226 // To keep things simple do not allow declaration list.
227 TIntermSequence *declSeq = decl->getSequence();
228 if (declSeq->size() != 1)
229 {
230 error(decl->getLine(), "Invalid init declaration", "for");
231 return -1;
232 }
233 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
234 if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
235 {
236 error(decl->getLine(), "Invalid init declaration", "for");
237 return -1;
238 }
239 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
240 if (symbol == nullptr)
241 {
242 error(declInit->getLine(), "Invalid init declaration", "for");
243 return -1;
244 }
245 // The loop index has type int or float.
246 TBasicType type = symbol->getBasicType();
247 if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
248 {
249 error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
250 return -1;
251 }
252 // The loop index is initialized with constant expression.
253 if (!isConstExpr(declInit->getRight()))
254 {
255 error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
256 symbol->getName());
257 return -1;
258 }
259
260 return symbol->uniqueId().get();
261 }
262
validateForLoopCond(TIntermLoop * node,int indexSymbolId)263 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
264 {
265 TIntermNode *cond = node->getCondition();
266 if (cond == nullptr)
267 {
268 error(node->getLine(), "Missing condition", "for");
269 return false;
270 }
271 //
272 // condition has the form:
273 // loop_index relational_operator constant_expression
274 //
275 TIntermBinary *binOp = cond->getAsBinaryNode();
276 if (binOp == nullptr)
277 {
278 error(node->getLine(), "Invalid condition", "for");
279 return false;
280 }
281 // Loop index should be to the left of relational operator.
282 TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
283 if (symbol == nullptr)
284 {
285 error(binOp->getLine(), "Invalid condition", "for");
286 return false;
287 }
288 if (symbol->uniqueId().get() != indexSymbolId)
289 {
290 error(symbol->getLine(), "Expected loop index", symbol->getName());
291 return false;
292 }
293 // Relational operator is one of: > >= < <= == or !=.
294 switch (binOp->getOp())
295 {
296 case EOpEqual:
297 case EOpNotEqual:
298 case EOpLessThan:
299 case EOpGreaterThan:
300 case EOpLessThanEqual:
301 case EOpGreaterThanEqual:
302 break;
303 default:
304 error(binOp->getLine(), "Invalid relational operator",
305 GetOperatorString(binOp->getOp()));
306 break;
307 }
308 // Loop index must be compared with a constant.
309 if (!isConstExpr(binOp->getRight()))
310 {
311 error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
312 symbol->getName());
313 return false;
314 }
315
316 return true;
317 }
318
validateForLoopExpr(TIntermLoop * node,int indexSymbolId)319 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
320 {
321 TIntermNode *expr = node->getExpression();
322 if (expr == nullptr)
323 {
324 error(node->getLine(), "Missing expression", "for");
325 return false;
326 }
327
328 // for expression has one of the following forms:
329 // loop_index++
330 // loop_index--
331 // loop_index += constant_expression
332 // loop_index -= constant_expression
333 // ++loop_index
334 // --loop_index
335 // The last two forms are not specified in the spec, but I am assuming
336 // its an oversight.
337 TIntermUnary *unOp = expr->getAsUnaryNode();
338 TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
339
340 TOperator op = EOpNull;
341 TIntermSymbol *symbol = nullptr;
342 if (unOp != nullptr)
343 {
344 op = unOp->getOp();
345 symbol = unOp->getOperand()->getAsSymbolNode();
346 }
347 else if (binOp != nullptr)
348 {
349 op = binOp->getOp();
350 symbol = binOp->getLeft()->getAsSymbolNode();
351 }
352
353 // The operand must be loop index.
354 if (symbol == nullptr)
355 {
356 error(expr->getLine(), "Invalid expression", "for");
357 return false;
358 }
359 if (symbol->uniqueId().get() != indexSymbolId)
360 {
361 error(symbol->getLine(), "Expected loop index", symbol->getName());
362 return false;
363 }
364
365 // The operator is one of: ++ -- += -=.
366 switch (op)
367 {
368 case EOpPostIncrement:
369 case EOpPostDecrement:
370 case EOpPreIncrement:
371 case EOpPreDecrement:
372 ASSERT((unOp != nullptr) && (binOp == nullptr));
373 break;
374 case EOpAddAssign:
375 case EOpSubAssign:
376 ASSERT((unOp == nullptr) && (binOp != nullptr));
377 break;
378 default:
379 error(expr->getLine(), "Invalid operator", GetOperatorString(op));
380 return false;
381 }
382
383 // Loop index must be incremented/decremented with a constant.
384 if (binOp != nullptr)
385 {
386 if (!isConstExpr(binOp->getRight()))
387 {
388 error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
389 symbol->getName());
390 return false;
391 }
392 }
393
394 return true;
395 }
396
isConstExpr(TIntermNode * node)397 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
398 {
399 ASSERT(node != nullptr);
400 return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
401 }
402
isConstIndexExpr(TIntermNode * node)403 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
404 {
405 ASSERT(node != nullptr);
406
407 ValidateConstIndexExpr validate(mLoopSymbolIds);
408 node->traverse(&validate);
409 return validate.isValid();
410 }
411
validateIndexing(TIntermBinary * node)412 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
413 {
414 ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
415
416 bool valid = true;
417 TIntermTyped *index = node->getRight();
418 // The index expession must be a constant-index-expression unless
419 // the operand is a uniform in a vertex shader.
420 TIntermTyped *operand = node->getLeft();
421 bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
422 if (!skip && !isConstIndexExpr(index))
423 {
424 error(index->getLine(), "Index expression must be constant", "[]");
425 valid = false;
426 }
427 return valid;
428 }
429
430 } // namespace
431
ValidateLimitations(TIntermNode * root,GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)432 bool ValidateLimitations(TIntermNode *root,
433 GLenum shaderType,
434 TSymbolTable *symbolTable,
435 TDiagnostics *diagnostics)
436 {
437 ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
438 root->traverse(&validate);
439 return diagnostics->numErrors() == 0;
440 }
441
442 } // namespace sh
443