1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef SRC_AST_TRAVERSE_EXPRESSIONS_H_
16 #define SRC_AST_TRAVERSE_EXPRESSIONS_H_
17
#include <type_traits>
#include <vector>

#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/call_expression.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/index_accessor_expression.h"
#include "src/ast/literal_expression.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/phony_expression.h"
#include "src/ast/unary_op_expression.h"
#include "src/utils/reverse.h"
29
30 namespace tint {
31 namespace ast {
32
/// Tells TraverseExpressions() how to proceed after its callback has been
/// invoked for an expression.
enum class TraverseAction {
  /// Abort the whole traversal right away.
  Stop = 0,
  /// Carry on into the sub-expressions of the current expression.
  Descend = 1,
  /// Leave the sub-expressions of the current expression unvisited.
  Skip = 2,
};
43
/// Selects the direction in which TraverseExpressions() visits sibling
/// sub-expressions.
enum class TraverseOrder {
  /// Sibling expressions are visited from left to right.
  LeftToRight = 0,
  /// Sibling expressions are visited from right to left.
  RightToLeft = 1,
};
51
52 /// TraverseExpressions performs a depth-first traversal of the expression nodes
53 /// from `root`, calling `callback` for each of the visited expressions that
54 /// match the predicate parameter type, in pre-ordering (root first).
55 /// @param root the root expression node
56 /// @param diags the diagnostics used for error messages
57 /// @param callback the callback function. Must be of the signature:
58 /// `TraverseAction(const T*)` where T is an ast::Expression type.
59 /// @return true on success, false on error
60 template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
TraverseExpressions(const ast::Expression * root,diag::List & diags,CALLBACK && callback)61 bool TraverseExpressions(const ast::Expression* root,
62 diag::List& diags,
63 CALLBACK&& callback) {
64 using EXPR_TYPE = std::remove_pointer_t<traits::ParameterType<CALLBACK, 0>>;
65 std::vector<const ast::Expression*> to_visit{root};
66
67 auto push_pair = [&](const ast::Expression* left,
68 const ast::Expression* right) {
69 if (ORDER == TraverseOrder::LeftToRight) {
70 to_visit.push_back(right);
71 to_visit.push_back(left);
72 } else {
73 to_visit.push_back(left);
74 to_visit.push_back(right);
75 }
76 };
77 auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
78 if (ORDER == TraverseOrder::LeftToRight) {
79 for (auto* expr : utils::Reverse(exprs)) {
80 to_visit.push_back(expr);
81 }
82 } else {
83 for (auto* expr : exprs) {
84 to_visit.push_back(expr);
85 }
86 }
87 };
88
89 while (!to_visit.empty()) {
90 auto* expr = to_visit.back();
91 to_visit.pop_back();
92
93 if (auto* filtered = expr->As<EXPR_TYPE>()) {
94 switch (callback(filtered)) {
95 case TraverseAction::Stop:
96 return true;
97 case TraverseAction::Skip:
98 continue;
99 case TraverseAction::Descend:
100 break;
101 }
102 }
103
104 if (auto* idx = expr->As<IndexAccessorExpression>()) {
105 push_pair(idx->object, idx->index);
106 } else if (auto* bin_op = expr->As<BinaryExpression>()) {
107 push_pair(bin_op->lhs, bin_op->rhs);
108 } else if (auto* bitcast = expr->As<BitcastExpression>()) {
109 to_visit.push_back(bitcast->expr);
110 } else if (auto* call = expr->As<CallExpression>()) {
111 // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
112 // function name in the traversal.
113 // to_visit.push_back(call->func);
114 push_list(call->args);
115 } else if (auto* member = expr->As<MemberAccessorExpression>()) {
116 // TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
117 // member name in the traversal.
118 // push_pair(member->structure, member->member);
119 to_visit.push_back(member->structure);
120 } else if (auto* unary = expr->As<UnaryOpExpression>()) {
121 to_visit.push_back(unary->expr);
122 } else if (expr->IsAnyOf<LiteralExpression, IdentifierExpression,
123 PhonyExpression>()) {
124 // Leaf expression
125 } else {
126 TINT_ICE(AST, diags) << "unhandled expression type: "
127 << expr->TypeInfo().name;
128 return false;
129 }
130 }
131 return true;
132 }
133
134 } // namespace ast
135 } // namespace tint
136
137 #endif // SRC_AST_TRAVERSE_EXPRESSIONS_H_
138