1 // 2 // Copyright 2020 The ANGLE Project Authors. All rights reserved. 3 // Use of this source code is governed by a BSD-style license that can be 4 // found in the LICENSE file. 5 // 6 7 #ifndef COMPILER_TRANSLATOR_MSL_INTERMREBUILD_H_ 8 #define COMPILER_TRANSLATOR_MSL_INTERMREBUILD_H_ 9 10 #include "compiler/translator/msl/NodeType.h" 11 #include "compiler/translator/tree_util/IntermTraverse.h" 12 13 namespace sh 14 { 15 16 // Walks the tree to rebuild nodes. 17 // This class is intended to be derived with overridden visitXXX functions. 18 // 19 // Each visitXXX function that does not have a Visit parameter simply has the visitor called 20 // exactly once, regardless of (preVisit) or (postVisit) values. 21 22 // Each visitXXX function that has a Visit parameter behaves as follows: 23 // * If (preVisit): 24 // - The node is visited before children are traversed. 25 // - The returned value is used to replace the visited node. The returned value may be the same 26 // as the original node. 27 // - If multiple nodes are returned, children and post visits of the returned nodes are not 28 // preformed, even if it is a singleton collection. 29 // * If (childVisit) 30 // - If any new children are returned, the node is automatically rebuilt with the new children 31 // before post visit. 32 // - Depending on the type of the node, null children may be discarded. 33 // - Ill-typed children cause rebuild errors. Ill-typed means the node to automatically rebuild 34 // cannot accept a child of a certain type as input to its constructor. 35 // - Only instances of TIntermAggregateBase can accept Multi results for any of its children. 36 // If supplied, the nodes are spliced children at the spot of the original child. 37 // * If (postVisit) 38 // - The node is visited after any children are traversed. 39 // - Only after such a rebuild (or lack thereof), the post-visit is performed. 40 // 41 // Nodes in visit functions are allowed to be modified in place, including TIntermAggregateBase 42 // child sequences. 43 // 44 // The default implementations of all the visitXXX functions support full pre and post traversal 45 // without modifying the visited nodes. 46 // 47 class TIntermRebuild : angle::NonCopyable 48 { 49 50 enum class Action 51 { 52 ReplaceSingle, 53 ReplaceMulti, 54 Drop, 55 Fail, 56 }; 57 58 public: 59 struct Fail 60 {}; 61 62 enum VisitBits : size_t 63 { 64 // No bits are set. 65 Empty = 0u, 66 67 // Allow visit of returned node's children. 68 Children = 1u << 0u, 69 70 // Allow post visit of returned node. 71 Post = 1u << 1u, 72 73 // If (Children) bit, only visit if the returned node is the same as the original node. 74 ChildrenRequiresSame = 1u << 2u, 75 76 // If (Post) bit, only visit if the returned node is the same as the original node. 77 PostRequiresSame = 1u << 3u, 78 79 RequireSame = ChildrenRequiresSame | PostRequiresSame, 80 Neither = Empty, 81 Both = Children | Post, 82 BothWhenSame = Both | RequireSame, 83 }; 84 85 private: 86 struct NodeStackGuard; 87 88 template <typename T> 89 struct ConsList 90 { 91 T value; 92 ConsList<T> *tail; 93 }; 94 95 class BaseResult 96 { 97 BaseResult(const BaseResult &) = delete; 98 BaseResult &operator=(const BaseResult &) = delete; 99 100 public: 101 BaseResult(BaseResult &&other) = default; 102 BaseResult(BaseResult &other); // For subclass move constructor impls 103 BaseResult(TIntermNode &node, VisitBits visit); 104 BaseResult(TIntermNode *node, VisitBits visit); 105 BaseResult(std::nullptr_t); 106 BaseResult(Fail); 107 BaseResult(std::vector<TIntermNode *> &&nodes); 108 109 void moveAssignImpl(BaseResult &other); // For subclass move assign impls 110 111 static BaseResult Multi(std::vector<TIntermNode *> &&nodes); 112 113 template <typename Iter> Multi(Iter nodesBegin,Iter nodesEnd)114 static BaseResult Multi(Iter nodesBegin, Iter nodesEnd) 115 { 116 std::vector<TIntermNode *> nodes; 117 for (Iter nodesCurr = nodesBegin; nodesCurr != nodesEnd; ++nodesCurr) 118 { 119 nodes.push_back(*nodesCurr); 120 } 121 return std::move(nodes); 122 } 123 124 bool isFail() const; 125 bool isDrop() const; 126 TIntermNode *single() const; 127 const std::vector<TIntermNode *> *multi() const; 128 129 public: 130 Action mAction; 131 VisitBits mVisit; 132 TIntermNode *mSingle; 133 std::vector<TIntermNode *> mMulti; 134 }; 135 136 public: 137 class PreResult : private BaseResult 138 { 139 friend class TIntermRebuild; 140 141 public: 142 PreResult(PreResult &&other); 143 PreResult(TIntermNode &node, VisitBits visit = VisitBits::BothWhenSame); 144 PreResult(TIntermNode *node, VisitBits visit = VisitBits::BothWhenSame); 145 PreResult(std::nullptr_t); // Used to drop a node. 146 PreResult(Fail); // Used to signal failure. 147 148 void operator=(PreResult &&other); 149 Multi(std::vector<TIntermNode * > && nodes)150 static PreResult Multi(std::vector<TIntermNode *> &&nodes) 151 { 152 return BaseResult::Multi(std::move(nodes)); 153 } 154 155 template <typename Iter> Multi(Iter nodesBegin,Iter nodesEnd)156 static PreResult Multi(Iter nodesBegin, Iter nodesEnd) 157 { 158 return BaseResult::Multi(nodesBegin, nodesEnd); 159 } 160 161 using BaseResult::isDrop; 162 using BaseResult::isFail; 163 using BaseResult::multi; 164 using BaseResult::single; 165 166 private: 167 PreResult(BaseResult &&other); 168 }; 169 170 class PostResult : private BaseResult 171 { 172 friend class TIntermRebuild; 173 174 public: 175 PostResult(PostResult &&other); 176 PostResult(TIntermNode &node); 177 PostResult(TIntermNode *node); 178 PostResult(std::nullptr_t); // Used to drop a node 179 PostResult(Fail); // Used to signal failure. 180 181 void operator=(PostResult &&other); 182 Multi(std::vector<TIntermNode * > && nodes)183 static PostResult Multi(std::vector<TIntermNode *> &&nodes) 184 { 185 return BaseResult::Multi(std::move(nodes)); 186 } 187 188 template <typename Iter> Multi(Iter nodesBegin,Iter nodesEnd)189 static PostResult Multi(Iter nodesBegin, Iter nodesEnd) 190 { 191 return BaseResult::Multi(nodesBegin, nodesEnd); 192 } 193 194 using BaseResult::isDrop; 195 using BaseResult::isFail; 196 using BaseResult::multi; 197 using BaseResult::single; 198 199 private: 200 PostResult(BaseResult &&other); 201 }; 202 203 public: 204 TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit); 205 206 virtual ~TIntermRebuild(); 207 208 // Rebuilds the tree starting at the provided root. If a new node would be returned for the 209 // root, the root node's children become that of the new node instead. Returns false if failure 210 // occurred. 211 [[nodiscard]] bool rebuildRoot(TIntermBlock &root); 212 213 protected: 214 virtual PreResult visitSymbolPre(TIntermSymbol &node); 215 virtual PreResult visitConstantUnionPre(TIntermConstantUnion &node); 216 virtual PreResult visitFunctionPrototypePre(TIntermFunctionPrototype &node); 217 virtual PreResult visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node); 218 virtual PreResult visitUnaryPre(TIntermUnary &node); 219 virtual PreResult visitBinaryPre(TIntermBinary &node); 220 virtual PreResult visitTernaryPre(TIntermTernary &node); 221 virtual PreResult visitSwizzlePre(TIntermSwizzle &node); 222 virtual PreResult visitIfElsePre(TIntermIfElse &node); 223 virtual PreResult visitSwitchPre(TIntermSwitch &node); 224 virtual PreResult visitCasePre(TIntermCase &node); 225 virtual PreResult visitLoopPre(TIntermLoop &node); 226 virtual PreResult visitBranchPre(TIntermBranch &node); 227 virtual PreResult visitDeclarationPre(TIntermDeclaration &node); 228 virtual PreResult visitBlockPre(TIntermBlock &node); 229 virtual PreResult visitAggregatePre(TIntermAggregate &node); 230 virtual PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node); 231 virtual PreResult visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration &node); 232 233 virtual PostResult visitSymbolPost(TIntermSymbol &node); 234 virtual PostResult visitConstantUnionPost(TIntermConstantUnion &node); 235 virtual PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node); 236 virtual PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node); 237 virtual PostResult visitUnaryPost(TIntermUnary &node); 238 virtual PostResult visitBinaryPost(TIntermBinary &node); 239 virtual PostResult visitTernaryPost(TIntermTernary &node); 240 virtual PostResult visitSwizzlePost(TIntermSwizzle &node); 241 virtual PostResult visitIfElsePost(TIntermIfElse &node); 242 virtual PostResult visitSwitchPost(TIntermSwitch &node); 243 virtual PostResult visitCasePost(TIntermCase &node); 244 virtual PostResult visitLoopPost(TIntermLoop &node); 245 virtual PostResult visitBranchPost(TIntermBranch &node); 246 virtual PostResult visitDeclarationPost(TIntermDeclaration &node); 247 virtual PostResult visitBlockPost(TIntermBlock &node); 248 virtual PostResult visitAggregatePost(TIntermAggregate &node); 249 virtual PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node); 250 virtual PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node); 251 252 // Can be used to rebuild a specific node during a traversal. Useful for fine control of 253 // rebuilding a node's children. 254 [[nodiscard]] PostResult rebuild(TIntermNode &node); 255 256 // Rebuilds the provided node in place. If a new node would be returned, the old node's children 257 // become that of the new node instead. Returns false if failure occurred. 258 [[nodiscard]] bool rebuildInPlace(TIntermAggregate &node); 259 260 // Rebuilds the provided node in place. If a new node would be returned, the old node's children 261 // become that of the new node instead. Returns false if failure occurred. 262 [[nodiscard]] bool rebuildInPlace(TIntermBlock &node); 263 264 // Rebuilds the provided node in place. If a new node would be returned, the old node's children 265 // become that of the new node instead. Returns false if failure occurred. 266 [[nodiscard]] bool rebuildInPlace(TIntermDeclaration &node); 267 268 // If currently at or below a function declaration body, this returns the function that encloses 269 // the currently visited node. (This returns null if at a function declaration node.) 270 const TFunction *getParentFunction() const; 271 272 TIntermNode *getParentNode(size_t offset = 0) const; 273 274 private: 275 template <typename Node> 276 [[nodiscard]] bool rebuildInPlaceImpl(Node &node); 277 278 PostResult traverseAny(TIntermNode &node); 279 280 template <typename Node> 281 Node *traverseAnyAs(TIntermNode &node); 282 283 template <typename Node> 284 bool traverseAnyAs(TIntermNode &node, Node *&out); 285 286 PreResult traversePre(TIntermNode &originalNode); 287 TIntermNode *traverseChildren(NodeType currNodeType, 288 const TIntermNode &originalNode, 289 TIntermNode &currNode, 290 VisitBits visit); 291 PostResult traversePost(NodeType nodeType, 292 const TIntermNode &originalNode, 293 TIntermNode &currNode, 294 VisitBits visit); 295 296 bool traverseAggregateBaseChildren(TIntermAggregateBase &node); 297 298 TIntermNode *traverseUnaryChildren(TIntermUnary &node); 299 TIntermNode *traverseBinaryChildren(TIntermBinary &node); 300 TIntermNode *traverseTernaryChildren(TIntermTernary &node); 301 TIntermNode *traverseSwizzleChildren(TIntermSwizzle &node); 302 TIntermNode *traverseIfElseChildren(TIntermIfElse &node); 303 TIntermNode *traverseSwitchChildren(TIntermSwitch &node); 304 TIntermNode *traverseCaseChildren(TIntermCase &node); 305 TIntermNode *traverseLoopChildren(TIntermLoop &node); 306 TIntermNode *traverseBranchChildren(TIntermBranch &node); 307 TIntermNode *traverseDeclarationChildren(TIntermDeclaration &node); 308 TIntermNode *traverseBlockChildren(TIntermBlock &node); 309 TIntermNode *traverseAggregateChildren(TIntermAggregate &node); 310 TIntermNode *traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node); 311 TIntermNode *traverseGlobalQualifierDeclarationChildren( 312 TIntermGlobalQualifierDeclaration &node); 313 314 protected: 315 TCompiler &mCompiler; 316 TSymbolTable &mSymbolTable; 317 const TFunction *mParentFunc = nullptr; 318 GetNodeType getNodeType; 319 320 private: 321 ConsList<TIntermNode *> mNodeStack{nullptr, nullptr}; 322 bool mPreVisit; 323 bool mPostVisit; 324 }; 325 326 } // namespace sh 327 328 #endif // COMPILER_TRANSLATOR_MSL_INTERMREBUILD_H_ 329