1 //
2 // Copyright (C) 2015-2016 Google, Inc.
3 //
4 // All rights reserved.
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions
8 // are met:
9 //
10 // Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //
13 // Redistributions in binary form must reproduce the above
14 // copyright notice, this list of conditions and the following
15 // disclaimer in the documentation and/or other materials provided
16 // with the distribution.
17 //
18 // Neither the name of Google Inc. nor the names of its
19 // contributors may be used to endorse or promote products derived
20 // from this software without specific prior written permission.
21 //
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33 // POSSIBILITY OF SUCH DAMAGE.
34
35 //
36 // Visit the nodes in the glslang intermediate tree representation to
37 // propagate the 'noContraction' qualifier.
38 //
39
40 #include "propagateNoContraction.h"
41
#include <cstdlib>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
47
48 #include "localintermediate.h"
49 namespace {
50
// An access chain identifies an object by the label of its root symbol,
// followed by the indices of the struct members selected on the way down, all
// joined by ObjectAccesschainDelimiter. A plain string is used because most
// chains are short, often just the symbol label itself.
//
// Example: struct {float a; float b;} s;
//   s.a  ->  <symbol label of s>/0
//   s.b  ->  <symbol label of s>/1
//   s    ->  <symbol label of s>
// Elements of vectors, matrices and arrays share the access chain of their
// container, because their preciseness is always that of the container.
using ObjectAccessChain = std::string;

// Separates the symbol label from the first struct member index, and the
// member indices of successive nesting levels from each other.
constexpr char ObjectAccesschainDelimiter = '/';
66
67 // Mapping from Symbol IDs of symbol nodes, to their defining operation
68 // nodes.
69 typedef std::unordered_multimap<ObjectAccessChain, glslang::TIntermOperator*> NodeMapping;
70 // Mapping from object nodes to their access chain info string.
71 typedef std::unordered_map<glslang::TIntermTyped*, ObjectAccessChain> AccessChainMapping;
72
73 // Set of object IDs.
74 typedef std::unordered_set<ObjectAccessChain> ObjectAccesschainSet;
75 // Set of return branch nodes.
76 typedef std::unordered_set<glslang::TIntermBranch*> ReturnBranchNodeSet;
77
78 // A helper function to tell whether a node is 'noContraction'. Returns true if
79 // the node has 'noContraction' qualifier, otherwise false.
isPreciseObjectNode(glslang::TIntermTyped * node)80 bool isPreciseObjectNode(glslang::TIntermTyped* node)
81 {
82 return node->getType().getQualifier().noContraction;
83 }
84
85 // Returns true if the opcode is a dereferencing one.
isDereferenceOperation(glslang::TOperator op)86 bool isDereferenceOperation(glslang::TOperator op)
87 {
88 switch (op) {
89 case glslang::EOpIndexDirect:
90 case glslang::EOpIndexDirectStruct:
91 case glslang::EOpIndexIndirect:
92 case glslang::EOpVectorSwizzle:
93 case glslang::EOpMatrixSwizzle:
94 return true;
95 default:
96 return false;
97 }
98 }
99
100 // Returns true if the opcode leads to an assignment operation.
isAssignOperation(glslang::TOperator op)101 bool isAssignOperation(glslang::TOperator op)
102 {
103 switch (op) {
104 case glslang::EOpAssign:
105 case glslang::EOpAddAssign:
106 case glslang::EOpSubAssign:
107 case glslang::EOpMulAssign:
108 case glslang::EOpVectorTimesMatrixAssign:
109 case glslang::EOpVectorTimesScalarAssign:
110 case glslang::EOpMatrixTimesScalarAssign:
111 case glslang::EOpMatrixTimesMatrixAssign:
112 case glslang::EOpDivAssign:
113 case glslang::EOpModAssign:
114 case glslang::EOpAndAssign:
115 case glslang::EOpLeftShiftAssign:
116 case glslang::EOpRightShiftAssign:
117 case glslang::EOpInclusiveOrAssign:
118 case glslang::EOpExclusiveOrAssign:
119
120 case glslang::EOpPostIncrement:
121 case glslang::EOpPostDecrement:
122 case glslang::EOpPreIncrement:
123 case glslang::EOpPreDecrement:
124 return true;
125 default:
126 return false;
127 }
128 }
129
130 // A helper function to get the unsigned int from a given constant union node.
131 // Note the node should only hold a uint scalar.
getStructIndexFromConstantUnion(glslang::TIntermTyped * node)132 unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped* node)
133 {
134 assert(node->getAsConstantUnion() && node->getAsConstantUnion()->isScalar());
135 unsigned struct_dereference_index = node->getAsConstantUnion()->getConstArray()[0].getUConst();
136 return struct_dereference_index;
137 }
138
139 // A helper function to generate symbol_label.
generateSymbolLabel(glslang::TIntermSymbol * node)140 ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol* node)
141 {
142 ObjectAccessChain symbol_id =
143 std::to_string(node->getId()) + "(" + node->getName().c_str() + ")";
144 return symbol_id;
145 }
146
147 // Returns true if the operation is an arithmetic operation and valid for
148 // the 'NoContraction' decoration.
isArithmeticOperation(glslang::TOperator op)149 bool isArithmeticOperation(glslang::TOperator op)
150 {
151 switch (op) {
152 case glslang::EOpAddAssign:
153 case glslang::EOpSubAssign:
154 case glslang::EOpMulAssign:
155 case glslang::EOpVectorTimesMatrixAssign:
156 case glslang::EOpVectorTimesScalarAssign:
157 case glslang::EOpMatrixTimesScalarAssign:
158 case glslang::EOpMatrixTimesMatrixAssign:
159 case glslang::EOpDivAssign:
160 case glslang::EOpModAssign:
161
162 case glslang::EOpNegative:
163
164 case glslang::EOpAdd:
165 case glslang::EOpSub:
166 case glslang::EOpMul:
167 case glslang::EOpDiv:
168 case glslang::EOpMod:
169
170 case glslang::EOpVectorTimesScalar:
171 case glslang::EOpVectorTimesMatrix:
172 case glslang::EOpMatrixTimesVector:
173 case glslang::EOpMatrixTimesScalar:
174 case glslang::EOpMatrixTimesMatrix:
175
176 case glslang::EOpDot:
177
178 case glslang::EOpPostIncrement:
179 case glslang::EOpPostDecrement:
180 case glslang::EOpPreIncrement:
181 case glslang::EOpPreDecrement:
182 return true;
183 default:
184 return false;
185 }
186 }
187
// A scope guard that saves the current value of a variable when constructed
// and restores that saved value when destroyed. It is used to change traverser
// state (e.g. the current function definition node, or the remained access
// chain) only for the duration of a nested traversal.
template <typename T> class StateSettingGuard {
public:
    // Saves the current value of '*state_ptr', then sets it to 'new_state_value'.
    StateSettingGuard(T* state_ptr, T new_state_value)
        : state_ptr_(state_ptr), previous_state_(*state_ptr)
    {
        *state_ptr_ = std::move(new_state_value);
    }

    // Saves the current value of '*state_ptr' without changing it; setState()
    // may change it later.
    explicit StateSettingGuard(T* state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {}

    // A copied guard would restore the saved value a second time when it is
    // destroyed, so copying is disabled.
    StateSettingGuard(const StateSettingGuard&) = delete;
    StateSettingGuard& operator=(const StateSettingGuard&) = delete;

    // Overwrites the guarded variable; the originally saved value is still
    // restored on destruction.
    void setState(T new_state_value) { *state_ptr_ = std::move(new_state_value); }

    // previous_state_ is not used after this point, so it can be moved from.
    ~StateSettingGuard() { *state_ptr_ = std::move(previous_state_); }

private:
    T* state_ptr_;
    T previous_state_;
};
204
205 // A helper function to get the front element from a given ObjectAccessChain
getFrontElement(const ObjectAccessChain & chain)206 ObjectAccessChain getFrontElement(const ObjectAccessChain& chain)
207 {
208 size_t pos_delimiter = chain.find(ObjectAccesschainDelimiter);
209 return pos_delimiter == std::string::npos ? chain : chain.substr(0, pos_delimiter);
210 }
211
212 // A helper function to get the access chain starting from the second element.
subAccessChainFromSecondElement(const ObjectAccessChain & chain)213 ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain)
214 {
215 size_t pos_delimiter = chain.find(ObjectAccesschainDelimiter);
216 return pos_delimiter == std::string::npos ? "" : chain.substr(pos_delimiter + 1);
217 }
218
219 // A helper function to get the access chain after removing a given prefix.
getSubAccessChainAfterPrefix(const ObjectAccessChain & chain,const ObjectAccessChain & prefix)220 ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain& chain,
221 const ObjectAccessChain& prefix)
222 {
223 size_t pos = chain.find(prefix);
224 if (pos != 0)
225 return chain;
226 return chain.substr(prefix.length() + sizeof(ObjectAccesschainDelimiter));
227 }
228
229 //
230 // A traverser which traverses the whole AST and populates:
231 // 1) A mapping from symbol nodes' IDs to their defining operation nodes.
232 // 2) A set of access chains of the initial precise object nodes.
233 //
234 class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser {
235 public:
236 TSymbolDefinitionCollectingTraverser(NodeMapping* symbol_definition_mapping,
237 AccessChainMapping* accesschain_mapping,
238 ObjectAccesschainSet* precise_objects,
239 ReturnBranchNodeSet* precise_return_nodes);
240
241 bool visitUnary(glslang::TVisit, glslang::TIntermUnary*) override;
242 bool visitBinary(glslang::TVisit, glslang::TIntermBinary*) override;
243 void visitSymbol(glslang::TIntermSymbol*) override;
244 bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate*) override;
245 bool visitBranch(glslang::TVisit, glslang::TIntermBranch*) override;
246
247 protected:
248 TSymbolDefinitionCollectingTraverser& operator=(const TSymbolDefinitionCollectingTraverser&);
249
250 // The mapping from symbol node IDs to their defining nodes. This should be
251 // populated along traversing the AST.
252 NodeMapping& symbol_definition_mapping_;
253 // The set of symbol node IDs for precise symbol nodes, the ones marked as
254 // 'noContraction'.
255 ObjectAccesschainSet& precise_objects_;
256 // The set of precise return nodes.
257 ReturnBranchNodeSet& precise_return_nodes_;
258 // A temporary cache of the symbol node whose defining node is to be found
259 // currently along traversing the AST.
260 ObjectAccessChain current_object_;
261 // A map from object node to its access chain. This traverser stores
262 // the built access chains into this map for each object node it has
263 // visited.
264 AccessChainMapping& accesschain_mapping_;
265 // The pointer to the Function Definition node, so we can get the
266 // preciseness of the return expression from it when we traverse the
267 // return branch node.
268 glslang::TIntermAggregate* current_function_definition_node_;
269 };
270
TSymbolDefinitionCollectingTraverser(NodeMapping * symbol_definition_mapping,AccessChainMapping * accesschain_mapping,ObjectAccesschainSet * precise_objects,std::unordered_set<glslang::TIntermBranch * > * precise_return_nodes)271 TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser(
272 NodeMapping* symbol_definition_mapping, AccessChainMapping* accesschain_mapping,
273 ObjectAccesschainSet* precise_objects,
274 std::unordered_set<glslang::TIntermBranch*>* precise_return_nodes)
275 : TIntermTraverser(true, false, false), symbol_definition_mapping_(*symbol_definition_mapping),
276 precise_objects_(*precise_objects), precise_return_nodes_(*precise_return_nodes),
277 current_object_(), accesschain_mapping_(*accesschain_mapping),
278 current_function_definition_node_(nullptr) {}
279
280 // Visits a symbol node, set the current_object_ to the
281 // current node symbol ID, and record a mapping from this node to the current
282 // current_object_, which is the just obtained symbol
283 // ID.
visitSymbol(glslang::TIntermSymbol * node)284 void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node)
285 {
286 current_object_ = generateSymbolLabel(node);
287 accesschain_mapping_[node] = current_object_;
288 }
289
290 // Visits an aggregate node, traverses all of its children.
visitAggregate(glslang::TVisit,glslang::TIntermAggregate * node)291 bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit,
292 glslang::TIntermAggregate* node)
293 {
294 // This aggregate node might be a function definition node, in which case we need to
295 // cache this node, so we can get the preciseness information of the return value
296 // of this function later.
297 StateSettingGuard<glslang::TIntermAggregate*> current_function_definition_node_setting_guard(
298 ¤t_function_definition_node_);
299 if (node->getOp() == glslang::EOpFunction) {
300 // This is function definition node, we need to cache this node so that we can
301 // get the preciseness of the return value later.
302 current_function_definition_node_setting_guard.setState(node);
303 }
304 // Traverse the items in the sequence.
305 glslang::TIntermSequence& seq = node->getSequence();
306 for (int i = 0; i < (int)seq.size(); ++i) {
307 current_object_.clear();
308 seq[i]->traverse(this);
309 }
310 return false;
311 }
312
visitBranch(glslang::TVisit,glslang::TIntermBranch * node)313 bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit,
314 glslang::TIntermBranch* node)
315 {
316 if (node->getFlowOp() == glslang::EOpReturn && node->getExpression() &&
317 current_function_definition_node_ &&
318 current_function_definition_node_->getType().getQualifier().noContraction) {
319 // This node is a return node with an expression, and its function has a
320 // precise return value. We need to find the involved objects in its
321 // expression and add them to the set of initial precise objects.
322 precise_return_nodes_.insert(node);
323 node->getExpression()->traverse(this);
324 }
325 return false;
326 }
327
328 // Visits a unary node. This might be an implicit assignment like i++, i--. etc.
visitUnary(glslang::TVisit,glslang::TIntermUnary * node)329 bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */,
330 glslang::TIntermUnary* node)
331 {
332 current_object_.clear();
333 node->getOperand()->traverse(this);
334 if (isAssignOperation(node->getOp())) {
335 // We should always be able to get an access chain of the operand node.
336 assert(!current_object_.empty());
337
338 // If the operand node object is 'precise', we collect its access chain
339 // for the initial set of 'precise' objects.
340 if (isPreciseObjectNode(node->getOperand())) {
341 // The operand node is an 'precise' object node, add its
342 // access chain to the set of 'precise' objects. This is to collect
343 // the initial set of 'precise' objects.
344 precise_objects_.insert(current_object_);
345 }
346 // Gets the symbol ID from the object's access chain.
347 ObjectAccessChain id_symbol = getFrontElement(current_object_);
348 // Add a mapping from the symbol ID to this assignment operation node.
349 symbol_definition_mapping_.insert(std::make_pair(id_symbol, node));
350 }
351 // A unary node is not a dereference node, so we clear the access chain which
352 // is under construction.
353 current_object_.clear();
354 return false;
355 }
356
357 // Visits a binary node and updates the mapping from symbol IDs to the definition
358 // nodes. Also collects the access chains for the initial precise objects.
visitBinary(glslang::TVisit,glslang::TIntermBinary * node)359 bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */,
360 glslang::TIntermBinary* node)
361 {
362 // Traverses the left node to build the access chain info for the object.
363 current_object_.clear();
364 node->getLeft()->traverse(this);
365
366 if (isAssignOperation(node->getOp())) {
367 // We should always be able to get an access chain for the left node.
368 assert(!current_object_.empty());
369
370 // If the left node object is 'precise', it is an initial precise object
371 // specified in the shader source. Adds it to the initial work list to
372 // process later.
373 if (isPreciseObjectNode(node->getLeft())) {
374 // The left node is an 'precise' object node, add its access chain to
375 // the set of 'precise' objects. This is to collect the initial set
376 // of 'precise' objects.
377 precise_objects_.insert(current_object_);
378 }
379 // Gets the symbol ID from the object access chain, which should be the
380 // first element recorded in the access chain.
381 ObjectAccessChain id_symbol = getFrontElement(current_object_);
382 // Adds a mapping from the symbol ID to this assignment operation node.
383 symbol_definition_mapping_.insert(std::make_pair(id_symbol, node));
384
385 // Traverses the right node, there may be other 'assignment'
386 // operations in the right.
387 current_object_.clear();
388 node->getRight()->traverse(this);
389
390 } else if (isDereferenceOperation(node->getOp())) {
391 // The left node (parent node) is a struct type object. We need to
392 // record the access chain information of the current node into its
393 // object id.
394 if (node->getOp() == glslang::EOpIndexDirectStruct) {
395 unsigned struct_dereference_index = getStructIndexFromConstantUnion(node->getRight());
396 current_object_.push_back(ObjectAccesschainDelimiter);
397 current_object_.append(std::to_string(struct_dereference_index));
398 }
399 accesschain_mapping_[node] = current_object_;
400
401 // For a dereference node, there is no need to traverse the right child
402 // node as the right node should always be an integer type object.
403
404 } else {
405 // For other binary nodes, still traverse the right node.
406 current_object_.clear();
407 node->getRight()->traverse(this);
408 }
409 return false;
410 }
411
412 // Traverses the AST and returns a tuple of four members:
413 // 1) a mapping from symbol IDs to the definition nodes (aka. assignment nodes) of these symbols.
414 // 2) a mapping from object nodes in the AST to the access chains of these objects.
415 // 3) a set of access chains of precise objects.
416 // 4) a set of return nodes with precise expressions.
417 std::tuple<NodeMapping, AccessChainMapping, ObjectAccesschainSet, ReturnBranchNodeSet>
getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate & intermediate)418 getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate& intermediate)
419 {
420 auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(),
421 ReturnBranchNodeSet());
422
423 TIntermNode* root = intermediate.getTreeRoot();
424 if (root == 0)
425 return result_tuple;
426
427 NodeMapping& symbol_definition_mapping = std::get<0>(result_tuple);
428 AccessChainMapping& accesschain_mapping = std::get<1>(result_tuple);
429 ObjectAccesschainSet& precise_objects = std::get<2>(result_tuple);
430 ReturnBranchNodeSet& precise_return_nodes = std::get<3>(result_tuple);
431
432 // Traverses the AST and populate the results.
433 TSymbolDefinitionCollectingTraverser collector(&symbol_definition_mapping, &accesschain_mapping,
434 &precise_objects, &precise_return_nodes);
435 root->traverse(&collector);
436
437 return result_tuple;
438 }
439
440 //
441 // A traverser that determine whether the left node (or operand node for unary
442 // node) of an assignment node is 'precise', containing 'precise' or not,
443 // according to the access chain a given precise object which share the same
444 // symbol as the left node.
445 //
446 // Post-orderly traverses the left node subtree of an binary assignment node and:
447 //
448 // 1) Propagates the 'precise' from the left object nodes to this object node.
449 //
450 // 2) Builds object access chain along the traversal, and also compares with
451 // the access chain of the given 'precise' object along with the traversal to
452 // tell if the node to be defined is 'precise' or not.
453 //
454 class TNoContractionAssigneeCheckingTraverser : public glslang::TIntermTraverser {
455
456 enum DecisionStatus {
457 // The object node to be assigned to may contain 'precise' objects and also not 'precise' objects.
458 Mixed = 0,
459 // The object node to be assigned to is either a 'precise' object or a struct objects whose members are all 'precise'.
460 Precise = 1,
461 // The object node to be assigned to is not a 'precise' object.
462 NotPreicse = 2,
463 };
464
465 public:
TNoContractionAssigneeCheckingTraverser(const AccessChainMapping & accesschain_mapping)466 TNoContractionAssigneeCheckingTraverser(const AccessChainMapping& accesschain_mapping)
467 : TIntermTraverser(true, false, false), accesschain_mapping_(accesschain_mapping),
468 precise_object_(nullptr) {}
469
470 // Checks the preciseness of a given assignment node with a precise object
471 // represented as access chain. The precise object shares the same symbol
472 // with the assignee of the given assignment node. Return a tuple of two:
473 //
474 // 1) The preciseness of the assignee node of this assignment node. True
475 // if the assignee contains 'precise' objects or is 'precise', false if
476 // the assignee is not 'precise' according to the access chain of the given
477 // precise object.
478 //
479 // 2) The incremental access chain from the assignee node to its nested
480 // 'precise' object, according to the access chain of the given precise
481 // object. This incremental access chain can be empty, which means the
482 // assignee is 'precise'. Otherwise it shows the path to the nested
483 // precise object.
484 std::tuple<bool, ObjectAccessChain>
getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator * node,const ObjectAccessChain & precise_object)485 getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node,
486 const ObjectAccessChain& precise_object)
487 {
488 assert(isAssignOperation(node->getOp()));
489 precise_object_ = &precise_object;
490 ObjectAccessChain assignee_object;
491 if (glslang::TIntermBinary* BN = node->getAsBinaryNode()) {
492 // This is a binary assignment node, we need to check the
493 // preciseness of the left node.
494 assert(accesschain_mapping_.count(BN->getLeft()));
495 // The left node (assignee node) is an object node, traverse the
496 // node to let the 'precise' of nesting objects being transfered to
497 // nested objects.
498 BN->getLeft()->traverse(this);
499 // After traversing the left node, if the left node is 'precise',
500 // we can conclude this assignment should propagate 'precise'.
501 if (isPreciseObjectNode(BN->getLeft())) {
502 return make_tuple(true, ObjectAccessChain());
503 }
504 // If the preciseness of the left node (assignee node) can not
505 // be determined by now, we need to compare the access chain string
506 // of the assignee object with the given precise object.
507 assignee_object = accesschain_mapping_.at(BN->getLeft());
508
509 } else if (glslang::TIntermUnary* UN = node->getAsUnaryNode()) {
510 // This is a unary assignment node, we need to check the
511 // preciseness of the operand node. For unary assignment node, the
512 // operand node should always be an object node.
513 assert(accesschain_mapping_.count(UN->getOperand()));
514 // Traverse the operand node to let the 'precise' being propagated
515 // from lower nodes to upper nodes.
516 UN->getOperand()->traverse(this);
517 // After traversing the operand node, if the operand node is
518 // 'precise', this assignment should propagate 'precise'.
519 if (isPreciseObjectNode(UN->getOperand())) {
520 return make_tuple(true, ObjectAccessChain());
521 }
522 // If the preciseness of the operand node (assignee node) can not
523 // be determined by now, we need to compare the access chain string
524 // of the assignee object with the given precise object.
525 assignee_object = accesschain_mapping_.at(UN->getOperand());
526 } else {
527 // Not a binary or unary node, should not happen.
528 assert(false);
529 }
530
531 // Compare the access chain string of the assignee node with the given
532 // precise object to determine if this assignment should propagate
533 // 'precise'.
534 if (assignee_object.find(precise_object) == 0) {
535 // The access chain string of the given precise object is a prefix
536 // of assignee's access chain string. The assignee should be
537 // 'precise'.
538 return make_tuple(true, ObjectAccessChain());
539 } else if (precise_object.find(assignee_object) == 0) {
540 // The assignee's access chain string is a prefix of the given
541 // precise object, the assignee object contains 'precise' object,
542 // and we need to pass the remained access chain to the object nodes
543 // in the right.
544 return make_tuple(true, getSubAccessChainAfterPrefix(precise_object, assignee_object));
545 } else {
546 // The access chain strings do not match, the assignee object can
547 // not be labeled as 'precise' according to the given precise
548 // object.
549 return make_tuple(false, ObjectAccessChain());
550 }
551 }
552
553 protected:
554 TNoContractionAssigneeCheckingTraverser& operator=(const TNoContractionAssigneeCheckingTraverser&);
555
556 bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override;
557 void visitSymbol(glslang::TIntermSymbol* node) override;
558
559 // A map from object nodes to their access chain string (used as object ID).
560 const AccessChainMapping& accesschain_mapping_;
561 // A given precise object, represented in it access chain string. This
562 // precise object is used to be compared with the assignee node to tell if
563 // the assignee node is 'precise', contains 'precise' object or not
564 // 'precise'.
565 const ObjectAccessChain* precise_object_;
566 };
567
568 // Visits a binary node. If the node is an object node, it must be a dereference
569 // node. In such cases, if the left node is 'precise', this node should also be
570 // 'precise'.
visitBinary(glslang::TVisit,glslang::TIntermBinary * node)571 bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit,
572 glslang::TIntermBinary* node)
573 {
574 // Traverses the left so that we transfer the 'precise' from nesting object
575 // to its nested object.
576 node->getLeft()->traverse(this);
577 // If this binary node is an object node, we should have it in the
578 // accesschain_mapping_.
579 if (accesschain_mapping_.count(node)) {
580 // A binary object node must be a dereference node.
581 assert(isDereferenceOperation(node->getOp()));
582 // If the left node is 'precise', this node should also be precise,
583 // otherwise, compare with the given precise_object_. If the
584 // access chain of this node matches with the given precise_object_,
585 // this node should be marked as 'precise'.
586 if (isPreciseObjectNode(node->getLeft())) {
587 node->getWritableType().getQualifier().noContraction = true;
588 } else if (accesschain_mapping_.at(node) == *precise_object_) {
589 node->getWritableType().getQualifier().noContraction = true;
590 }
591 }
592 return false;
593 }
594
595 // Visits a symbol node, if the symbol node ID (its access chain string) matches
596 // with the given precise object, this node should be 'precise'.
visitSymbol(glslang::TIntermSymbol * node)597 void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol* node)
598 {
599 // A symbol node should always be an object node, and should have been added
600 // to the map from object nodes to their access chain strings.
601 assert(accesschain_mapping_.count(node));
602 if (accesschain_mapping_.at(node) == *precise_object_) {
603 node->getWritableType().getQualifier().noContraction = true;
604 }
605 }
606
//
// A traverser that only traverses the right side of binary assignment nodes,
// the operand node of unary assignment nodes, and the expressions of precise
// return nodes. It:
//
// 1) Marks arithmetic operations as 'noContraction'.
//
// 2) Finds the objects on the right side that should be marked as 'precise'
// and adds them to the 'precise' object work list.
//
616 class TNoContractionPropagator : public glslang::TIntermTraverser {
617 public:
// 'precise_objects' is the caller-owned work list that newly found precise
// objects are added to; 'accesschain_mapping' maps object nodes to their
// access chains.
TNoContractionPropagator(ObjectAccesschainSet* precise_objects,
                         const AccessChainMapping& accesschain_mapping)
    : TIntermTraverser(true, false, false),
      precise_objects_(*precise_objects),
      added_precise_object_ids_(),
      remained_accesschain_(),
      accesschain_mapping_(accesschain_mapping)
{
}
623
624 // Propagates 'precise' in the right nodes of a given assignment node with
625 // access chain record from the assignee node to a 'precise' object it
626 // contains.
627 void
propagateNoContractionInOneExpression(glslang::TIntermTyped * defining_node,const ObjectAccessChain & assignee_remained_accesschain)628 propagateNoContractionInOneExpression(glslang::TIntermTyped* defining_node,
629 const ObjectAccessChain& assignee_remained_accesschain)
630 {
631 remained_accesschain_ = assignee_remained_accesschain;
632 if (glslang::TIntermBinary* BN = defining_node->getAsBinaryNode()) {
633 assert(isAssignOperation(BN->getOp()));
634 BN->getRight()->traverse(this);
635 if (isArithmeticOperation(BN->getOp())) {
636 BN->getWritableType().getQualifier().noContraction = true;
637 }
638 } else if (glslang::TIntermUnary* UN = defining_node->getAsUnaryNode()) {
639 assert(isAssignOperation(UN->getOp()));
640 UN->getOperand()->traverse(this);
641 if (isArithmeticOperation(UN->getOp())) {
642 UN->getWritableType().getQualifier().noContraction = true;
643 }
644 }
645 }
646
647 // Propagates 'precise' in a given precise return node.
propagateNoContractionInReturnNode(glslang::TIntermBranch * return_node)648 void propagateNoContractionInReturnNode(glslang::TIntermBranch* return_node)
649 {
650 remained_accesschain_ = "";
651 assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression());
652 return_node->getExpression()->traverse(this);
653 }
654
655 protected:
656 TNoContractionPropagator& operator=(const TNoContractionPropagator&);
657
658 // Visits an aggregate node. The node can be a initializer list, in which
659 // case we need to find the 'precise' or 'precise' containing object node
660 // with the access chain record. In other cases, just need to traverse all
661 // the children nodes.
visitAggregate(glslang::TVisit,glslang::TIntermAggregate * node)662 bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override
663 {
664 if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) {
665 // This is a struct initializer node, and the remained
666 // access chain is not empty, we need to refer to the
667 // assignee_remained_access_chain_ to find the nested
668 // 'precise' object. And we don't need to visit other nodes in this
669 // aggregate node.
670
671 // Gets the struct dereference index that leads to 'precise' object.
672 ObjectAccessChain precise_accesschain_index_str =
673 getFrontElement(remained_accesschain_);
674 unsigned precise_accesschain_index = (unsigned)strtoul(precise_accesschain_index_str.c_str(), nullptr, 10);
675 // Gets the node pointed by the access chain index extracted before.
676 glslang::TIntermTyped* potential_precise_node =
677 node->getSequence()[precise_accesschain_index]->getAsTyped();
678 assert(potential_precise_node);
679 // Pop the front access chain index from the path, and visit the nested node.
680 {
681 ObjectAccessChain next_level_accesschain =
682 subAccessChainFromSecondElement(remained_accesschain_);
683 StateSettingGuard<ObjectAccessChain> setup_remained_accesschain_for_next_level(
684 &remained_accesschain_, next_level_accesschain);
685 potential_precise_node->traverse(this);
686 }
687 return false;
688 }
689 return true;
690 }
691
692 // Visits a binary node. A binary node can be an object node, e.g. a dereference node.
693 // As only the top object nodes in the right side of an assignment needs to be visited
694 // and added to 'precise' work list, this traverser won't visit the children nodes of
695 // an object node. If the binary node does not represent an object node, it should
696 // go on to traverse its children nodes and if it is an arithmetic operation node, this
697 // operation should be marked as 'noContraction'.
visitBinary(glslang::TVisit,glslang::TIntermBinary * node)698 bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override
699 {
700 if (isDereferenceOperation(node->getOp())) {
701 // This binary node is an object node. Need to update the precise
702 // object set with the access chain of this node + remained
703 // access chain .
704 ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node);
705 if (remained_accesschain_.empty()) {
706 node->getWritableType().getQualifier().noContraction = true;
707 } else {
708 new_precise_accesschain += ObjectAccesschainDelimiter + remained_accesschain_;
709 }
710 // Cache the access chain as added precise object, so we won't add the
711 // same object to the work list again.
712 if (!added_precise_object_ids_.count(new_precise_accesschain)) {
713 precise_objects_.insert(new_precise_accesschain);
714 added_precise_object_ids_.insert(new_precise_accesschain);
715 }
716 // Only the upper-most object nodes should be visited, so do not
717 // visit children of this object node.
718 return false;
719 }
720 // If this is an arithmetic operation, marks this node as 'noContraction'.
721 if (isArithmeticOperation(node->getOp()) && node->getBasicType() != glslang::EbtInt) {
722 node->getWritableType().getQualifier().noContraction = true;
723 }
724 // As this node is not an object node, need to traverse the children nodes.
725 return true;
726 }
727
728 // Visits a unary node. A unary node can not be an object node. If the operation
729 // is an arithmetic operation, need to mark this node as 'noContraction'.
visitUnary(glslang::TVisit,glslang::TIntermUnary * node)730 bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* node) override
731 {
732 // If this is an arithmetic operation, marks this with 'noContraction'
733 if (isArithmeticOperation(node->getOp())) {
734 node->getWritableType().getQualifier().noContraction = true;
735 }
736 return true;
737 }
738
739 // Visits a symbol node. A symbol node is always an object node. So we
740 // should always be able to find its in our collected mapping from object
741 // nodes to access chains. As an object node, a symbol node can be either
742 // 'precise' or containing 'precise' objects according to unused
743 // access chain information we have when we visit this node.
visitSymbol(glslang::TIntermSymbol * node)744 void visitSymbol(glslang::TIntermSymbol* node) override
745 {
746 // Symbol nodes are object nodes and should always have an
747 // access chain collected before matches with it.
748 assert(accesschain_mapping_.count(node));
749 ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node);
750 // If the unused access chain is empty, this symbol node should be
751 // marked as 'precise'. Otherwise, the unused access chain should be
752 // appended to the symbol ID to build a new access chain which points to
753 // the nested 'precise' object in this symbol object.
754 if (remained_accesschain_.empty()) {
755 node->getWritableType().getQualifier().noContraction = true;
756 } else {
757 new_precise_accesschain += ObjectAccesschainDelimiter + remained_accesschain_;
758 }
759 // Add the new 'precise' access chain to the work list and make sure we
760 // don't visit it again.
761 if (!added_precise_object_ids_.count(new_precise_accesschain)) {
762 precise_objects_.insert(new_precise_accesschain);
763 added_precise_object_ids_.insert(new_precise_accesschain);
764 }
765 }
766
// The work list of 'precise' objects, represented as access chains. This is a
// reference to an external set; newly discovered 'precise' objects are
// inserted into it.
ObjectAccesschainSet& precise_objects_;
// Access chains this propagator has already inserted into precise_objects_.
// Consulted so the same object is not queued on the work list twice.
ObjectAccesschainSet added_precise_object_ids_;
// The left node of an assignment may be a parent of 'precise' objects rather
// than a 'precise' object itself: it may contain a member that carries the
// 'precise' qualifier, and that qualifier must be propagated to the
// corresponding node on the right side. This holds the not-yet-consumed path
// from the node currently being visited down to that nested 'precise' node.
// It is empty when the visited object node itself is 'precise'.
ObjectAccessChain remained_accesschain_;
// Maps object nodes to the access chains recorded for them.
const AccessChainMapping& accesschain_mapping_;
779 };
780 }
781
782 namespace glslang {
783
PropagateNoContraction(const glslang::TIntermediate & intermediate)784 void PropagateNoContraction(const glslang::TIntermediate& intermediate)
785 {
786 // First, traverses the AST, records symbols with their defining operations
787 // and collects the initial set of precise symbols (symbol nodes that marked
788 // as 'noContraction') and precise return nodes.
789 auto mappings_and_precise_objects =
790 getSymbolToDefinitionMappingAndPreciseSymbolIDs(intermediate);
791
792 // The mapping of symbol node IDs to their defining nodes. This enables us
793 // to get the defining node directly from a given symbol ID without
794 // traversing the tree again.
795 NodeMapping& symbol_definition_mapping = std::get<0>(mappings_and_precise_objects);
796
797 // The mapping of object nodes to their access chains recorded.
798 AccessChainMapping& accesschain_mapping = std::get<1>(mappings_and_precise_objects);
799
800 // The initial set of 'precise' objects which are represented as the
801 // access chain toward them.
802 ObjectAccesschainSet& precise_object_accesschains = std::get<2>(mappings_and_precise_objects);
803
804 // The set of 'precise' return nodes.
805 ReturnBranchNodeSet& precise_return_nodes = std::get<3>(mappings_and_precise_objects);
806
807 // Second, uses the initial set of precise objects as a work list, pops an
808 // access chain, extract the symbol ID from it. Then:
809 // 1) Check the assignee object, see if it is 'precise' object node or
810 // contains 'precise' object. Obtain the incremental access chain from the
811 // assignee node to its nested 'precise' node (if any).
812 // 2) If the assignee object node is 'precise' or it contains 'precise'
813 // objects, traverses the right side of the assignment operation
814 // expression to mark arithmetic operations as 'noContration' and update
815 // 'precise' access chain work list with new found object nodes.
816 // Repeat above steps until the work list is empty.
817 TNoContractionAssigneeCheckingTraverser checker(accesschain_mapping);
818 TNoContractionPropagator propagator(&precise_object_accesschains, accesschain_mapping);
819
820 // We have two initial precise work lists to handle:
821 // 1) precise return nodes
822 // 2) precise object access chains
823 // We should process the precise return nodes first and the involved
824 // objects in the return expression should be added to the precise object
825 // access chain set.
826 while (!precise_return_nodes.empty()) {
827 glslang::TIntermBranch* precise_return_node = *precise_return_nodes.begin();
828 propagator.propagateNoContractionInReturnNode(precise_return_node);
829 precise_return_nodes.erase(precise_return_node);
830 }
831
832 while (!precise_object_accesschains.empty()) {
833 // Get the access chain of a precise object from the work list.
834 ObjectAccessChain precise_object_accesschain = *precise_object_accesschains.begin();
835 // Get the symbol id from the access chain.
836 ObjectAccessChain symbol_id = getFrontElement(precise_object_accesschain);
837 // Get all the defining nodes of that symbol ID.
838 std::pair<NodeMapping::iterator, NodeMapping::iterator> range =
839 symbol_definition_mapping.equal_range(symbol_id);
840 // Visits all the assignment nodes of that symbol ID and
841 // 1) Check if the assignee node is 'precise' or contains 'precise'
842 // objects.
843 // 2) Propagate the 'precise' to the top layer object nodes
844 // in the right side of the assignment operation, update the 'precise'
845 // work list with new access chains representing the new 'precise'
846 // objects, and mark arithmetic operations as 'noContraction'.
847 for (NodeMapping::iterator defining_node_iter = range.first;
848 defining_node_iter != range.second; defining_node_iter++) {
849 TIntermOperator* defining_node = defining_node_iter->second;
850 // Check the assignee node.
851 auto checker_result = checker.getPrecisenessAndRemainedAccessChain(
852 defining_node, precise_object_accesschain);
853 bool& contain_precise = std::get<0>(checker_result);
854 ObjectAccessChain& remained_accesschain = std::get<1>(checker_result);
855 // If the assignee node is 'precise' or contains 'precise', propagate the
856 // 'precise' to the right. Otherwise just skip this assignment node.
857 if (contain_precise) {
858 propagator.propagateNoContractionInOneExpression(defining_node,
859 remained_accesschain);
860 }
861 }
862 // Remove the last processed 'precise' object from the work list.
863 precise_object_accesschains.erase(precise_object_accesschain);
864 }
865 }
866 };
867