1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/stack.h"
18
19 #include <memory>
20
21 #include "abstract/abstract_value.h"
22 #include "abstract/dshape.h"
23 #include "abstract/ops/primitive_infer_map.h"
24 #include "abstract/utils.h"
25 #include "base/base.h"
26 #include "ir/anf.h"
27 #include "ir/dtype/type.h"
28 #include "ir/primitive.h"
29 #include "mindapi/base/shape_vector.h"
30 #include "mindapi/base/shared_ptr.h"
31 #include "mindapi/ir/value.h"
32 #include "mindapi/src/helper.h"
33 #include "mindspore/core/ops/array_ops.h"
34 #include "ops/op_name.h"
35 #include "ops/primitive_c.h"
36 #include "ops/stack_comm.h"
37 #include "utils/check_convert_utils.h"
38 #include "utils/convert_utils_base.h"
39 #include "utils/log_adapter.h"
40 #include "utils/shape_utils.h"
41
42 namespace mindspore {
43 namespace ops {
namespace {
// Sentinel stored in a single shape dimension whose extent is not yet known.
constexpr int64_t kUnknownDim{-1};
// Sentinel shape ({kUnknownRank}) returned when not even the rank of the output is known.
constexpr int64_t kUnknownRank{-2};
}  // namespace
set_axis(const int64_t axis)48 void Stack::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
49
get_axis() const50 int64_t Stack::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
51
Init(const int64_t axis)52 void Stack::Init(const int64_t axis) { this->set_axis(axis); }
53 namespace {
StackInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)54 abstract::ShapePtr StackInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
55 MS_EXCEPTION_IF_NULL(primitive);
56 if (input_args.size() < 1) {
57 MS_LOG(ERROR) << "Invalid input size " << input_args.size();
58 }
59 AbstractBasePtrList elements = input_args;
60 if (input_args.size() == 1 && input_args[0]->isa<abstract::AbstractSequence>()) {
61 elements = input_args[0]->cast<abstract::AbstractSequencePtr>()->elements();
62 }
63 (void)CheckAndConvertUtils::CheckInteger("stack element num", SizeToLong(elements.size()), kGreaterEqual, 1,
64 primitive->name());
65
66 bool has_rank_valid_shape = false;
67 ShapeVector input_shape;
68 size_t element_rank = 0;
69 for (size_t i = 0; i < elements.size(); ++i) {
70 MS_EXCEPTION_IF_NULL(elements[i]);
71 auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->GetShape())[kShape];
72 if (IsDynamicRank(input_shape_tmp)) {
73 continue;
74 }
75
76 if (!has_rank_valid_shape) {
77 has_rank_valid_shape = true;
78 input_shape = input_shape_tmp;
79 element_rank = input_shape_tmp.size();
80 continue;
81 }
82 if (input_shape_tmp.size() != input_shape.size()) {
83 MS_EXCEPTION(ValueError) << "All input shape size must be the same!";
84 }
85 for (size_t j = 0; j < input_shape.size(); ++j) {
86 if (input_shape.at(j) == kUnknownDim && input_shape_tmp.at(j) != kUnknownDim) {
87 input_shape[j] = input_shape_tmp.at(j);
88 continue;
89 }
90 if (input_shape_tmp.at(j) != input_shape.at(j)) {
91 MS_EXCEPTION(ValueError) << "All input shape must be the same! " << input_shape_tmp << " And " << input_shape;
92 }
93 }
94 }
95
96 if (!has_rank_valid_shape) {
97 return std::make_shared<abstract::Shape>(ShapeVector{kUnknownRank});
98 }
99 std::vector<int64_t> infer_shape = input_shape;
100 auto axis_temp = GetValue<int64_t>(primitive->GetAttr(kAxis));
101 CheckAndConvertUtils::CheckInRange<int64_t>("Stack axis", axis_temp, kIncludeBoth,
102 {-SizeToLong(element_rank) - 1, SizeToLong(element_rank)},
103 primitive->name());
104 auto axis = axis_temp < 0 ? static_cast<size_t>(axis_temp) + element_rank + 1 : LongToSize(axis_temp);
105 (void)infer_shape.insert(infer_shape.begin() + axis, elements.size());
106 return std::make_shared<abstract::Shape>(infer_shape);
107 }
108
StackInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)109 TypePtr StackInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
110 const auto &prim_name = primitive->name();
111 AbstractBasePtrList elements = input_args;
112 if (input_args.size() == 1) {
113 if (!input_args[0]->isa<abstract::AbstractSequence>()) {
114 MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input data type must be list or tuple of tensors.";
115 }
116 elements = input_args[0]->cast<abstract::AbstractSequencePtr>()->elements();
117 }
118 (void)CheckAndConvertUtils::CheckInteger("stack element num", SizeToLong(elements.size()), kGreaterEqual, 1,
119 primitive->name());
120 primitive->AddAttr("num", MakeValue(SizeToLong(elements.size())));
121 auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
122 if (element0 == nullptr) {
123 MS_EXCEPTION(TypeError) << "Infer type failed.";
124 }
125 auto infer_type0 = element0->GetType();
126 for (size_t i = 1; i < elements.size(); i++) {
127 auto elementi = elements[i]->cast<abstract::AbstractTensorPtr>();
128 MS_EXCEPTION_IF_NULL(elementi);
129 auto infer_typei = elementi->GetType();
130 MS_EXCEPTION_IF_NULL(infer_typei);
131 if (infer_typei->ToString() != infer_type0->ToString()) {
132 MS_EXCEPTION(TypeError) << "All input must have the same data type!input[" << i << "] data type = " << infer_typei
133 << "infer_type0= " << infer_type0;
134 }
135 }
136 return infer_type0;
137 }
138 } // namespace
139
// Operator boilerplate for Stack with BaseOperator as its base; the exact expansion comes from the
// MIND_API_OPERATOR_IMPL macro, which is not visible here (NOTE(review): check the macro before relying on it).
MIND_API_OPERATOR_IMPL(Stack, BaseOperator);
StackInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)141 AbstractBasePtr StackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
142 const std::vector<AbstractBasePtr> &input_args) {
143 auto infer_shape = StackInferShape(primitive, input_args);
144 auto infer_type = StackInferType(primitive, input_args);
145 return abstract::MakeAbstract(infer_shape, infer_type);
146 }
147
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const148 BaseShapePtr AGStackInfer::InferShape(const PrimitivePtr &primitive,
149 const std::vector<AbstractBasePtr> &input_args) const {
150 return StackInferShape(primitive, input_args);
151 }
152
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const153 TypePtr AGStackInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
154 return StackInferType(primitive, input_args);
155 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const156 AbstractBasePtr AGStackInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
157 const PrimitivePtr &primitive,
158 const std::vector<AbstractBasePtr> &input_args) const {
159 return StackInfer(engine, primitive, input_args);
160 }
161
// Registers AGStackInfer as the inference implementation for prim::kPrimStack.
// NOTE(review): the meaning of the trailing `false` flag is defined by the macro, which is not visible here;
// confirm it before changing.
REGISTER_PRIMITIVE_OP_INFER_IMPL(Stack, prim::kPrimStack, AGStackInfer, false);
163 } // namespace ops
164 } // namespace mindspore
165