• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_EVALUATE_TRAVERSE_H_
10 #define FORTRAN_EVALUATE_TRAVERSE_H_
11 
// A utility for scanning all of the constituent objects in an Expr<>
// expression representation, using a collection of mutually recursive
// functions to compose a function object.
//
// The class template Traverse<> below implements a function object that
// can handle every type that can appear in or around an Expr<>.
// Each of its overloads of operator() should be viewed as a *default*
// handler; some of these must be overridden by the client to accomplish
// its particular task.
//
// The client (Visitor) of Traverse<Visitor,Result> must define:
// - a member function "Result Default();"
// - a member function "Result Combine(Result &&, Result &&)"
// - overloads of "Result operator()" for the cases it cares about
//   (typically alongside "using Base::operator();" so that the default
//   handlers remain visible)
//
// Boilerplate classes (AllTraverse, AnyTraverse, SetTraverse) also appear
// below to ease construction of visitors.
// See CheckSpecificationExpr() in check-expression.cpp for an example client.
//
// How this works:
// - The operator() overloads in Traverse<> invoke the visitor's Default() for
//   expression leaf nodes, as well as for null pointers, empty optionals,
//   and empty sequences.  They invoke the visitor's operator() for the
//   subtrees of interior nodes, and the visitor's Combine() to merge their
//   results together.
// - Overloads of operator() in each visitor handle the cases of interest.
36 
37 #include "expression.h"
38 #include "flang/Semantics/symbol.h"
39 #include "flang/Semantics/type.h"
40 #include <set>
41 #include <type_traits>
42 
43 namespace Fortran::evaluate {
44 template <typename Visitor, typename Result> class Traverse {
45 public:
Traverse(Visitor & v)46   explicit Traverse(Visitor &v) : visitor_{v} {}
47 
48   // Packaging
49   template <typename A, bool C>
operator()50   Result operator()(const common::Indirection<A, C> &x) const {
51     return visitor_(x.value());
52   }
operator()53   template <typename A> Result operator()(SymbolRef x) const {
54     return visitor_(*x);
55   }
operator()56   template <typename A> Result operator()(const std::unique_ptr<A> &x) const {
57     return visitor_(x.get());
58   }
operator()59   template <typename A> Result operator()(const std::shared_ptr<A> &x) const {
60     return visitor_(x.get());
61   }
operator()62   template <typename A> Result operator()(const A *x) const {
63     if (x) {
64       return visitor_(*x);
65     } else {
66       return visitor_.Default();
67     }
68   }
operator()69   template <typename A> Result operator()(const std::optional<A> &x) const {
70     if (x) {
71       return visitor_(*x);
72     } else {
73       return visitor_.Default();
74     }
75   }
76   template <typename... A>
operator()77   Result operator()(const std::variant<A...> &u) const {
78     return std::visit(visitor_, u);
79   }
operator()80   template <typename A> Result operator()(const std::vector<A> &x) const {
81     return CombineContents(x);
82   }
83 
84   // Leaves
operator()85   Result operator()(const BOZLiteralConstant &) const {
86     return visitor_.Default();
87   }
operator()88   Result operator()(const NullPointer &) const { return visitor_.Default(); }
operator()89   template <typename T> Result operator()(const Constant<T> &x) const {
90     if constexpr (T::category == TypeCategory::Derived) {
91       std::optional<Result> result;
92       for (const StructureConstructorValues &map : x.values()) {
93         for (const auto &pair : map) {
94           auto value{visitor_(pair.second.value())};
95           result = result
96               ? visitor_.Combine(std::move(*result), std::move(value))
97               : std::move(value);
98         }
99       }
100       return result ? *result : visitor_.Default();
101     } else {
102       return visitor_.Default();
103     }
104   }
operator()105   Result operator()(const Symbol &) const { return visitor_.Default(); }
operator()106   Result operator()(const StaticDataObject &) const {
107     return visitor_.Default();
108   }
operator()109   Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); }
110 
111   // Variables
operator()112   Result operator()(const BaseObject &x) const { return visitor_(x.u); }
operator()113   Result operator()(const Component &x) const {
114     return Combine(x.base(), x.GetLastSymbol());
115   }
operator()116   Result operator()(const NamedEntity &x) const {
117     if (const Component * component{x.UnwrapComponent()}) {
118       return visitor_(*component);
119     } else {
120       return visitor_(x.GetFirstSymbol());
121     }
122   }
operator()123   Result operator()(const TypeParamInquiry &x) const {
124     return visitor_(x.base());
125   }
operator()126   Result operator()(const Triplet &x) const {
127     return Combine(x.lower(), x.upper(), x.stride());
128   }
operator()129   Result operator()(const Subscript &x) const { return visitor_(x.u); }
operator()130   Result operator()(const ArrayRef &x) const {
131     return Combine(x.base(), x.subscript());
132   }
operator()133   Result operator()(const CoarrayRef &x) const {
134     return Combine(
135         x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team());
136   }
operator()137   Result operator()(const DataRef &x) const { return visitor_(x.u); }
operator()138   Result operator()(const Substring &x) const {
139     return Combine(x.parent(), x.lower(), x.upper());
140   }
operator()141   Result operator()(const ComplexPart &x) const {
142     return visitor_(x.complex());
143   }
operator()144   template <typename T> Result operator()(const Designator<T> &x) const {
145     return visitor_(x.u);
146   }
operator()147   template <typename T> Result operator()(const Variable<T> &x) const {
148     return visitor_(x.u);
149   }
operator()150   Result operator()(const DescriptorInquiry &x) const {
151     return visitor_(x.base());
152   }
153 
154   // Calls
operator()155   Result operator()(const SpecificIntrinsic &) const {
156     return visitor_.Default();
157   }
operator()158   Result operator()(const ProcedureDesignator &x) const {
159     if (const Component * component{x.GetComponent()}) {
160       return visitor_(*component);
161     } else if (const Symbol * symbol{x.GetSymbol()}) {
162       return visitor_(*symbol);
163     } else {
164       return visitor_(DEREF(x.GetSpecificIntrinsic()));
165     }
166   }
operator()167   Result operator()(const ActualArgument &x) const {
168     if (const auto *symbol{x.GetAssumedTypeDummy()}) {
169       return visitor_(*symbol);
170     } else {
171       return visitor_(x.UnwrapExpr());
172     }
173   }
operator()174   Result operator()(const ProcedureRef &x) const {
175     return Combine(x.proc(), x.arguments());
176   }
operator()177   template <typename T> Result operator()(const FunctionRef<T> &x) const {
178     return visitor_(static_cast<const ProcedureRef &>(x));
179   }
180 
181   // Other primaries
182   template <typename T>
operator()183   Result operator()(const ArrayConstructorValue<T> &x) const {
184     return visitor_(x.u);
185   }
186   template <typename T>
operator()187   Result operator()(const ArrayConstructorValues<T> &x) const {
188     return CombineContents(x);
189   }
operator()190   template <typename T> Result operator()(const ImpliedDo<T> &x) const {
191     return Combine(x.lower(), x.upper(), x.stride(), x.values());
192   }
operator()193   Result operator()(const semantics::ParamValue &x) const {
194     return visitor_(x.GetExplicit());
195   }
operator()196   Result operator()(
197       const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const {
198     return visitor_(x.second);
199   }
operator()200   Result operator()(const semantics::DerivedTypeSpec &x) const {
201     return CombineContents(x.parameters());
202   }
operator()203   Result operator()(const StructureConstructorValues::value_type &x) const {
204     return visitor_(x.second);
205   }
operator()206   Result operator()(const StructureConstructor &x) const {
207     return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x));
208   }
209 
210   // Operations and wrappers
211   template <typename D, typename R, typename O>
operator()212   Result operator()(const Operation<D, R, O> &op) const {
213     return visitor_(op.left());
214   }
215   template <typename D, typename R, typename LO, typename RO>
operator()216   Result operator()(const Operation<D, R, LO, RO> &op) const {
217     return Combine(op.left(), op.right());
218   }
operator()219   Result operator()(const Relational<SomeType> &x) const {
220     return visitor_(x.u);
221   }
operator()222   template <typename T> Result operator()(const Expr<T> &x) const {
223     return visitor_(x.u);
224   }
225 
226 private:
CombineRange(ITER iter,ITER end)227   template <typename ITER> Result CombineRange(ITER iter, ITER end) const {
228     if (iter == end) {
229       return visitor_.Default();
230     } else {
231       Result result{visitor_(*iter++)};
232       for (; iter != end; ++iter) {
233         result = visitor_.Combine(std::move(result), visitor_(*iter));
234       }
235       return result;
236     }
237   }
238 
CombineContents(const A & x)239   template <typename A> Result CombineContents(const A &x) const {
240     return CombineRange(x.begin(), x.end());
241   }
242 
243   template <typename A, typename... Bs>
Combine(const A & x,const Bs &...ys)244   Result Combine(const A &x, const Bs &...ys) const {
245     if constexpr (sizeof...(Bs) == 0) {
246       return visitor_(x);
247     } else {
248       return visitor_.Combine(visitor_(x), Combine(ys...));
249     }
250   }
251 
252   Visitor &visitor_;
253 };
254 
255 // For validity checks across an expression: if any operator() result is
256 // false, so is the overall result.
257 template <typename Visitor, bool DefaultValue,
258     typename Base = Traverse<Visitor, bool>>
259 struct AllTraverse : public Base {
AllTraverseAllTraverse260   explicit AllTraverse(Visitor &v) : Base{v} {}
261   using Base::operator();
DefaultAllTraverse262   static bool Default() { return DefaultValue; }
CombineAllTraverse263   static bool Combine(bool x, bool y) { return x && y; }
264 };
265 
266 // For searches over an expression: the first operator() result that
267 // is truthful is the final result.  Works for Booleans, pointers,
268 // and std::optional<>.
269 template <typename Visitor, typename Result = bool,
270     typename Base = Traverse<Visitor, Result>>
271 class AnyTraverse : public Base {
272 public:
AnyTraverse(Visitor & v)273   explicit AnyTraverse(Visitor &v) : Base{v} {}
274   using Base::operator();
Default()275   Result Default() const { return default_; }
Combine(Result && x,Result && y)276   static Result Combine(Result &&x, Result &&y) {
277     if (x) {
278       return std::move(x);
279     } else {
280       return std::move(y);
281     }
282   }
283 
284 private:
285   Result default_{};
286 };
287 
288 template <typename Visitor, typename Set,
289     typename Base = Traverse<Visitor, Set>>
290 struct SetTraverse : public Base {
SetTraverseSetTraverse291   explicit SetTraverse(Visitor &v) : Base{v} {}
292   using Base::operator();
DefaultSetTraverse293   static Set Default() { return {}; }
CombineSetTraverse294   static Set Combine(Set &&x, Set &&y) {
295 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES)
296     x.merge(y);
297 #else
298     // std::set::merge() not available (yet)
299     for (auto &value : y) {
300       x.insert(std::move(value));
301     }
302 #endif
303     return std::move(x);
304   }
305 };
306 
307 } // namespace Fortran::evaluate
308 #endif
309