1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22
23 namespace Fortran::evaluate {
24
IsImpliedShape(const Symbol & symbol0)25 bool IsImpliedShape(const Symbol &symbol0) {
26 const Symbol &symbol{ResolveAssociations(symbol0)};
27 const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
28 return symbol.attrs().test(semantics::Attr::PARAMETER) && details &&
29 details->shape().IsImpliedShape();
30 }
31
IsExplicitShape(const Symbol & symbol0)32 bool IsExplicitShape(const Symbol &symbol0) {
33 const Symbol &symbol{ResolveAssociations(symbol0)};
34 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
35 const auto &shape{details->shape()};
36 return shape.Rank() == 0 || shape.IsExplicitShape(); // even if scalar
37 } else {
38 return false;
39 }
40 }
41
AsShape(const Constant<ExtentType> & arrayConstant)42 Shape AsShape(const Constant<ExtentType> &arrayConstant) {
43 CHECK(arrayConstant.Rank() == 1);
44 Shape result;
45 std::size_t dimensions{arrayConstant.size()};
46 for (std::size_t j{0}; j < dimensions; ++j) {
47 Scalar<ExtentType> extent{arrayConstant.values().at(j)};
48 result.emplace_back(MaybeExtentExpr{ExtentExpr{extent}});
49 }
50 return result;
51 }
52
AsShape(FoldingContext & context,ExtentExpr && arrayExpr)53 std::optional<Shape> AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) {
54 // Flatten any array expression into an array constructor if possible.
55 arrayExpr = Fold(context, std::move(arrayExpr));
56 if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
57 return AsShape(*constArray);
58 }
59 if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
60 Shape result;
61 for (auto &value : *constructor) {
62 if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
63 if (expr->Rank() == 0) {
64 result.emplace_back(std::move(*expr));
65 continue;
66 }
67 }
68 return std::nullopt;
69 }
70 return result;
71 }
72 return std::nullopt;
73 }
74
AsExtentArrayExpr(const Shape & shape)75 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
76 ArrayConstructorValues<ExtentType> values;
77 for (const auto &dim : shape) {
78 if (dim) {
79 values.Push(common::Clone(*dim));
80 } else {
81 return std::nullopt;
82 }
83 }
84 return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
85 }
86
AsConstantShape(FoldingContext & context,const Shape & shape)87 std::optional<Constant<ExtentType>> AsConstantShape(
88 FoldingContext &context, const Shape &shape) {
89 if (auto shapeArray{AsExtentArrayExpr(shape)}) {
90 auto folded{Fold(context, std::move(*shapeArray))};
91 if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
92 return std::move(*p);
93 }
94 }
95 return std::nullopt;
96 }
97
AsConstantShape(const ConstantSubscripts & shape)98 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
99 using IntType = Scalar<SubscriptInteger>;
100 std::vector<IntType> result;
101 for (auto dim : shape) {
102 result.emplace_back(dim);
103 }
104 return {std::move(result), ConstantSubscripts{GetRank(shape)}};
105 }
106
AsConstantExtents(const Constant<ExtentType> & shape)107 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
108 ConstantSubscripts result;
109 for (const auto &extent : shape.values()) {
110 result.push_back(extent.ToInt64());
111 }
112 return result;
113 }
114
AsConstantExtents(FoldingContext & context,const Shape & shape)115 std::optional<ConstantSubscripts> AsConstantExtents(
116 FoldingContext &context, const Shape &shape) {
117 if (auto shapeConstant{AsConstantShape(context, shape)}) {
118 return AsConstantExtents(*shapeConstant);
119 } else {
120 return std::nullopt;
121 }
122 }
123
ComputeTripCount(FoldingContext & context,ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)124 static ExtentExpr ComputeTripCount(FoldingContext &context, ExtentExpr &&lower,
125 ExtentExpr &&upper, ExtentExpr &&stride) {
126 ExtentExpr strideCopy{common::Clone(stride)};
127 ExtentExpr span{
128 (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
129 std::move(stride)};
130 ExtentExpr extent{
131 Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
132 return Fold(context, std::move(extent));
133 }
134
CountTrips(FoldingContext & context,ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)135 ExtentExpr CountTrips(FoldingContext &context, ExtentExpr &&lower,
136 ExtentExpr &&upper, ExtentExpr &&stride) {
137 return ComputeTripCount(
138 context, std::move(lower), std::move(upper), std::move(stride));
139 }
140
CountTrips(FoldingContext & context,const ExtentExpr & lower,const ExtentExpr & upper,const ExtentExpr & stride)141 ExtentExpr CountTrips(FoldingContext &context, const ExtentExpr &lower,
142 const ExtentExpr &upper, const ExtentExpr &stride) {
143 return ComputeTripCount(context, common::Clone(lower), common::Clone(upper),
144 common::Clone(stride));
145 }
146
CountTrips(FoldingContext & context,MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)147 MaybeExtentExpr CountTrips(FoldingContext &context, MaybeExtentExpr &&lower,
148 MaybeExtentExpr &&upper, MaybeExtentExpr &&stride) {
149 std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
150 std::bind(ComputeTripCount, context, _1, _2, _3)};
151 return common::MapOptional(
152 std::move(bound), std::move(lower), std::move(upper), std::move(stride));
153 }
154
GetSize(Shape && shape)155 MaybeExtentExpr GetSize(Shape &&shape) {
156 ExtentExpr extent{1};
157 for (auto &&dim : std::move(shape)) {
158 if (dim) {
159 extent = std::move(extent) * std::move(*dim);
160 } else {
161 return std::nullopt;
162 }
163 }
164 return extent;
165 }
166
ContainsAnyImpliedDoIndex(const ExtentExpr & expr)167 bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
168 struct MyVisitor : public AnyTraverse<MyVisitor> {
169 using Base = AnyTraverse<MyVisitor>;
170 MyVisitor() : Base{*this} {}
171 using Base::operator();
172 bool operator()(const ImpliedDoIndex &) { return true; }
173 };
174 return MyVisitor{}(expr);
175 }
176
177 // Determines lower bound on a dimension. This can be other than 1 only
178 // for a reference to a whole array object or component. (See LBOUND, 16.9.109).
179 // ASSOCIATE construct entities may require tranversal of their referents.
180 class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
181 public:
182 using Result = ExtentExpr;
183 using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
184 using Base::operator();
GetLowerBoundHelper(FoldingContext & c,int d)185 GetLowerBoundHelper(FoldingContext &c, int d)
186 : Base{*this}, context_{c}, dimension_{d} {}
Default()187 static ExtentExpr Default() { return ExtentExpr{1}; }
Combine(Result &&,Result &&)188 static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
189 ExtentExpr operator()(const Symbol &);
190 ExtentExpr operator()(const Component &);
191
192 private:
193 FoldingContext &context_;
194 int dimension_;
195 };
196
operator ()(const Symbol & symbol0)197 auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
198 const Symbol &symbol{symbol0.GetUltimate()};
199 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
200 int j{0};
201 for (const auto &shapeSpec : details->shape()) {
202 if (j++ == dimension_) {
203 if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
204 return Fold(context_, common::Clone(*bound));
205 } else if (IsDescriptor(symbol)) {
206 return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
207 DescriptorInquiry::Field::LowerBound, dimension_}};
208 } else {
209 break;
210 }
211 }
212 }
213 } else if (const auto *assoc{
214 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
215 return (*this)(assoc->expr());
216 }
217 return Default();
218 }
219
operator ()(const Component & component)220 auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
221 if (component.base().Rank() == 0) {
222 const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
223 if (const auto *details{
224 symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
225 int j{0};
226 for (const auto &shapeSpec : details->shape()) {
227 if (j++ == dimension_) {
228 if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
229 return Fold(context_, common::Clone(*bound));
230 } else if (IsDescriptor(symbol)) {
231 return ExtentExpr{
232 DescriptorInquiry{NamedEntity{common::Clone(component)},
233 DescriptorInquiry::Field::LowerBound, dimension_}};
234 } else {
235 break;
236 }
237 }
238 }
239 }
240 }
241 return Default();
242 }
243
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)244 ExtentExpr GetLowerBound(
245 FoldingContext &context, const NamedEntity &base, int dimension) {
246 return GetLowerBoundHelper{context, dimension}(base);
247 }
248
GetLowerBounds(FoldingContext & context,const NamedEntity & base)249 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
250 Shape result;
251 int rank{base.Rank()};
252 for (int dim{0}; dim < rank; ++dim) {
253 result.emplace_back(GetLowerBound(context, base, dim));
254 }
255 return result;
256 }
257
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)258 MaybeExtentExpr GetExtent(
259 FoldingContext &context, const NamedEntity &base, int dimension) {
260 CHECK(dimension >= 0);
261 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
262 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
263 if (IsImpliedShape(symbol)) {
264 Shape shape{GetShape(context, symbol).value()};
265 return std::move(shape.at(dimension));
266 }
267 int j{0};
268 for (const auto &shapeSpec : details->shape()) {
269 if (j++ == dimension) {
270 if (shapeSpec.ubound().isExplicit()) {
271 if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
272 if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
273 return Fold(context,
274 common::Clone(ubound.value()) -
275 common::Clone(lbound.value()) + ExtentExpr{1});
276 } else {
277 return Fold(context, common::Clone(ubound.value()));
278 }
279 }
280 } else if (details->IsAssumedSize() && j == symbol.Rank()) {
281 return std::nullopt;
282 } else if (semantics::IsDescriptor(symbol)) {
283 return ExtentExpr{DescriptorInquiry{
284 NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
285 }
286 }
287 }
288 } else if (const auto *assoc{
289 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
290 if (auto shape{GetShape(context, assoc->expr())}) {
291 if (dimension < static_cast<int>(shape->size())) {
292 return std::move(shape->at(dimension));
293 }
294 }
295 }
296 return std::nullopt;
297 }
298
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)299 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
300 const NamedEntity &base, int dimension) {
301 return std::visit(
302 common::visitors{
303 [&](const Triplet &triplet) -> MaybeExtentExpr {
304 MaybeExtentExpr upper{triplet.upper()};
305 if (!upper) {
306 upper = GetUpperBound(context, base, dimension);
307 }
308 MaybeExtentExpr lower{triplet.lower()};
309 if (!lower) {
310 lower = GetLowerBound(context, base, dimension);
311 }
312 return CountTrips(context, std::move(lower), std::move(upper),
313 MaybeExtentExpr{triplet.stride()});
314 },
315 [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
316 if (auto shape{GetShape(context, subs.value())}) {
317 if (GetRank(*shape) > 0) {
318 CHECK(GetRank(*shape) == 1); // vector-valued subscript
319 return std::move(shape->at(0));
320 }
321 }
322 return std::nullopt;
323 },
324 },
325 subscript.u);
326 }
327
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)328 MaybeExtentExpr ComputeUpperBound(
329 FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
330 if (extent) {
331 return Fold(context, std::move(*extent) - std::move(lower) + ExtentExpr{1});
332 } else {
333 return std::nullopt;
334 }
335 }
336
GetUpperBound(FoldingContext & context,const NamedEntity & base,int dimension)337 MaybeExtentExpr GetUpperBound(
338 FoldingContext &context, const NamedEntity &base, int dimension) {
339 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
340 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
341 int j{0};
342 for (const auto &shapeSpec : details->shape()) {
343 if (j++ == dimension) {
344 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
345 return Fold(context, common::Clone(*bound));
346 } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
347 break;
348 } else {
349 return ComputeUpperBound(context,
350 GetLowerBound(context, base, dimension),
351 GetExtent(context, base, dimension));
352 }
353 }
354 }
355 } else if (const auto *assoc{
356 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
357 if (auto shape{GetShape(context, assoc->expr())}) {
358 if (dimension < static_cast<int>(shape->size())) {
359 return ComputeUpperBound(context,
360 GetLowerBound(context, base, dimension),
361 std::move(shape->at(dimension)));
362 }
363 }
364 }
365 return std::nullopt;
366 }
367
GetUpperBounds(FoldingContext & context,const NamedEntity & base)368 Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
369 const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
370 if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
371 Shape result;
372 int dim{0};
373 for (const auto &shapeSpec : details->shape()) {
374 if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
375 result.emplace_back(Fold(context, common::Clone(*bound)));
376 } else if (details->IsAssumedSize()) {
377 CHECK(dim + 1 == base.Rank());
378 result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
379 } else {
380 result.emplace_back(ComputeUpperBound(context,
381 GetLowerBound(context, base, dim), GetExtent(context, base, dim)));
382 }
383 ++dim;
384 }
385 CHECK(GetRank(result) == symbol.Rank());
386 return result;
387 } else {
388 return std::move(GetShape(context, base).value());
389 }
390 }
391
operator ()(const Symbol & symbol) const392 auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
393 return std::visit(
394 common::visitors{
395 [&](const semantics::ObjectEntityDetails &object) {
396 if (IsImpliedShape(symbol)) {
397 return (*this)(object.init());
398 } else {
399 int n{object.shape().Rank()};
400 NamedEntity base{symbol};
401 return Result{CreateShape(n, base)};
402 }
403 },
404 [](const semantics::EntityDetails &) {
405 return Scalar(); // no dimensions seen
406 },
407 [&](const semantics::ProcEntityDetails &proc) {
408 if (const Symbol * interface{proc.interface().symbol()}) {
409 return (*this)(*interface);
410 } else {
411 return Scalar();
412 }
413 },
414 [&](const semantics::AssocEntityDetails &assoc) {
415 if (!assoc.rank()) {
416 return (*this)(assoc.expr());
417 } else {
418 int n{assoc.rank().value()};
419 NamedEntity base{symbol};
420 return Result{CreateShape(n, base)};
421 }
422 },
423 [&](const semantics::SubprogramDetails &subp) {
424 if (subp.isFunction()) {
425 return (*this)(subp.result());
426 } else {
427 return Result{};
428 }
429 },
430 [&](const semantics::ProcBindingDetails &binding) {
431 return (*this)(binding.symbol());
432 },
433 [&](const semantics::UseDetails &use) {
434 return (*this)(use.symbol());
435 },
436 [&](const semantics::HostAssocDetails &assoc) {
437 return (*this)(assoc.symbol());
438 },
439 [](const semantics::TypeParamDetails &) { return Scalar(); },
440 [](const auto &) { return Result{}; },
441 },
442 symbol.details());
443 }
444
operator ()(const Component & component) const445 auto GetShapeHelper::operator()(const Component &component) const -> Result {
446 const Symbol &symbol{component.GetLastSymbol()};
447 int rank{symbol.Rank()};
448 if (rank == 0) {
449 return (*this)(component.base());
450 } else if (symbol.has<semantics::ObjectEntityDetails>()) {
451 NamedEntity base{Component{component}};
452 return CreateShape(rank, base);
453 } else if (symbol.has<semantics::AssocEntityDetails>()) {
454 NamedEntity base{Component{component}};
455 return Result{CreateShape(rank, base)};
456 } else {
457 return (*this)(symbol);
458 }
459 }
460
operator ()(const ArrayRef & arrayRef) const461 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
462 Shape shape;
463 int dimension{0};
464 const NamedEntity &base{arrayRef.base()};
465 for (const Subscript &ss : arrayRef.subscript()) {
466 if (ss.Rank() > 0) {
467 shape.emplace_back(GetExtent(context_, ss, base, dimension));
468 }
469 ++dimension;
470 }
471 if (shape.empty()) {
472 if (const Component * component{base.UnwrapComponent()}) {
473 return (*this)(component->base());
474 }
475 }
476 return shape;
477 }
478
operator ()(const CoarrayRef & coarrayRef) const479 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
480 NamedEntity base{coarrayRef.GetBase()};
481 if (coarrayRef.subscript().empty()) {
482 return (*this)(base);
483 } else {
484 Shape shape;
485 int dimension{0};
486 for (const Subscript &ss : coarrayRef.subscript()) {
487 if (ss.Rank() > 0) {
488 shape.emplace_back(GetExtent(context_, ss, base, dimension));
489 }
490 ++dimension;
491 }
492 return shape;
493 }
494 }
495
operator ()(const Substring & substring) const496 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
497 return (*this)(substring.parent());
498 }
499
operator ()(const ProcedureRef & call) const500 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
501 if (call.Rank() == 0) {
502 return Scalar();
503 } else if (call.IsElemental()) {
504 for (const auto &arg : call.arguments()) {
505 if (arg && arg->Rank() > 0) {
506 return (*this)(*arg);
507 }
508 }
509 return Scalar();
510 } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
511 return (*this)(*symbol);
512 } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
513 if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
514 intrinsic->name == "ubound") {
515 // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
516 const auto *expr{call.arguments().front().value().UnwrapExpr()};
517 CHECK(expr);
518 return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
519 } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
520 intrinsic->name == "count" || intrinsic->name == "iall" ||
521 intrinsic->name == "iany" || intrinsic->name == "iparity" ||
522 intrinsic->name == "maxloc" || intrinsic->name == "maxval" ||
523 intrinsic->name == "minloc" || intrinsic->name == "minval" ||
524 intrinsic->name == "norm2" || intrinsic->name == "parity" ||
525 intrinsic->name == "product" || intrinsic->name == "sum") {
526 // Reduction with DIM=
527 if (call.arguments().size() >= 2) {
528 auto arrayShape{
529 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
530 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
531 if (arrayShape && dimArg) {
532 if (auto dim{ToInt64(*dimArg)}) {
533 if (*dim >= 1 &&
534 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
535 arrayShape->erase(arrayShape->begin() + (*dim - 1));
536 return std::move(*arrayShape);
537 }
538 }
539 }
540 }
541 } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
542 if (!call.arguments().empty()) {
543 return (*this)(call.arguments()[0]);
544 }
545 } else if (intrinsic->name == "matmul") {
546 if (call.arguments().size() == 2) {
547 if (auto ashape{(*this)(call.arguments()[0])}) {
548 if (auto bshape{(*this)(call.arguments()[1])}) {
549 if (ashape->size() == 1 && bshape->size() == 2) {
550 bshape->erase(bshape->begin());
551 return std::move(*bshape); // matmul(vector, matrix)
552 } else if (ashape->size() == 2 && bshape->size() == 1) {
553 ashape->pop_back();
554 return std::move(*ashape); // matmul(matrix, vector)
555 } else if (ashape->size() == 2 && bshape->size() == 2) {
556 (*ashape)[1] = std::move((*bshape)[1]);
557 return std::move(*ashape); // matmul(matrix, matrix)
558 }
559 }
560 }
561 }
562 } else if (intrinsic->name == "reshape") {
563 if (call.arguments().size() >= 2 && call.arguments().at(1)) {
564 // SHAPE(RESHAPE(array,shape)) -> shape
565 if (const auto *shapeExpr{
566 call.arguments().at(1).value().UnwrapExpr()}) {
567 auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
568 return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
569 }
570 }
571 } else if (intrinsic->name == "pack") {
572 if (call.arguments().size() >= 3 && call.arguments().at(2)) {
573 // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
574 return (*this)(call.arguments().at(2));
575 } else if (call.arguments().size() >= 2) {
576 if (auto maskShape{(*this)(call.arguments().at(1))}) {
577 if (maskShape->size() == 0) {
578 // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
579 if (auto arrayShape{(*this)(call.arguments().at(0))}) {
580 auto arraySize{GetSize(std::move(*arrayShape))};
581 CHECK(arraySize);
582 ActualArguments toMerge{
583 ActualArgument{AsGenericExpr(std::move(*arraySize))},
584 ActualArgument{AsGenericExpr(ExtentExpr{0})},
585 common::Clone(call.arguments().at(1))};
586 auto specific{context_.intrinsics().Probe(
587 CallCharacteristics{"merge"}, toMerge, context_)};
588 CHECK(specific);
589 return Shape{ExtentExpr{FunctionRef<ExtentType>{
590 ProcedureDesignator{std::move(specific->specificIntrinsic)},
591 std::move(specific->arguments)}}};
592 }
593 } else {
594 // Non-scalar MASK= -> [COUNT(mask)]
595 ActualArguments toCount{ActualArgument{common::Clone(
596 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
597 auto specific{context_.intrinsics().Probe(
598 CallCharacteristics{"count"}, toCount, context_)};
599 CHECK(specific);
600 return Shape{ExtentExpr{FunctionRef<ExtentType>{
601 ProcedureDesignator{std::move(specific->specificIntrinsic)},
602 std::move(specific->arguments)}}};
603 }
604 }
605 }
606 } else if (intrinsic->name == "spread") {
607 // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
608 // at position DIM.
609 if (call.arguments().size() == 3) {
610 auto arrayShape{
611 (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
612 const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
613 const auto *nCopies{
614 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
615 if (arrayShape && dimArg && nCopies) {
616 if (auto dim{ToInt64(*dimArg)}) {
617 if (*dim >= 1 &&
618 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
619 arrayShape->emplace(arrayShape->begin() + *dim - 1,
620 ConvertToType<ExtentType>(common::Clone(*nCopies)));
621 return std::move(*arrayShape);
622 }
623 }
624 }
625 }
626 } else if (intrinsic->name == "transfer") {
627 if (call.arguments().size() == 3 && call.arguments().at(2)) {
628 // SIZE= is present; shape is vector [SIZE=]
629 if (const auto *size{
630 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
631 return Shape{
632 MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
633 }
634 } else if (auto moldTypeAndShape{
635 characteristics::TypeAndShape::Characterize(
636 call.arguments().at(1), context_)}) {
637 if (GetRank(moldTypeAndShape->shape()) == 0) {
638 // SIZE= is absent and MOLD= is scalar: result is scalar
639 return Scalar();
640 } else {
641 // SIZE= is absent and MOLD= is array: result is vector whose
642 // length is determined by sizes of types. See 16.9.193p4 case(ii).
643 if (auto sourceTypeAndShape{
644 characteristics::TypeAndShape::Characterize(
645 call.arguments().at(0), context_)}) {
646 auto sourceElements{
647 GetSize(common::Clone(sourceTypeAndShape->shape()))};
648 auto sourceElementBytes{
649 sourceTypeAndShape->MeasureSizeInBytes(&context_)};
650 auto moldElementBytes{
651 moldTypeAndShape->MeasureSizeInBytes(&context_)};
652 if (sourceElements && sourceElementBytes && moldElementBytes) {
653 ExtentExpr extent{Fold(context_,
654 ((std::move(*sourceElements) *
655 std::move(*sourceElementBytes)) +
656 common::Clone(*moldElementBytes) - ExtentExpr{1}) /
657 common::Clone(*moldElementBytes))};
658 return Shape{MaybeExtentExpr{std::move(extent)}};
659 }
660 }
661 }
662 }
663 } else if (intrinsic->name == "transpose") {
664 if (call.arguments().size() >= 1) {
665 if (auto shape{(*this)(call.arguments().at(0))}) {
666 if (shape->size() == 2) {
667 std::swap((*shape)[0], (*shape)[1]);
668 return shape;
669 }
670 }
671 }
672 } else if (intrinsic->characteristics.value().attrs.test(characteristics::
673 Procedure::Attr::NullPointer)) { // NULL(MOLD=)
674 return (*this)(call.arguments());
675 } else {
676 // TODO: shapes of other non-elemental intrinsic results
677 }
678 }
679 return std::nullopt;
680 }
681
682 // Check conformance of the passed shapes. Only return true if we can verify
683 // that they conform
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,const char * leftIs,const char * rightIs,bool leftScalarExpandable,bool rightScalarExpandable)684 bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
685 const Shape &right, const char *leftIs, const char *rightIs,
686 bool leftScalarExpandable, bool rightScalarExpandable) {
687 int n{GetRank(left)};
688 if (n == 0 && leftScalarExpandable) {
689 return true;
690 }
691 int rn{GetRank(right)};
692 if (rn == 0 && rightScalarExpandable) {
693 return true;
694 }
695 if (n != rn) {
696 messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
697 leftIs, n, rightIs, rn);
698 return false;
699 }
700 for (int j{0}; j < n; ++j) {
701 auto leftDim{ToInt64(left[j])};
702 auto rightDim{ToInt64(right[j])};
703 if (!leftDim || !rightDim) {
704 return false;
705 }
706 if (*leftDim != *rightDim) {
707 messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
708 "but %4$s has extent %5$jd"_err_en_US,
709 j + 1, leftIs, *leftDim, rightIs, *rightDim);
710 return false;
711 }
712 }
713 return true;
714 }
715
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)716 bool IncrementSubscripts(
717 ConstantSubscripts &indices, const ConstantSubscripts &extents) {
718 std::size_t rank(indices.size());
719 CHECK(rank <= extents.size());
720 for (std::size_t j{0}; j < rank; ++j) {
721 if (extents[j] < 1) {
722 return false;
723 }
724 }
725 for (std::size_t j{0}; j < rank; ++j) {
726 if (indices[j]++ < extents[j]) {
727 return true;
728 }
729 indices[j] = 1;
730 }
731 return false;
732 }
733 } // namespace Fortran::evaluate
734