• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/ValidateLimitations.h"
8 
9 #include "angle_gl.h"
10 #include "compiler/translator/Diagnostics.h"
11 #include "compiler/translator/ParseContext.h"
12 #include "compiler/translator/tree_util/IntermTraverse.h"
13 
14 namespace sh
15 {
16 
17 namespace
18 {
19 
GetLoopSymbolId(TIntermLoop * loop)20 int GetLoopSymbolId(TIntermLoop *loop)
21 {
22     // Here we assume all the operations are valid, because the loop node is
23     // already validated before this call.
24     TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25     TIntermBinary *declInit  = (*declSeq)[0]->getAsBinaryNode();
26     TIntermSymbol *symbol    = declInit->getLeft()->getAsSymbolNode();
27 
28     return symbol->uniqueId().get();
29 }
30 
31 // Traverses a node to check if it represents a constant index expression.
32 // Definition:
33 // constant-index-expressions are a superset of constant-expressions.
34 // Constant-index-expressions can include loop indices as defined in
35 // GLSL ES 1.0 spec, Appendix A, section 4.
36 // The following are constant-index-expressions:
37 // - Constant expressions
38 // - Loop indices as defined in section 4
39 // - Expressions composed of both of the above
40 class ValidateConstIndexExpr : public TIntermTraverser
41 {
42   public:
ValidateConstIndexExpr(const std::vector<int> & loopSymbols)43     ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44         : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45     {}
46 
47     // Returns true if the parsed node represents a constant index expression.
isValid() const48     bool isValid() const { return mValid; }
49 
visitSymbol(TIntermSymbol * symbol)50     void visitSymbol(TIntermSymbol *symbol) override
51     {
52         // Only constants and loop indices are allowed in a
53         // constant index expression.
54         if (mValid)
55         {
56             bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
57                                           symbol->uniqueId().get()) != mLoopSymbolIds.end();
58             mValid            = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
59         }
60     }
61 
62   private:
63     bool mValid;
64     const std::vector<int> mLoopSymbolIds;
65 };
66 
67 // Traverses intermediate tree to ensure that the shader does not exceed the
68 // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
69 class ValidateLimitationsTraverser : public TLValueTrackingTraverser
70 {
71   public:
72     ValidateLimitationsTraverser(sh::GLenum shaderType,
73                                  TSymbolTable *symbolTable,
74                                  TDiagnostics *diagnostics);
75 
76     void visitSymbol(TIntermSymbol *node) override;
77     bool visitBinary(Visit, TIntermBinary *) override;
78     bool visitLoop(Visit, TIntermLoop *) override;
79 
80   private:
81     void error(TSourceLoc loc, const char *reason, const char *token);
82     void error(TSourceLoc loc, const char *reason, const ImmutableString &token);
83 
84     bool isLoopIndex(TIntermSymbol *symbol);
85     bool validateLoopType(TIntermLoop *node);
86 
87     bool validateForLoopHeader(TIntermLoop *node);
88     // If valid, return the index symbol id; Otherwise, return -1.
89     int validateForLoopInit(TIntermLoop *node);
90     bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
91     bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
92 
93     // Returns true if indexing does not exceed the minimum functionality
94     // mandated in GLSL 1.0 spec, Appendix A, Section 5.
95     bool isConstExpr(TIntermNode *node);
96     bool isConstIndexExpr(TIntermNode *node);
97     bool validateIndexing(TIntermBinary *node);
98 
99     sh::GLenum mShaderType;
100     TDiagnostics *mDiagnostics;
101     std::vector<int> mLoopSymbolIds;
102 };
103 
ValidateLimitationsTraverser(sh::GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)104 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
105                                                            TSymbolTable *symbolTable,
106                                                            TDiagnostics *diagnostics)
107     : TLValueTrackingTraverser(true, false, false, symbolTable),
108       mShaderType(shaderType),
109       mDiagnostics(diagnostics)
110 {
111     ASSERT(diagnostics);
112 }
113 
visitSymbol(TIntermSymbol * node)114 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
115 {
116     if (isLoopIndex(node) && isLValueRequiredHere())
117     {
118         error(node->getLine(),
119               "Loop index cannot be statically assigned to within the body of the loop",
120               node->getName());
121     }
122 }
123 
visitBinary(Visit,TIntermBinary * node)124 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
125 {
126     // Check indexing.
127     switch (node->getOp())
128     {
129         case EOpIndexDirect:
130         case EOpIndexIndirect:
131             validateIndexing(node);
132             break;
133         default:
134             break;
135     }
136     return true;
137 }
138 
visitLoop(Visit,TIntermLoop * node)139 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
140 {
141     if (!validateLoopType(node))
142         return false;
143 
144     if (!validateForLoopHeader(node))
145         return false;
146 
147     TIntermNode *body = node->getBody();
148     if (body != nullptr)
149     {
150         mLoopSymbolIds.push_back(GetLoopSymbolId(node));
151         body->traverse(this);
152         mLoopSymbolIds.pop_back();
153     }
154 
155     // The loop is fully processed - no need to visit children.
156     return false;
157 }
158 
error(TSourceLoc loc,const char * reason,const char * token)159 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
160 {
161     mDiagnostics->error(loc, reason, token);
162 }
163 
error(TSourceLoc loc,const char * reason,const ImmutableString & token)164 void ValidateLimitationsTraverser::error(TSourceLoc loc,
165                                          const char *reason,
166                                          const ImmutableString &token)
167 {
168     error(loc, reason, token.data());
169 }
170 
isLoopIndex(TIntermSymbol * symbol)171 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
172 {
173     return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
174            mLoopSymbolIds.end();
175 }
176 
validateLoopType(TIntermLoop * node)177 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
178 {
179     TLoopType type = node->getType();
180     if (type == ELoopFor)
181         return true;
182 
183     // Reject while and do-while loops.
184     error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
185     return false;
186 }
187 
validateForLoopHeader(TIntermLoop * node)188 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
189 {
190     ASSERT(node->getType() == ELoopFor);
191 
192     //
193     // The for statement has the form:
194     //    for ( init-declaration ; condition ; expression ) statement
195     //
196     int indexSymbolId = validateForLoopInit(node);
197     if (indexSymbolId < 0)
198         return false;
199     if (!validateForLoopCond(node, indexSymbolId))
200         return false;
201     if (!validateForLoopExpr(node, indexSymbolId))
202         return false;
203 
204     return true;
205 }
206 
validateForLoopInit(TIntermLoop * node)207 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
208 {
209     TIntermNode *init = node->getInit();
210     if (init == nullptr)
211     {
212         error(node->getLine(), "Missing init declaration", "for");
213         return -1;
214     }
215 
216     //
217     // init-declaration has the form:
218     //     type-specifier identifier = constant-expression
219     //
220     TIntermDeclaration *decl = init->getAsDeclarationNode();
221     if (decl == nullptr)
222     {
223         error(init->getLine(), "Invalid init declaration", "for");
224         return -1;
225     }
226     // To keep things simple do not allow declaration list.
227     TIntermSequence *declSeq = decl->getSequence();
228     if (declSeq->size() != 1)
229     {
230         error(decl->getLine(), "Invalid init declaration", "for");
231         return -1;
232     }
233     TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
234     if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
235     {
236         error(decl->getLine(), "Invalid init declaration", "for");
237         return -1;
238     }
239     TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
240     if (symbol == nullptr)
241     {
242         error(declInit->getLine(), "Invalid init declaration", "for");
243         return -1;
244     }
245     // The loop index has type int or float.
246     TBasicType type = symbol->getBasicType();
247     if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
248     {
249         error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
250         return -1;
251     }
252     // The loop index is initialized with constant expression.
253     if (!isConstExpr(declInit->getRight()))
254     {
255         error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
256               symbol->getName());
257         return -1;
258     }
259 
260     return symbol->uniqueId().get();
261 }
262 
validateForLoopCond(TIntermLoop * node,int indexSymbolId)263 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
264 {
265     TIntermNode *cond = node->getCondition();
266     if (cond == nullptr)
267     {
268         error(node->getLine(), "Missing condition", "for");
269         return false;
270     }
271     //
272     // condition has the form:
273     //     loop_index relational_operator constant_expression
274     //
275     TIntermBinary *binOp = cond->getAsBinaryNode();
276     if (binOp == nullptr)
277     {
278         error(node->getLine(), "Invalid condition", "for");
279         return false;
280     }
281     // Loop index should be to the left of relational operator.
282     TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
283     if (symbol == nullptr)
284     {
285         error(binOp->getLine(), "Invalid condition", "for");
286         return false;
287     }
288     if (symbol->uniqueId().get() != indexSymbolId)
289     {
290         error(symbol->getLine(), "Expected loop index", symbol->getName());
291         return false;
292     }
293     // Relational operator is one of: > >= < <= == or !=.
294     switch (binOp->getOp())
295     {
296         case EOpEqual:
297         case EOpNotEqual:
298         case EOpLessThan:
299         case EOpGreaterThan:
300         case EOpLessThanEqual:
301         case EOpGreaterThanEqual:
302             break;
303         default:
304             error(binOp->getLine(), "Invalid relational operator",
305                   GetOperatorString(binOp->getOp()));
306             break;
307     }
308     // Loop index must be compared with a constant.
309     if (!isConstExpr(binOp->getRight()))
310     {
311         error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
312               symbol->getName());
313         return false;
314     }
315 
316     return true;
317 }
318 
validateForLoopExpr(TIntermLoop * node,int indexSymbolId)319 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
320 {
321     TIntermNode *expr = node->getExpression();
322     if (expr == nullptr)
323     {
324         error(node->getLine(), "Missing expression", "for");
325         return false;
326     }
327 
328     // for expression has one of the following forms:
329     //     loop_index++
330     //     loop_index--
331     //     loop_index += constant_expression
332     //     loop_index -= constant_expression
333     //     ++loop_index
334     //     --loop_index
335     // The last two forms are not specified in the spec, but I am assuming
336     // its an oversight.
337     TIntermUnary *unOp   = expr->getAsUnaryNode();
338     TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
339 
340     TOperator op          = EOpNull;
341     TIntermSymbol *symbol = nullptr;
342     if (unOp != nullptr)
343     {
344         op     = unOp->getOp();
345         symbol = unOp->getOperand()->getAsSymbolNode();
346     }
347     else if (binOp != nullptr)
348     {
349         op     = binOp->getOp();
350         symbol = binOp->getLeft()->getAsSymbolNode();
351     }
352 
353     // The operand must be loop index.
354     if (symbol == nullptr)
355     {
356         error(expr->getLine(), "Invalid expression", "for");
357         return false;
358     }
359     if (symbol->uniqueId().get() != indexSymbolId)
360     {
361         error(symbol->getLine(), "Expected loop index", symbol->getName());
362         return false;
363     }
364 
365     // The operator is one of: ++ -- += -=.
366     switch (op)
367     {
368         case EOpPostIncrement:
369         case EOpPostDecrement:
370         case EOpPreIncrement:
371         case EOpPreDecrement:
372             ASSERT((unOp != nullptr) && (binOp == nullptr));
373             break;
374         case EOpAddAssign:
375         case EOpSubAssign:
376             ASSERT((unOp == nullptr) && (binOp != nullptr));
377             break;
378         default:
379             error(expr->getLine(), "Invalid operator", GetOperatorString(op));
380             return false;
381     }
382 
383     // Loop index must be incremented/decremented with a constant.
384     if (binOp != nullptr)
385     {
386         if (!isConstExpr(binOp->getRight()))
387         {
388             error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
389                   symbol->getName());
390             return false;
391         }
392     }
393 
394     return true;
395 }
396 
isConstExpr(TIntermNode * node)397 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
398 {
399     ASSERT(node != nullptr);
400     return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
401 }
402 
isConstIndexExpr(TIntermNode * node)403 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
404 {
405     ASSERT(node != nullptr);
406 
407     ValidateConstIndexExpr validate(mLoopSymbolIds);
408     node->traverse(&validate);
409     return validate.isValid();
410 }
411 
validateIndexing(TIntermBinary * node)412 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
413 {
414     ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
415 
416     bool valid          = true;
417     TIntermTyped *index = node->getRight();
418     // The index expession must be a constant-index-expression unless
419     // the operand is a uniform in a vertex shader.
420     TIntermTyped *operand = node->getLeft();
421     bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
422     if (!skip && !isConstIndexExpr(index))
423     {
424         error(index->getLine(), "Index expression must be constant", "[]");
425         valid = false;
426     }
427     return valid;
428 }
429 
430 }  // namespace
431 
ValidateLimitations(TIntermNode * root,GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)432 bool ValidateLimitations(TIntermNode *root,
433                          GLenum shaderType,
434                          TSymbolTable *symbolTable,
435                          TDiagnostics *diagnostics)
436 {
437     ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
438     root->traverse(&validate);
439     return diagnostics->numErrors() == 0;
440 }
441 
442 }  // namespace sh
443