1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <cmath>
18 #include "abstract/infer_functions.h"
19 #include "abstract/utils.h"
20 #include "abstract/param_validator.h"
21 #include "utils/check_convert_utils.h"
22 #include "utils/shape_utils.h"
23
24 namespace mindspore {
25 namespace abstract {
// Start index and element count handed to CheckAttrIntOrTuple when InferImplConv2D reads
// its "stride", "dilation" and "pad" attributes.
constexpr size_t stride_num_element = 2;
constexpr size_t stride_start_idx = 2;
constexpr size_t dilation_num_element = 2;
constexpr size_t dilation_start_idx = 2;
constexpr size_t padding_num_element = 4;
constexpr size_t padding_start_idx = 0;
GetAndCheckFormat(const ValuePtr & value)32 int64_t GetAndCheckFormat(const ValuePtr &value) {
33 int64_t data_format;
34 bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
35 if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
36 MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
37 }
38 return data_format;
39 }
40
InferImplPooling(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)41 AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
42 const AbstractBasePtrList &args_spec_list) {
43 // Inputs: a tensor.
44 const std::string op_name = primitive->name();
45 CheckArgsSize(op_name, args_spec_list, 1);
46 AbstractTensorPtr input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
47 (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s");
48
49 ShapePtr input_shape = dyn_cast<Shape>(input_tensor->GetShapeTrack()); // NCHW
50 MS_EXCEPTION_IF_NULL(input_shape);
51 const size_t input_shape_size = 4;
52 if (input_shape->shape().size() != input_shape_size) {
53 MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor.";
54 }
55 const size_t H_INDEX = 2;
56 const size_t W_INDEX = 3;
57 int64_t h_input = input_shape->shape()[H_INDEX];
58 int64_t w_input = input_shape->shape()[W_INDEX];
59
60 int64_t window = GetValue<int64_t>(primitive->GetAttr("window"));
61 int64_t stride = GetValue<int64_t>(primitive->GetAttr("stride"));
62 int64_t padding = GetValue<int64_t>(primitive->GetAttr("pad"));
63 int64_t nan_opt = GetValue<int64_t>(primitive->GetAttr("nan_opt"));
64 int64_t data_mode = GetValue<int64_t>(primitive->GetAttr("data_mode"));
65 int64_t ceil_mode = GetValue<int64_t>(primitive->GetAttr("ceil_mode"));
66
67 if (stride <= 0) {
68 MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should greater then 0";
69 }
70 if (nan_opt != 0) {
71 MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0";
72 }
73 if (data_mode != 1) {
74 MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1";
75 }
76 if (ceil_mode != 0) {
77 MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
78 }
79
80 auto pad_mode_ptr = primitive->GetAttr("pad_mode");
81 if (pad_mode_ptr != nullptr) {
82 int64_t pad_mode;
83 CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
84 if (pad_mode == PadMode::VALID) {
85 padding = 0;
86 } else if (pad_mode == PadMode::SAME) {
87 padding = (window - 1) / 2;
88 }
89 }
90 std::set<std::string> available_mode{"max", "avg"};
91 auto mode_ptr = primitive->GetAttr("mode");
92 if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
93 auto mode = mode_ptr->cast<StringImmPtr>()->value();
94 if (available_mode.find(mode) == available_mode.end()) {
95 MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
96 }
97 }
98
99 int64_t h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1;
100 int64_t w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1;
101 ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};
102 AbstractBasePtr ret = input_tensor->Broaden();
103 ret->set_shape(std::make_shared<Shape>(shape_out));
104 return ret;
105 }
106
InferImplPoolingGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)107 AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
108 const AbstractBasePtrList &args_spec_list) {
109 // Inputs: three tensors(y, dy, x).
110 constexpr auto kPoolingGradInputNum = 3;
111 const std::string op_name = primitive->name();
112 CheckArgsSize(op_name, args_spec_list, kPoolingGradInputNum);
113 auto out_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
114 auto d_out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
115 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
116 (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat},
117 op_name + "evaluator three inputs should be %s");
118
119 AbstractBasePtr ret = d_out->Broaden();
120 auto x_shape = dyn_cast<Shape>(args_spec_list[2]->GetShapeTrack());
121 MS_EXCEPTION_IF_NULL(x_shape);
122
123 ret->set_shape(x_shape);
124 return ret;
125 }
126
InferImplBatchNorm(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)127 AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
128 const AbstractBasePtrList &args_spec_list) {
129 // Inputs: five tensors(x, gamma, beta, mean, variance).
130 constexpr auto kBatchNormInputNum = 5;
131 const std::string op_name = primitive->name();
132 CheckArgsSize(op_name, args_spec_list, kBatchNormInputNum);
133 AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
134 MS_EXCEPTION_IF_NULL(input_x);
135 MS_EXCEPTION_IF_NULL(input_x->shape());
136 ShapeVector x_shape = input_x->shape()->shape();
137 ShapeVector x_min_shape = input_x->shape()->min_shape();
138 ShapeVector x_max_shape = input_x->shape()->max_shape();
139 CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
140
141 auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
142 (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be");
143 AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
144 for (size_t i = 1; i < args_spec_list.size(); ++i) {
145 auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
146 tensorPtrList.push_back(param);
147 }
148 (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32},
149 "param gamma, beta, mean, variance of Batchnorm should be");
150
151 auto data_format_ptr = primitive->GetAttr("format");
152 MS_EXCEPTION_IF_NULL(data_format_ptr);
153 int64_t data_format = GetAndCheckFormat(data_format_ptr);
154
155 size_t c_axis = 1;
156 if (data_format == Format::NHWC) {
157 c_axis = 3;
158 }
159 for (size_t i = 1; i < args_spec_list.size(); ++i) {
160 AbstractTensorPtr arg_spec = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
161 MS_EXCEPTION_IF_NULL(arg_spec);
162 MS_EXCEPTION_IF_NULL(arg_spec->shape());
163 ShapeVector arg_shape = arg_spec->shape()->shape();
164 if (arg_shape.size() != 1) {
165 MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size();
166 }
167 if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) {
168 MS_EXCEPTION(ValueError) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis
169 << "]=" << x_shape[c_axis] << ", but got " << arg_shape[0];
170 }
171 }
172 AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
173 ShapeVector gamma_shape = input_gamma->shape()->shape();
174 ShapeVector gamma_min_shape = input_gamma->shape()->min_shape();
175 ShapeVector gamma_max_shape = input_gamma->shape()->max_shape();
176 CheckMinMaxShape(gamma_shape, &gamma_min_shape, &gamma_max_shape);
177 ShapePtr output_shape_ptr = std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape);
178 AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
179 ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape);
180 AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr);
181 AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma};
182 return std::make_shared<AbstractTuple>(rets);
183 }
184
InferImplFusedSparseAdam(const AnalysisEnginePtr &,const PrimitivePtr &,const AbstractBasePtrList & args_spec_list)185 AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &,
186 const AbstractBasePtrList &args_spec_list) {
187 // the output is useless, so we dont have to focus on the output shape
188 constexpr size_t dx_index = 1;
189 constexpr size_t dscale_index = 2;
190 constexpr size_t dbias_index = 3;
191 MS_EXCEPTION_IF_NULL(args_spec_list[dx_index]);
192 MS_EXCEPTION_IF_NULL(args_spec_list[dscale_index]);
193 MS_EXCEPTION_IF_NULL(args_spec_list[dbias_index]);
194
195 auto dx = args_spec_list[dx_index]->Broaden();
196 auto dscale = args_spec_list[dscale_index]->Broaden();
197 auto dbias = args_spec_list[dbias_index]->Broaden();
198
199 AbstractBasePtrList rets = {dx, dscale, dbias};
200 return std::make_shared<AbstractTuple>(rets);
201 }
202
Conv2DPadFunction(std::vector<int64_t> * output_hw,std::vector<int64_t> * pad_list,const int64_t x_h,const int64_t x_w,const std::vector<int64_t> & kernel,const std::vector<int64_t> & stride,const std::vector<int64_t> & dilation,const int64_t & pad_mode,const std::vector<int64_t> & padding)203 void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
204 const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
205 const std::vector<int64_t> &dilation, const int64_t &pad_mode,
206 const std::vector<int64_t> &padding) {
207 if (pad_mode == PadMode::VALID) {
208 output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
209 output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
210 const size_t nhwc = 4;
211 (void)pad_list->insert(pad_list->begin(), nhwc, 0);
212 } else if (pad_mode == PadMode::SAME) {
213 output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
214 output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
215 int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
216 pad_needed_h = std::max((int64_t)0, pad_needed_h);
217 pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
218 pad_list->push_back(pad_needed_h - pad_list->at(0));
219 int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
220 pad_needed_w = std::max((int64_t)0, pad_needed_w);
221 pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
222 pad_list->push_back(pad_needed_w - pad_list->at(2));
223 } else if (pad_mode == PadMode::PAD) {
224 (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
225 output_hw->push_back(static_cast<int64_t>(std::floor(
226 1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) /
227 stride[0])));
228 output_hw->push_back(static_cast<int64_t>(std::floor(
229 1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
230 stride[1])));
231 }
232 }
233
CheckShape(const std::string & op_name,const ShapeVector & w_shape,const AbstractTensorPtr & input_w)234 void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const AbstractTensorPtr &input_w) {
235 ShapeVector w_min_shape = input_w->shape()->min_shape();
236 ShapeVector w_max_shape = input_w->shape()->max_shape();
237 CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
238 CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
239 CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
240 CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
241 }
242
InferImplConv2D(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)243 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
244 const AbstractBasePtrList &args_spec_list) {
245 constexpr auto kConv2DInputNum = 2;
246 const std::string op_name = primitive->name();
247 CheckArgsSize(op_name, args_spec_list, kConv2DInputNum);
248 AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
249 MS_EXCEPTION_IF_NULL(input_x);
250 MS_EXCEPTION_IF_NULL(input_x->shape());
251 ShapeVector x_shape = input_x->shape()->shape();
252 ShapeVector x_min_shape = input_x->shape()->min_shape();
253 ShapeVector x_max_shape = input_x->shape()->max_shape();
254 CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
255 CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
256 CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape);
257 CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape);
258 AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
259 MS_EXCEPTION_IF_NULL(input_w);
260 MS_EXCEPTION_IF_NULL(input_w->shape());
261 ShapeVector w_shape = input_w->shape()->shape();
262 CheckShape(op_name, w_shape, input_w);
263 const uint64_t n_axis = 0;
264 uint64_t c_axis = 1;
265 uint64_t h_axis = 2;
266 uint64_t w_axis = 3;
267 int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
268 if (data_format == Format::NHWC) {
269 c_axis = 3;
270 h_axis = 1;
271 w_axis = 2;
272 }
273 int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
274 if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
275 ((x_shape[c_axis] / group) != w_shape[c_axis])) {
276 MS_LOG(EXCEPTION) << "x_shape[C_in] / group must equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got "
277 << (x_shape[c_axis] / group);
278 }
279 int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
280 if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
281 MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis] << " must equal to = " << out_channel;
282 }
283 const size_t kernel_size_num_element = 2;
284 std::vector<int64_t> kernel_size =
285 CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, kernel_size_num_element);
286 if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
287 MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis] << ", must equal to = " << kernel_size[0];
288 }
289 if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
290 MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis] << ", must equal to = " << kernel_size[1];
291 }
292 std::vector<int64_t> stride =
293 CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), stride_start_idx, stride_num_element);
294 std::vector<int64_t> dilation =
295 CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), dilation_start_idx, dilation_num_element);
296 std::vector<int64_t> padding =
297 CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element);
298 int64_t pad_mode;
299 CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
300 std::vector<int64_t> output_hw;
301 std::vector<int64_t> pad_list;
302 std::vector<int64_t> output_hw_min;
303 std::vector<int64_t> pad_list_min;
304 std::vector<int64_t> output_hw_max;
305 std::vector<int64_t> pad_list_max;
306 Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
307 padding);
308 if (x_shape[h_axis] == Shape::SHP_ANY) {
309 output_hw[0] = Shape::SHP_ANY;
310 }
311 if (x_shape[w_axis] == Shape::SHP_ANY) {
312 output_hw[1] = Shape::SHP_ANY;
313 }
314 Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
315 dilation, pad_mode, padding);
316 Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
317 dilation, pad_mode, padding);
318 std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
319 MakeValue(pad_list[3])};
320 primitive->set_attr("pad_list", MakeValue(pad_list_val));
321 ShapeVector output_shape;
322 ShapeVector output_shape_min;
323 ShapeVector output_shape_max;
324 if (data_format == Format::NHWC) {
325 output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
326 output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
327 output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
328 } else {
329 output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
330 output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
331 output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
332 }
333 CheckShapeAnyAndPositive(op_name + " output_shape", output_shape);
334 CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min);
335 CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max);
336 TypePtr x_type = input_x->element()->GetTypeTrack();
337 if (x_type->type_id() == TypeId::kNumberTypeInt8) {
338 x_type = kInt32;
339 }
340 ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
341 return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
342 }
343
InferImplBiasAdd(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)344 AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
345 const AbstractBasePtrList &args_spec_list) {
346 const std::string op_name = primitive->name();
347 constexpr size_t args_size = 2;
348 CheckArgsSize(op_name, args_spec_list, args_size);
349 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
350 auto bias = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
351 MS_EXCEPTION_IF_NULL(x);
352 MS_EXCEPTION_IF_NULL(x->shape());
353 ShapeVector x_shape = x->shape()->shape();
354 MS_EXCEPTION_IF_NULL(bias);
355 MS_EXCEPTION_IF_NULL(bias->shape());
356 ShapeVector bias_shape = bias->shape()->shape();
357 ShapeVector x_min_shape = x->shape()->min_shape();
358 ShapeVector x_max_shape = x->shape()->max_shape();
359 auto data_format_ptr = primitive->GetAttr("format");
360 int64_t data_format = Format::NCHW;
361 if (data_format_ptr != nullptr) {
362 data_format = GetAndCheckFormat(data_format_ptr);
363 }
364 auto x_channel = data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1];
365 // Additional check for dynamic shape
366 // Last infer will be real shape values
367 bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
368 if (x_not_dyn && bias_shape[0] != x_channel) {
369 MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
370 << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
371 }
372 CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
373 return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape));
374 }
375
InferImplBiasAddGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)376 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
377 const AbstractBasePtrList &args_spec_list) {
378 // Inputs: at least one tensor(y_backprop)
379 // Outputs: dbias
380 if (args_spec_list.empty()) {
381 MS_LOG(EXCEPTION) << primitive->name() << " evaluator at least has 1 parameters, while the input size is "
382 << args_spec_list.size() << ".";
383 }
384
385 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
386 ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
387 MS_EXCEPTION_IF_NULL(shape_y);
388 ShapeVector y_dims = shape_y->shape();
389 if (y_dims.size() < 2) {
390 MS_LOG(EXCEPTION) << primitive->name() << " input y backprop, dim should >= 2, while " << y_dims.size() << ".";
391 }
392 ShapeVector bias_dims = {y_dims[1]};
393 ShapePtr ret_shape = std::make_shared<Shape>(bias_dims);
394 AbstractBasePtr ret = args_spec_list[0]->Broaden();
395 ret->set_shape(ret_shape);
396 return ret;
397 }
398
InferImplHSigmoid(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)399 AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
400 const AbstractBasePtrList &args_spec_list) {
401 // Inputs: a tensor.
402 CheckArgsSize(primitive->name(), args_spec_list, 1);
403 // add check, types other than half and float are from cpu
404 auto tensor = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
405 (void)CheckTensorDType(tensor, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "Input of HSigmoid should be %s");
406 return args_spec_list[0]->Broaden();
407 }
408
InferImplHSigmoidGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)409 AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
410 const AbstractBasePtrList &args_spec_list) {
411 // Inputs: a tensor.
412 CheckArgsSize(primitive->name(), args_spec_list, 2);
413 // add check, types other than half and float are from cpu
414 auto dout = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
415 auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1);
416 (void)CheckTensorDType(dout, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32},
417 "Dout of HSigmoidGrad should be %s");
418 (void)CheckTensorDType(x, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "X of HSigmoidGrad should be %s");
419 return args_spec_list[1]->Broaden();
420 }
421
InferImplBpropCut(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)422 AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
423 const AbstractBasePtrList &args_spec_list) {
424 // Inputs: a tensor.
425 AbstractBasePtrList args_list;
426 constexpr size_t out_and_dout_size = 2;
427 for (size_t i = 0; i < args_spec_list.size() - out_and_dout_size; i++) {
428 args_list.push_back(args_spec_list[i]->Broaden());
429 }
430 return std::make_shared<AbstractTuple>(args_list);
431 }
432
InferImplDropout(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)433 AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
434 const AbstractBasePtrList &args_spec_list) {
435 const std::string op_name = primitive->name();
436 CheckArgsSize(op_name, args_spec_list, 1);
437 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
438 MS_EXCEPTION_IF_NULL(x);
439 MS_EXCEPTION_IF_NULL(x->shape());
440 ShapeVector shape = x->shape()->shape();
441 ShapeVector min_shape = x->shape()->min_shape();
442 ShapeVector max_shape = x->shape()->max_shape();
443 CheckMinMaxShape(shape, &min_shape, &max_shape);
444 auto output_shape =
445 std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
446 AbstractBasePtrList ret = {output_shape, output_shape};
447 return std::make_shared<AbstractTuple>(ret);
448 }
449
InferImplSparseApplyFtrl(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)450 AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
451 const AbstractBasePtrList &args_spec_list) {
452 CheckRequiredArgsSize(primitive->name(), args_spec_list, 5);
453 AbstractBasePtrList elements;
454 for (size_t i = 0; i < 3; ++i) {
455 elements.push_back(args_spec_list[i]->Clone()->Broaden());
456 }
457 return std::make_shared<AbstractTuple>(elements);
458 }
459
InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)460 AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
461 const AbstractBasePtrList &args_spec_list) {
462 CheckRequiredArgsSize(primitive->name(), args_spec_list, 7);
463 AbstractBasePtrList elements;
464 const size_t args_size = 2;
465 for (size_t i = 0; i < args_size; ++i) {
466 elements.push_back(args_spec_list[i]->Clone()->Broaden());
467 }
468 return std::make_shared<AbstractTuple>(elements);
469 }
470
InferImplSGD(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)471 AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
472 const AbstractBasePtrList &args_spec_list) {
473 CheckRequiredArgsSize(primitive->name(), args_spec_list, 6);
474 AbstractBasePtrList elements;
475 elements.push_back(args_spec_list[0]->Clone()->Broaden());
476 return std::make_shared<AbstractTuple>(elements);
477 }
478
InferImplCTCGreedyDecoder(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)479 AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
480 const AbstractBasePtrList &args_spec_list) {
481 // inputs: inputs, sequence_length
482 const std::string op_name = primitive->name();
483 CheckArgsSize(op_name, args_spec_list, 2);
484 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
485
486 constexpr size_t size_expected = 3;
487 auto shape = input->shape();
488 if (shape->shape().size() != size_expected) {
489 MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 3.";
490 }
491
492 ShapeVector indices_shape = {Shape::SHP_ANY, 2};
493 ShapeVector min_shape = {1, 2};
494 ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1], 2};
495 auto decoded_indices =
496 std::make_shared<AbstractTensor>(kInt64, std::make_shared<Shape>(indices_shape, min_shape, max_shape));
497
498 ShapeVector values_shape = {Shape::SHP_ANY};
499 ShapeVector values_min_shape = {1};
500 ShapeVector values_max_shape = {shape->shape()[0] * shape->shape()[1]};
501 ShapePtr values_shapes = std::make_shared<Shape>(values_shape, values_min_shape, values_max_shape);
502 auto decoded_values = std::make_shared<AbstractTensor>(kInt64, values_shapes);
503
504 ShapeVector decoded_shape_shape = {2};
505 auto decoded_shape = std::make_shared<AbstractTensor>(kInt64, decoded_shape_shape);
506
507 ShapeVector log_probability_shape = {shape->shape()[1], 1};
508 auto log_probability =
509 std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(log_probability_shape));
510
511 // outputs: decoded_indices, decoded_values, decoded_shape, log_probability
512 AbstractBasePtrList elements = {decoded_indices, decoded_values, decoded_shape, log_probability};
513 return std::make_shared<AbstractTuple>(elements);
514 }
515
InferImplPad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)516 AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
517 const AbstractBasePtrList &args_spec_list) {
518 const std::string op_name = primitive->name();
519 CheckArgsSize(op_name, args_spec_list, 1);
520 auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
521 auto input_shp = arg->shape()->shape();
522 MS_EXCEPTION_IF_NULL(primitive);
523 auto padding_attr = primitive->GetAttr("paddings");
524 MS_EXCEPTION_IF_NULL(padding_attr);
525 if (!padding_attr->isa<ValueTuple>()) {
526 MS_LOG(EXCEPTION) << "Paddings is not a ValueTuple";
527 }
528 std::vector<ValuePtr> paddings = padding_attr->cast<ValueTuplePtr>()->value();
529 std::vector<std::vector<int64_t>> paddings_vec;
530 for (ValuePtr paddings_elements : paddings) {
531 std::vector<ValuePtr> paddings_elements_tuple = paddings_elements->cast<ValueTuplePtr>()->value();
532 std::vector<int64_t> paddings_vec_item;
533 (void)std::transform(std::begin(paddings_elements_tuple), std::end(paddings_elements_tuple),
534 std::back_inserter(paddings_vec_item),
535 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
536 paddings_vec.push_back(paddings_vec_item);
537 }
538
539 ShapeVector result_shp;
540 size_t length = paddings_vec.size();
541 for (size_t i = 0; i < length; ++i) {
542 if (paddings_vec[i].size() != 2) {
543 MS_LOG(EXCEPTION) << "Paddings 's second dim size is not 2";
544 }
545 result_shp.push_back(input_shp[i] + paddings_vec[i][0] + paddings_vec[i][1]);
546 }
547 return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
548 }
549
InferImplComputeAccidentalHits(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)550 AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
551 const AbstractBasePtrList &args_spec_list) {
552 // inputs: true_classes, sampled_candidates
553 const std::string op_name = primitive->name();
554 constexpr size_t size_expected = 2;
555 CheckArgsSize(op_name, args_spec_list, size_expected);
556 AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
557
558 auto shape = input->shape();
559 if (shape->shape().size() != size_expected) {
560 MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2.";
561 }
562 ShapeVector indices_shape = {Shape::SHP_ANY};
563 ShapeVector min_shape = {1};
564 ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1]};
565
566 auto indices =
567 std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(indices_shape, min_shape, max_shape));
568
569 auto weights = std::make_shared<AbstractTensor>(kFloat32, indices_shape);
570 weights->set_shape(std::make_shared<Shape>(indices_shape, min_shape, max_shape));
571 // outputs: indices, ids, weights
572 AbstractBasePtrList elements = {indices, indices, weights};
573 return std::make_shared<AbstractTuple>(elements);
574 }
575 } // namespace abstract
576 } // namespace mindspore
577