• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_AST_TRAVERSE_EXPRESSIONS_H_
16 #define SRC_AST_TRAVERSE_EXPRESSIONS_H_
17 
18 #include <vector>
19 
20 #include "src/ast/binary_expression.h"
21 #include "src/ast/bitcast_expression.h"
22 #include "src/ast/call_expression.h"
23 #include "src/ast/index_accessor_expression.h"
24 #include "src/ast/literal_expression.h"
25 #include "src/ast/member_accessor_expression.h"
26 #include "src/ast/phony_expression.h"
27 #include "src/ast/unary_op_expression.h"
28 #include "src/utils/reverse.h"
29 
30 namespace tint {
31 namespace ast {
32 
/// The action to perform after calling the TraverseExpressions() callback
/// function.
enum class TraverseAction {
  /// Stop traversal immediately. This is not an error: TraverseExpressions()
  /// returns true.
  Stop,
  /// Descend into this expression, visiting its sub-expressions.
  Descend,
  /// Do not descend into this expression. Its sub-expressions are not
  /// visited, but traversal continues with the remaining pending expressions.
  Skip,
};
43 
/// The order in which TraverseExpressions() visits sibling sub-expressions
/// (e.g. the operands of a binary expression, or the arguments of a call).
/// In either order, a parent expression is visited before its children.
enum class TraverseOrder {
  /// Expressions will be traversed from left to right
  LeftToRight,
  /// Expressions will be traversed from right to left
  RightToLeft,
};
51 
52 /// TraverseExpressions performs a depth-first traversal of the expression nodes
53 /// from `root`, calling `callback` for each of the visited expressions that
54 /// match the predicate parameter type, in pre-ordering (root first).
55 /// @param root the root expression node
56 /// @param diags the diagnostics used for error messages
57 /// @param callback the callback function. Must be of the signature:
58 ///        `TraverseAction(const T*)` where T is an ast::Expression type.
59 /// @return true on success, false on error
60 template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
TraverseExpressions(const ast::Expression * root,diag::List & diags,CALLBACK && callback)61 bool TraverseExpressions(const ast::Expression* root,
62                          diag::List& diags,
63                          CALLBACK&& callback) {
64   using EXPR_TYPE = std::remove_pointer_t<traits::ParameterType<CALLBACK, 0>>;
65   std::vector<const ast::Expression*> to_visit{root};
66 
67   auto push_pair = [&](const ast::Expression* left,
68                        const ast::Expression* right) {
69     if (ORDER == TraverseOrder::LeftToRight) {
70       to_visit.push_back(right);
71       to_visit.push_back(left);
72     } else {
73       to_visit.push_back(left);
74       to_visit.push_back(right);
75     }
76   };
77   auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
78     if (ORDER == TraverseOrder::LeftToRight) {
79       for (auto* expr : utils::Reverse(exprs)) {
80         to_visit.push_back(expr);
81       }
82     } else {
83       for (auto* expr : exprs) {
84         to_visit.push_back(expr);
85       }
86     }
87   };
88 
89   while (!to_visit.empty()) {
90     auto* expr = to_visit.back();
91     to_visit.pop_back();
92 
93     if (auto* filtered = expr->As<EXPR_TYPE>()) {
94       switch (callback(filtered)) {
95         case TraverseAction::Stop:
96           return true;
97         case TraverseAction::Skip:
98           continue;
99         case TraverseAction::Descend:
100           break;
101       }
102     }
103 
104     if (auto* idx = expr->As<IndexAccessorExpression>()) {
105       push_pair(idx->object, idx->index);
106     } else if (auto* bin_op = expr->As<BinaryExpression>()) {
107       push_pair(bin_op->lhs, bin_op->rhs);
108     } else if (auto* bitcast = expr->As<BitcastExpression>()) {
109       to_visit.push_back(bitcast->expr);
110     } else if (auto* call = expr->As<CallExpression>()) {
111       // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
112       // function name in the traversal.
113       // to_visit.push_back(call->func);
114       push_list(call->args);
115     } else if (auto* member = expr->As<MemberAccessorExpression>()) {
116       // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
117       // member name in the traversal.
118       // push_pair(member->structure, member->member);
119       to_visit.push_back(member->structure);
120     } else if (auto* unary = expr->As<UnaryOpExpression>()) {
121       to_visit.push_back(unary->expr);
122     } else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
123                              PhonyExpression>()) {
124       // Leaf expression
125     } else {
126       TINT_ICE(AST, diags) << "unhandled expression type: "
127                            << expr->TypeInfo().name;
128       return false;
129     }
130   }
131   return true;
132 }
133 
134 }  // namespace ast
135 }  // namespace tint
136 
137 #endif  // SRC_AST_TRAVERSE_EXPRESSIONS_H_
138