//===-- include/flang/Evaluate/traverse.h -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_EVALUATE_TRAVERSE_H_
#define FORTRAN_EVALUATE_TRAVERSE_H_

// A utility for scanning all of the constituent objects in an Expr<>
// expression representation using a collection of mutually recursive
// functions to compose a function object.
//
// The class template Traverse<> below implements a function object that
// can handle every type that can appear in or around an Expr<>.
// Each of its overloads for operator() should be viewed as a *default*
// handler; the client must supersede some of these with its own overloads
// to accomplish its particular task.
//
// The client (Visitor) of Traverse<Visitor,Result> must define:
// - a member function "Result Default();"
// - a member function "Result Combine(Result &&, Result &&);"
// - its own overloads of "Result operator()" for the node types it cares
//   about, alongside the inherited defaults
//
// Boilerplate classes also appear below to ease construction of visitors.
// See CheckSpecificationExpr() in check-expression.cpp for an example client.
//
// How this works:
// - The operator() overloads in Traverse<> invoke the visitor's Default() for
//   expression leaf nodes. They invoke the visitor's operator() for the
//   subtrees of interior nodes, and the visitor's Combine() to merge their
//   results together.
// - Overloads of operator() in each visitor handle the cases of interest.
36 37 #include "expression.h" 38 #include "flang/Semantics/symbol.h" 39 #include "flang/Semantics/type.h" 40 #include <set> 41 #include <type_traits> 42 43 namespace Fortran::evaluate { 44 template <typename Visitor, typename Result> class Traverse { 45 public: Traverse(Visitor & v)46 explicit Traverse(Visitor &v) : visitor_{v} {} 47 48 // Packaging 49 template <typename A, bool C> operator()50 Result operator()(const common::Indirection<A, C> &x) const { 51 return visitor_(x.value()); 52 } operator()53 template <typename A> Result operator()(SymbolRef x) const { 54 return visitor_(*x); 55 } operator()56 template <typename A> Result operator()(const std::unique_ptr<A> &x) const { 57 return visitor_(x.get()); 58 } operator()59 template <typename A> Result operator()(const std::shared_ptr<A> &x) const { 60 return visitor_(x.get()); 61 } operator()62 template <typename A> Result operator()(const A *x) const { 63 if (x) { 64 return visitor_(*x); 65 } else { 66 return visitor_.Default(); 67 } 68 } operator()69 template <typename A> Result operator()(const std::optional<A> &x) const { 70 if (x) { 71 return visitor_(*x); 72 } else { 73 return visitor_.Default(); 74 } 75 } 76 template <typename... A> operator()77 Result operator()(const std::variant<A...> &u) const { 78 return std::visit(visitor_, u); 79 } operator()80 template <typename A> Result operator()(const std::vector<A> &x) const { 81 return CombineContents(x); 82 } 83 84 // Leaves operator()85 Result operator()(const BOZLiteralConstant &) const { 86 return visitor_.Default(); 87 } operator()88 Result operator()(const NullPointer &) const { return visitor_.Default(); } operator()89 template <typename T> Result operator()(const Constant<T> &x) const { 90 if constexpr (T::category == TypeCategory::Derived) { 91 std::optional<Result> result; 92 for (const StructureConstructorValues &map : x.values()) { 93 for (const auto &pair : map) { 94 auto value{visitor_(pair.second.value())}; 95 result = result 96 ? 
visitor_.Combine(std::move(*result), std::move(value)) 97 : std::move(value); 98 } 99 } 100 return result ? *result : visitor_.Default(); 101 } else { 102 return visitor_.Default(); 103 } 104 } operator()105 Result operator()(const Symbol &) const { return visitor_.Default(); } operator()106 Result operator()(const StaticDataObject &) const { 107 return visitor_.Default(); 108 } operator()109 Result operator()(const ImpliedDoIndex &) const { return visitor_.Default(); } 110 111 // Variables operator()112 Result operator()(const BaseObject &x) const { return visitor_(x.u); } operator()113 Result operator()(const Component &x) const { 114 return Combine(x.base(), x.GetLastSymbol()); 115 } operator()116 Result operator()(const NamedEntity &x) const { 117 if (const Component * component{x.UnwrapComponent()}) { 118 return visitor_(*component); 119 } else { 120 return visitor_(x.GetFirstSymbol()); 121 } 122 } operator()123 Result operator()(const TypeParamInquiry &x) const { 124 return visitor_(x.base()); 125 } operator()126 Result operator()(const Triplet &x) const { 127 return Combine(x.lower(), x.upper(), x.stride()); 128 } operator()129 Result operator()(const Subscript &x) const { return visitor_(x.u); } operator()130 Result operator()(const ArrayRef &x) const { 131 return Combine(x.base(), x.subscript()); 132 } operator()133 Result operator()(const CoarrayRef &x) const { 134 return Combine( 135 x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team()); 136 } operator()137 Result operator()(const DataRef &x) const { return visitor_(x.u); } operator()138 Result operator()(const Substring &x) const { 139 return Combine(x.parent(), x.lower(), x.upper()); 140 } operator()141 Result operator()(const ComplexPart &x) const { 142 return visitor_(x.complex()); 143 } operator()144 template <typename T> Result operator()(const Designator<T> &x) const { 145 return visitor_(x.u); 146 } operator()147 template <typename T> Result operator()(const Variable<T> &x) const { 148 
return visitor_(x.u); 149 } operator()150 Result operator()(const DescriptorInquiry &x) const { 151 return visitor_(x.base()); 152 } 153 154 // Calls operator()155 Result operator()(const SpecificIntrinsic &) const { 156 return visitor_.Default(); 157 } operator()158 Result operator()(const ProcedureDesignator &x) const { 159 if (const Component * component{x.GetComponent()}) { 160 return visitor_(*component); 161 } else if (const Symbol * symbol{x.GetSymbol()}) { 162 return visitor_(*symbol); 163 } else { 164 return visitor_(DEREF(x.GetSpecificIntrinsic())); 165 } 166 } operator()167 Result operator()(const ActualArgument &x) const { 168 if (const auto *symbol{x.GetAssumedTypeDummy()}) { 169 return visitor_(*symbol); 170 } else { 171 return visitor_(x.UnwrapExpr()); 172 } 173 } operator()174 Result operator()(const ProcedureRef &x) const { 175 return Combine(x.proc(), x.arguments()); 176 } operator()177 template <typename T> Result operator()(const FunctionRef<T> &x) const { 178 return visitor_(static_cast<const ProcedureRef &>(x)); 179 } 180 181 // Other primaries 182 template <typename T> operator()183 Result operator()(const ArrayConstructorValue<T> &x) const { 184 return visitor_(x.u); 185 } 186 template <typename T> operator()187 Result operator()(const ArrayConstructorValues<T> &x) const { 188 return CombineContents(x); 189 } operator()190 template <typename T> Result operator()(const ImpliedDo<T> &x) const { 191 return Combine(x.lower(), x.upper(), x.stride(), x.values()); 192 } operator()193 Result operator()(const semantics::ParamValue &x) const { 194 return visitor_(x.GetExplicit()); 195 } operator()196 Result operator()( 197 const semantics::DerivedTypeSpec::ParameterMapType::value_type &x) const { 198 return visitor_(x.second); 199 } operator()200 Result operator()(const semantics::DerivedTypeSpec &x) const { 201 return CombineContents(x.parameters()); 202 } operator()203 Result operator()(const StructureConstructorValues::value_type &x) const { 204 
return visitor_(x.second); 205 } operator()206 Result operator()(const StructureConstructor &x) const { 207 return visitor_.Combine(visitor_(x.derivedTypeSpec()), CombineContents(x)); 208 } 209 210 // Operations and wrappers 211 template <typename D, typename R, typename O> operator()212 Result operator()(const Operation<D, R, O> &op) const { 213 return visitor_(op.left()); 214 } 215 template <typename D, typename R, typename LO, typename RO> operator()216 Result operator()(const Operation<D, R, LO, RO> &op) const { 217 return Combine(op.left(), op.right()); 218 } operator()219 Result operator()(const Relational<SomeType> &x) const { 220 return visitor_(x.u); 221 } operator()222 template <typename T> Result operator()(const Expr<T> &x) const { 223 return visitor_(x.u); 224 } 225 226 private: CombineRange(ITER iter,ITER end)227 template <typename ITER> Result CombineRange(ITER iter, ITER end) const { 228 if (iter == end) { 229 return visitor_.Default(); 230 } else { 231 Result result{visitor_(*iter++)}; 232 for (; iter != end; ++iter) { 233 result = visitor_.Combine(std::move(result), visitor_(*iter)); 234 } 235 return result; 236 } 237 } 238 CombineContents(const A & x)239 template <typename A> Result CombineContents(const A &x) const { 240 return CombineRange(x.begin(), x.end()); 241 } 242 243 template <typename A, typename... Bs> Combine(const A & x,const Bs &...ys)244 Result Combine(const A &x, const Bs &...ys) const { 245 if constexpr (sizeof...(Bs) == 0) { 246 return visitor_(x); 247 } else { 248 return visitor_.Combine(visitor_(x), Combine(ys...)); 249 } 250 } 251 252 Visitor &visitor_; 253 }; 254 255 // For validity checks across an expression: if any operator() result is 256 // false, so is the overall result. 
257 template <typename Visitor, bool DefaultValue, 258 typename Base = Traverse<Visitor, bool>> 259 struct AllTraverse : public Base { AllTraverseAllTraverse260 explicit AllTraverse(Visitor &v) : Base{v} {} 261 using Base::operator(); DefaultAllTraverse262 static bool Default() { return DefaultValue; } CombineAllTraverse263 static bool Combine(bool x, bool y) { return x && y; } 264 }; 265 266 // For searches over an expression: the first operator() result that 267 // is truthful is the final result. Works for Booleans, pointers, 268 // and std::optional<>. 269 template <typename Visitor, typename Result = bool, 270 typename Base = Traverse<Visitor, Result>> 271 class AnyTraverse : public Base { 272 public: AnyTraverse(Visitor & v)273 explicit AnyTraverse(Visitor &v) : Base{v} {} 274 using Base::operator(); Default()275 Result Default() const { return default_; } Combine(Result && x,Result && y)276 static Result Combine(Result &&x, Result &&y) { 277 if (x) { 278 return std::move(x); 279 } else { 280 return std::move(y); 281 } 282 } 283 284 private: 285 Result default_{}; 286 }; 287 288 template <typename Visitor, typename Set, 289 typename Base = Traverse<Visitor, Set>> 290 struct SetTraverse : public Base { SetTraverseSetTraverse291 explicit SetTraverse(Visitor &v) : Base{v} {} 292 using Base::operator(); DefaultSetTraverse293 static Set Default() { return {}; } CombineSetTraverse294 static Set Combine(Set &&x, Set &&y) { 295 #if defined __GNUC__ && !defined __APPLE__ && !(CLANG_LIBRARIES) 296 x.merge(y); 297 #else 298 // std::set::merge() not available (yet) 299 for (auto &value : y) { 300 x.insert(std::move(value)); 301 } 302 #endif 303 return std::move(x); 304 } 305 }; 306 307 } // namespace Fortran::evaluate 308 #endif 309