• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include <cstdint>
11 
12 namespace Fortran::evaluate {
13 
GetConstantSubscript(FoldingContext & context,Subscript & ss,const NamedEntity & base,int dim)14 std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
15     FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
16   ss = FoldOperation(context, std::move(ss));
17   return std::visit(
18       common::visitors{
19           [](IndirectSubscriptIntegerExpr &expr)
20               -> std::optional<Constant<SubscriptInteger>> {
21             if (const auto *constant{
22                     UnwrapConstantValue<SubscriptInteger>(expr.value())}) {
23               return *constant;
24             } else {
25               return std::nullopt;
26             }
27           },
28           [&](Triplet &triplet) -> std::optional<Constant<SubscriptInteger>> {
29             auto lower{triplet.lower()}, upper{triplet.upper()};
30             std::optional<ConstantSubscript> stride{ToInt64(triplet.stride())};
31             if (!lower) {
32               lower = GetLowerBound(context, base, dim);
33             }
34             if (!upper) {
35               upper =
36                   ComputeUpperBound(context, GetLowerBound(context, base, dim),
37                       GetExtent(context, base, dim));
38             }
39             auto lbi{ToInt64(lower)}, ubi{ToInt64(upper)};
40             if (lbi && ubi && stride && *stride != 0) {
41               std::vector<SubscriptInteger::Scalar> values;
42               while ((*stride > 0 && *lbi <= *ubi) ||
43                   (*stride < 0 && *lbi >= *ubi)) {
44                 values.emplace_back(*lbi);
45                 *lbi += *stride;
46               }
47               return Constant<SubscriptInteger>{std::move(values),
48                   ConstantSubscripts{
49                       static_cast<ConstantSubscript>(values.size())}};
50             } else {
51               return std::nullopt;
52             }
53           },
54       },
55       ss.u);
56 }
57 
FoldOperation(FoldingContext & context,StructureConstructor && structure)58 Expr<SomeDerived> FoldOperation(
59     FoldingContext &context, StructureConstructor &&structure) {
60   StructureConstructor ctor{structure.derivedTypeSpec()};
61   bool constantExtents{true};
62   for (auto &&[symbol, value] : std::move(structure)) {
63     auto expr{Fold(context, std::move(value.value()))};
64     if (!IsPointer(symbol)) {
65       bool ok{false};
66       if (auto valueShape{GetConstantExtents(context, expr)}) {
67         if (auto componentShape{GetConstantExtents(context, symbol)}) {
68           if (GetRank(*componentShape) > 0 && GetRank(*valueShape) == 0) {
69             expr = ScalarConstantExpander{std::move(*componentShape)}.Expand(
70                 std::move(expr));
71             ok = expr.Rank() > 0;
72           } else {
73             ok = *valueShape == *componentShape;
74           }
75         }
76       }
77       if (!ok) {
78         constantExtents = false;
79       }
80     }
81     ctor.Add(symbol, Fold(context, std::move(expr)));
82   }
83   if (constantExtents && IsConstantExpr(ctor)) {
84     return Expr<SomeDerived>{Constant<SomeDerived>{std::move(ctor)}};
85   } else {
86     return Expr<SomeDerived>{std::move(ctor)};
87   }
88 }
89 
FoldOperation(FoldingContext & context,Component && component)90 Component FoldOperation(FoldingContext &context, Component &&component) {
91   return {FoldOperation(context, std::move(component.base())),
92       component.GetLastSymbol()};
93 }
94 
FoldOperation(FoldingContext & context,NamedEntity && x)95 NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
96   if (Component * c{x.UnwrapComponent()}) {
97     return NamedEntity{FoldOperation(context, std::move(*c))};
98   } else {
99     return std::move(x);
100   }
101 }
102 
FoldOperation(FoldingContext & context,Triplet && triplet)103 Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
104   MaybeExtentExpr lower{triplet.lower()};
105   MaybeExtentExpr upper{triplet.upper()};
106   return {Fold(context, std::move(lower)), Fold(context, std::move(upper)),
107       Fold(context, triplet.stride())};
108 }
109 
FoldOperation(FoldingContext & context,Subscript && subscript)110 Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
111   return std::visit(common::visitors{
112                         [&](IndirectSubscriptIntegerExpr &&expr) {
113                           expr.value() = Fold(context, std::move(expr.value()));
114                           return Subscript(std::move(expr));
115                         },
116                         [&](Triplet &&triplet) {
117                           return Subscript(
118                               FoldOperation(context, std::move(triplet)));
119                         },
120                     },
121       std::move(subscript.u));
122 }
123 
FoldOperation(FoldingContext & context,ArrayRef && arrayRef)124 ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
125   NamedEntity base{FoldOperation(context, std::move(arrayRef.base()))};
126   for (Subscript &subscript : arrayRef.subscript()) {
127     subscript = FoldOperation(context, std::move(subscript));
128   }
129   return ArrayRef{std::move(base), std::move(arrayRef.subscript())};
130 }
131 
FoldOperation(FoldingContext & context,CoarrayRef && coarrayRef)132 CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
133   std::vector<Subscript> subscript;
134   for (Subscript x : coarrayRef.subscript()) {
135     subscript.emplace_back(FoldOperation(context, std::move(x)));
136   }
137   std::vector<Expr<SubscriptInteger>> cosubscript;
138   for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
139     cosubscript.emplace_back(Fold(context, std::move(x)));
140   }
141   CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
142       std::move(cosubscript)};
143   if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
144     folded.set_stat(Fold(context, std::move(*stat)));
145   }
146   if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
147     folded.set_team(
148         Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
149   }
150   return folded;
151 }
152 
FoldOperation(FoldingContext & context,DataRef && dataRef)153 DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
154   return std::visit(common::visitors{
155                         [&](SymbolRef symbol) { return DataRef{*symbol}; },
156                         [&](auto &&x) {
157                           return DataRef{FoldOperation(context, std::move(x))};
158                         },
159                     },
160       std::move(dataRef.u));
161 }
162 
FoldOperation(FoldingContext & context,Substring && substring)163 Substring FoldOperation(FoldingContext &context, Substring &&substring) {
164   auto lower{Fold(context, substring.lower())};
165   auto upper{Fold(context, substring.upper())};
166   if (const DataRef * dataRef{substring.GetParentIf<DataRef>()}) {
167     return Substring{FoldOperation(context, DataRef{*dataRef}),
168         std::move(lower), std::move(upper)};
169   } else {
170     auto p{*substring.GetParentIf<StaticDataObject::Pointer>()};
171     return Substring{std::move(p), std::move(lower), std::move(upper)};
172   }
173 }
174 
FoldOperation(FoldingContext & context,ComplexPart && complexPart)175 ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
176   DataRef complex{complexPart.complex()};
177   return ComplexPart{
178       FoldOperation(context, std::move(complex)), complexPart.part()};
179 }
180 
GetInt64Arg(const std::optional<ActualArgument> & arg)181 std::optional<std::int64_t> GetInt64Arg(
182     const std::optional<ActualArgument> &arg) {
183   if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
184     return ToInt64(*intExpr);
185   } else {
186     return std::nullopt;
187   }
188 }
189 
GetInt64ArgOr(const std::optional<ActualArgument> & arg,std::int64_t defaultValue)190 std::optional<std::int64_t> GetInt64ArgOr(
191     const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
192   if (!arg) {
193     return defaultValue;
194   } else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
195     return ToInt64(*intExpr);
196   } else {
197     return std::nullopt;
198   }
199 }
200 
FoldOperation(FoldingContext & context,ImpliedDoIndex && iDo)201 Expr<ImpliedDoIndex::Result> FoldOperation(
202     FoldingContext &context, ImpliedDoIndex &&iDo) {
203   if (std::optional<ConstantSubscript> value{context.GetImpliedDo(iDo.name)}) {
204     return Expr<ImpliedDoIndex::Result>{*value};
205   } else {
206     return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
207   }
208 }
209 
210 template class ExpressionBase<SomeDerived>;
211 template class ExpressionBase<SomeType>;
212 
213 } // namespace Fortran::evaluate
214