• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- lib/Evaluate/fold-implementation.h --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
10 #define FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
11 
12 #include "character.h"
13 #include "host.h"
14 #include "int-power.h"
15 #include "flang/Common/indirection.h"
16 #include "flang/Common/template.h"
17 #include "flang/Common/unwrap.h"
18 #include "flang/Evaluate/characteristics.h"
19 #include "flang/Evaluate/common.h"
20 #include "flang/Evaluate/constant.h"
21 #include "flang/Evaluate/expression.h"
22 #include "flang/Evaluate/fold.h"
23 #include "flang/Evaluate/formatting.h"
24 #include "flang/Evaluate/intrinsics-library.h"
25 #include "flang/Evaluate/intrinsics.h"
26 #include "flang/Evaluate/shape.h"
27 #include "flang/Evaluate/tools.h"
28 #include "flang/Evaluate/traverse.h"
29 #include "flang/Evaluate/type.h"
30 #include "flang/Parser/message.h"
31 #include "flang/Semantics/scope.h"
32 #include "flang/Semantics/symbol.h"
33 #include "flang/Semantics/tools.h"
34 #include <algorithm>
35 #include <cmath>
36 #include <complex>
37 #include <cstdio>
38 #include <optional>
39 #include <type_traits>
40 #include <variant>
41 
// Some environments, viz. clang on Darwin, allow the macro HUGE
// to leak out of <math.h> even when it is never directly included.
// Undefine it so that, from here on, the preprocessor cannot expand it in
// place of an ordinary identifier spelled HUGE.
#undef HUGE
45 
46 namespace Fortran::evaluate {
47 
48 // Utilities
// Folder<T> groups the routines that fold references to data of type T into
// Constant<T> values. It holds only a reference to the FoldingContext (used
// to emit messages and to fold subexpressions), so instances are cheap and
// are routinely created as temporaries.
template <typename T> class Folder {
public:
  explicit Folder(FoldingContext &c) : context_{c} {}
  // Value of a named constant (after resolving associations) whose
  // initializer is a Constant<T>; std::nullopt otherwise.
  std::optional<Constant<T>> GetNamedConstant(const Symbol &);
  // Elements of `array` selected by one constant scalar or rank-1 subscript
  // per dimension; reports an error and yields std::nullopt when a subscript
  // is out of bounds.
  std::optional<Constant<T>> ApplySubscripts(const Constant<T> &array,
      const std::vector<Constant<SubscriptInteger>> &subscripts);
  // Value of `component` within constant structure(s), optionally with
  // subscripts applied to an array-valued component.
  std::optional<Constant<T>> ApplyComponent(Constant<SomeDerived> &&,
      const Symbol &component,
      const std::vector<Constant<SubscriptInteger>> * = nullptr);
  // Constant value of a component reference whose base folds to constant
  // structure(s).
  std::optional<Constant<T>> GetConstantComponent(
      Component &, const std::vector<Constant<SubscriptInteger>> * = nullptr);
  // Constant value of an array element or section reference, if any.
  std::optional<Constant<T>> Folding(ArrayRef &);
  // Folds a designator, producing a constant when possible.
  Expr<T> Folding(Designator<T> &&);
  // Converts an actual argument to type T in place (folding the conversion)
  // and returns its constant value, or nullptr if it is not constant.
  Constant<T> *Folding(std::optional<ActualArgument> &);
  // Folds a reference to the RESHAPE intrinsic.
  Expr<T> Reshape(FunctionRef<T> &&);

private:
  FoldingContext &context_;
};
68 
// Attempts to fold subscript `ss` of a reference to `base` into a constant
// subscript (scalar or rank-1). Folder<T>::Folding(ArrayRef &) numbers `dim`
// from 0 and treats std::nullopt as "not constant". The definition is not
// visible in this part of the file.
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
    FoldingContext &, Subscript &, const NamedEntity &, int dim);
71 
72 // Helper to use host runtime on scalars for folding.
73 template <typename TR, typename... TA>
74 std::optional<std::function<Scalar<TR>(FoldingContext &, Scalar<TA>...)>>
GetHostRuntimeWrapper(const std::string & name)75 GetHostRuntimeWrapper(const std::string &name) {
76   std::vector<DynamicType> argTypes{TA{}.GetType()...};
77   if (auto hostWrapper{GetHostRuntimeWrapper(name, TR{}.GetType(), argTypes)}) {
78     return [hostWrapper](
79                FoldingContext &context, Scalar<TA>... args) -> Scalar<TR> {
80       std::vector<Expr<SomeType>> genericArgs{
81           AsGenericExpr(Constant<TA>{args})...};
82       return GetScalarConstantValue<TR>(
83           (*hostWrapper)(context, std::move(genericArgs)))
84           .value();
85     };
86   }
87   return std::nullopt;
88 }
89 
90 // FoldOperation() rewrites expression tree nodes.
91 // If there is any possibility that the rewritten node will
92 // not have the same representation type, the result of
93 // FoldOperation() will be packaged in an Expr<> of the same
94 // specific type.
95 
96 // no-op base case
97 template <typename A>
FoldOperation(FoldingContext &,A && x)98 common::IfNoLvalue<Expr<ResultType<A>>, A> FoldOperation(
99     FoldingContext &, A &&x) {
100   static_assert(!std::is_same_v<A, Expr<ResultType<A>>>,
101       "call Fold() instead for Expr<>");
102   return Expr<ResultType<A>>{std::move(x)};
103 }
104 
// FoldOperation() overloads for the parts of data references: components,
// named entities, subscript triplets, subscripts, array and coarray
// references, substrings, and complex parts. Each returns a node of the same
// type (presumably with its subexpressions folded); the definitions are not
// visible in this part of the file.
Component FoldOperation(FoldingContext &, Component &&);
NamedEntity FoldOperation(FoldingContext &, NamedEntity &&);
Triplet FoldOperation(FoldingContext &, Triplet &&);
Subscript FoldOperation(FoldingContext &, Subscript &&);
ArrayRef FoldOperation(FoldingContext &, ArrayRef &&);
CoarrayRef FoldOperation(FoldingContext &, CoarrayRef &&);
DataRef FoldOperation(FoldingContext &, DataRef &&);
Substring FoldOperation(FoldingContext &, Substring &&);
ComplexPart FoldOperation(FoldingContext &, ComplexPart &&);
114 
// Folding of function references; declared ahead of its definition.
template <typename T>
Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&);
// Folding of references to intrinsic functions, one overload per intrinsic
// type category (integer, real, complex, logical). Declared here so that
// FoldOperation(FunctionRef) can dispatch to them; the definitions are not
// visible in this part of the file.
template <int KIND>
Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&);
template <int KIND>
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&);
template <int KIND>
Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&);
template <int KIND>
Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&);
129 
130 template <typename T>
FoldOperation(FoldingContext & context,Designator<T> && designator)131 Expr<T> FoldOperation(FoldingContext &context, Designator<T> &&designator) {
132   return Folder<T>{context}.Folding(std::move(designator));
133 }
134 
// FoldOperation() overloads for type parameter inquiries, implied DO
// indices, array constructors, and structure constructors, declared ahead of
// their definitions.
Expr<TypeParamInquiry::Result> FoldOperation(
    FoldingContext &, TypeParamInquiry &&);
Expr<ImpliedDoIndex::Result> FoldOperation(
    FoldingContext &context, ImpliedDoIndex &&);
template <typename T>
Expr<T> FoldOperation(FoldingContext &, ArrayConstructor<T> &&);
Expr<SomeDerived> FoldOperation(FoldingContext &, StructureConstructor &&);
142 
143 template <typename T>
GetNamedConstant(const Symbol & symbol0)144 std::optional<Constant<T>> Folder<T>::GetNamedConstant(const Symbol &symbol0) {
145   const Symbol &symbol{ResolveAssociations(symbol0)};
146   if (IsNamedConstant(symbol)) {
147     if (const auto *object{
148             symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
149       if (const auto *constant{UnwrapConstantValue<T>(object->init())}) {
150         return *constant;
151       }
152     }
153   }
154   return std::nullopt;
155 }
156 
157 template <typename T>
Folding(ArrayRef & aRef)158 std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) {
159   std::vector<Constant<SubscriptInteger>> subscripts;
160   int dim{0};
161   for (Subscript &ss : aRef.subscript()) {
162     if (auto constant{GetConstantSubscript(context_, ss, aRef.base(), dim++)}) {
163       subscripts.emplace_back(std::move(*constant));
164     } else {
165       return std::nullopt;
166     }
167   }
168   if (Component * component{aRef.base().UnwrapComponent()}) {
169     return GetConstantComponent(*component, &subscripts);
170   } else if (std::optional<Constant<T>> array{
171                  GetNamedConstant(aRef.base().GetLastSymbol())}) {
172     return ApplySubscripts(*array, subscripts);
173   } else {
174     return std::nullopt;
175   }
176 }
177 
// Selects elements of the constant `array` using one constant subscript per
// dimension. A rank-0 subscript fixes the index in its dimension and adds no
// result dimension. A rank-1 subscript (a vector subscript, or presumably a
// triplet already expanded to a vector — TODO confirm in
// GetConstantSubscript) adds one result dimension whose extent is its size.
// Reports an error and returns std::nullopt if any selected index lies
// outside the array's bounds.
template <typename T>
std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array,
    const std::vector<Constant<SubscriptInteger>> &subscripts) {
  const auto &shape{array.shape()};
  const auto &lbounds{array.lbounds()};
  int rank{GetRank(shape)};
  CHECK(rank == static_cast<int>(subscripts.size()));
  // Result shape, total number of selected elements, and the lower bound of
  // each rank-1 subscript vector.
  std::size_t elements{1};
  ConstantSubscripts resultShape;
  ConstantSubscripts ssLB;
  for (const auto &ss : subscripts) {
    CHECK(ss.Rank() <= 1);
    if (ss.Rank() == 1) {
      resultShape.push_back(static_cast<ConstantSubscript>(ss.size()));
      elements *= ss.size();
      ssLB.push_back(ss.lbounds().front());
    }
  }
  // ssAt: zero-based position within each rank-1 subscript (one per result
  // dimension); at: the resulting subscripts into `array`; tmp: scratch
  // one-element index used to read a subscript vector.
  ConstantSubscripts ssAt(rank, 0), at(rank, 0), tmp(1, 0);
  std::vector<Scalar<T>> values;
  while (elements-- > 0) {
    // `increment` is the carry that advances ssAt odometer-style, with the
    // first result dimension varying fastest (array element order).
    bool increment{true};
    int k{0};
    for (int j{0}; j < rank; ++j) {
      if (subscripts[j].Rank() == 0) {
        at[j] = subscripts[j].GetScalarValue().value().ToInt64();
      } else {
        CHECK(k < GetRank(resultShape));
        tmp[0] = ssLB.at(k) + ssAt.at(k);
        at[j] = subscripts[j].At(tmp).ToInt64();
        if (increment) {
          if (++ssAt[k] == resultShape[k]) {
            ssAt[k] = 0;
          } else {
            increment = false;
          }
        }
        ++k;
      }
      if (at[j] < lbounds[j] || at[j] >= lbounds[j] + shape[j]) {
        context_.messages().Say(
            "Subscript value (%jd) is out of range on dimension %d in reference to a constant array value"_err_en_US,
            at[j], j + 1);
        return std::nullopt;
      }
    }
    values.emplace_back(array.At(at));
    // A carry out of the last dimension may only happen on the last element.
    CHECK(!increment || elements == 0);
    CHECK(k == GetRank(resultShape));
  }
  // Character and derived-type constants need their length/type as well.
  if constexpr (T::category == TypeCategory::Character) {
    return Constant<T>{array.LEN(), std::move(values), std::move(resultShape)};
  } else if constexpr (std::is_same_v<T, SomeDerived>) {
    return Constant<T>{array.result().derivedTypeSpec(), std::move(values),
        std::move(resultShape)};
  } else {
    return Constant<T>{std::move(values), std::move(resultShape)};
  }
}
237 
238 template <typename T>
ApplyComponent(Constant<SomeDerived> && structures,const Symbol & component,const std::vector<Constant<SubscriptInteger>> * subscripts)239 std::optional<Constant<T>> Folder<T>::ApplyComponent(
240     Constant<SomeDerived> &&structures, const Symbol &component,
241     const std::vector<Constant<SubscriptInteger>> *subscripts) {
242   if (auto scalar{structures.GetScalarValue()}) {
243     if (std::optional<Expr<SomeType>> expr{scalar->Find(component)}) {
244       if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
245         if (!subscripts) {
246           return std::move(*value);
247         } else {
248           return ApplySubscripts(*value, *subscripts);
249         }
250       }
251     }
252   } else {
253     // A(:)%scalar_component & A(:)%array_component(subscripts)
254     std::unique_ptr<ArrayConstructor<T>> array;
255     if (structures.empty()) {
256       return std::nullopt;
257     }
258     ConstantSubscripts at{structures.lbounds()};
259     do {
260       StructureConstructor scalar{structures.At(at)};
261       if (std::optional<Expr<SomeType>> expr{scalar.Find(component)}) {
262         if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
263           if (!array.get()) {
264             // This technique ensures that character length or derived type
265             // information is propagated to the array constructor.
266             auto *typedExpr{UnwrapExpr<Expr<T>>(expr.value())};
267             CHECK(typedExpr);
268             array = std::make_unique<ArrayConstructor<T>>(*typedExpr);
269           }
270           if (subscripts) {
271             if (auto element{ApplySubscripts(*value, *subscripts)}) {
272               CHECK(element->Rank() == 0);
273               array->Push(Expr<T>{std::move(*element)});
274             } else {
275               return std::nullopt;
276             }
277           } else {
278             CHECK(value->Rank() == 0);
279             array->Push(Expr<T>{*value});
280           }
281         } else {
282           return std::nullopt;
283         }
284       }
285     } while (structures.IncrementSubscripts(at));
286     // Fold the ArrayConstructor<> into a Constant<>.
287     CHECK(array);
288     Expr<T> result{Fold(context_, Expr<T>{std::move(*array)})};
289     if (auto *constant{UnwrapConstantValue<T>(result)}) {
290       return constant->Reshape(common::Clone(structures.shape()));
291     }
292   }
293   return std::nullopt;
294 }
295 
296 template <typename T>
GetConstantComponent(Component & component,const std::vector<Constant<SubscriptInteger>> * subscripts)297 std::optional<Constant<T>> Folder<T>::GetConstantComponent(Component &component,
298     const std::vector<Constant<SubscriptInteger>> *subscripts) {
299   if (std::optional<Constant<SomeDerived>> structures{std::visit(
300           common::visitors{
301               [&](const Symbol &symbol) {
302                 return Folder<SomeDerived>{context_}.GetNamedConstant(symbol);
303               },
304               [&](ArrayRef &aRef) {
305                 return Folder<SomeDerived>{context_}.Folding(aRef);
306               },
307               [&](Component &base) {
308                 return Folder<SomeDerived>{context_}.GetConstantComponent(base);
309               },
310               [&](CoarrayRef &) {
311                 return std::optional<Constant<SomeDerived>>{};
312               },
313           },
314           component.base().u)}) {
315     return ApplyComponent(
316         std::move(*structures), component.GetLastSymbol(), subscripts);
317   } else {
318     return std::nullopt;
319   }
320 }
321 
// Folds a designator of type T, producing a constant where possible:
// - a character substring whose value folds, or whose length folds to zero
//   (an empty string constant);
// - a whole named constant;
// - an array element/section or a component of constant data, via
//   Folding(ArrayRef &) and GetConstantComponent().
// Otherwise the designator's parts are folded and it is returned as a
// (non-constant) Designator<T>.
template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
  if constexpr (T::category == TypeCategory::Character) {
    if (auto *substring{common::Unwrap<Substring>(designator.u)}) {
      if (std::optional<Expr<SomeCharacter>> folded{
              substring->Fold(context_)}) {
        if (auto value{GetScalarConstantValue<T>(*folded)}) {
          return Expr<T>{*value};
        }
      }
      // Even when the substring's value does not fold, a length that folds
      // to zero makes it an empty string constant.
      if (auto length{ToInt64(Fold(context_, substring->LEN()))}) {
        if (*length == 0) {
          return Expr<T>{Constant<T>{Scalar<T>{}}};
        }
      }
    }
  }
  return std::visit(
      common::visitors{
          [&](SymbolRef &&symbol) {
            if (auto constant{GetNamedConstant(*symbol)}) {
              return Expr<T>{std::move(*constant)};
            }
            // std::visit only cast designator.u to an rvalue; nothing has
            // been moved yet, so the whole designator can be reused here.
            return Expr<T>{std::move(designator)};
          },
          [&](ArrayRef &&aRef) {
            aRef = FoldOperation(context_, std::move(aRef));
            if (auto c{Folding(aRef)}) {
              return Expr<T>{std::move(*c)};
            } else {
              return Expr<T>{Designator<T>{std::move(aRef)}};
            }
          },
          [&](Component &&component) {
            component = FoldOperation(context_, std::move(component));
            if (auto c{GetConstantComponent(component)}) {
              return Expr<T>{std::move(*c)};
            } else {
              return Expr<T>{Designator<T>{std::move(component)}};
            }
          },
          [&](auto &&x) {
            return Expr<T>{
                Designator<T>{FoldOperation(context_, std::move(x))}};
          },
      },
      std::move(designator.u));
}
369 
370 // Apply type conversion and re-folding if necessary.
371 // This is where BOZ arguments are converted.
372 template <typename T>
Folding(std::optional<ActualArgument> & arg)373 Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
374   if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
375     if (!UnwrapExpr<Expr<T>>(*expr)) {
376       if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
377         *expr = Fold(context_, std::move(*converted));
378       }
379     }
380     return UnwrapConstantValue<T>(*expr);
381   }
382   return nullptr;
383 }
384 
385 template <typename... A, std::size_t... I>
GetConstantArgumentsHelper(FoldingContext & context,ActualArguments & arguments,std::index_sequence<I...>)386 std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
387     FoldingContext &context, ActualArguments &arguments,
388     std::index_sequence<I...>) {
389   static_assert(
390       (... && IsSpecificIntrinsicType<A>)); // TODO derived types for MERGE?
391   static_assert(sizeof...(A) > 0);
392   std::tuple<const Constant<A> *...> args{
393       Folder<A>{context}.Folding(arguments.at(I))...};
394   if ((... && (std::get<I>(args)))) {
395     return args;
396   } else {
397     return std::nullopt;
398   }
399 }
400 
401 template <typename... A>
GetConstantArguments(FoldingContext & context,ActualArguments & args)402 std::optional<std::tuple<const Constant<A> *...>> GetConstantArguments(
403     FoldingContext &context, ActualArguments &args) {
404   return GetConstantArgumentsHelper<A...>(
405       context, args, std::index_sequence_for<A...>{});
406 }
407 
408 template <typename... A, std::size_t... I>
GetScalarConstantArgumentsHelper(FoldingContext & context,ActualArguments & args,std::index_sequence<I...>)409 std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArgumentsHelper(
410     FoldingContext &context, ActualArguments &args, std::index_sequence<I...>) {
411   if (auto constArgs{GetConstantArguments<A...>(context, args)}) {
412     return std::tuple<Scalar<A>...>{
413         std::get<I>(*constArgs)->GetScalarValue().value()...};
414   } else {
415     return std::nullopt;
416   }
417 }
418 
419 template <typename... A>
GetScalarConstantArguments(FoldingContext & context,ActualArguments & args)420 std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArguments(
421     FoldingContext &context, ActualArguments &args) {
422   return GetScalarConstantArgumentsHelper<A...>(
423       context, args, std::index_sequence_for<A...>{});
424 }
425 
// Helpers to fold references to elemental intrinsic functions.
// The callable types below are accepted by FoldElementalIntrinsic(), a common
// utility that handles the array (elementwise) and argument conversion
// aspects, so that each intrinsic only needs a scalar implementation.

// ScalarFunc: computes one result element from one element per argument.
template <typename TR, typename... TArgs>
using ScalarFunc = std::function<Scalar<TR>(const Scalar<TArgs> &...)>;
// ScalarFuncWithContext: the same, but also receives the FoldingContext
// (presumably for messages or nested folding).
template <typename TR, typename... TArgs>
using ScalarFuncWithContext =
    std::function<Scalar<TR>(FoldingContext &, const Scalar<TArgs> &...)>;
435 
436 template <template <typename, typename...> typename WrapperType, typename TR,
437     typename... TA, std::size_t... I>
FoldElementalIntrinsicHelper(FoldingContext & context,FunctionRef<TR> && funcRef,WrapperType<TR,TA...> func,std::index_sequence<I...>)438 Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
439     FunctionRef<TR> &&funcRef, WrapperType<TR, TA...> func,
440     std::index_sequence<I...>) {
441   if (std::optional<std::tuple<const Constant<TA> *...>> args{
442           GetConstantArguments<TA...>(context, funcRef.arguments())}) {
443     // Compute the shape of the result based on shapes of arguments
444     ConstantSubscripts shape;
445     int rank{0};
446     const ConstantSubscripts *shapes[sizeof...(TA)]{
447         &std::get<I>(*args)->shape()...};
448     const int ranks[sizeof...(TA)]{std::get<I>(*args)->Rank()...};
449     for (unsigned int i{0}; i < sizeof...(TA); ++i) {
450       if (ranks[i] > 0) {
451         if (rank == 0) {
452           rank = ranks[i];
453           shape = *shapes[i];
454         } else {
455           if (shape != *shapes[i]) {
456             // TODO: Rank compatibility was already checked but it seems to be
457             // the first place where the actual shapes are checked to be the
458             // same. Shouldn't this be checked elsewhere so that this is also
459             // checked for non constexpr call to elemental intrinsics function?
460             context.messages().Say(
461                 "Arguments in elemental intrinsic function are not conformable"_err_en_US);
462             return Expr<TR>{std::move(funcRef)};
463           }
464         }
465       }
466     }
467     CHECK(rank == GetRank(shape));
468 
469     // Compute all the scalar values of the results
470     std::vector<Scalar<TR>> results;
471     if (TotalElementCount(shape) > 0) {
472       ConstantBounds bounds{shape};
473       ConstantSubscripts index(rank, 1);
474       do {
475         if constexpr (std::is_same_v<WrapperType<TR, TA...>,
476                           ScalarFuncWithContext<TR, TA...>>) {
477           results.emplace_back(func(context,
478               (ranks[I] ? std::get<I>(*args)->At(index)
479                         : std::get<I>(*args)->GetScalarValue().value())...));
480         } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
481                                  ScalarFunc<TR, TA...>>) {
482           results.emplace_back(func(
483               (ranks[I] ? std::get<I>(*args)->At(index)
484                         : std::get<I>(*args)->GetScalarValue().value())...));
485         }
486       } while (bounds.IncrementSubscripts(index));
487     }
488     // Build and return constant result
489     if constexpr (TR::category == TypeCategory::Character) {
490       auto len{static_cast<ConstantSubscript>(
491           results.size() ? results[0].length() : 0)};
492       return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
493     } else {
494       return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
495     }
496   }
497   return Expr<TR>{std::move(funcRef)};
498 }
499 
500 template <typename TR, typename... TA>
FoldElementalIntrinsic(FoldingContext & context,FunctionRef<TR> && funcRef,ScalarFunc<TR,TA...> func)501 Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
502     FunctionRef<TR> &&funcRef, ScalarFunc<TR, TA...> func) {
503   return FoldElementalIntrinsicHelper<ScalarFunc, TR, TA...>(
504       context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
505 }
506 template <typename TR, typename... TA>
FoldElementalIntrinsic(FoldingContext & context,FunctionRef<TR> && funcRef,ScalarFuncWithContext<TR,TA...> func)507 Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
508     FunctionRef<TR> &&funcRef, ScalarFuncWithContext<TR, TA...> func) {
509   return FoldElementalIntrinsicHelper<ScalarFuncWithContext, TR, TA...>(
510       context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
511 }
512 
// Integer value of an optional actual argument when it is a known constant.
// Presumably, GetInt64Arg yields std::nullopt otherwise, and GetInt64ArgOr
// substitutes `defaultValue` when the argument is absent. The definitions
// are not visible here — TODO confirm.
std::optional<std::int64_t> GetInt64Arg(const std::optional<ActualArgument> &);
std::optional<std::int64_t> GetInt64ArgOr(
    const std::optional<ActualArgument> &, std::int64_t defaultValue);
516 
517 template <typename A, typename B>
GetIntegerVector(const B & x)518 std::optional<std::vector<A>> GetIntegerVector(const B &x) {
519   static_assert(std::is_integral_v<A>);
520   if (const auto *someInteger{UnwrapExpr<Expr<SomeInteger>>(x)}) {
521     return std::visit(
522         [](const auto &typedExpr) -> std::optional<std::vector<A>> {
523           using T = ResultType<decltype(typedExpr)>;
524           if (const auto *constant{UnwrapConstantValue<T>(typedExpr)}) {
525             if (constant->Rank() == 1) {
526               std::vector<A> result;
527               for (const auto &value : constant->values()) {
528                 result.push_back(static_cast<A>(value.ToInt64()));
529               }
530               return result;
531             }
532           }
533           return std::nullopt;
534         },
535         someInteger->u);
536   }
537   return std::nullopt;
538 }
539 
540 // Transform an intrinsic function reference that contains user errors
541 // into an intrinsic with the same characteristic but the "invalid" name.
542 // This to prevent generating warnings over and over if the expression
543 // gets re-folded.
MakeInvalidIntrinsic(FunctionRef<T> && funcRef)544 template <typename T> Expr<T> MakeInvalidIntrinsic(FunctionRef<T> &&funcRef) {
545   SpecificIntrinsic invalid{std::get<SpecificIntrinsic>(funcRef.proc().u)};
546   invalid.name = IntrinsicProcTable::InvalidName;
547   return Expr<T>{FunctionRef<T>{ProcedureDesignator{std::move(invalid)},
548       ActualArguments{std::move(funcRef.arguments())}}};
549 }
550 
Reshape(FunctionRef<T> && funcRef)551 template <typename T> Expr<T> Folder<T>::Reshape(FunctionRef<T> &&funcRef) {
552   auto args{funcRef.arguments()};
553   CHECK(args.size() == 4);
554   const auto *source{UnwrapConstantValue<T>(args[0])};
555   const auto *pad{UnwrapConstantValue<T>(args[2])};
556   std::optional<std::vector<ConstantSubscript>> shape{
557       GetIntegerVector<ConstantSubscript>(args[1])};
558   std::optional<std::vector<int>> order{GetIntegerVector<int>(args[3])};
559   if (!source || !shape || (args[2] && !pad) || (args[3] && !order)) {
560     return Expr<T>{std::move(funcRef)}; // Non-constant arguments
561   } else if (shape.value().size() > common::maxRank) {
562     context_.messages().Say(
563         "Size of 'shape=' argument must not be greater than %d"_err_en_US,
564         common::maxRank);
565   } else if (HasNegativeExtent(shape.value())) {
566     context_.messages().Say(
567         "'shape=' argument must not have a negative extent"_err_en_US);
568   } else {
569     int rank{GetRank(shape.value())};
570     std::size_t resultElements{TotalElementCount(shape.value())};
571     std::optional<std::vector<int>> dimOrder;
572     if (order) {
573       dimOrder = ValidateDimensionOrder(rank, *order);
574     }
575     std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
576     if (order && !dimOrder) {
577       context_.messages().Say("Invalid 'order=' argument in RESHAPE"_err_en_US);
578     } else if (resultElements > source->size() && (!pad || pad->empty())) {
579       context_.messages().Say(
580           "Too few elements in 'source=' argument and 'pad=' "
581           "argument is not present or has null size"_err_en_US);
582     } else {
583       Constant<T> result{!source->empty() || !pad
584               ? source->Reshape(std::move(shape.value()))
585               : pad->Reshape(std::move(shape.value()))};
586       ConstantSubscripts subscripts{result.lbounds()};
587       auto copied{result.CopyFrom(*source,
588           std::min(source->size(), resultElements), subscripts, dimOrderPtr)};
589       if (copied < resultElements) {
590         CHECK(pad);
591         copied += result.CopyFrom(
592             *pad, resultElements - copied, subscripts, dimOrderPtr);
593       }
594       CHECK(copied == resultElements);
595       return Expr<T>{std::move(result)};
596     }
597   }
598   // Invalid, prevent re-folding
599   return MakeInvalidIntrinsic(std::move(funcRef));
600 }
601 
602 template <typename T>
FoldMINorMAX(FoldingContext & context,FunctionRef<T> && funcRef,Ordering order)603 Expr<T> FoldMINorMAX(
604     FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
605   std::vector<Constant<T> *> constantArgs;
606   // Call Folding on all arguments, even if some are not constant,
607   // to make operand promotion explicit.
608   for (auto &arg : funcRef.arguments()) {
609     if (auto *cst{Folder<T>{context}.Folding(arg)}) {
610       constantArgs.push_back(cst);
611     }
612   }
613   if (constantArgs.size() != funcRef.arguments().size())
614     return Expr<T>(std::move(funcRef));
615   CHECK(constantArgs.size() > 0);
616   Expr<T> result{std::move(*constantArgs[0])};
617   for (std::size_t i{1}; i < constantArgs.size(); ++i) {
618     Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
619     result = FoldOperation(context, std::move(extremum));
620   }
621   return result;
622 }
623 
// For AMAX0, AMIN0, AMAX1, AMIN1, DMAX1, DMIN1, MAX0, MIN0, MAX1, and MIN1,
// special care must be taken to insert the conversion on the result of the
// MIN/MAX. This is made slightly more complex by the f18 extension that
// allows arguments of different kinds: the type of the generated MIN/MAX
// result cannot be deduced from the standard and has to be deduced from the
// arguments.
// e.g. AMAX0(int8, int4) is rewritten to REAL(MAX(int8, INT(int4, 8))).
template <typename T>
Expr<T> RewriteSpecificMINorMAX(
    FoldingContext &context, FunctionRef<T> &&funcRef) {
  ActualArguments &args{funcRef.arguments()};
  auto &intrinsic{DEREF(std::get_if<SpecificIntrinsic>(&funcRef.proc().u))};
  // Rewrite MAX1(args) to INT(MAX(args)) and fold. Same logic for MIN1.
  // Find result type for max/min based on the arguments.
  // Start from the first argument, then scan the others from last to second:
  // within one category the largest kind wins, and an argument of a
  // different category replaces an integer result type.
  DynamicType resultType{args[0].value().GetType().value()};
  auto *resultTypeArg{&args[0]};
  for (auto j{args.size() - 1}; j > 0; --j) {
    DynamicType type{args[j].value().GetType().value()};
    if (type.category() == resultType.category()) {
      if (type.kind() > resultType.kind()) {
        resultTypeArg = &args[j];
        resultType = type;
      }
    } else if (resultType.category() == TypeCategory::Integer) {
      // Handle mixed real/integer arguments: all the previous arguments were
      // integers and this one is real. The type of the MAX/MIN result will
      // be the one of the real argument.
      resultTypeArg = &args[j];
      resultType = type;
    }
  }
  // Turn the specific intrinsic into generic "max" or "min" with the result
  // type computed above.
  intrinsic.name =
      intrinsic.name.find("max") != std::string::npos ? "max"s : "min"s;
  intrinsic.characteristics.value().functionResult.value().SetType(resultType);
  // Builds the MAX/MIN reference with result type TR (the specific type of
  // the selected argument; `x` is used only for its type) and converts its
  // result to T. The procedure designator and arguments are moved out of
  // funcRef, so this lambda is invoked at most once.
  auto insertConversion{[&](const auto &x) -> Expr<T> {
    using TR = ResultType<decltype(x)>;
    FunctionRef<TR> maxRef{std::move(funcRef.proc()), std::move(args)};
    return Fold(context, ConvertToType<T>(AsCategoryExpr(std::move(maxRef))));
  }};
  if (auto *sx{UnwrapExpr<Expr<SomeReal>>(*resultTypeArg)}) {
    return std::visit(insertConversion, sx->u);
  }
  auto &sx{DEREF(UnwrapExpr<Expr<SomeInteger>>(*resultTypeArg))};
  return std::visit(insertConversion, sx.u);
}
669 
670 template <typename T>
FoldOperation(FoldingContext & context,FunctionRef<T> && funcRef)671 Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
672   ActualArguments &args{funcRef.arguments()};
673   for (std::optional<ActualArgument> &arg : args) {
674     if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
675       *expr = Fold(context, std::move(*expr));
676     }
677   }
678   if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
679     const std::string name{intrinsic->name};
680     if (name == "reshape") {
681       return Folder<T>{context}.Reshape(std::move(funcRef));
682     }
683     // TODO: other type independent transformationals
684     if constexpr (!std::is_same_v<T, SomeDerived>) {
685       return FoldIntrinsicFunction(context, std::move(funcRef));
686     }
687   }
688   return Expr<T>{std::move(funcRef)};
689 }
690 
691 template <typename T>
FoldMerge(FoldingContext & context,FunctionRef<T> && funcRef)692 Expr<T> FoldMerge(FoldingContext &context, FunctionRef<T> &&funcRef) {
693   return FoldElementalIntrinsic<T, T, T, LogicalResult>(context,
694       std::move(funcRef),
695       ScalarFunc<T, T, T, LogicalResult>(
696           [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
697               const Scalar<LogicalResult> &predicate) -> Scalar<T> {
698             return predicate.IsTrue() ? ifTrue : ifFalse;
699           }));
700 }
701 
// NOTE(review): this repeats the ImpliedDoIndex overload of FoldOperation()
// already declared earlier in this file; the redeclaration is harmless.
Expr<ImpliedDoIndex::Result> FoldOperation(FoldingContext &, ImpliedDoIndex &&);
703 
704 // Array constructor folding
705 template <typename T> class ArrayConstructorFolder {
706 public:
ArrayConstructorFolder(const FoldingContext & c)707   explicit ArrayConstructorFolder(const FoldingContext &c) : context_{c} {}
708 
FoldArray(ArrayConstructor<T> && array)709   Expr<T> FoldArray(ArrayConstructor<T> &&array) {
710     // Calls FoldArray(const ArrayConstructorValues<T> &) below
711     if (FoldArray(array)) {
712       auto n{static_cast<ConstantSubscript>(elements_.size())};
713       if constexpr (std::is_same_v<T, SomeDerived>) {
714         return Expr<T>{Constant<T>{array.GetType().GetDerivedTypeSpec(),
715             std::move(elements_), ConstantSubscripts{n}}};
716       } else if constexpr (T::category == TypeCategory::Character) {
717         auto length{Fold(context_, common::Clone(array.LEN()))};
718         if (std::optional<ConstantSubscript> lengthValue{ToInt64(length)}) {
719           return Expr<T>{Constant<T>{
720               *lengthValue, std::move(elements_), ConstantSubscripts{n}}};
721         }
722       } else {
723         return Expr<T>{
724             Constant<T>{std::move(elements_), ConstantSubscripts{n}}};
725       }
726     }
727     return Expr<T>{std::move(array)};
728   }
729 
730 private:
FoldArray(const common::CopyableIndirection<Expr<T>> & expr)731   bool FoldArray(const common::CopyableIndirection<Expr<T>> &expr) {
732     Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
733     if (const auto *c{UnwrapConstantValue<T>(folded)}) {
734       // Copy elements in Fortran array element order
735       ConstantSubscripts shape{c->shape()};
736       int rank{c->Rank()};
737       ConstantSubscripts index(GetRank(shape), 1);
738       for (std::size_t n{c->size()}; n-- > 0;) {
739         elements_.emplace_back(c->At(index));
740         for (int d{0}; d < rank; ++d) {
741           if (++index[d] <= shape[d]) {
742             break;
743           }
744           index[d] = 1;
745         }
746       }
747       return true;
748     } else {
749       return false;
750     }
751   }
FoldArray(const ImpliedDo<T> & iDo)752   bool FoldArray(const ImpliedDo<T> &iDo) {
753     Expr<SubscriptInteger> lower{
754         Fold(context_, Expr<SubscriptInteger>{iDo.lower()})};
755     Expr<SubscriptInteger> upper{
756         Fold(context_, Expr<SubscriptInteger>{iDo.upper()})};
757     Expr<SubscriptInteger> stride{
758         Fold(context_, Expr<SubscriptInteger>{iDo.stride()})};
759     std::optional<ConstantSubscript> start{ToInt64(lower)}, end{ToInt64(upper)},
760         step{ToInt64(stride)};
761     if (start && end && step && *step != 0) {
762       bool result{true};
763       ConstantSubscript &j{context_.StartImpliedDo(iDo.name(), *start)};
764       if (*step > 0) {
765         for (; j <= *end; j += *step) {
766           result &= FoldArray(iDo.values());
767         }
768       } else {
769         for (; j >= *end; j += *step) {
770           result &= FoldArray(iDo.values());
771         }
772       }
773       context_.EndImpliedDo(iDo.name());
774       return result;
775     } else {
776       return false;
777     }
778   }
FoldArray(const ArrayConstructorValue<T> & x)779   bool FoldArray(const ArrayConstructorValue<T> &x) {
780     return std::visit([&](const auto &y) { return FoldArray(y); }, x.u);
781   }
FoldArray(const ArrayConstructorValues<T> & xs)782   bool FoldArray(const ArrayConstructorValues<T> &xs) {
783     for (const auto &x : xs) {
784       if (!FoldArray(x)) {
785         return false;
786       }
787     }
788     return true;
789   }
790 
791   FoldingContext context_;
792   std::vector<Scalar<T>> elements_;
793 };
794 
795 template <typename T>
FoldOperation(FoldingContext & context,ArrayConstructor<T> && array)796 Expr<T> FoldOperation(FoldingContext &context, ArrayConstructor<T> &&array) {
797   return ArrayConstructorFolder<T>{context}.FoldArray(std::move(array));
798 }
799 
800 // Array operation elemental application: When all operands to an operation
801 // are constant arrays, array constructors without any implied DO loops,
802 // &/or expanded scalars, pull the operation "into" the array result by
803 // applying it in an elementwise fashion.  For example, [A,1]+[B,2]
804 // is rewritten into [A+B,1+2] and then partially folded to [A+B,3].
805 
806 // If possible, restructures an array expression into an array constructor
807 // that comprises a "flat" ArrayConstructorValues with no implied DO loops.
808 template <typename T>
ArrayConstructorIsFlat(const ArrayConstructorValues<T> & values)809 bool ArrayConstructorIsFlat(const ArrayConstructorValues<T> &values) {
810   for (const ArrayConstructorValue<T> &x : values) {
811     if (!std::holds_alternative<Expr<T>>(x.u)) {
812       return false;
813     }
814   }
815   return true;
816 }
817 
818 template <typename T>
AsFlatArrayConstructor(const Expr<T> & expr)819 std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
820   if (const auto *c{UnwrapConstantValue<T>(expr)}) {
821     ArrayConstructor<T> result{expr};
822     if (c->size() > 0) {
823       ConstantSubscripts at{c->lbounds()};
824       do {
825         result.Push(Expr<T>{Constant<T>{c->At(at)}});
826       } while (c->IncrementSubscripts(at));
827     }
828     return std::make_optional<Expr<T>>(std::move(result));
829   } else if (const auto *a{UnwrapExpr<ArrayConstructor<T>>(expr)}) {
830     if (ArrayConstructorIsFlat(*a)) {
831       return std::make_optional<Expr<T>>(expr);
832     }
833   } else if (const auto *p{UnwrapExpr<Parentheses<T>>(expr)}) {
834     return AsFlatArrayConstructor(Expr<T>{p->left()});
835   }
836   return std::nullopt;
837 }
838 
839 template <TypeCategory CAT>
840 std::enable_if_t<CAT != TypeCategory::Derived,
841     std::optional<Expr<SomeKind<CAT>>>>
AsFlatArrayConstructor(const Expr<SomeKind<CAT>> & expr)842 AsFlatArrayConstructor(const Expr<SomeKind<CAT>> &expr) {
843   return std::visit(
844       [&](const auto &kindExpr) -> std::optional<Expr<SomeKind<CAT>>> {
845         if (auto flattened{AsFlatArrayConstructor(kindExpr)}) {
846           return Expr<SomeKind<CAT>>{std::move(*flattened)};
847         } else {
848           return std::nullopt;
849         }
850       },
851       expr.u);
852 }
853 
854 // FromArrayConstructor is a subroutine for MapOperation() below.
855 // Given a flat ArrayConstructor<T> and a shape, it wraps the array
856 // into an Expr<T>, folds it, and returns the resulting wrapped
857 // array constructor or constant array value.
858 template <typename T>
FromArrayConstructor(FoldingContext & context,ArrayConstructor<T> && values,std::optional<ConstantSubscripts> && shape)859 Expr<T> FromArrayConstructor(FoldingContext &context,
860     ArrayConstructor<T> &&values, std::optional<ConstantSubscripts> &&shape) {
861   Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
862   if (shape) {
863     if (auto *constant{UnwrapConstantValue<T>(result)}) {
864       return Expr<T>{constant->Reshape(std::move(*shape))};
865     }
866   }
867   return result;
868 }
869 
870 // MapOperation is a utility for various specializations of ApplyElementwise()
871 // that follow.  Given one or two flat ArrayConstructor<OPERAND> (wrapped in an
872 // Expr<OPERAND>) for some specific operand type(s), apply a given function f
873 // to each of their corresponding elements to produce a flat
874 // ArrayConstructor<RESULT> (wrapped in an Expr<RESULT>).
875 // Preserves shape.
876 
877 // Unary case
878 template <typename RESULT, typename OPERAND>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<OPERAND> &&)> && f,const Shape & shape,Expr<OPERAND> && values)879 Expr<RESULT> MapOperation(FoldingContext &context,
880     std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f, const Shape &shape,
881     Expr<OPERAND> &&values) {
882   ArrayConstructor<RESULT> result{values};
883   if constexpr (common::HasMember<OPERAND, AllIntrinsicCategoryTypes>) {
884     std::visit(
885         [&](auto &&kindExpr) {
886           using kindType = ResultType<decltype(kindExpr)>;
887           auto &aConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
888           for (auto &acValue : aConst) {
889             auto &scalar{std::get<Expr<kindType>>(acValue.u)};
890             result.Push(Fold(context, f(Expr<OPERAND>{std::move(scalar)})));
891           }
892         },
893         std::move(values.u));
894   } else {
895     auto &aConst{std::get<ArrayConstructor<OPERAND>>(values.u)};
896     for (auto &acValue : aConst) {
897       auto &scalar{std::get<Expr<OPERAND>>(acValue.u)};
898       result.Push(Fold(context, f(std::move(scalar))));
899     }
900   }
901   return FromArrayConstructor(
902       context, std::move(result), AsConstantExtents(context, shape));
903 }
904 
905 // array * array case
906 template <typename RESULT, typename LEFT, typename RIGHT>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<LEFT> &&,Expr<RIGHT> &&)> && f,const Shape & shape,Expr<LEFT> && leftValues,Expr<RIGHT> && rightValues)907 Expr<RESULT> MapOperation(FoldingContext &context,
908     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
909     const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
910   ArrayConstructor<RESULT> result{leftValues};
911   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
912   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
913     std::visit(
914         [&](auto &&kindExpr) {
915           using kindType = ResultType<decltype(kindExpr)>;
916 
917           auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
918           auto rightIter{rightArrConst.begin()};
919           for (auto &leftValue : leftArrConst) {
920             CHECK(rightIter != rightArrConst.end());
921             auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
922             auto &rightScalar{std::get<Expr<kindType>>(rightIter->u)};
923             result.Push(Fold(context,
924                 f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
925             ++rightIter;
926           }
927         },
928         std::move(rightValues.u));
929   } else {
930     auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
931     auto rightIter{rightArrConst.begin()};
932     for (auto &leftValue : leftArrConst) {
933       CHECK(rightIter != rightArrConst.end());
934       auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
935       auto &rightScalar{std::get<Expr<RIGHT>>(rightIter->u)};
936       result.Push(
937           Fold(context, f(std::move(leftScalar), std::move(rightScalar))));
938       ++rightIter;
939     }
940   }
941   return FromArrayConstructor(
942       context, std::move(result), AsConstantExtents(context, shape));
943 }
944 
945 // array * scalar case
946 template <typename RESULT, typename LEFT, typename RIGHT>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<LEFT> &&,Expr<RIGHT> &&)> && f,const Shape & shape,Expr<LEFT> && leftValues,const Expr<RIGHT> & rightScalar)947 Expr<RESULT> MapOperation(FoldingContext &context,
948     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
949     const Shape &shape, Expr<LEFT> &&leftValues,
950     const Expr<RIGHT> &rightScalar) {
951   ArrayConstructor<RESULT> result{leftValues};
952   auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
953   for (auto &leftValue : leftArrConst) {
954     auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
955     result.Push(
956         Fold(context, f(std::move(leftScalar), Expr<RIGHT>{rightScalar})));
957   }
958   return FromArrayConstructor(
959       context, std::move(result), AsConstantExtents(context, shape));
960 }
961 
962 // scalar * array case
963 template <typename RESULT, typename LEFT, typename RIGHT>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<LEFT> &&,Expr<RIGHT> &&)> && f,const Shape & shape,const Expr<LEFT> & leftScalar,Expr<RIGHT> && rightValues)964 Expr<RESULT> MapOperation(FoldingContext &context,
965     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
966     const Shape &shape, const Expr<LEFT> &leftScalar,
967     Expr<RIGHT> &&rightValues) {
968   ArrayConstructor<RESULT> result{leftScalar};
969   if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
970     std::visit(
971         [&](auto &&kindExpr) {
972           using kindType = ResultType<decltype(kindExpr)>;
973           auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
974           for (auto &rightValue : rightArrConst) {
975             auto &rightScalar{std::get<Expr<kindType>>(rightValue.u)};
976             result.Push(Fold(context,
977                 f(Expr<LEFT>{leftScalar},
978                     Expr<RIGHT>{std::move(rightScalar)})));
979           }
980         },
981         std::move(rightValues.u));
982   } else {
983     auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
984     for (auto &rightValue : rightArrConst) {
985       auto &rightScalar{std::get<Expr<RIGHT>>(rightValue.u)};
986       result.Push(
987           Fold(context, f(Expr<LEFT>{leftScalar}, std::move(rightScalar))));
988     }
989   }
990   return FromArrayConstructor(
991       context, std::move(result), AsConstantExtents(context, shape));
992 }
993 
994 // ApplyElementwise() recursively folds the operand expression(s) of an
995 // operation, then attempts to apply the operation to the (corresponding)
996 // scalar element(s) of those operands.  Returns std::nullopt for scalars
997 // or unlinearizable operands.
998 template <typename DERIVED, typename RESULT, typename OPERAND>
999 auto ApplyElementwise(FoldingContext &context,
1000     Operation<DERIVED, RESULT, OPERAND> &operation,
1001     std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f)
1002     -> std::optional<Expr<RESULT>> {
1003   auto &expr{operation.left()};
1004   expr = Fold(context, std::move(expr));
1005   if (expr.Rank() > 0) {
1006     if (std::optional<Shape> shape{GetShape(context, expr)}) {
1007       if (auto values{AsFlatArrayConstructor(expr)}) {
1008         return MapOperation(context, std::move(f), *shape, std::move(*values));
1009       }
1010     }
1011   }
1012   return std::nullopt;
1013 }
1014 
1015 template <typename DERIVED, typename RESULT, typename OPERAND>
1016 auto ApplyElementwise(
1017     FoldingContext &context, Operation<DERIVED, RESULT, OPERAND> &operation)
1018     -> std::optional<Expr<RESULT>> {
1019   return ApplyElementwise(context, operation,
1020       std::function<Expr<RESULT>(Expr<OPERAND> &&)>{
1021           [](Expr<OPERAND> &&operand) {
1022             return Expr<RESULT>{DERIVED{std::move(operand)}};
1023           }});
1024 }
1025 
1026 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1027 auto ApplyElementwise(FoldingContext &context,
1028     Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
1029     std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
1030     -> std::optional<Expr<RESULT>> {
1031   auto &leftExpr{operation.left()};
1032   leftExpr = Fold(context, std::move(leftExpr));
1033   auto &rightExpr{operation.right()};
1034   rightExpr = Fold(context, std::move(rightExpr));
1035   if (leftExpr.Rank() > 0) {
1036     if (std::optional<Shape> leftShape{GetShape(context, leftExpr)}) {
1037       if (auto left{AsFlatArrayConstructor(leftExpr)}) {
1038         if (rightExpr.Rank() > 0) {
1039           if (std::optional<Shape> rightShape{GetShape(context, rightExpr)}) {
1040             if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1041               if (CheckConformance(
1042                       context.messages(), *leftShape, *rightShape)) {
1043                 return MapOperation(context, std::move(f), *leftShape,
1044                     std::move(*left), std::move(*right));
1045               } else {
1046                 return std::nullopt;
1047               }
1048               return MapOperation(context, std::move(f), *leftShape,
1049                   std::move(*left), std::move(*right));
1050             }
1051           }
1052         } else if (IsExpandableScalar(rightExpr)) {
1053           return MapOperation(
1054               context, std::move(f), *leftShape, std::move(*left), rightExpr);
1055         }
1056       }
1057     }
1058   } else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
1059     if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
1060       if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1061         return MapOperation(
1062             context, std::move(f), *shape, leftExpr, std::move(*right));
1063       }
1064     }
1065   }
1066   return std::nullopt;
1067 }
1068 
1069 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1070 auto ApplyElementwise(
1071     FoldingContext &context, Operation<DERIVED, RESULT, LEFT, RIGHT> &operation)
1072     -> std::optional<Expr<RESULT>> {
1073   return ApplyElementwise(context, operation,
1074       std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)>{
1075           [](Expr<LEFT> &&left, Expr<RIGHT> &&right) {
1076             return Expr<RESULT>{DERIVED{std::move(left), std::move(right)}};
1077           }});
1078 }
1079 
1080 // Unary operations
1081 
1082 template <typename TO, typename FROM>
ConvertString(FROM && s)1083 common::IfNoLvalue<std::optional<TO>, FROM> ConvertString(FROM &&s) {
1084   if constexpr (std::is_same_v<TO, FROM>) {
1085     return std::make_optional<TO>(std::move(s));
1086   } else {
1087     // Fortran character conversion is well defined between distinct kinds
1088     // only when the actual characters are valid 7-bit ASCII.
1089     TO str;
1090     for (auto iter{s.cbegin()}; iter != s.cend(); ++iter) {
1091       if (static_cast<std::uint64_t>(*iter) > 127) {
1092         return std::nullopt;
1093       }
1094       str.push_back(*iter);
1095     }
1096     return std::make_optional<TO>(std::move(str));
1097   }
1098 }
1099 
1100 template <typename TO, TypeCategory FROMCAT>
FoldOperation(FoldingContext & context,Convert<TO,FROMCAT> && convert)1101 Expr<TO> FoldOperation(
1102     FoldingContext &context, Convert<TO, FROMCAT> &&convert) {
1103   if (auto array{ApplyElementwise(context, convert)}) {
1104     return *array;
1105   }
1106   struct {
1107     FoldingContext &context;
1108     Convert<TO, FROMCAT> &convert;
1109   } msvcWorkaround{context, convert};
1110   return std::visit(
1111       [&msvcWorkaround](auto &kindExpr) -> Expr<TO> {
1112         using Operand = ResultType<decltype(kindExpr)>;
1113         // This variable is a workaround for msvc which emits an error when
1114         // using the FROMCAT template parameter below.
1115         TypeCategory constexpr FromCat{FROMCAT};
1116         auto &convert{msvcWorkaround.convert};
1117         char buffer[64];
1118         if (auto value{GetScalarConstantValue<Operand>(kindExpr)}) {
1119           FoldingContext &ctx{msvcWorkaround.context};
1120           if constexpr (TO::category == TypeCategory::Integer) {
1121             if constexpr (Operand::category == TypeCategory::Integer) {
1122               auto converted{Scalar<TO>::ConvertSigned(*value)};
1123               if (converted.overflow) {
1124                 ctx.messages().Say(
1125                     "INTEGER(%d) to INTEGER(%d) conversion overflowed"_en_US,
1126                     Operand::kind, TO::kind);
1127               }
1128               return ScalarConstantToExpr(std::move(converted.value));
1129             } else if constexpr (Operand::category == TypeCategory::Real) {
1130               auto converted{value->template ToInteger<Scalar<TO>>()};
1131               if (converted.flags.test(RealFlag::InvalidArgument)) {
1132                 ctx.messages().Say(
1133                     "REAL(%d) to INTEGER(%d) conversion: invalid argument"_en_US,
1134                     Operand::kind, TO::kind);
1135               } else if (converted.flags.test(RealFlag::Overflow)) {
1136                 ctx.messages().Say(
1137                     "REAL(%d) to INTEGER(%d) conversion overflowed"_en_US,
1138                     Operand::kind, TO::kind);
1139               }
1140               return ScalarConstantToExpr(std::move(converted.value));
1141             }
1142           } else if constexpr (TO::category == TypeCategory::Real) {
1143             if constexpr (Operand::category == TypeCategory::Integer) {
1144               auto converted{Scalar<TO>::FromInteger(*value)};
1145               if (!converted.flags.empty()) {
1146                 std::snprintf(buffer, sizeof buffer,
1147                     "INTEGER(%d) to REAL(%d) conversion", Operand::kind,
1148                     TO::kind);
1149                 RealFlagWarnings(ctx, converted.flags, buffer);
1150               }
1151               return ScalarConstantToExpr(std::move(converted.value));
1152             } else if constexpr (Operand::category == TypeCategory::Real) {
1153               auto converted{Scalar<TO>::Convert(*value)};
1154               if (!converted.flags.empty()) {
1155                 std::snprintf(buffer, sizeof buffer,
1156                     "REAL(%d) to REAL(%d) conversion", Operand::kind, TO::kind);
1157                 RealFlagWarnings(ctx, converted.flags, buffer);
1158               }
1159               if (ctx.flushSubnormalsToZero()) {
1160                 converted.value = converted.value.FlushSubnormalToZero();
1161               }
1162               return ScalarConstantToExpr(std::move(converted.value));
1163             }
1164           } else if constexpr (TO::category == TypeCategory::Complex) {
1165             if constexpr (Operand::category == TypeCategory::Complex) {
1166               return FoldOperation(ctx,
1167                   ComplexConstructor<TO::kind>{
1168                       AsExpr(Convert<typename TO::Part>{AsCategoryExpr(
1169                           Constant<typename Operand::Part>{value->REAL()})}),
1170                       AsExpr(Convert<typename TO::Part>{AsCategoryExpr(
1171                           Constant<typename Operand::Part>{value->AIMAG()})})});
1172             }
1173           } else if constexpr (TO::category == TypeCategory::Character &&
1174               Operand::category == TypeCategory::Character) {
1175             if (auto converted{ConvertString<Scalar<TO>>(std::move(*value))}) {
1176               return ScalarConstantToExpr(std::move(*converted));
1177             }
1178           } else if constexpr (TO::category == TypeCategory::Logical &&
1179               Operand::category == TypeCategory::Logical) {
1180             return Expr<TO>{value->IsTrue()};
1181           }
1182         } else if constexpr (std::is_same_v<Operand, TO> &&
1183             FromCat != TypeCategory::Character) {
1184           return std::move(kindExpr); // remove needless conversion
1185         }
1186         return Expr<TO>{std::move(convert)};
1187       },
1188       convert.left().u);
1189 }
1190 
1191 template <typename T>
FoldOperation(FoldingContext & context,Parentheses<T> && x)1192 Expr<T> FoldOperation(FoldingContext &context, Parentheses<T> &&x) {
1193   auto &operand{x.left()};
1194   operand = Fold(context, std::move(operand));
1195   if (auto value{GetScalarConstantValue<T>(operand)}) {
1196     // Preserve parentheses, even around constants.
1197     return Expr<T>{Parentheses<T>{Expr<T>{Constant<T>{*value}}}};
1198   } else if (std::holds_alternative<Parentheses<T>>(operand.u)) {
1199     // ((x)) -> (x)
1200     return std::move(operand);
1201   } else {
1202     return Expr<T>{Parentheses<T>{std::move(operand)}};
1203   }
1204 }
1205 
1206 template <typename T>
FoldOperation(FoldingContext & context,Negate<T> && x)1207 Expr<T> FoldOperation(FoldingContext &context, Negate<T> &&x) {
1208   if (auto array{ApplyElementwise(context, x)}) {
1209     return *array;
1210   }
1211   auto &operand{x.left()};
1212   if (auto value{GetScalarConstantValue<T>(operand)}) {
1213     if constexpr (T::category == TypeCategory::Integer) {
1214       auto negated{value->Negate()};
1215       if (negated.overflow) {
1216         context.messages().Say(
1217             "INTEGER(%d) negation overflowed"_en_US, T::kind);
1218       }
1219       return Expr<T>{Constant<T>{std::move(negated.value)}};
1220     } else {
1221       // REAL & COMPLEX negation: no exceptions possible
1222       return Expr<T>{Constant<T>{value->Negate()}};
1223     }
1224   }
1225   return Expr<T>{std::move(x)};
1226 }
1227 
1228 // Binary (dyadic) operations
1229 
1230 template <typename LEFT, typename RIGHT>
OperandsAreConstants(const Expr<LEFT> & x,const Expr<RIGHT> & y)1231 std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
1232     const Expr<LEFT> &x, const Expr<RIGHT> &y) {
1233   if (auto xvalue{GetScalarConstantValue<LEFT>(x)}) {
1234     if (auto yvalue{GetScalarConstantValue<RIGHT>(y)}) {
1235       return {std::make_pair(*xvalue, *yvalue)};
1236     }
1237   }
1238   return std::nullopt;
1239 }
1240 
1241 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
OperandsAreConstants(const Operation<DERIVED,RESULT,LEFT,RIGHT> & operation)1242 std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
1243     const Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
1244   return OperandsAreConstants(operation.left(), operation.right());
1245 }
1246 
1247 template <typename T>
FoldOperation(FoldingContext & context,Add<T> && x)1248 Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) {
1249   if (auto array{ApplyElementwise(context, x)}) {
1250     return *array;
1251   }
1252   if (auto folded{OperandsAreConstants(x)}) {
1253     if constexpr (T::category == TypeCategory::Integer) {
1254       auto sum{folded->first.AddSigned(folded->second)};
1255       if (sum.overflow) {
1256         context.messages().Say(
1257             "INTEGER(%d) addition overflowed"_en_US, T::kind);
1258       }
1259       return Expr<T>{Constant<T>{sum.value}};
1260     } else {
1261       auto sum{folded->first.Add(folded->second, context.rounding())};
1262       RealFlagWarnings(context, sum.flags, "addition");
1263       if (context.flushSubnormalsToZero()) {
1264         sum.value = sum.value.FlushSubnormalToZero();
1265       }
1266       return Expr<T>{Constant<T>{sum.value}};
1267     }
1268   }
1269   return Expr<T>{std::move(x)};
1270 }
1271 
1272 template <typename T>
FoldOperation(FoldingContext & context,Subtract<T> && x)1273 Expr<T> FoldOperation(FoldingContext &context, Subtract<T> &&x) {
1274   if (auto array{ApplyElementwise(context, x)}) {
1275     return *array;
1276   }
1277   if (auto folded{OperandsAreConstants(x)}) {
1278     if constexpr (T::category == TypeCategory::Integer) {
1279       auto difference{folded->first.SubtractSigned(folded->second)};
1280       if (difference.overflow) {
1281         context.messages().Say(
1282             "INTEGER(%d) subtraction overflowed"_en_US, T::kind);
1283       }
1284       return Expr<T>{Constant<T>{difference.value}};
1285     } else {
1286       auto difference{
1287           folded->first.Subtract(folded->second, context.rounding())};
1288       RealFlagWarnings(context, difference.flags, "subtraction");
1289       if (context.flushSubnormalsToZero()) {
1290         difference.value = difference.value.FlushSubnormalToZero();
1291       }
1292       return Expr<T>{Constant<T>{difference.value}};
1293     }
1294   }
1295   return Expr<T>{std::move(x)};
1296 }
1297 
1298 template <typename T>
FoldOperation(FoldingContext & context,Multiply<T> && x)1299 Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) {
1300   if (auto array{ApplyElementwise(context, x)}) {
1301     return *array;
1302   }
1303   if (auto folded{OperandsAreConstants(x)}) {
1304     if constexpr (T::category == TypeCategory::Integer) {
1305       auto product{folded->first.MultiplySigned(folded->second)};
1306       if (product.SignedMultiplicationOverflowed()) {
1307         context.messages().Say(
1308             "INTEGER(%d) multiplication overflowed"_en_US, T::kind);
1309       }
1310       return Expr<T>{Constant<T>{product.lower}};
1311     } else {
1312       auto product{folded->first.Multiply(folded->second, context.rounding())};
1313       RealFlagWarnings(context, product.flags, "multiplication");
1314       if (context.flushSubnormalsToZero()) {
1315         product.value = product.value.FlushSubnormalToZero();
1316       }
1317       return Expr<T>{Constant<T>{product.value}};
1318     }
1319   }
1320   return Expr<T>{std::move(x)};
1321 }
1322 
1323 template <typename T>
FoldOperation(FoldingContext & context,Divide<T> && x)1324 Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) {
1325   if (auto array{ApplyElementwise(context, x)}) {
1326     return *array;
1327   }
1328   if (auto folded{OperandsAreConstants(x)}) {
1329     if constexpr (T::category == TypeCategory::Integer) {
1330       auto quotAndRem{folded->first.DivideSigned(folded->second)};
1331       if (quotAndRem.divisionByZero) {
1332         context.messages().Say("INTEGER(%d) division by zero"_en_US, T::kind);
1333         return Expr<T>{std::move(x)};
1334       }
1335       if (quotAndRem.overflow) {
1336         context.messages().Say(
1337             "INTEGER(%d) division overflowed"_en_US, T::kind);
1338       }
1339       return Expr<T>{Constant<T>{quotAndRem.quotient}};
1340     } else {
1341       auto quotient{folded->first.Divide(folded->second, context.rounding())};
1342       RealFlagWarnings(context, quotient.flags, "division");
1343       if (context.flushSubnormalsToZero()) {
1344         quotient.value = quotient.value.FlushSubnormalToZero();
1345       }
1346       return Expr<T>{Constant<T>{quotient.value}};
1347     }
1348   }
1349   return Expr<T>{std::move(x)};
1350 }
1351 
1352 template <typename T>
FoldOperation(FoldingContext & context,Power<T> && x)1353 Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
1354   if (auto array{ApplyElementwise(context, x)}) {
1355     return *array;
1356   }
1357   if (auto folded{OperandsAreConstants(x)}) {
1358     if constexpr (T::category == TypeCategory::Integer) {
1359       auto power{folded->first.Power(folded->second)};
1360       if (power.divisionByZero) {
1361         context.messages().Say(
1362             "INTEGER(%d) zero to negative power"_en_US, T::kind);
1363       } else if (power.overflow) {
1364         context.messages().Say("INTEGER(%d) power overflowed"_en_US, T::kind);
1365       } else if (power.zeroToZero) {
1366         context.messages().Say(
1367             "INTEGER(%d) 0**0 is not defined"_en_US, T::kind);
1368       }
1369       return Expr<T>{Constant<T>{power.power}};
1370     } else {
1371       if (auto callable{GetHostRuntimeWrapper<T, T, T>("pow")}) {
1372         return Expr<T>{
1373             Constant<T>{(*callable)(context, folded->first, folded->second)}};
1374       } else {
1375         context.messages().Say(
1376             "Power for %s cannot be folded on host"_en_US, T{}.AsFortran());
1377       }
1378     }
1379   }
1380   return Expr<T>{std::move(x)};
1381 }
1382 
1383 template <typename T>
FoldOperation(FoldingContext & context,RealToIntPower<T> && x)1384 Expr<T> FoldOperation(FoldingContext &context, RealToIntPower<T> &&x) {
1385   if (auto array{ApplyElementwise(context, x)}) {
1386     return *array;
1387   }
1388   return std::visit(
1389       [&](auto &y) -> Expr<T> {
1390         if (auto folded{OperandsAreConstants(x.left(), y)}) {
1391           auto power{evaluate::IntPower(folded->first, folded->second)};
1392           RealFlagWarnings(context, power.flags, "power with INTEGER exponent");
1393           if (context.flushSubnormalsToZero()) {
1394             power.value = power.value.FlushSubnormalToZero();
1395           }
1396           return Expr<T>{Constant<T>{power.value}};
1397         } else {
1398           return Expr<T>{std::move(x)};
1399         }
1400       },
1401       x.right().u);
1402 }
1403 
1404 template <typename T>
FoldOperation(FoldingContext & context,Extremum<T> && x)1405 Expr<T> FoldOperation(FoldingContext &context, Extremum<T> &&x) {
1406   if (auto array{ApplyElementwise(context, x,
1407           std::function<Expr<T>(Expr<T> &&, Expr<T> &&)>{[=](Expr<T> &&l,
1408                                                              Expr<T> &&r) {
1409             return Expr<T>{Extremum<T>{x.ordering, std::move(l), std::move(r)}};
1410           }})}) {
1411     return *array;
1412   }
1413   if (auto folded{OperandsAreConstants(x)}) {
1414     if constexpr (T::category == TypeCategory::Integer) {
1415       if (folded->first.CompareSigned(folded->second) == x.ordering) {
1416         return Expr<T>{Constant<T>{folded->first}};
1417       }
1418     } else if constexpr (T::category == TypeCategory::Real) {
1419       if (folded->first.IsNotANumber() ||
1420           (folded->first.Compare(folded->second) == Relation::Less) ==
1421               (x.ordering == Ordering::Less)) {
1422         return Expr<T>{Constant<T>{folded->first}};
1423       }
1424     } else {
1425       static_assert(T::category == TypeCategory::Character);
1426       // Result of MIN and MAX on character has the length of
1427       // the longest argument.
1428       auto maxLen{std::max(folded->first.length(), folded->second.length())};
1429       bool isFirst{x.ordering == Compare(folded->first, folded->second)};
1430       auto res{isFirst ? std::move(folded->first) : std::move(folded->second)};
1431       res = res.length() == maxLen
1432           ? std::move(res)
1433           : CharacterUtils<T::kind>::Resize(res, maxLen);
1434       return Expr<T>{Constant<T>{std::move(res)}};
1435     }
1436     return Expr<T>{Constant<T>{folded->second}};
1437   }
1438   return Expr<T>{std::move(x)};
1439 }
1440 
1441 template <int KIND>
ToReal(FoldingContext & context,Expr<SomeType> && expr)1442 Expr<Type<TypeCategory::Real, KIND>> ToReal(
1443     FoldingContext &context, Expr<SomeType> &&expr) {
1444   using Result = Type<TypeCategory::Real, KIND>;
1445   std::optional<Expr<Result>> result;
1446   std::visit(
1447       [&](auto &&x) {
1448         using From = std::decay_t<decltype(x)>;
1449         if constexpr (std::is_same_v<From, BOZLiteralConstant>) {
1450           // Move the bits without any integer->real conversion
1451           From original{x};
1452           result = ConvertToType<Result>(std::move(x));
1453           const auto *constant{UnwrapExpr<Constant<Result>>(*result)};
1454           CHECK(constant);
1455           Scalar<Result> real{constant->GetScalarValue().value()};
1456           From converted{From::ConvertUnsigned(real.RawBits()).value};
1457           if (original != converted) { // C1601
1458             context.messages().Say(
1459                 "Nonzero bits truncated from BOZ literal constant in REAL intrinsic"_en_US);
1460           }
1461         } else if constexpr (IsNumericCategoryExpr<From>()) {
1462           result = Fold(context, ConvertToType<Result>(std::move(x)));
1463         } else {
1464           common::die("ToReal: bad argument expression");
1465         }
1466       },
1467       std::move(expr.u));
1468   return result.value();
1469 }
1470 
1471 template <typename T>
Rewrite(FoldingContext & context,Expr<T> && expr)1472 Expr<T> ExpressionBase<T>::Rewrite(FoldingContext &context, Expr<T> &&expr) {
1473   return std::visit(
1474       [&](auto &&x) -> Expr<T> {
1475         if constexpr (IsSpecificIntrinsicType<T>) {
1476           return FoldOperation(context, std::move(x));
1477         } else if constexpr (std::is_same_v<T, SomeDerived>) {
1478           return FoldOperation(context, std::move(x));
1479         } else if constexpr (common::HasMember<decltype(x),
1480                                  TypelessExpression>) {
1481           return std::move(expr);
1482         } else {
1483           return Expr<T>{Fold(context, std::move(x))};
1484         }
1485       },
1486       std::move(expr.u));
1487 }
1488 
// Explicit instantiation declarations of ExpressionBase. This presumably
// expands to `extern template class ExpressionBase<T>;` for each type/kind T
// enumerated by FOR_EACH_TYPE_AND_KIND (confirm against the macro's
// definition). An extern template declaration suppresses implicit
// instantiation in translation units that include this header. The matching
// explicit instantiation definitions presumably live in a separate .cpp
// file (TODO confirm).
1489 FOR_EACH_TYPE_AND_KIND(extern template class ExpressionBase, )
1490 
1491 } // namespace Fortran::evaluate
1492 #endif // FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
1493