1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/fold.h"
10 #include "fold-implementation.h"
11
12 namespace Fortran::evaluate {
13
GetConstantSubscript(FoldingContext & context,Subscript & ss,const NamedEntity & base,int dim)14 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
15 FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
16 ss = FoldOperation(context, std::move(ss));
17 return std::visit(
18 common::visitors{
19 [](IndirectSubscriptIntegerExpr &expr)
20 -> std::optional<Constant<SubscriptInteger>> {
21 if (const auto *constant{
22 UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
23 return *constant;
24 } else {
25 return std::nullopt;
26 }
27 },
28 [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
29 auto lower{triplet.lower()}, upper{triplet.upper()};
30 std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
31 if (!lower) {
32 lower = GetLowerBound(context, base, dim);
33 }
34 if (!upper) {
35 upper =
36 ComputeUpperBound(context, GetLowerBound(context, base, dim),
37 GetExtent(context, base, dim));
38 }
39 auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
40 if (lbi && ubi && stride && *stride != 0) {
41 std::vector<SubscriptInteger::Scalar> values;
42 while ((*stride > 0 && *lbi <= *ubi) ||
43 (*stride < 0 && *lbi >= *ubi)) {
44 values.emplace_back(*lbi);
45 *lbi += *stride;
46 }
47 return Constant<SubscriptInteger>{std::move(values),
48 ConstantSubscripts{
49 static_cast<ConstantSubscript>(values.size())}};
50 } else {
51 return std::nullopt;
52 }
53 },
54 },
55 ss.u);
56 }
57
FoldOperation(FoldingContext & context,StructureConstructor && structure)58 Expr<SomeDerived> FoldOperation(
59 FoldingContext &context, StructureConstructor &&structure) {
60 StructureConstructor ctor{structure.derivedTypeSpec()};
61 bool constantExtents{true};
62 for (auto &&[symbol, value] : std::move(structure)) {
63 auto expr{Fold(context, std::move(value.value()))};
64 if (!IsPointer(symbol)) {
65 bool ok{false};
66 if (auto valueShape{GetConstantExtents(context, expr)}) {
67 if (auto componentShape{GetConstantExtents(context, symbol)}) {
68 if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
69 expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
70 std::move(expr));
71 ok = expr.Rank() > 0;
72 } else {
73 ok = *valueShape == *componentShape;
74 }
75 }
76 }
77 if (!ok) {
78 constantExtents = false;
79 }
80 }
81 ctor.Add(symbol, Fold(context, std::move(expr)));
82 }
83 if (constantExtents && IsConstantExpr(ctor)) {
84 return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
85 } else {
86 return Expr<SomeDerived>{std::move(ctor)};
87 }
88 }
89
FoldOperation(FoldingContext & context,Component && component)90 Component FoldOperation(FoldingContext &context, Component &&component) {
91 return {FoldOperation(context, std::move(component.base())),
92 component.GetLastSymbol()};
93 }
94
FoldOperation(FoldingContext & context,NamedEntity && x)95 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
96 if (Component * c{x.UnwrapComponent()}) {
97 return NamedEntity{FoldOperation(context, std::move(*c))};
98 } else {
99 return std::move(x);
100 }
101 }
102
FoldOperation(FoldingContext & context,Triplet && triplet)103 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
104 MaybeExtentExpr lower{triplet.lower()};
105 MaybeExtentExpr upper{triplet.upper()};
106 return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
107 Fold(context, triplet.stride())};
108 }
109
FoldOperation(FoldingContext & context,Subscript && subscript)110 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
111 return std::visit(common::visitors{
112 [&](IndirectSubscriptIntegerExpr &&expr) {
113 expr.value() = Fold(context, std::move(expr.value()));
114 return Subscript(std::move(expr));
115 },
116 [&](Triplet &&triplet) {
117 return Subscript(
118 FoldOperation(context, std::move(triplet)));
119 },
120 },
121 std::move(subscript.u));
122 }
123
FoldOperation(FoldingContext & context,ArrayRef && arrayRef)124 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
125 NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
126 for (Subscript &subscript : arrayRef.subscript()) {
127 subscript = FoldOperation(context, std::move(subscript));
128 }
129 return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
130 }
131
FoldOperation(FoldingContext & context,CoarrayRef && coarrayRef)132 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
133 std::vector<Subscript> subscript;
134 for (Subscript x : coarrayRef.subscript()) {
135 subscript.emplace_back(FoldOperation(context, std::move(x)));
136 }
137 std::vector<Expr<SubscriptInteger>> cosubscript;
138 for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
139 cosubscript.emplace_back(Fold(context, std::move(x)));
140 }
141 CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
142 std::move(cosubscript)};
143 if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
144 folded.set_stat(Fold(context, std::move(*stat)));
145 }
146 if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
147 folded.set_team(
148 Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
149 }
150 return folded;
151 }
152
FoldOperation(FoldingContext & context,DataRef && dataRef)153 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
154 return std::visit(common::visitors{
155 [&](SymbolRef symbol) { return DataRef{*symbol}; },
156 [&](auto &&x) {
157 return DataRef{FoldOperation(context, std::move(x))};
158 },
159 },
160 std::move(dataRef.u));
161 }
162
FoldOperation(FoldingContext & context,Substring && substring)163 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
164 auto lower{Fold(context, substring.lower())};
165 auto upper{Fold(context, substring.upper())};
166 if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
167 return Substring{FoldOperation(context, DataRef{*dataRef}),
168 std::move(lower), std::move(upper)};
169 } else {
170 auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
171 return Substring{std::move(p), std::move(lower), std::move(upper)};
172 }
173 }
174
FoldOperation(FoldingContext & context,ComplexPart && complexPart)175 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
176 DataRef complex{complexPart.complex()};
177 return ComplexPart{
178 FoldOperation(context, std::move(complex)), complexPart.part()};
179 }
180
GetInt64Arg(const std::optional<ActualArgument> & arg)181 std::optional<std::int64_t> GetInt64Arg(
182 const std::optional<ActualArgument> &arg) {
183 if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
184 return ToInt64(*intExpr);
185 } else {
186 return std::nullopt;
187 }
188 }
189
GetInt64ArgOr(const std::optional<ActualArgument> & arg,std::int64_t defaultValue)190 std::optional<std::int64_t> GetInt64ArgOr(
191 const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
192 if (!arg) {
193 return defaultValue;
194 } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
195 return ToInt64(*intExpr);
196 } else {
197 return std::nullopt;
198 }
199 }
200
FoldOperation(FoldingContext & context,ImpliedDoIndex && iDo)201 Expr<ImpliedDoIndex::Result> FoldOperation(
202 FoldingContext &context, ImpliedDoIndex &&iDo) {
203 if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
204 return Expr<ImpliedDoIndex::Result>{*value};
205 } else {
206 return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
207 }
208 }
209
// Explicit instantiations of ExpressionBase<> for SomeDerived and SomeType
// in this translation unit.
// NOTE(review): presumably placed here so that member definitions made
// visible by the included folding headers are emitted once; confirm.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
212
213 } // namespace Fortran::evaluate
214