1 //===-- lib/Evaluate/fold-implementation.h --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
10 #define FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
11
12 #include "character.h"
13 #include "host.h"
14 #include "int-power.h"
15 #include "flang/Common/indirection.h"
16 #include "flang/Common/template.h"
17 #include "flang/Common/unwrap.h"
18 #include "flang/Evaluate/characteristics.h"
19 #include "flang/Evaluate/common.h"
20 #include "flang/Evaluate/constant.h"
21 #include "flang/Evaluate/expression.h"
22 #include "flang/Evaluate/fold.h"
23 #include "flang/Evaluate/formatting.h"
24 #include "flang/Evaluate/intrinsics-library.h"
25 #include "flang/Evaluate/intrinsics.h"
26 #include "flang/Evaluate/shape.h"
27 #include "flang/Evaluate/tools.h"
28 #include "flang/Evaluate/traverse.h"
29 #include "flang/Evaluate/type.h"
30 #include "flang/Parser/message.h"
31 #include "flang/Semantics/scope.h"
32 #include "flang/Semantics/symbol.h"
33 #include "flang/Semantics/tools.h"
34 #include <algorithm>
35 #include <cmath>
36 #include <complex>
37 #include <cstdio>
38 #include <optional>
39 #include <type_traits>
40 #include <variant>
41
42 // Some environments, viz. clang on Darwin, allow the macro HUGE
43 // to leak out of <math.h> even when it is never directly included.
44 #undef HUGE
45
46 namespace Fortran::evaluate {
47
48 // Utilities
// Folder<T> bundles the folding helpers that need both a result type T and
// a FoldingContext (held by reference).
template <typename T> class Folder {
public:
  explicit Folder(FoldingContext &c) : context_{c} {}
  // Returns the constant initializer of a named constant (after resolving
  // associations) when it has one of type T.
  std::optional<Constant<T>> GetNamedConstant(const Symbol &);
  // Selects elements of a constant array with constant subscripts, one
  // subscript per dimension of the array.
  std::optional<Constant<T>> ApplySubscripts(const Constant<T> &array,
      const std::vector<Constant<SubscriptInteger>> &subscripts);
  // Extracts a component, optionally subscripted, from a constant structure
  // or from each element of a constant array of structures.
  std::optional<Constant<T>> ApplyComponent(Constant<SomeDerived> &&,
      const Symbol &component,
      const std::vector<Constant<SubscriptInteger>> * = nullptr);
  // Evaluates a component reference whose base yields constant structures.
  std::optional<Constant<T>> GetConstantComponent(
      Component &, const std::vector<Constant<SubscriptInteger>> * = nullptr);
  // Folds an array reference whose subscripts are all constant.
  std::optional<Constant<T>> Folding(ArrayRef &);
  // Folds a designator to a constant when possible; otherwise returns a
  // designator whose parts have been folded.
  Expr<T> Folding(Designator<T> &&);
  // Converts an actual argument to type T when necessary and returns a
  // pointer to its constant value, or nullptr when it is not constant.
  Constant<T> *Folding(std::optional<ActualArgument> &);
  // Folds a reference to the RESHAPE intrinsic.
  Expr<T> Reshape(FunctionRef<T> &&);

private:
  FoldingContext &context_;
};
68
// Declared here and defined outside this part of the file.  Callers in this
// file pass the subscript, the base entity of the array reference, and a
// zero-based dimension index, and treat std::nullopt as "not constant".
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
    FoldingContext &, Subscript &, const NamedEntity &, int dim);
71
72 // Helper to use host runtime on scalars for folding.
73 template <typename TR, typename... TA>
74 std::optional<std::function<Scalar<TR>(FoldingContext &, Scalar<TA>...)>>
GetHostRuntimeWrapper(const std::string & name)75 GetHostRuntimeWrapper(const std::string &name) {
76 std::vector<DynamicType> argTypes{TA{}.GetType()...};
77 if (auto hostWrapper{GetHostRuntimeWrapper(name, TR{}.GetType(), argTypes)}) {
78 return [hostWrapper](
79 FoldingContext &context, Scalar<TA>... args) -> Scalar<TR> {
80 std::vector<Expr<SomeType>> genericArgs{
81 AsGenericExpr(Constant<TA>{args})...};
82 return GetScalarConstantValue<TR>(
83 (*hostWrapper)(context, std::move(genericArgs)))
84 .value();
85 };
86 }
87 return std::nullopt;
88 }
89
90 // FoldOperation() rewrites expression tree nodes.
91 // If there is any possibility that the rewritten node will
92 // not have the same representation type, the result of
93 // FoldOperation() will be packaged in an Expr<> of the same
94 // specific type.
95
96 // no-op base case
97 template <typename A>
FoldOperation(FoldingContext &,A && x)98 common::IfNoLvalue<Expr<ResultType<A>>, A> FoldOperation(
99 FoldingContext &, A &&x) {
100 static_assert(!std::is_same_v<A, Expr<ResultType<A>>>,
101 "call Fold() instead for Expr<>");
102 return Expr<ResultType<A>>{std::move(x)};
103 }
104
// FoldOperation() overloads for the building blocks of data references.
// Each consumes its operand by rvalue reference and returns a value of the
// same type.  They are declared here and defined outside this part of the
// file.
Component FoldOperation(FoldingContext &, Component &&);
NamedEntity FoldOperation(FoldingContext &, NamedEntity &&);
Triplet FoldOperation(FoldingContext &, Triplet &&);
Subscript FoldOperation(FoldingContext &, Subscript &&);
ArrayRef FoldOperation(FoldingContext &, ArrayRef &&);
CoarrayRef FoldOperation(FoldingContext &, CoarrayRef &&);
DataRef FoldOperation(FoldingContext &, DataRef &&);
Substring FoldOperation(FoldingContext &, Substring &&);
ComplexPart FoldOperation(FoldingContext &, ComplexPart &&);
114
// Folding of function references.  The generic FoldOperation() for
// FunctionRef<T> is defined later in this file.
template <typename T>
Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&);
// Per-category folding of intrinsic function references, one overload for
// each of the Integer, Real, Complex, and Logical categories, templated on
// the kind.  Their definitions are not in this part of the file.
template <int KIND>
Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Integer, KIND>> &&);
template <int KIND>
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Real, KIND>> &&);
template <int KIND>
Expr<Type<TypeCategory::Complex, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Complex, KIND>> &&);
template <int KIND>
Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
    FoldingContext &context, FunctionRef<Type<TypeCategory::Logical, KIND>> &&);
129
130 template <typename T>
FoldOperation(FoldingContext & context,Designator<T> && designator)131 Expr<T> FoldOperation(FoldingContext &context, Designator<T> &&designator) {
132 return Folder<T>{context}.Folding(std::move(designator));
133 }
134
// FoldOperation() overloads for type parameter inquiries, implied DO
// indices, and structure constructors are declared here and defined outside
// this part of the file; the array constructor overload is defined later in
// this file.
Expr<TypeParamInquiry::Result> FoldOperation(
    FoldingContext &, TypeParamInquiry &&);
Expr<ImpliedDoIndex::Result> FoldOperation(
    FoldingContext &context, ImpliedDoIndex &&);
template <typename T>
Expr<T> FoldOperation(FoldingContext &, ArrayConstructor<T> &&);
Expr<SomeDerived> FoldOperation(FoldingContext &, StructureConstructor &&);
142
143 template <typename T>
GetNamedConstant(const Symbol & symbol0)144 std::optional<Constant<T>> Folder<T>::GetNamedConstant(const Symbol &symbol0) {
145 const Symbol &symbol{ResolveAssociations(symbol0)};
146 if (IsNamedConstant(symbol)) {
147 if (const auto *object{
148 symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
149 if (const auto *constant{UnwrapConstantValue<T>(object->init())}) {
150 return *constant;
151 }
152 }
153 }
154 return std::nullopt;
155 }
156
157 template <typename T>
Folding(ArrayRef & aRef)158 std::optional<Constant<T>> Folder<T>::Folding(ArrayRef &aRef) {
159 std::vector<Constant<SubscriptInteger>> subscripts;
160 int dim{0};
161 for (Subscript &ss : aRef.subscript()) {
162 if (auto constant{GetConstantSubscript(context_, ss, aRef.base(), dim++)}) {
163 subscripts.emplace_back(std::move(*constant));
164 } else {
165 return std::nullopt;
166 }
167 }
168 if (Component * component{aRef.base().UnwrapComponent()}) {
169 return GetConstantComponent(*component, &subscripts);
170 } else if (std::optional<Constant<T>> array{
171 GetNamedConstant(aRef.base().GetLastSymbol())}) {
172 return ApplySubscripts(*array, subscripts);
173 } else {
174 return std::nullopt;
175 }
176 }
177
// Applies constant subscripts to a constant array and returns the selected
// elements as a new constant.  There must be exactly one subscript per
// dimension of `array`, each of rank 0 (a single value) or rank 1 (a vector
// of values).  Every rank-1 subscript contributes one dimension to the
// result, in order.  Emits an error message and returns std::nullopt when a
// subscript value lies outside the bounds of its dimension.
template <typename T>
std::optional<Constant<T>> Folder<T>::ApplySubscripts(const Constant<T> &array,
    const std::vector<Constant<SubscriptInteger>> &subscripts) {
  const auto &shape{array.shape()};
  const auto &lbounds{array.lbounds()};
  int rank{GetRank(shape)};
  CHECK(rank == static_cast<int>(subscripts.size()));
  // Determine the result shape and element count from the vector-valued
  // subscripts, remembering the lower bound of each such subscript vector.
  std::size_t elements{1};
  ConstantSubscripts resultShape;
  ConstantSubscripts ssLB;
  for (const auto &ss : subscripts) {
    CHECK(ss.Rank() <= 1);
    if (ss.Rank() == 1) {
      resultShape.push_back(static_cast<ConstantSubscript>(ss.size()));
      elements *= ss.size();
      ssLB.push_back(ss.lbounds().front());
    }
  }
  // ssAt holds the zero-based position within each vector subscript (only
  // its first GetRank(resultShape) entries are used); at holds the actual
  // subscripts into `array`; tmp is a one-element index into a subscript
  // vector.
  ConstantSubscripts ssAt(rank, 0), at(rank, 0), tmp(1, 0);
  std::vector<Scalar<T>> values;
  while (elements-- > 0) {
    // `increment` acts as a carry: it remains true while the positions
    // advanced so far have wrapped around to zero.
    bool increment{true};
    int k{0}; // index of the current vector-valued subscript
    for (int j{0}; j < rank; ++j) {
      if (subscripts[j].Rank() == 0) {
        at[j] = subscripts[j].GetScalarValue().value().ToInt64();
      } else {
        CHECK(k < GetRank(resultShape));
        tmp[0] = ssLB.at(k) + ssAt.at(k);
        at[j] = subscripts[j].At(tmp).ToInt64();
        if (increment) {
          if (++ssAt[k] == resultShape[k]) {
            ssAt[k] = 0;
          } else {
            increment = false;
          }
        }
        ++k;
      }
      if (at[j] < lbounds[j] || at[j] >= lbounds[j] + shape[j]) {
        context_.messages().Say(
            "Subscript value (%jd) is out of range on dimension %d in reference to a constant array value"_err_en_US,
            at[j], j + 1);
        return std::nullopt;
      }
    }
    values.emplace_back(array.At(at));
    // The carry may only survive the last dimension on the final element.
    CHECK(!increment || elements == 0);
    CHECK(k == GetRank(resultShape));
  }
  // Character and derived type constants need the length or the derived
  // type specification, respectively, in addition to values and shape.
  if constexpr (T::category == TypeCategory::Character) {
    return Constant<T>{array.LEN(), std::move(values), std::move(resultShape)};
  } else if constexpr (std::is_same_v<T, SomeDerived>) {
    return Constant<T>{array.result().derivedTypeSpec(), std::move(values),
        std::move(resultShape)};
  } else {
    return Constant<T>{std::move(values), std::move(resultShape)};
  }
}
237
238 template <typename T>
ApplyComponent(Constant<SomeDerived> && structures,const Symbol & component,const std::vector<Constant<SubscriptInteger>> * subscripts)239 std::optional<Constant<T>> Folder<T>::ApplyComponent(
240 Constant<SomeDerived> &&structures, const Symbol &component,
241 const std::vector<Constant<SubscriptInteger>> *subscripts) {
242 if (auto scalar{structures.GetScalarValue()}) {
243 if (std::optional<Expr<SomeType>> expr{scalar->Find(component)}) {
244 if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
245 if (!subscripts) {
246 return std::move(*value);
247 } else {
248 return ApplySubscripts(*value, *subscripts);
249 }
250 }
251 }
252 } else {
253 // A(:)%scalar_component & A(:)%array_component(subscripts)
254 std::unique_ptr<ArrayConstructor<T>> array;
255 if (structures.empty()) {
256 return std::nullopt;
257 }
258 ConstantSubscripts at{structures.lbounds()};
259 do {
260 StructureConstructor scalar{structures.At(at)};
261 if (std::optional<Expr<SomeType>> expr{scalar.Find(component)}) {
262 if (const Constant<T> *value{UnwrapConstantValue<T>(expr.value())}) {
263 if (!array.get()) {
264 // This technique ensures that character length or derived type
265 // information is propagated to the array constructor.
266 auto *typedExpr{UnwrapExpr<Expr<T>>(expr.value())};
267 CHECK(typedExpr);
268 array = std::make_unique<ArrayConstructor<T>>(*typedExpr);
269 }
270 if (subscripts) {
271 if (auto element{ApplySubscripts(*value, *subscripts)}) {
272 CHECK(element->Rank() == 0);
273 array->Push(Expr<T>{std::move(*element)});
274 } else {
275 return std::nullopt;
276 }
277 } else {
278 CHECK(value->Rank() == 0);
279 array->Push(Expr<T>{*value});
280 }
281 } else {
282 return std::nullopt;
283 }
284 }
285 } while (structures.IncrementSubscripts(at));
286 // Fold the ArrayConstructor<> into a Constant<>.
287 CHECK(array);
288 Expr<T> result{Fold(context_, Expr<T>{std::move(*array)})};
289 if (auto *constant{UnwrapConstantValue<T>(result)}) {
290 return constant->Reshape(common::Clone(structures.shape()));
291 }
292 }
293 return std::nullopt;
294 }
295
296 template <typename T>
GetConstantComponent(Component & component,const std::vector<Constant<SubscriptInteger>> * subscripts)297 std::optional<Constant<T>> Folder<T>::GetConstantComponent(Component &component,
298 const std::vector<Constant<SubscriptInteger>> *subscripts) {
299 if (std::optional<Constant<SomeDerived>> structures{std::visit(
300 common::visitors{
301 [&](const Symbol &symbol) {
302 return Folder<SomeDerived>{context_}.GetNamedConstant(symbol);
303 },
304 [&](ArrayRef &aRef) {
305 return Folder<SomeDerived>{context_}.Folding(aRef);
306 },
307 [&](Component &base) {
308 return Folder<SomeDerived>{context_}.GetConstantComponent(base);
309 },
310 [&](CoarrayRef &) {
311 return std::optional<Constant<SomeDerived>>{};
312 },
313 },
314 component.base().u)}) {
315 return ApplyComponent(
316 std::move(*structures), component.GetLastSymbol(), subscripts);
317 } else {
318 return std::nullopt;
319 }
320 }
321
// Folds a designator.  The result is a constant when the designator is a
// character substring that folds to a scalar constant or has a constant
// length of zero, a named constant, an array reference that folds to a
// constant, or a component reference that folds to a constant.  In all
// other cases a designator is returned whose constituent parts have been
// folded.
template <typename T> Expr<T> Folder<T>::Folding(Designator<T> &&designator) {
  if constexpr (T::category == TypeCategory::Character) {
    if (auto *substring{common::Unwrap<Substring>(designator.u)}) {
      if (std::optional<Expr<SomeCharacter>> folded{
              substring->Fold(context_)}) {
        if (auto value{GetScalarConstantValue<T>(*folded)}) {
          return Expr<T>{*value};
        }
      }
      // A substring whose length folds to zero is the empty string, whether
      // or not the substring itself could be folded.
      if (auto length{ToInt64(Fold(context_, substring->LEN()))}) {
        if (*length == 0) {
          return Expr<T>{Constant<T>{Scalar<T>{}}};
        }
      }
    }
  }
  return std::visit(
      common::visitors{
          [&](SymbolRef &&symbol) {
            if (auto constant{GetNamedConstant(*symbol)}) {
              return Expr<T>{std::move(*constant)};
            }
            // Not a named constant: the designator is returned as it was.
            return Expr<T>{std::move(designator)};
          },
          [&](ArrayRef &&aRef) {
            // Fold the parts of the reference first, then try to evaluate it.
            aRef = FoldOperation(context_, std::move(aRef));
            if (auto c{Folding(aRef)}) {
              return Expr<T>{std::move(*c)};
            } else {
              return Expr<T>{Designator<T>{std::move(aRef)}};
            }
          },
          [&](Component &&component) {
            component = FoldOperation(context_, std::move(component));
            if (auto c{GetConstantComponent(component)}) {
              return Expr<T>{std::move(*c)};
            } else {
              return Expr<T>{Designator<T>{std::move(component)}};
            }
          },
          [&](auto &&x) {
            // Any other kind of designator: fold its parts and rewrap it.
            return Expr<T>{
                Designator<T>{FoldOperation(context_, std::move(x))}};
          },
      },
      std::move(designator.u));
}
369
370 // Apply type conversion and re-folding if necessary.
371 // This is where BOZ arguments are converted.
372 template <typename T>
Folding(std::optional<ActualArgument> & arg)373 Constant<T> *Folder<T>::Folding(std::optional<ActualArgument> &arg) {
374 if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
375 if (!UnwrapExpr<Expr<T>>(*expr)) {
376 if (auto converted{ConvertToType(T::GetType(), std::move(*expr))}) {
377 *expr = Fold(context_, std::move(*converted));
378 }
379 }
380 return UnwrapConstantValue<T>(*expr);
381 }
382 return nullptr;
383 }
384
385 template <typename... A, std::size_t... I>
GetConstantArgumentsHelper(FoldingContext & context,ActualArguments & arguments,std::index_sequence<I...>)386 std::optional<std::tuple<const Constant<A> *...>> GetConstantArgumentsHelper(
387 FoldingContext &context, ActualArguments &arguments,
388 std::index_sequence<I...>) {
389 static_assert(
390 (... && IsSpecificIntrinsicType<A>)); // TODO derived types for MERGE?
391 static_assert(sizeof...(A) > 0);
392 std::tuple<const Constant<A> *...> args{
393 Folder<A>{context}.Folding(arguments.at(I))...};
394 if ((... && (std::get<I>(args)))) {
395 return args;
396 } else {
397 return std::nullopt;
398 }
399 }
400
401 template <typename... A>
GetConstantArguments(FoldingContext & context,ActualArguments & args)402 std::optional<std::tuple<const Constant<A> *...>> GetConstantArguments(
403 FoldingContext &context, ActualArguments &args) {
404 return GetConstantArgumentsHelper<A...>(
405 context, args, std::index_sequence_for<A...>{});
406 }
407
408 template <typename... A, std::size_t... I>
GetScalarConstantArgumentsHelper(FoldingContext & context,ActualArguments & args,std::index_sequence<I...>)409 std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArgumentsHelper(
410 FoldingContext &context, ActualArguments &args, std::index_sequence<I...>) {
411 if (auto constArgs{GetConstantArguments<A...>(context, args)}) {
412 return std::tuple<Scalar<A>...>{
413 std::get<I>(*constArgs)->GetScalarValue().value()...};
414 } else {
415 return std::nullopt;
416 }
417 }
418
419 template <typename... A>
GetScalarConstantArguments(FoldingContext & context,ActualArguments & args)420 std::optional<std::tuple<Scalar<A>...>> GetScalarConstantArguments(
421 FoldingContext &context, ActualArguments &args) {
422 return GetScalarConstantArgumentsHelper<A...>(
423 context, args, std::index_sequence_for<A...>{});
424 }
425
// Helpers to fold intrinsic function references.
// These callable types are consumed by a common utility (see
// FoldElementalIntrinsic() below) that takes care of the array and
// conversion aspects of elemental intrinsics.

// A scalar operation from argument values of types TArgs... to a result of
// type TR.
template <typename TR, typename... TArgs>
using ScalarFunc = std::function<Scalar<TR>(const Scalar<TArgs> &...)>;
// The same, for operations that also need the FoldingContext.
template <typename TR, typename... TArgs>
using ScalarFuncWithContext =
    std::function<Scalar<TR>(FoldingContext &, const Scalar<TArgs> &...)>;
435
436 template <template <typename, typename...> typename WrapperType, typename TR,
437 typename... TA, std::size_t... I>
FoldElementalIntrinsicHelper(FoldingContext & context,FunctionRef<TR> && funcRef,WrapperType<TR,TA...> func,std::index_sequence<I...>)438 Expr<TR> FoldElementalIntrinsicHelper(FoldingContext &context,
439 FunctionRef<TR> &&funcRef, WrapperType<TR, TA...> func,
440 std::index_sequence<I...>) {
441 if (std::optional<std::tuple<const Constant<TA> *...>> args{
442 GetConstantArguments<TA...>(context, funcRef.arguments())}) {
443 // Compute the shape of the result based on shapes of arguments
444 ConstantSubscripts shape;
445 int rank{0};
446 const ConstantSubscripts *shapes[sizeof...(TA)]{
447 &std::get<I>(*args)->shape()...};
448 const int ranks[sizeof...(TA)]{std::get<I>(*args)->Rank()...};
449 for (unsigned int i{0}; i < sizeof...(TA); ++i) {
450 if (ranks[i] > 0) {
451 if (rank == 0) {
452 rank = ranks[i];
453 shape = *shapes[i];
454 } else {
455 if (shape != *shapes[i]) {
456 // TODO: Rank compatibility was already checked but it seems to be
457 // the first place where the actual shapes are checked to be the
458 // same. Shouldn't this be checked elsewhere so that this is also
459 // checked for non constexpr call to elemental intrinsics function?
460 context.messages().Say(
461 "Arguments in elemental intrinsic function are not conformable"_err_en_US);
462 return Expr<TR>{std::move(funcRef)};
463 }
464 }
465 }
466 }
467 CHECK(rank == GetRank(shape));
468
469 // Compute all the scalar values of the results
470 std::vector<Scalar<TR>> results;
471 if (TotalElementCount(shape) > 0) {
472 ConstantBounds bounds{shape};
473 ConstantSubscripts index(rank, 1);
474 do {
475 if constexpr (std::is_same_v<WrapperType<TR, TA...>,
476 ScalarFuncWithContext<TR, TA...>>) {
477 results.emplace_back(func(context,
478 (ranks[I] ? std::get<I>(*args)->At(index)
479 : std::get<I>(*args)->GetScalarValue().value())...));
480 } else if constexpr (std::is_same_v<WrapperType<TR, TA...>,
481 ScalarFunc<TR, TA...>>) {
482 results.emplace_back(func(
483 (ranks[I] ? std::get<I>(*args)->At(index)
484 : std::get<I>(*args)->GetScalarValue().value())...));
485 }
486 } while (bounds.IncrementSubscripts(index));
487 }
488 // Build and return constant result
489 if constexpr (TR::category == TypeCategory::Character) {
490 auto len{static_cast<ConstantSubscript>(
491 results.size() ? results[0].length() : 0)};
492 return Expr<TR>{Constant<TR>{len, std::move(results), std::move(shape)}};
493 } else {
494 return Expr<TR>{Constant<TR>{std::move(results), std::move(shape)}};
495 }
496 }
497 return Expr<TR>{std::move(funcRef)};
498 }
499
500 template <typename TR, typename... TA>
FoldElementalIntrinsic(FoldingContext & context,FunctionRef<TR> && funcRef,ScalarFunc<TR,TA...> func)501 Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
502 FunctionRef<TR> &&funcRef, ScalarFunc<TR, TA...> func) {
503 return FoldElementalIntrinsicHelper<ScalarFunc, TR, TA...>(
504 context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
505 }
506 template <typename TR, typename... TA>
FoldElementalIntrinsic(FoldingContext & context,FunctionRef<TR> && funcRef,ScalarFuncWithContext<TR,TA...> func)507 Expr<TR> FoldElementalIntrinsic(FoldingContext &context,
508 FunctionRef<TR> &&funcRef, ScalarFuncWithContext<TR, TA...> func) {
509 return FoldElementalIntrinsicHelper<ScalarFuncWithContext, TR, TA...>(
510 context, std::move(funcRef), func, std::index_sequence_for<TA...>{});
511 }
512
// Helpers that read an actual argument as a 64-bit integer value; the second
// one takes a default value.  Both are declared here and defined outside
// this part of the file.
// NOTE(review): presumably the default applies when the argument is absent;
// confirm against the definitions.
std::optional<std::int64_t> GetInt64Arg(const std::optional<ActualArgument> &);
std::optional<std::int64_t> GetInt64ArgOr(
    const std::optional<ActualArgument> &, std::int64_t defaultValue);
516
517 template <typename A, typename B>
GetIntegerVector(const B & x)518 std::optional<std::vector<A>> GetIntegerVector(const B &x) {
519 static_assert(std::is_integral_v<A>);
520 if (const auto *someInteger{UnwrapExpr<Expr<SomeInteger>>(x)}) {
521 return std::visit(
522 [](const auto &typedExpr) -> std::optional<std::vector<A>> {
523 using T = ResultType<decltype(typedExpr)>;
524 if (const auto *constant{UnwrapConstantValue<T>(typedExpr)}) {
525 if (constant->Rank() == 1) {
526 std::vector<A> result;
527 for (const auto &value : constant->values()) {
528 result.push_back(static_cast<A>(value.ToInt64()));
529 }
530 return result;
531 }
532 }
533 return std::nullopt;
534 },
535 someInteger->u);
536 }
537 return std::nullopt;
538 }
539
540 // Transform an intrinsic function reference that contains user errors
541 // into an intrinsic with the same characteristic but the "invalid" name.
542 // This to prevent generating warnings over and over if the expression
543 // gets re-folded.
MakeInvalidIntrinsic(FunctionRef<T> && funcRef)544 template <typename T> Expr<T> MakeInvalidIntrinsic(FunctionRef<T> &&funcRef) {
545 SpecificIntrinsic invalid{std::get<SpecificIntrinsic>(funcRef.proc().u)};
546 invalid.name = IntrinsicProcTable::InvalidName;
547 return Expr<T>{FunctionRef<T>{ProcedureDesignator{std::move(invalid)},
548 ActualArguments{std::move(funcRef.arguments())}}};
549 }
550
Reshape(FunctionRef<T> && funcRef)551 template <typename T> Expr<T> Folder<T>::Reshape(FunctionRef<T> &&funcRef) {
552 auto args{funcRef.arguments()};
553 CHECK(args.size() == 4);
554 const auto *source{UnwrapConstantValue<T>(args[0])};
555 const auto *pad{UnwrapConstantValue<T>(args[2])};
556 std::optional<std::vector<ConstantSubscript>> shape{
557 GetIntegerVector<ConstantSubscript>(args[1])};
558 std::optional<std::vector<int>> order{GetIntegerVector<int>(args[3])};
559 if (!source || !shape || (args[2] && !pad) || (args[3] && !order)) {
560 return Expr<T>{std::move(funcRef)}; // Non-constant arguments
561 } else if (shape.value().size() > common::maxRank) {
562 context_.messages().Say(
563 "Size of 'shape=' argument must not be greater than %d"_err_en_US,
564 common::maxRank);
565 } else if (HasNegativeExtent(shape.value())) {
566 context_.messages().Say(
567 "'shape=' argument must not have a negative extent"_err_en_US);
568 } else {
569 int rank{GetRank(shape.value())};
570 std::size_t resultElements{TotalElementCount(shape.value())};
571 std::optional<std::vector<int>> dimOrder;
572 if (order) {
573 dimOrder = ValidateDimensionOrder(rank, *order);
574 }
575 std::vector<int> *dimOrderPtr{dimOrder ? &dimOrder.value() : nullptr};
576 if (order && !dimOrder) {
577 context_.messages().Say("Invalid 'order=' argument in RESHAPE"_err_en_US);
578 } else if (resultElements > source->size() && (!pad || pad->empty())) {
579 context_.messages().Say(
580 "Too few elements in 'source=' argument and 'pad=' "
581 "argument is not present or has null size"_err_en_US);
582 } else {
583 Constant<T> result{!source->empty() || !pad
584 ? source->Reshape(std::move(shape.value()))
585 : pad->Reshape(std::move(shape.value()))};
586 ConstantSubscripts subscripts{result.lbounds()};
587 auto copied{result.CopyFrom(*source,
588 std::min(source->size(), resultElements), subscripts, dimOrderPtr)};
589 if (copied < resultElements) {
590 CHECK(pad);
591 copied += result.CopyFrom(
592 *pad, resultElements - copied, subscripts, dimOrderPtr);
593 }
594 CHECK(copied == resultElements);
595 return Expr<T>{std::move(result)};
596 }
597 }
598 // Invalid, prevent re-folding
599 return MakeInvalidIntrinsic(std::move(funcRef));
600 }
601
602 template <typename T>
FoldMINorMAX(FoldingContext & context,FunctionRef<T> && funcRef,Ordering order)603 Expr<T> FoldMINorMAX(
604 FoldingContext &context, FunctionRef<T> &&funcRef, Ordering order) {
605 std::vector<Constant<T> *> constantArgs;
606 // Call Folding on all arguments, even if some are not constant,
607 // to make operand promotion explicit.
608 for (auto &arg : funcRef.arguments()) {
609 if (auto *cst{Folder<T>{context}.Folding(arg)}) {
610 constantArgs.push_back(cst);
611 }
612 }
613 if (constantArgs.size() != funcRef.arguments().size())
614 return Expr<T>(std::move(funcRef));
615 CHECK(constantArgs.size() > 0);
616 Expr<T> result{std::move(*constantArgs[0])};
617 for (std::size_t i{1}; i < constantArgs.size(); ++i) {
618 Extremum<T> extremum{order, result, Expr<T>{std::move(*constantArgs[i])}};
619 result = FoldOperation(context, std::move(extremum));
620 }
621 return result;
622 }
623
// For AMAX0, AMIN0, AMAX1, AMIN1, DMAX1, DMIN1, MAX0, MIN0, MAX1, and MIN1,
// special care has to be taken to insert the conversion on the result
// of the MIN/MAX. This is made slightly more complex by the extension
// supported by f18 that arguments may have different kinds. This implies
// that the created MIN/MAX result type cannot be deduced from the standard
// but has to be deduced from the arguments.
// e.g. AMAX0(int8, int4) is rewritten to REAL(MAX(int8, INT(int4, 8))).
template <typename T>
Expr<T> RewriteSpecificMINorMAX(
    FoldingContext &context, FunctionRef<T> &&funcRef) {
  ActualArguments &args{funcRef.arguments()};
  auto &intrinsic{DEREF(std::get_if<SpecificIntrinsic>(&funcRef.proc().u))};
  // Rewrite MAX1(args) to INT(MAX(args)) and fold. Same logic for MIN1.
  // Find result type for max/min based on the arguments.
  // The scan starts from the type of the first argument and visits the
  // remaining arguments from last to second.
  DynamicType resultType{args[0].value().GetType().value()};
  auto *resultTypeArg{&args[0]};
  for (auto j{args.size() - 1}; j > 0; --j) {
    DynamicType type{args[j].value().GetType().value()};
    if (type.category() == resultType.category()) {
      // Same category: the larger kind value is retained.
      if (type.kind() > resultType.kind()) {
        resultTypeArg = &args[j];
        resultType = type;
      }
    } else if (resultType.category() == TypeCategory::Integer) {
      // Handle mixed real/integer arguments: all the previous arguments were
      // integers and this one is real. The type of the MAX/MIN result will
      // be the one of the real argument.
      resultTypeArg = &args[j];
      resultType = type;
    }
  }
  // Turn the specific intrinsic into the generic "max" or "min", with the
  // result type that was selected above.
  intrinsic.name =
      intrinsic.name.find("max") != std::string::npos ? "max"s : "min"s;
  intrinsic.characteristics.value().functionResult.value().SetType(resultType);
  // `x` is used only for its type: TR is the result type of the selected
  // argument.  The procedure designator and the arguments are moved out of
  // funcRef into a new FunctionRef<TR>, whose result is converted to T and
  // folded.
  auto insertConversion{[&](const auto &x) -> Expr<T> {
    using TR = ResultType<decltype(x)>;
    FunctionRef<TR> maxRef{std::move(funcRef.proc()), std::move(args)};
    return Fold(context, ConvertToType<T>(AsCategoryExpr(std::move(maxRef))));
  }};
  // The selected argument is either a real or an integer expression.
  if (auto *sx{UnwrapExpr<Expr<SomeReal>>(*resultTypeArg)}) {
    return std::visit(insertConversion, sx->u);
  }
  auto &sx{DEREF(UnwrapExpr<Expr<SomeInteger>>(*resultTypeArg))};
  return std::visit(insertConversion, sx.u);
}
669
670 template <typename T>
FoldOperation(FoldingContext & context,FunctionRef<T> && funcRef)671 Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
672 ActualArguments &args{funcRef.arguments()};
673 for (std::optional<ActualArgument> &arg : args) {
674 if (auto *expr{UnwrapExpr<Expr<SomeType>>(arg)}) {
675 *expr = Fold(context, std::move(*expr));
676 }
677 }
678 if (auto *intrinsic{std::get_if<SpecificIntrinsic>(&funcRef.proc().u)}) {
679 const std::string name{intrinsic->name};
680 if (name == "reshape") {
681 return Folder<T>{context}.Reshape(std::move(funcRef));
682 }
683 // TODO: other type independent transformationals
684 if constexpr (!std::is_same_v<T, SomeDerived>) {
685 return FoldIntrinsicFunction(context, std::move(funcRef));
686 }
687 }
688 return Expr<T>{std::move(funcRef)};
689 }
690
691 template <typename T>
FoldMerge(FoldingContext & context,FunctionRef<T> && funcRef)692 Expr<T> FoldMerge(FoldingContext &context, FunctionRef<T> &&funcRef) {
693 return FoldElementalIntrinsic<T, T, T, LogicalResult>(context,
694 std::move(funcRef),
695 ScalarFunc<T, T, T, LogicalResult>(
696 [](const Scalar<T> &ifTrue, const Scalar<T> &ifFalse,
697 const Scalar<LogicalResult> &predicate) -> Scalar<T> {
698 return predicate.IsTrue() ? ifTrue : ifFalse;
699 }));
700 }
701
// Repeats a declaration that already appears earlier in this file; the
// definition is not in this part of the file.
Expr<ImpliedDoIndex::Result> FoldOperation(FoldingContext &, ImpliedDoIndex &&);
703
704 // Array constructor folding
705 template <typename T> class ArrayConstructorFolder {
706 public:
ArrayConstructorFolder(const FoldingContext & c)707 explicit ArrayConstructorFolder(const FoldingContext &c) : context_{c} {}
708
FoldArray(ArrayConstructor<T> && array)709 Expr<T> FoldArray(ArrayConstructor<T> &&array) {
710 // Calls FoldArray(const ArrayConstructorValues<T> &) below
711 if (FoldArray(array)) {
712 auto n{static_cast<ConstantSubscript>(elements_.size())};
713 if constexpr (std::is_same_v<T, SomeDerived>) {
714 return Expr<T>{Constant<T>{array.GetType().GetDerivedTypeSpec(),
715 std::move(elements_), ConstantSubscripts{n}}};
716 } else if constexpr (T::category == TypeCategory::Character) {
717 auto length{Fold(context_, common::Clone(array.LEN()))};
718 if (std::optional<ConstantSubscript> lengthValue{ToInt64(length)}) {
719 return Expr<T>{Constant<T>{
720 *lengthValue, std::move(elements_), ConstantSubscripts{n}}};
721 }
722 } else {
723 return Expr<T>{
724 Constant<T>{std::move(elements_), ConstantSubscripts{n}}};
725 }
726 }
727 return Expr<T>{std::move(array)};
728 }
729
730 private:
FoldArray(const common::CopyableIndirection<Expr<T>> & expr)731 bool FoldArray(const common::CopyableIndirection<Expr<T>> &expr) {
732 Expr<T> folded{Fold(context_, common::Clone(expr.value()))};
733 if (const auto *c{UnwrapConstantValue<T>(folded)}) {
734 // Copy elements in Fortran array element order
735 ConstantSubscripts shape{c->shape()};
736 int rank{c->Rank()};
737 ConstantSubscripts index(GetRank(shape), 1);
738 for (std::size_t n{c->size()}; n-- > 0;) {
739 elements_.emplace_back(c->At(index));
740 for (int d{0}; d < rank; ++d) {
741 if (++index[d] <= shape[d]) {
742 break;
743 }
744 index[d] = 1;
745 }
746 }
747 return true;
748 } else {
749 return false;
750 }
751 }
FoldArray(const ImpliedDo<T> & iDo)752 bool FoldArray(const ImpliedDo<T> &iDo) {
753 Expr<SubscriptInteger> lower{
754 Fold(context_, Expr<SubscriptInteger>{iDo.lower()})};
755 Expr<SubscriptInteger> upper{
756 Fold(context_, Expr<SubscriptInteger>{iDo.upper()})};
757 Expr<SubscriptInteger> stride{
758 Fold(context_, Expr<SubscriptInteger>{iDo.stride()})};
759 std::optional<ConstantSubscript> start{ToInt64(lower)}, end{ToInt64(upper)},
760 step{ToInt64(stride)};
761 if (start && end && step && *step != 0) {
762 bool result{true};
763 ConstantSubscript &j{context_.StartImpliedDo(iDo.name(), *start)};
764 if (*step > 0) {
765 for (; j <= *end; j += *step) {
766 result &= FoldArray(iDo.values());
767 }
768 } else {
769 for (; j >= *end; j += *step) {
770 result &= FoldArray(iDo.values());
771 }
772 }
773 context_.EndImpliedDo(iDo.name());
774 return result;
775 } else {
776 return false;
777 }
778 }
FoldArray(const ArrayConstructorValue<T> & x)779 bool FoldArray(const ArrayConstructorValue<T> &x) {
780 return std::visit([&](const auto &y) { return FoldArray(y); }, x.u);
781 }
FoldArray(const ArrayConstructorValues<T> & xs)782 bool FoldArray(const ArrayConstructorValues<T> &xs) {
783 for (const auto &x : xs) {
784 if (!FoldArray(x)) {
785 return false;
786 }
787 }
788 return true;
789 }
790
791 FoldingContext context_;
792 std::vector<Scalar<T>> elements_;
793 };
794
795 template <typename T>
FoldOperation(FoldingContext & context,ArrayConstructor<T> && array)796 Expr<T> FoldOperation(FoldingContext &context, ArrayConstructor<T> &&array) {
797 return ArrayConstructorFolder<T>{context}.FoldArray(std::move(array));
798 }
799
800 // Array operation elemental application: When all operands to an operation
801 // are constant arrays, array constructors without any implied DO loops,
802 // &/or expanded scalars, pull the operation "into" the array result by
803 // applying it in an elementwise fashion. For example, [A,1]+[B,2]
804 // is rewritten into [A+B,1+2] and then partially folded to [A+B,3].
805
806 // If possible, restructures an array expression into an array constructor
807 // that comprises a "flat" ArrayConstructorValues with no implied DO loops.
808 template <typename T>
ArrayConstructorIsFlat(const ArrayConstructorValues<T> & values)809 bool ArrayConstructorIsFlat(const ArrayConstructorValues<T> &values) {
810 for (const ArrayConstructorValue<T> &x : values) {
811 if (!std::holds_alternative<Expr<T>>(x.u)) {
812 return false;
813 }
814 }
815 return true;
816 }
817
818 template <typename T>
AsFlatArrayConstructor(const Expr<T> & expr)819 std::optional<Expr<T>> AsFlatArrayConstructor(const Expr<T> &expr) {
820 if (const auto *c{UnwrapConstantValue<T>(expr)}) {
821 ArrayConstructor<T> result{expr};
822 if (c->size() > 0) {
823 ConstantSubscripts at{c->lbounds()};
824 do {
825 result.Push(Expr<T>{Constant<T>{c->At(at)}});
826 } while (c->IncrementSubscripts(at));
827 }
828 return std::make_optional<Expr<T>>(std::move(result));
829 } else if (const auto *a{UnwrapExpr<ArrayConstructor<T>>(expr)}) {
830 if (ArrayConstructorIsFlat(*a)) {
831 return std::make_optional<Expr<T>>(expr);
832 }
833 } else if (const auto *p{UnwrapExpr<Parentheses<T>>(expr)}) {
834 return AsFlatArrayConstructor(Expr<T>{p->left()});
835 }
836 return std::nullopt;
837 }
838
839 template <TypeCategory CAT>
840 std::enable_if_t<CAT != TypeCategory::Derived,
841 std::optional<Expr<SomeKind<CAT>>>>
AsFlatArrayConstructor(const Expr<SomeKind<CAT>> & expr)842 AsFlatArrayConstructor(const Expr<SomeKind<CAT>> &expr) {
843 return std::visit(
844 [&](const auto &kindExpr) -> std::optional<Expr<SomeKind<CAT>>> {
845 if (auto flattened{AsFlatArrayConstructor(kindExpr)}) {
846 return Expr<SomeKind<CAT>>{std::move(*flattened)};
847 } else {
848 return std::nullopt;
849 }
850 },
851 expr.u);
852 }
853
854 // FromArrayConstructor is a subroutine for MapOperation() below.
855 // Given a flat ArrayConstructor<T> and a shape, it wraps the array
856 // into an Expr<T>, folds it, and returns the resulting wrapped
857 // array constructor or constant array value.
858 template <typename T>
FromArrayConstructor(FoldingContext & context,ArrayConstructor<T> && values,std::optional<ConstantSubscripts> && shape)859 Expr<T> FromArrayConstructor(FoldingContext &context,
860 ArrayConstructor<T> &&values, std::optional<ConstantSubscripts> &&shape) {
861 Expr<T> result{Fold(context, Expr<T>{std::move(values)})};
862 if (shape) {
863 if (auto *constant{UnwrapConstantValue<T>(result)}) {
864 return Expr<T>{constant->Reshape(std::move(*shape))};
865 }
866 }
867 return result;
868 }
869
870 // MapOperation is a utility for various specializations of ApplyElementwise()
871 // that follow. Given one or two flat ArrayConstructor<OPERAND> (wrapped in an
872 // Expr<OPERAND>) for some specific operand type(s), apply a given function f
873 // to each of their corresponding elements to produce a flat
874 // ArrayConstructor<RESULT> (wrapped in an Expr<RESULT>).
875 // Preserves shape.
876
877 // Unary case
878 template <typename RESULT, typename OPERAND>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<OPERAND> &&)> && f,const Shape & shape,Expr<OPERAND> && values)879 Expr<RESULT> MapOperation(FoldingContext &context,
880 std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f, const Shape &shape,
881 Expr<OPERAND> &&values) {
882 ArrayConstructor<RESULT> result{values};
883 if constexpr (common::HasMember<OPERAND, AllIntrinsicCategoryTypes>) {
884 std::visit(
885 [&](auto &&kindExpr) {
886 using kindType = ResultType<decltype(kindExpr)>;
887 auto &aConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
888 for (auto &acValue : aConst) {
889 auto &scalar{std::get<Expr<kindType>>(acValue.u)};
890 result.Push(Fold(context, f(Expr<OPERAND>{std::move(scalar)})));
891 }
892 },
893 std::move(values.u));
894 } else {
895 auto &aConst{std::get<ArrayConstructor<OPERAND>>(values.u)};
896 for (auto &acValue : aConst) {
897 auto &scalar{std::get<Expr<OPERAND>>(acValue.u)};
898 result.Push(Fold(context, f(std::move(scalar))));
899 }
900 }
901 return FromArrayConstructor(
902 context, std::move(result), AsConstantExtents(context, shape));
903 }
904
905 // array * array case
906 template <typename RESULT, typename LEFT, typename RIGHT>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<LEFT> &&,Expr<RIGHT> &&)> && f,const Shape & shape,Expr<LEFT> && leftValues,Expr<RIGHT> && rightValues)907 Expr<RESULT> MapOperation(FoldingContext &context,
908 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
909 const Shape &shape, Expr<LEFT> &&leftValues, Expr<RIGHT> &&rightValues) {
910 ArrayConstructor<RESULT> result{leftValues};
911 auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
912 if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
913 std::visit(
914 [&](auto &&kindExpr) {
915 using kindType = ResultType<decltype(kindExpr)>;
916
917 auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
918 auto rightIter{rightArrConst.begin()};
919 for (auto &leftValue : leftArrConst) {
920 CHECK(rightIter != rightArrConst.end());
921 auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
922 auto &rightScalar{std::get<Expr<kindType>>(rightIter->u)};
923 result.Push(Fold(context,
924 f(std::move(leftScalar), Expr<RIGHT>{std::move(rightScalar)})));
925 ++rightIter;
926 }
927 },
928 std::move(rightValues.u));
929 } else {
930 auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
931 auto rightIter{rightArrConst.begin()};
932 for (auto &leftValue : leftArrConst) {
933 CHECK(rightIter != rightArrConst.end());
934 auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
935 auto &rightScalar{std::get<Expr<RIGHT>>(rightIter->u)};
936 result.Push(
937 Fold(context, f(std::move(leftScalar), std::move(rightScalar))));
938 ++rightIter;
939 }
940 }
941 return FromArrayConstructor(
942 context, std::move(result), AsConstantExtents(context, shape));
943 }
944
945 // array * scalar case
946 template <typename RESULT, typename LEFT, typename RIGHT>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<LEFT> &&,Expr<RIGHT> &&)> && f,const Shape & shape,Expr<LEFT> && leftValues,const Expr<RIGHT> & rightScalar)947 Expr<RESULT> MapOperation(FoldingContext &context,
948 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
949 const Shape &shape, Expr<LEFT> &&leftValues,
950 const Expr<RIGHT> &rightScalar) {
951 ArrayConstructor<RESULT> result{leftValues};
952 auto &leftArrConst{std::get<ArrayConstructor<LEFT>>(leftValues.u)};
953 for (auto &leftValue : leftArrConst) {
954 auto &leftScalar{std::get<Expr<LEFT>>(leftValue.u)};
955 result.Push(
956 Fold(context, f(std::move(leftScalar), Expr<RIGHT>{rightScalar})));
957 }
958 return FromArrayConstructor(
959 context, std::move(result), AsConstantExtents(context, shape));
960 }
961
962 // scalar * array case
963 template <typename RESULT, typename LEFT, typename RIGHT>
MapOperation(FoldingContext & context,std::function<Expr<RESULT> (Expr<LEFT> &&,Expr<RIGHT> &&)> && f,const Shape & shape,const Expr<LEFT> & leftScalar,Expr<RIGHT> && rightValues)964 Expr<RESULT> MapOperation(FoldingContext &context,
965 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f,
966 const Shape &shape, const Expr<LEFT> &leftScalar,
967 Expr<RIGHT> &&rightValues) {
968 ArrayConstructor<RESULT> result{leftScalar};
969 if constexpr (common::HasMember<RIGHT, AllIntrinsicCategoryTypes>) {
970 std::visit(
971 [&](auto &&kindExpr) {
972 using kindType = ResultType<decltype(kindExpr)>;
973 auto &rightArrConst{std::get<ArrayConstructor<kindType>>(kindExpr.u)};
974 for (auto &rightValue : rightArrConst) {
975 auto &rightScalar{std::get<Expr<kindType>>(rightValue.u)};
976 result.Push(Fold(context,
977 f(Expr<LEFT>{leftScalar},
978 Expr<RIGHT>{std::move(rightScalar)})));
979 }
980 },
981 std::move(rightValues.u));
982 } else {
983 auto &rightArrConst{std::get<ArrayConstructor<RIGHT>>(rightValues.u)};
984 for (auto &rightValue : rightArrConst) {
985 auto &rightScalar{std::get<Expr<RIGHT>>(rightValue.u)};
986 result.Push(
987 Fold(context, f(Expr<LEFT>{leftScalar}, std::move(rightScalar))));
988 }
989 }
990 return FromArrayConstructor(
991 context, std::move(result), AsConstantExtents(context, shape));
992 }
993
994 // ApplyElementwise() recursively folds the operand expression(s) of an
995 // operation, then attempts to apply the operation to the (corresponding)
996 // scalar element(s) of those operands. Returns std::nullopt for scalars
997 // or unlinearizable operands.
998 template <typename DERIVED, typename RESULT, typename OPERAND>
999 auto ApplyElementwise(FoldingContext &context,
1000 Operation<DERIVED, RESULT, OPERAND> &operation,
1001 std::function<Expr<RESULT>(Expr<OPERAND> &&)> &&f)
1002 -> std::optional<Expr<RESULT>> {
1003 auto &expr{operation.left()};
1004 expr = Fold(context, std::move(expr));
1005 if (expr.Rank() > 0) {
1006 if (std::optional<Shape> shape{GetShape(context, expr)}) {
1007 if (auto values{AsFlatArrayConstructor(expr)}) {
1008 return MapOperation(context, std::move(f), *shape, std::move(*values));
1009 }
1010 }
1011 }
1012 return std::nullopt;
1013 }
1014
1015 template <typename DERIVED, typename RESULT, typename OPERAND>
1016 auto ApplyElementwise(
1017 FoldingContext &context, Operation<DERIVED, RESULT, OPERAND> &operation)
1018 -> std::optional<Expr<RESULT>> {
1019 return ApplyElementwise(context, operation,
1020 std::function<Expr<RESULT>(Expr<OPERAND> &&)>{
1021 [](Expr<OPERAND> &&operand) {
1022 return Expr<RESULT>{DERIVED{std::move(operand)}};
1023 }});
1024 }
1025
1026 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1027 auto ApplyElementwise(FoldingContext &context,
1028 Operation<DERIVED, RESULT, LEFT, RIGHT> &operation,
1029 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)> &&f)
1030 -> std::optional<Expr<RESULT>> {
1031 auto &leftExpr{operation.left()};
1032 leftExpr = Fold(context, std::move(leftExpr));
1033 auto &rightExpr{operation.right()};
1034 rightExpr = Fold(context, std::move(rightExpr));
1035 if (leftExpr.Rank() > 0) {
1036 if (std::optional<Shape> leftShape{GetShape(context, leftExpr)}) {
1037 if (auto left{AsFlatArrayConstructor(leftExpr)}) {
1038 if (rightExpr.Rank() > 0) {
1039 if (std::optional<Shape> rightShape{GetShape(context, rightExpr)}) {
1040 if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1041 if (CheckConformance(
1042 context.messages(), *leftShape, *rightShape)) {
1043 return MapOperation(context, std::move(f), *leftShape,
1044 std::move(*left), std::move(*right));
1045 } else {
1046 return std::nullopt;
1047 }
1048 return MapOperation(context, std::move(f), *leftShape,
1049 std::move(*left), std::move(*right));
1050 }
1051 }
1052 } else if (IsExpandableScalar(rightExpr)) {
1053 return MapOperation(
1054 context, std::move(f), *leftShape, std::move(*left), rightExpr);
1055 }
1056 }
1057 }
1058 } else if (rightExpr.Rank() > 0 && IsExpandableScalar(leftExpr)) {
1059 if (std::optional<Shape> shape{GetShape(context, rightExpr)}) {
1060 if (auto right{AsFlatArrayConstructor(rightExpr)}) {
1061 return MapOperation(
1062 context, std::move(f), *shape, leftExpr, std::move(*right));
1063 }
1064 }
1065 }
1066 return std::nullopt;
1067 }
1068
1069 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
1070 auto ApplyElementwise(
1071 FoldingContext &context, Operation<DERIVED, RESULT, LEFT, RIGHT> &operation)
1072 -> std::optional<Expr<RESULT>> {
1073 return ApplyElementwise(context, operation,
1074 std::function<Expr<RESULT>(Expr<LEFT> &&, Expr<RIGHT> &&)>{
1075 [](Expr<LEFT> &&left, Expr<RIGHT> &&right) {
1076 return Expr<RESULT>{DERIVED{std::move(left), std::move(right)}};
1077 }});
1078 }
1079
1080 // Unary operations
1081
1082 template <typename TO, typename FROM>
ConvertString(FROM && s)1083 common::IfNoLvalue<std::optional<TO>, FROM> ConvertString(FROM &&s) {
1084 if constexpr (std::is_same_v<TO, FROM>) {
1085 return std::make_optional<TO>(std::move(s));
1086 } else {
1087 // Fortran character conversion is well defined between distinct kinds
1088 // only when the actual characters are valid 7-bit ASCII.
1089 TO str;
1090 for (auto iter{s.cbegin()}; iter != s.cend(); ++iter) {
1091 if (static_cast<std::uint64_t>(*iter) > 127) {
1092 return std::nullopt;
1093 }
1094 str.push_back(*iter);
1095 }
1096 return std::make_optional<TO>(std::move(str));
1097 }
1098 }
1099
1100 template <typename TO, TypeCategory FROMCAT>
FoldOperation(FoldingContext & context,Convert<TO,FROMCAT> && convert)1101 Expr<TO> FoldOperation(
1102 FoldingContext &context, Convert<TO, FROMCAT> &&convert) {
1103 if (auto array{ApplyElementwise(context, convert)}) {
1104 return *array;
1105 }
1106 struct {
1107 FoldingContext &context;
1108 Convert<TO, FROMCAT> &convert;
1109 } msvcWorkaround{context, convert};
1110 return std::visit(
1111 [&msvcWorkaround](auto &kindExpr) -> Expr<TO> {
1112 using Operand = ResultType<decltype(kindExpr)>;
1113 // This variable is a workaround for msvc which emits an error when
1114 // using the FROMCAT template parameter below.
1115 TypeCategory constexpr FromCat{FROMCAT};
1116 auto &convert{msvcWorkaround.convert};
1117 char buffer[64];
1118 if (auto value{GetScalarConstantValue<Operand>(kindExpr)}) {
1119 FoldingContext &ctx{msvcWorkaround.context};
1120 if constexpr (TO::category == TypeCategory::Integer) {
1121 if constexpr (Operand::category == TypeCategory::Integer) {
1122 auto converted{Scalar<TO>::ConvertSigned(*value)};
1123 if (converted.overflow) {
1124 ctx.messages().Say(
1125 "INTEGER(%d) to INTEGER(%d) conversion overflowed"_en_US,
1126 Operand::kind, TO::kind);
1127 }
1128 return ScalarConstantToExpr(std::move(converted.value));
1129 } else if constexpr (Operand::category == TypeCategory::Real) {
1130 auto converted{value->template ToInteger<Scalar<TO>>()};
1131 if (converted.flags.test(RealFlag::InvalidArgument)) {
1132 ctx.messages().Say(
1133 "REAL(%d) to INTEGER(%d) conversion: invalid argument"_en_US,
1134 Operand::kind, TO::kind);
1135 } else if (converted.flags.test(RealFlag::Overflow)) {
1136 ctx.messages().Say(
1137 "REAL(%d) to INTEGER(%d) conversion overflowed"_en_US,
1138 Operand::kind, TO::kind);
1139 }
1140 return ScalarConstantToExpr(std::move(converted.value));
1141 }
1142 } else if constexpr (TO::category == TypeCategory::Real) {
1143 if constexpr (Operand::category == TypeCategory::Integer) {
1144 auto converted{Scalar<TO>::FromInteger(*value)};
1145 if (!converted.flags.empty()) {
1146 std::snprintf(buffer, sizeof buffer,
1147 "INTEGER(%d) to REAL(%d) conversion", Operand::kind,
1148 TO::kind);
1149 RealFlagWarnings(ctx, converted.flags, buffer);
1150 }
1151 return ScalarConstantToExpr(std::move(converted.value));
1152 } else if constexpr (Operand::category == TypeCategory::Real) {
1153 auto converted{Scalar<TO>::Convert(*value)};
1154 if (!converted.flags.empty()) {
1155 std::snprintf(buffer, sizeof buffer,
1156 "REAL(%d) to REAL(%d) conversion", Operand::kind, TO::kind);
1157 RealFlagWarnings(ctx, converted.flags, buffer);
1158 }
1159 if (ctx.flushSubnormalsToZero()) {
1160 converted.value = converted.value.FlushSubnormalToZero();
1161 }
1162 return ScalarConstantToExpr(std::move(converted.value));
1163 }
1164 } else if constexpr (TO::category == TypeCategory::Complex) {
1165 if constexpr (Operand::category == TypeCategory::Complex) {
1166 return FoldOperation(ctx,
1167 ComplexConstructor<TO::kind>{
1168 AsExpr(Convert<typename TO::Part>{AsCategoryExpr(
1169 Constant<typename Operand::Part>{value->REAL()})}),
1170 AsExpr(Convert<typename TO::Part>{AsCategoryExpr(
1171 Constant<typename Operand::Part>{value->AIMAG()})})});
1172 }
1173 } else if constexpr (TO::category == TypeCategory::Character &&
1174 Operand::category == TypeCategory::Character) {
1175 if (auto converted{ConvertString<Scalar<TO>>(std::move(*value))}) {
1176 return ScalarConstantToExpr(std::move(*converted));
1177 }
1178 } else if constexpr (TO::category == TypeCategory::Logical &&
1179 Operand::category == TypeCategory::Logical) {
1180 return Expr<TO>{value->IsTrue()};
1181 }
1182 } else if constexpr (std::is_same_v<Operand, TO> &&
1183 FromCat != TypeCategory::Character) {
1184 return std::move(kindExpr); // remove needless conversion
1185 }
1186 return Expr<TO>{std::move(convert)};
1187 },
1188 convert.left().u);
1189 }
1190
1191 template <typename T>
FoldOperation(FoldingContext & context,Parentheses<T> && x)1192 Expr<T> FoldOperation(FoldingContext &context, Parentheses<T> &&x) {
1193 auto &operand{x.left()};
1194 operand = Fold(context, std::move(operand));
1195 if (auto value{GetScalarConstantValue<T>(operand)}) {
1196 // Preserve parentheses, even around constants.
1197 return Expr<T>{Parentheses<T>{Expr<T>{Constant<T>{*value}}}};
1198 } else if (std::holds_alternative<Parentheses<T>>(operand.u)) {
1199 // ((x)) -> (x)
1200 return std::move(operand);
1201 } else {
1202 return Expr<T>{Parentheses<T>{std::move(operand)}};
1203 }
1204 }
1205
1206 template <typename T>
FoldOperation(FoldingContext & context,Negate<T> && x)1207 Expr<T> FoldOperation(FoldingContext &context, Negate<T> &&x) {
1208 if (auto array{ApplyElementwise(context, x)}) {
1209 return *array;
1210 }
1211 auto &operand{x.left()};
1212 if (auto value{GetScalarConstantValue<T>(operand)}) {
1213 if constexpr (T::category == TypeCategory::Integer) {
1214 auto negated{value->Negate()};
1215 if (negated.overflow) {
1216 context.messages().Say(
1217 "INTEGER(%d) negation overflowed"_en_US, T::kind);
1218 }
1219 return Expr<T>{Constant<T>{std::move(negated.value)}};
1220 } else {
1221 // REAL & COMPLEX negation: no exceptions possible
1222 return Expr<T>{Constant<T>{value->Negate()}};
1223 }
1224 }
1225 return Expr<T>{std::move(x)};
1226 }
1227
1228 // Binary (dyadic) operations
1229
1230 template <typename LEFT, typename RIGHT>
OperandsAreConstants(const Expr<LEFT> & x,const Expr<RIGHT> & y)1231 std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
1232 const Expr<LEFT> &x, const Expr<RIGHT> &y) {
1233 if (auto xvalue{GetScalarConstantValue<LEFT>(x)}) {
1234 if (auto yvalue{GetScalarConstantValue<RIGHT>(y)}) {
1235 return {std::make_pair(*xvalue, *yvalue)};
1236 }
1237 }
1238 return std::nullopt;
1239 }
1240
1241 template <typename DERIVED, typename RESULT, typename LEFT, typename RIGHT>
OperandsAreConstants(const Operation<DERIVED,RESULT,LEFT,RIGHT> & operation)1242 std::optional<std::pair<Scalar<LEFT>, Scalar<RIGHT>>> OperandsAreConstants(
1243 const Operation<DERIVED, RESULT, LEFT, RIGHT> &operation) {
1244 return OperandsAreConstants(operation.left(), operation.right());
1245 }
1246
1247 template <typename T>
FoldOperation(FoldingContext & context,Add<T> && x)1248 Expr<T> FoldOperation(FoldingContext &context, Add<T> &&x) {
1249 if (auto array{ApplyElementwise(context, x)}) {
1250 return *array;
1251 }
1252 if (auto folded{OperandsAreConstants(x)}) {
1253 if constexpr (T::category == TypeCategory::Integer) {
1254 auto sum{folded->first.AddSigned(folded->second)};
1255 if (sum.overflow) {
1256 context.messages().Say(
1257 "INTEGER(%d) addition overflowed"_en_US, T::kind);
1258 }
1259 return Expr<T>{Constant<T>{sum.value}};
1260 } else {
1261 auto sum{folded->first.Add(folded->second, context.rounding())};
1262 RealFlagWarnings(context, sum.flags, "addition");
1263 if (context.flushSubnormalsToZero()) {
1264 sum.value = sum.value.FlushSubnormalToZero();
1265 }
1266 return Expr<T>{Constant<T>{sum.value}};
1267 }
1268 }
1269 return Expr<T>{std::move(x)};
1270 }
1271
1272 template <typename T>
FoldOperation(FoldingContext & context,Subtract<T> && x)1273 Expr<T> FoldOperation(FoldingContext &context, Subtract<T> &&x) {
1274 if (auto array{ApplyElementwise(context, x)}) {
1275 return *array;
1276 }
1277 if (auto folded{OperandsAreConstants(x)}) {
1278 if constexpr (T::category == TypeCategory::Integer) {
1279 auto difference{folded->first.SubtractSigned(folded->second)};
1280 if (difference.overflow) {
1281 context.messages().Say(
1282 "INTEGER(%d) subtraction overflowed"_en_US, T::kind);
1283 }
1284 return Expr<T>{Constant<T>{difference.value}};
1285 } else {
1286 auto difference{
1287 folded->first.Subtract(folded->second, context.rounding())};
1288 RealFlagWarnings(context, difference.flags, "subtraction");
1289 if (context.flushSubnormalsToZero()) {
1290 difference.value = difference.value.FlushSubnormalToZero();
1291 }
1292 return Expr<T>{Constant<T>{difference.value}};
1293 }
1294 }
1295 return Expr<T>{std::move(x)};
1296 }
1297
1298 template <typename T>
FoldOperation(FoldingContext & context,Multiply<T> && x)1299 Expr<T> FoldOperation(FoldingContext &context, Multiply<T> &&x) {
1300 if (auto array{ApplyElementwise(context, x)}) {
1301 return *array;
1302 }
1303 if (auto folded{OperandsAreConstants(x)}) {
1304 if constexpr (T::category == TypeCategory::Integer) {
1305 auto product{folded->first.MultiplySigned(folded->second)};
1306 if (product.SignedMultiplicationOverflowed()) {
1307 context.messages().Say(
1308 "INTEGER(%d) multiplication overflowed"_en_US, T::kind);
1309 }
1310 return Expr<T>{Constant<T>{product.lower}};
1311 } else {
1312 auto product{folded->first.Multiply(folded->second, context.rounding())};
1313 RealFlagWarnings(context, product.flags, "multiplication");
1314 if (context.flushSubnormalsToZero()) {
1315 product.value = product.value.FlushSubnormalToZero();
1316 }
1317 return Expr<T>{Constant<T>{product.value}};
1318 }
1319 }
1320 return Expr<T>{std::move(x)};
1321 }
1322
1323 template <typename T>
FoldOperation(FoldingContext & context,Divide<T> && x)1324 Expr<T> FoldOperation(FoldingContext &context, Divide<T> &&x) {
1325 if (auto array{ApplyElementwise(context, x)}) {
1326 return *array;
1327 }
1328 if (auto folded{OperandsAreConstants(x)}) {
1329 if constexpr (T::category == TypeCategory::Integer) {
1330 auto quotAndRem{folded->first.DivideSigned(folded->second)};
1331 if (quotAndRem.divisionByZero) {
1332 context.messages().Say("INTEGER(%d) division by zero"_en_US, T::kind);
1333 return Expr<T>{std::move(x)};
1334 }
1335 if (quotAndRem.overflow) {
1336 context.messages().Say(
1337 "INTEGER(%d) division overflowed"_en_US, T::kind);
1338 }
1339 return Expr<T>{Constant<T>{quotAndRem.quotient}};
1340 } else {
1341 auto quotient{folded->first.Divide(folded->second, context.rounding())};
1342 RealFlagWarnings(context, quotient.flags, "division");
1343 if (context.flushSubnormalsToZero()) {
1344 quotient.value = quotient.value.FlushSubnormalToZero();
1345 }
1346 return Expr<T>{Constant<T>{quotient.value}};
1347 }
1348 }
1349 return Expr<T>{std::move(x)};
1350 }
1351
1352 template <typename T>
FoldOperation(FoldingContext & context,Power<T> && x)1353 Expr<T> FoldOperation(FoldingContext &context, Power<T> &&x) {
1354 if (auto array{ApplyElementwise(context, x)}) {
1355 return *array;
1356 }
1357 if (auto folded{OperandsAreConstants(x)}) {
1358 if constexpr (T::category == TypeCategory::Integer) {
1359 auto power{folded->first.Power(folded->second)};
1360 if (power.divisionByZero) {
1361 context.messages().Say(
1362 "INTEGER(%d) zero to negative power"_en_US, T::kind);
1363 } else if (power.overflow) {
1364 context.messages().Say("INTEGER(%d) power overflowed"_en_US, T::kind);
1365 } else if (power.zeroToZero) {
1366 context.messages().Say(
1367 "INTEGER(%d) 0**0 is not defined"_en_US, T::kind);
1368 }
1369 return Expr<T>{Constant<T>{power.power}};
1370 } else {
1371 if (auto callable{GetHostRuntimeWrapper<T, T, T>("pow")}) {
1372 return Expr<T>{
1373 Constant<T>{(*callable)(context, folded->first, folded->second)}};
1374 } else {
1375 context.messages().Say(
1376 "Power for %s cannot be folded on host"_en_US, T{}.AsFortran());
1377 }
1378 }
1379 }
1380 return Expr<T>{std::move(x)};
1381 }
1382
1383 template <typename T>
FoldOperation(FoldingContext & context,RealToIntPower<T> && x)1384 Expr<T> FoldOperation(FoldingContext &context, RealToIntPower<T> &&x) {
1385 if (auto array{ApplyElementwise(context, x)}) {
1386 return *array;
1387 }
1388 return std::visit(
1389 [&](auto &y) -> Expr<T> {
1390 if (auto folded{OperandsAreConstants(x.left(), y)}) {
1391 auto power{evaluate::IntPower(folded->first, folded->second)};
1392 RealFlagWarnings(context, power.flags, "power with INTEGER exponent");
1393 if (context.flushSubnormalsToZero()) {
1394 power.value = power.value.FlushSubnormalToZero();
1395 }
1396 return Expr<T>{Constant<T>{power.value}};
1397 } else {
1398 return Expr<T>{std::move(x)};
1399 }
1400 },
1401 x.right().u);
1402 }
1403
1404 template <typename T>
FoldOperation(FoldingContext & context,Extremum<T> && x)1405 Expr<T> FoldOperation(FoldingContext &context, Extremum<T> &&x) {
1406 if (auto array{ApplyElementwise(context, x,
1407 std::function<Expr<T>(Expr<T> &&, Expr<T> &&)>{[=](Expr<T> &&l,
1408 Expr<T> &&r) {
1409 return Expr<T>{Extremum<T>{x.ordering, std::move(l), std::move(r)}};
1410 }})}) {
1411 return *array;
1412 }
1413 if (auto folded{OperandsAreConstants(x)}) {
1414 if constexpr (T::category == TypeCategory::Integer) {
1415 if (folded->first.CompareSigned(folded->second) == x.ordering) {
1416 return Expr<T>{Constant<T>{folded->first}};
1417 }
1418 } else if constexpr (T::category == TypeCategory::Real) {
1419 if (folded->first.IsNotANumber() ||
1420 (folded->first.Compare(folded->second) == Relation::Less) ==
1421 (x.ordering == Ordering::Less)) {
1422 return Expr<T>{Constant<T>{folded->first}};
1423 }
1424 } else {
1425 static_assert(T::category == TypeCategory::Character);
1426 // Result of MIN and MAX on character has the length of
1427 // the longest argument.
1428 auto maxLen{std::max(folded->first.length(), folded->second.length())};
1429 bool isFirst{x.ordering == Compare(folded->first, folded->second)};
1430 auto res{isFirst ? std::move(folded->first) : std::move(folded->second)};
1431 res = res.length() == maxLen
1432 ? std::move(res)
1433 : CharacterUtils<T::kind>::Resize(res, maxLen);
1434 return Expr<T>{Constant<T>{std::move(res)}};
1435 }
1436 return Expr<T>{Constant<T>{folded->second}};
1437 }
1438 return Expr<T>{std::move(x)};
1439 }
1440
1441 template <int KIND>
ToReal(FoldingContext & context,Expr<SomeType> && expr)1442 Expr<Type<TypeCategory::Real, KIND>> ToReal(
1443 FoldingContext &context, Expr<SomeType> &&expr) {
1444 using Result = Type<TypeCategory::Real, KIND>;
1445 std::optional<Expr<Result>> result;
1446 std::visit(
1447 [&](auto &&x) {
1448 using From = std::decay_t<decltype(x)>;
1449 if constexpr (std::is_same_v<From, BOZLiteralConstant>) {
1450 // Move the bits without any integer->real conversion
1451 From original{x};
1452 result = ConvertToType<Result>(std::move(x));
1453 const auto *constant{UnwrapExpr<Constant<Result>>(*result)};
1454 CHECK(constant);
1455 Scalar<Result> real{constant->GetScalarValue().value()};
1456 From converted{From::ConvertUnsigned(real.RawBits()).value};
1457 if (original != converted) { // C1601
1458 context.messages().Say(
1459 "Nonzero bits truncated from BOZ literal constant in REAL intrinsic"_en_US);
1460 }
1461 } else if constexpr (IsNumericCategoryExpr<From>()) {
1462 result = Fold(context, ConvertToType<Result>(std::move(x)));
1463 } else {
1464 common::die("ToReal: bad argument expression");
1465 }
1466 },
1467 std::move(expr.u));
1468 return result.value();
1469 }
1470
1471 template <typename T>
Rewrite(FoldingContext & context,Expr<T> && expr)1472 Expr<T> ExpressionBase<T>::Rewrite(FoldingContext &context, Expr<T> &&expr) {
1473 return std::visit(
1474 [&](auto &&x) -> Expr<T> {
1475 if constexpr (IsSpecificIntrinsicType<T>) {
1476 return FoldOperation(context, std::move(x));
1477 } else if constexpr (std::is_same_v<T, SomeDerived>) {
1478 return FoldOperation(context, std::move(x));
1479 } else if constexpr (common::HasMember<decltype(x),
1480 TypelessExpression>) {
1481 return std::move(expr);
1482 } else {
1483 return Expr<T>{Fold(context, std::move(x))};
1484 }
1485 },
1486 std::move(expr.u));
1487 }
1488
1489 FOR_EACH_TYPE_AND_KIND(extern template class ExpressionBase, )
1490
1491 } // namespace Fortran::evaluate
1492 #endif // FORTRAN_EVALUATE_FOLD_IMPLEMENTATION_H_
1493