/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "abstract/ops/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace abstract {
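// Normalizes a reduction axis: a negative axis is converted to its positive
// equivalent (axis + dim), and an axis outside [-dim, dim) raises an exception.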
int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {
  int64_t dim_ = static_cast<int64_t>(dim);
  if (axis < -dim_ || axis >= dim_) {
    MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
  }
  int64_t ret_axis = axis;
  if (axis >= -dim_ && axis < 0) {
    ret_axis += dim_;
  }
  return ret_axis;
}

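// Computes the output shape of a reduce operation from the input shape, the
// axis value (int, tuple, or list) and the keep_dims flag. Reduced dimensions
// are kept as 1 when keep_dims is true, otherwise they are removed.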
void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
                                 bool keep_dims_value) {
  MS_EXCEPTION_IF_NULL(axis);
  if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
    auto axis_ptr_list =
      axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
    if (axis_ptr_list.empty()) {
      if (keep_dims_value) {
        (void)shape->insert(shape->end(), x_shape.size(), 1);
      }
    } else {
      if (keep_dims_value) {
        *shape = x_shape;
        for (auto it = axis_ptr_list.begin(); it != axis_ptr_list.end(); ++it) {
          int64_t axis_value = GetValue<int64_t>(*it);
          axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
          shape->at(LongToSize(axis_value)) = 1;
        }
      } else {
        std::set<size_t> axis_items;
        for (auto &axis_ptr : axis_ptr_list) {
          auto positive_axis = InferImplReduceFuncCheckAxis(GetValue<int64_t>(axis_ptr), x_shape.size());
          (void)axis_items.insert(LongToSize(positive_axis));
        }
        for (size_t i = 0; i < x_shape.size(); ++i) {
          if (axis_items.count(i) == 0) {
            (void)shape->emplace_back(x_shape[i]);
          }
        }
      }
    }
  } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
    (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
    auto axis_value = GetValue<int64_t>(axis);
    axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
    if (keep_dims_value) {
      shape->at(LongToSize(axis_value)) = 1;
    } else {
      (void)shape->erase(shape->begin() + axis_value);
    }
  } else {
    MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
  }
  return;
}

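// Common infer routine for broadcasting binary element-wise ops: the output
// shape is the broadcast of the two input shapes, and the output element type
// follows the input whose number type has the higher priority.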
AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_abs_list) {
  constexpr auto kBinaryBaseInputNum = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, kBinaryBaseInputNum);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());

  auto input_y = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
  MS_EXCEPTION_IF_NULL(input_y);
  MS_EXCEPTION_IF_NULL(input_y->shape());

  auto x_shape = input_x->shape()->shape();
  auto y_shape = input_y->shape()->shape();
  auto output_shape = BroadcastShape(x_shape, y_shape);

  auto x_type = input_x->BuildType();
  MS_EXCEPTION_IF_NULL(x_type);
  MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
  auto y_type = input_y->BuildType();
  MS_EXCEPTION_IF_NULL(y_type);
  MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());

  auto x_element = x_type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(x_element);
  auto y_element = y_type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(y_element);

  auto x_element_type = x_element->number_type();
  auto y_element_type = y_element->number_type();

  auto x_priority = type_priority_map().find(x_element_type);
  if (x_priority == type_priority_map().cend()) {
    MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", which is not a number type.";
  }
  auto y_priority = type_priority_map().find(y_element_type);
  if (y_priority == type_priority_map().cend()) {
    MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", which is not a number type.";
  }

  if (x_priority->second >= y_priority->second) {
    return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
  } else {
    return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
  }
}

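// Minimum is an element-wise broadcasting op, so it reuses the binary base inference.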
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_abs_list) {
  return InferImplBinaryBase(engine_ptr, primitive, args_abs_list);
}

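// DivNoNan is likewise an element-wise broadcasting op and shares the binary base inference.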
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_abs_list) {
  return InferImplBinaryBase(engine_ptr, primitive, args_abs_list);
}

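// Infers LinSpace(start, stop, num): start and stop must be float32 tensors, and
// num (a scalar, or a tensor in the dynamic-shape case) gives the length of the
// 1-D output, whose element type follows start.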
AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_abs_list) {
  constexpr auto kLinSpaceInputNum = 3;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, kLinSpaceInputNum);
  auto start = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
  MS_EXCEPTION_IF_NULL(start);
  MS_EXCEPTION_IF_NULL(start->shape());
  auto stop = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
  MS_EXCEPTION_IF_NULL(stop);
  MS_EXCEPTION_IF_NULL(stop->shape());
  (void)CheckTensorDType(start, {kFloat32}, "Input 0 (start) for LinSpace should be %s");
  (void)CheckTensorDType(stop, {kFloat32}, "Input 1 (stop) for LinSpace should be %s");
  ShapeVector shape;
  int64_t num_val = 0;
  // 3rd input is a Tensor when LinSpace is a dynamic shape operator
  const size_t tensor_index = 2;
  auto abs_num = args_abs_list[tensor_index];
  if (abs_num->isa<AbstractTensor>()) {
    auto num = abs_num->cast<AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(num);
    auto num_value_ptr = num->BuildValue();
    MS_EXCEPTION_IF_NULL(num_value_ptr);
    auto num_tensor = num_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(num_tensor);
    num_val = *static_cast<int64_t *>(num_tensor->data_c());
  } else if (abs_num->isa<AbstractScalar>()) {
    auto num = abs_num->cast<AbstractScalarPtr>();
    num_val = GetValue<int64_t>(num->BuildValue());
  } else {
    MS_LOG(EXCEPTION) << "Invalid abstract type:" << abs_num->type_name();
  }
  shape.emplace_back(num_val);
  if (shape[0] < 0) {
    MS_LOG(EXCEPTION) << "num must be >= 0 in LinSpace";
  }
  AbstractTensorPtr ret = std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape));
  return ret;
}

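// Infers Real: for a complex64/complex128 tensor input the output is a float32/float64
// tensor with the same shape; any other input abstract is returned unchanged.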
AbstractBasePtr InferImplRealInner(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_abs_list) {
  // Inputs: one tensor.
  constexpr auto kRealInputNum = 1;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_abs_list, kRealInputNum);
  AbstractBasePtr input_abs = args_abs_list[0];
  auto input = dyn_cast<AbstractTensor>(input_abs);
  if (input == nullptr) {
    return input_abs->Clone();
  }
  TypePtr input_type = input->element()->GetTypeTrack();
  TypePtr output_type = nullptr;
  if (input_type->type_id() == TypeId::kNumberTypeComplex64) {
    output_type = kFloat32;
  } else if (input_type->type_id() == TypeId::kNumberTypeComplex128) {
    output_type = kFloat64;
  } else {
    return input_abs->Clone();
  }

  return std::make_shared<AbstractTensor>(output_type, input->shape());
}
}  // namespace abstract
}  // namespace mindspore