• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/ops/infer_functions.h"
18 #include "abstract/utils.h"
19 #include "abstract/param_validator.h"
20 #include "utils/check_convert_utils.h"
21 
22 namespace mindspore {
23 namespace abstract {
InferImplReduceFuncCheckAxis(const int64_t & axis,const size_t dim)24 int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {
25   int64_t dim_ = static_cast<int64_t>(dim);
26   if (axis < -dim_ || axis >= dim_) {
27     MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
28   }
29   int64_t ret_axis = axis;
30   if (axis >= -dim_ && axis < 0) {
31     ret_axis += dim_;
32   }
33   return ret_axis;
34 }
35 
InferImplReduceFuncCalShape(ShapeVector * shape,const ShapeVector & x_shape,const ValuePtr & axis,bool keep_dims_value)36 void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
37                                  bool keep_dims_value) {
38   MS_EXCEPTION_IF_NULL(axis);
39   if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
40     auto axis_ptr_list =
41       axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
42     if (axis_ptr_list.empty()) {
43       if (keep_dims_value) {
44         (void)shape->insert(shape->end(), x_shape.size(), 1);
45       }
46     } else {
47       if (keep_dims_value) {
48         *shape = x_shape;
49         for (auto it = axis_ptr_list.begin(); it != axis_ptr_list.end(); ++it) {
50           int64_t axis_value = GetValue<int64_t>(*it);
51           axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
52           shape->at(LongToSize(axis_value)) = 1;
53         }
54       } else {
55         std::set<size_t> axis_items;
56         for (auto &axis_ptr : axis_ptr_list) {
57           auto positive_axis = InferImplReduceFuncCheckAxis(GetValue<int64_t>(axis_ptr), x_shape.size());
58           (void)axis_items.insert(LongToSize(positive_axis));
59         }
60         for (size_t i = 0; i < x_shape.size(); ++i) {
61           if (axis_items.count(i) == 0) {
62             (void)shape->emplace_back(x_shape[i]);
63           }
64         }
65       }
66     }
67   } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
68     (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
69     auto axis_value = GetValue<int64_t>(axis);
70     axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
71     if (keep_dims_value) {
72       shape->at(LongToSize(axis_value)) = 1;
73     } else {
74       (void)shape->erase(shape->begin() + axis_value);
75     }
76   } else {
77     MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
78   }
79   return;
80 }
81 
InferImplBinaryBase(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)82 AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
83                                     const AbstractBasePtrList &args_abs_list) {
84   constexpr auto kBinaryBaseInputNum = 2;
85   const std::string op_name = primitive->name();
86   CheckArgsSize(op_name, args_abs_list, kBinaryBaseInputNum);
87   auto input_x = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
88   MS_EXCEPTION_IF_NULL(input_x);
89   MS_EXCEPTION_IF_NULL(input_x->shape());
90 
91   auto input_y = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
92   MS_EXCEPTION_IF_NULL(input_y);
93   MS_EXCEPTION_IF_NULL(input_y->shape());
94 
95   auto x_shape = input_x->shape()->shape();
96   auto y_shape = input_y->shape()->shape();
97   auto output_shape = BroadcastShape(x_shape, y_shape);
98 
99   auto x_type = input_x->BuildType();
100   MS_EXCEPTION_IF_NULL(x_type);
101   MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
102   auto y_type = input_y->BuildType();
103   MS_EXCEPTION_IF_NULL(y_type);
104   MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());
105 
106   auto x_element = x_type->cast<TensorTypePtr>()->element();
107   MS_EXCEPTION_IF_NULL(x_element);
108   auto y_element = y_type->cast<TensorTypePtr>()->element();
109   MS_EXCEPTION_IF_NULL(y_element);
110 
111   auto x_element_type = x_element->number_type();
112   auto y_element_type = y_element->number_type();
113 
114   auto x_priority = type_priority_map().find(x_element_type);
115   if (x_priority == type_priority_map().cend()) {
116     MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type.";
117   }
118   auto y_priority = type_priority_map().find(y_element_type);
119   if (y_priority == type_priority_map().cend()) {
120     MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type.";
121   }
122 
123   if (x_priority->second >= y_priority->second) {
124     return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
125   } else {
126     return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
127   }
128 }
129 
InferImplMinimum(const AnalysisEnginePtr & engine_ptr,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)130 AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
131                                  const AbstractBasePtrList &args_abs_list) {
132   return InferImplBinaryBase(engine_ptr, primitive, args_abs_list);
133 }
134 
InferImplDivNoNan(const AnalysisEnginePtr & engine_ptr,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)135 AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
136                                   const AbstractBasePtrList &args_abs_list) {
137   return InferImplBinaryBase(engine_ptr, primitive, args_abs_list);
138 }
139 
InferImplLinSpace(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)140 AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
141                                   const AbstractBasePtrList &args_abs_list) {
142   constexpr auto kLinSpaceInputNum = 3;
143   const std::string op_name = primitive->name();
144   CheckArgsSize(op_name, args_abs_list, kLinSpaceInputNum);
145   auto start = CheckArg<AbstractTensor>(op_name, args_abs_list, 0);
146   MS_EXCEPTION_IF_NULL(start);
147   MS_EXCEPTION_IF_NULL(start->shape());
148   auto stop = CheckArg<AbstractTensor>(op_name, args_abs_list, 1);
149   MS_EXCEPTION_IF_NULL(stop);
150   MS_EXCEPTION_IF_NULL(stop->shape());
151   (void)CheckTensorDType(start, {kFloat32}, "Input 0 (start) for LinSpace should be %s");
152   (void)CheckTensorDType(stop, {kFloat32}, "Input 1 (stop) for LinSpace should be %s");
153   ShapeVector shape;
154   int64_t num_val = 0;
155   // 3rd input is a Tensor when LinSpace is a dynamic shape operator
156   const size_t tensor_index = 2;
157   auto abs_num = args_abs_list[tensor_index];
158   if (abs_num->isa<AbstractTensor>()) {
159     auto num = abs_num->cast<AbstractTensorPtr>();
160     MS_EXCEPTION_IF_NULL(num);
161     auto num_value_ptr = num->BuildValue();
162     MS_EXCEPTION_IF_NULL(num_value_ptr);
163     auto num_tensor = num_value_ptr->cast<tensor::TensorPtr>();
164     MS_EXCEPTION_IF_NULL(num_tensor);
165     num_val = *static_cast<int64_t *>(num_tensor->data_c());
166   } else if (abs_num->isa<AbstractScalar>()) {
167     auto num = abs_num->cast<AbstractScalarPtr>();
168     num_val = GetValue<int64_t>(num->BuildValue());
169   } else {
170     MS_LOG(EXCEPTION) << "Invalid abstract type:" << abs_num->type_name();
171   }
172   shape.emplace_back(num_val);
173   if (shape[0] < 0) {
174     MS_LOG(EXCEPTION) << "num must be >= 0 in LinSpace";
175   }
176   AbstractTensorPtr ret = std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape));
177   return ret;
178 }
179 
InferImplRealInner(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_abs_list)180 AbstractBasePtr InferImplRealInner(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
181                                    const AbstractBasePtrList &args_abs_list) {
182   // Inputs: one tensors.
183   constexpr auto kRealInputNum = 1;
184   const std::string op_name = primitive->name();
185   CheckArgsSize(op_name, args_abs_list, kRealInputNum);
186   AbstractBasePtr input_abs = args_abs_list[0];
187   auto input = dyn_cast<AbstractTensor>(input_abs);
188   if (input == nullptr) {
189     return input_abs->Clone();
190   }
191   TypePtr input_type = input->element()->GetTypeTrack();
192   TypePtr output_type = nullptr;
193   if (input_type->type_id() == TypeId::kNumberTypeComplex64) {
194     output_type = kFloat32;
195   } else if (input_type->type_id() == TypeId::kNumberTypeComplex128) {
196     output_type = kFloat64;
197   } else {
198     return input_abs->Clone();
199   }
200 
201   return std::make_shared<AbstractTensor>(output_type, input->shape());
202 }
203 }  // namespace abstract
204 }  // namespace mindspore
205