• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/stack.h"
18 
19 #include <memory>
20 
21 #include "abstract/abstract_value.h"
22 #include "abstract/dshape.h"
23 #include "abstract/ops/primitive_infer_map.h"
24 #include "abstract/utils.h"
25 #include "base/base.h"
26 #include "ir/anf.h"
27 #include "ir/dtype/type.h"
28 #include "ir/primitive.h"
29 #include "mindapi/base/shape_vector.h"
30 #include "mindapi/base/shared_ptr.h"
31 #include "mindapi/ir/value.h"
32 #include "mindapi/src/helper.h"
33 #include "mindspore/core/ops/array_ops.h"
34 #include "ops/op_name.h"
35 #include "ops/primitive_c.h"
36 #include "ops/stack_comm.h"
37 #include "utils/check_convert_utils.h"
38 #include "utils/convert_utils_base.h"
39 #include "utils/log_adapter.h"
40 #include "utils/shape_utils.h"
41 
42 namespace mindspore {
43 namespace ops {
namespace {
// Sentinel stored in a ShapeVector entry: that dimension's size is not known yet.
constexpr int64_t kUnknownDim{-1};
// Sentinel used as the sole ShapeVector entry when not even the rank is known.
constexpr int64_t kUnknownRank{-2};
}  // namespace
set_axis(const int64_t axis)48 void Stack::set_axis(const int64_t axis) { (void)AddAttr(kAxis, api::MakeValue(axis)); }
49 
get_axis() const50 int64_t Stack::get_axis() const { return GetValue<int64_t>(GetAttr(kAxis)); }
51 
Init(const int64_t axis)52 void Stack::Init(const int64_t axis) { this->set_axis(axis); }
53 namespace {
StackInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)54 abstract::ShapePtr StackInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
55   MS_EXCEPTION_IF_NULL(primitive);
56   if (input_args.size() < 1) {
57     MS_LOG(ERROR) << "Invalid input size " << input_args.size();
58   }
59   AbstractBasePtrList elements = input_args;
60   if (input_args.size() == 1 && input_args[0]->isa<abstract::AbstractSequence>()) {
61     elements = input_args[0]->cast<abstract::AbstractSequencePtr>()->elements();
62   }
63   (void)CheckAndConvertUtils::CheckInteger("stack element num", SizeToLong(elements.size()), kGreaterEqual, 1,
64                                            primitive->name());
65 
66   bool has_rank_valid_shape = false;
67   ShapeVector input_shape;
68   size_t element_rank = 0;
69   for (size_t i = 0; i < elements.size(); ++i) {
70     MS_EXCEPTION_IF_NULL(elements[i]);
71     auto input_shape_tmp = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->GetShape())[kShape];
72     if (IsDynamicRank(input_shape_tmp)) {
73       continue;
74     }
75 
76     if (!has_rank_valid_shape) {
77       has_rank_valid_shape = true;
78       input_shape = input_shape_tmp;
79       element_rank = input_shape_tmp.size();
80       continue;
81     }
82     if (input_shape_tmp.size() != input_shape.size()) {
83       MS_EXCEPTION(ValueError) << "All input shape size must be the same!";
84     }
85     for (size_t j = 0; j < input_shape.size(); ++j) {
86       if (input_shape.at(j) == kUnknownDim && input_shape_tmp.at(j) != kUnknownDim) {
87         input_shape[j] = input_shape_tmp.at(j);
88         continue;
89       }
90       if (input_shape_tmp.at(j) != input_shape.at(j)) {
91         MS_EXCEPTION(ValueError) << "All input shape must be the same! " << input_shape_tmp << " And " << input_shape;
92       }
93     }
94   }
95 
96   if (!has_rank_valid_shape) {
97     return std::make_shared<abstract::Shape>(ShapeVector{kUnknownRank});
98   }
99   std::vector<int64_t> infer_shape = input_shape;
100   auto axis_temp = GetValue<int64_t>(primitive->GetAttr(kAxis));
101   CheckAndConvertUtils::CheckInRange<int64_t>("Stack axis", axis_temp, kIncludeBoth,
102                                               {-SizeToLong(element_rank) - 1, SizeToLong(element_rank)},
103                                               primitive->name());
104   auto axis = axis_temp < 0 ? static_cast<size_t>(axis_temp) + element_rank + 1 : LongToSize(axis_temp);
105   (void)infer_shape.insert(infer_shape.begin() + axis, elements.size());
106   return std::make_shared<abstract::Shape>(infer_shape);
107 }
108 
StackInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)109 TypePtr StackInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
110   const auto &prim_name = primitive->name();
111   AbstractBasePtrList elements = input_args;
112   if (input_args.size() == 1) {
113     if (!input_args[0]->isa<abstract::AbstractSequence>()) {
114       MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the input data type must be list or tuple of tensors.";
115     }
116     elements = input_args[0]->cast<abstract::AbstractSequencePtr>()->elements();
117   }
118   (void)CheckAndConvertUtils::CheckInteger("stack element num", SizeToLong(elements.size()), kGreaterEqual, 1,
119                                            primitive->name());
120   primitive->AddAttr("num", MakeValue(SizeToLong(elements.size())));
121   auto element0 = elements[0]->cast<abstract::AbstractTensorPtr>();
122   if (element0 == nullptr) {
123     MS_EXCEPTION(TypeError) << "Infer type failed.";
124   }
125   auto infer_type0 = element0->GetType();
126   for (size_t i = 1; i < elements.size(); i++) {
127     auto elementi = elements[i]->cast<abstract::AbstractTensorPtr>();
128     MS_EXCEPTION_IF_NULL(elementi);
129     auto infer_typei = elementi->GetType();
130     MS_EXCEPTION_IF_NULL(infer_typei);
131     if (infer_typei->ToString() != infer_type0->ToString()) {
132       MS_EXCEPTION(TypeError) << "All input must have the same data type!input[" << i << "] data type = " << infer_typei
133                               << "infer_type0= " << infer_type0;
134     }
135   }
136   return infer_type0;
137 }
138 }  // namespace
139 
// Expands the MindAPI operator boilerplate for Stack with BaseOperator as its base class.
// NOTE(review): presumably this generates the constructors/type registration for the class declared in
// ops/stack.h; confirm against the macro definition.
140 MIND_API_OPERATOR_IMPL(Stack, BaseOperator);
StackInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)141 AbstractBasePtr StackInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
142                            const std::vector<AbstractBasePtr> &input_args) {
143   auto infer_shape = StackInferShape(primitive, input_args);
144   auto infer_type = StackInferType(primitive, input_args);
145   return abstract::MakeAbstract(infer_shape, infer_type);
146 }
147 
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const148 BaseShapePtr AGStackInfer::InferShape(const PrimitivePtr &primitive,
149                                       const std::vector<AbstractBasePtr> &input_args) const {
150   return StackInferShape(primitive, input_args);
151 }
152 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const153 TypePtr AGStackInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
154   return StackInferType(primitive, input_args);
155 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const156 AbstractBasePtr AGStackInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
157                                                 const PrimitivePtr &primitive,
158                                                 const std::vector<AbstractBasePtr> &input_args) const {
159   return StackInfer(engine, primitive, input_args);
160 }
161 
// Registers AGStackInfer as the inference implementation for prim::kPrimStack.
// NOTE(review): the meaning of the trailing `false` flag is defined by the macro; confirm before changing it.
162 REGISTER_PRIMITIVE_OP_INFER_IMPL(Stack, prim::kPrimStack, AGStackInfer, false);
163 }  // namespace ops
164 }  // namespace mindspore
165