/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/check_convert_utils.h"
#include "utils/shape_utils.h"

namespace mindspore {
namespace abstract {
const size_t stride_num_element = 2;
const size_t stride_start_idx = 2;
const size_t dilation_num_element = 2;
const size_t dilation_start_idx = 2;
const size_t padding_num_element = 4;
const size_t padding_start_idx = 0;
int64_t GetAndCheckFormat(const ValuePtr &value) {
  int64_t data_format;
  bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
    MS_LOG(EXCEPTION) << "Data format is invalid, only NCHW, NHWC and NCDHW are supported.";
  }
  return data_format;
}

AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  AbstractTensorPtr input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s");

  ShapePtr input_shape = dyn_cast<Shape>(input_tensor->GetShapeTrack());  // NCHW
  MS_EXCEPTION_IF_NULL(input_shape);
  const size_t input_shape_size = 4;
  if (input_shape->shape().size() != input_shape_size) {
    MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor.";
  }
  const size_t H_INDEX = 2;
  const size_t W_INDEX = 3;
  int64_t h_input = input_shape->shape()[H_INDEX];
  int64_t w_input = input_shape->shape()[W_INDEX];

  int64_t window = GetValue<int64_t>(primitive->GetAttr("window"));
  int64_t stride = GetValue<int64_t>(primitive->GetAttr("stride"));
  int64_t padding = GetValue<int64_t>(primitive->GetAttr("pad"));
  int64_t nan_opt = GetValue<int64_t>(primitive->GetAttr("nan_opt"));
  int64_t data_mode = GetValue<int64_t>(primitive->GetAttr("data_mode"));
  int64_t ceil_mode = GetValue<int64_t>(primitive->GetAttr("ceil_mode"));

  if (stride <= 0) {
    MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should be greater than 0";
  }
  if (nan_opt != 0) {
    MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0";
  }
  if (data_mode != 1) {
    MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1";
  }
  if (ceil_mode != 0) {
    MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
  }

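  // pad_mode, when present, overrides the raw "pad" attribute: VALID means no
  // padding, while SAME pads symmetrically by (window - 1) / 2 so the output
  // size stays near input / stride.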
  auto pad_mode_ptr = primitive->GetAttr("pad_mode");
  if (pad_mode_ptr != nullptr) {
    int64_t pad_mode;
    CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
    if (pad_mode == PadMode::VALID) {
      padding = 0;
    } else if (pad_mode == PadMode::SAME) {
      padding = (window - 1) / 2;
    }
  }
  std::set<std::string> available_mode{"max", "avg"};
  auto mode_ptr = primitive->GetAttr("mode");
  if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
    auto mode = mode_ptr->cast<StringImmPtr>()->value();
    if (available_mode.find(mode) == available_mode.end()) {
      MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
    }
  }

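  // Standard pooling output size: floor((input + 2 * padding - window) / stride) + 1,
  // where the integer division performs the floor. For example, h_input = 32,
  // window = 2, padding = 0, stride = 2 gives h_out = (32 - 2) / 2 + 1 = 16.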
  int64_t h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1;
  int64_t w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1;
  ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};
  AbstractBasePtr ret = input_tensor->Broaden();
  ret->set_shape(std::make_shared<Shape>(shape_out));
  return ret;
}

AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: three tensors (y, dy, x).
  constexpr auto kPoolingGradInputNum = 3;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kPoolingGradInputNum);
  auto out_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto d_out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
  (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat},
                              op_name + " evaluator three inputs should be %s");

  AbstractBasePtr ret = d_out->Broaden();
  auto x_shape = dyn_cast<Shape>(args_spec_list[2]->GetShapeTrack());
  MS_EXCEPTION_IF_NULL(x_shape);

  ret->set_shape(x_shape);
  return ret;
}

AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  // Inputs: five tensors (x, gamma, beta, mean, variance).
  constexpr auto kBatchNormInputNum = 5;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kBatchNormInputNum);
  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());
  ShapeVector x_shape = input_x->shape()->shape();
  ShapeVector x_min_shape = input_x->shape()->min_shape();
  ShapeVector x_max_shape = input_x->shape()->max_shape();
  CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);

  auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be %s");
  AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
  for (size_t i = 1; i < args_spec_list.size(); ++i) {
    auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
    tensorPtrList.push_back(param);
  }
  (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32},
                              "param gamma, beta, mean, variance of BatchNorm should be %s");

  auto data_format_ptr = primitive->GetAttr("format");
  MS_EXCEPTION_IF_NULL(data_format_ptr);
  int64_t data_format = GetAndCheckFormat(data_format_ptr);

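  // The per-channel parameters are 1-D, so each must match the channel axis of x:
  // axis 1 for NCHW/NCDHW, axis 3 for NHWC.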
  size_t c_axis = 1;
  if (data_format == Format::NHWC) {
    c_axis = 3;
  }
  for (size_t i = 1; i < args_spec_list.size(); ++i) {
    AbstractTensorPtr arg_spec = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
    MS_EXCEPTION_IF_NULL(arg_spec);
    MS_EXCEPTION_IF_NULL(arg_spec->shape());
    ShapeVector arg_shape = arg_spec->shape()->shape();
    if (arg_shape.size() != 1) {
      MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size();
    }
    if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) {
      MS_EXCEPTION(ValueError) << "Arg " << i << " shape[0] should be equal to x_shape[" << c_axis
                               << "]=" << x_shape[c_axis] << ", but got " << arg_shape[0];
    }
  }
  AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  ShapeVector gamma_shape = input_gamma->shape()->shape();
  ShapeVector gamma_min_shape = input_gamma->shape()->min_shape();
  ShapeVector gamma_max_shape = input_gamma->shape()->max_shape();
  CheckMinMaxShape(gamma_shape, &gamma_min_shape, &gamma_max_shape);
  ShapePtr output_shape_ptr = std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape);
  AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
  ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape);
  AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr);
  AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma};
  return std::make_shared<AbstractTuple>(rets);
}

AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &,
                                         const AbstractBasePtrList &args_spec_list) {
  // The output is unused, so we don't have to focus on the output shape.
  constexpr size_t dx_index = 1;
  constexpr size_t dscale_index = 2;
  constexpr size_t dbias_index = 3;
  MS_EXCEPTION_IF_NULL(args_spec_list[dx_index]);
  MS_EXCEPTION_IF_NULL(args_spec_list[dscale_index]);
  MS_EXCEPTION_IF_NULL(args_spec_list[dbias_index]);

  auto dx = args_spec_list[dx_index]->Broaden();
  auto dscale = args_spec_list[dscale_index]->Broaden();
  auto dbias = args_spec_list[dbias_index]->Broaden();

  AbstractBasePtrList rets = {dx, dscale, dbias};
  return std::make_shared<AbstractTuple>(rets);
}

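// Computes the output H/W and the pad list (top, bottom, left, right) for Conv2D:
//   VALID: out = ceil((x - dilation * (k - 1)) / stride), with no padding;
//   SAME:  out = ceil(x / stride), with the needed padding split as evenly as
//          possible (the extra pixel, if any, goes to the bottom/right side);
//   PAD:   out = floor((x + pad_before + pad_after - k - (k - 1) * (dilation - 1)) / stride) + 1.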
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
                       const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
                       const std::vector<int64_t> &dilation, const int64_t &pad_mode,
                       const std::vector<int64_t> &padding) {
  if (pad_mode == PadMode::VALID) {
    output_hw->push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])));
    output_hw->push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])));
    const size_t pad_num = 4;  // top, bottom, left and right pads are all zero
    (void)pad_list->insert(pad_list->begin(), pad_num, 0);
  } else if (pad_mode == PadMode::SAME) {
    output_hw->push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride[0])));
    output_hw->push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride[1])));
    int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
    pad_needed_h = std::max((int64_t)0, pad_needed_h);
    pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_h / 2)));
    pad_list->push_back(pad_needed_h - pad_list->at(0));
    int64_t pad_needed_w = (output_hw->at(1) - 1) * stride[1] + dilation[1] * (kernel[1] - 1) + 1 - x_w;
    pad_needed_w = std::max((int64_t)0, pad_needed_w);
    pad_list->push_back(static_cast<int64_t>(std::floor(pad_needed_w / 2)));
    pad_list->push_back(pad_needed_w - pad_list->at(2));
  } else if (pad_mode == PadMode::PAD) {
    (void)pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
    output_hw->push_back(static_cast<int64_t>(std::floor(
      1 + ((x_h * 1.0) + pad_list->at(0) + pad_list->at(1) - kernel[0] - (kernel[0] - 1) * (dilation[0] - 1)) /
            stride[0])));
    output_hw->push_back(static_cast<int64_t>(std::floor(
      1 + ((x_w * 1.0) + pad_list->at(2) + pad_list->at(3) - kernel[1] - (kernel[1] - 1) * (dilation[1] - 1)) /
            stride[1])));
  }
}

void CheckShape(const std::string &op_name, const ShapeVector &w_shape, const AbstractTensorPtr &input_w) {
  ShapeVector w_min_shape = input_w->shape()->min_shape();
  ShapeVector w_max_shape = input_w->shape()->max_shape();
  CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
  CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
  CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
  CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
}

AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
  constexpr auto kConv2DInputNum = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kConv2DInputNum);
  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(input_x);
  MS_EXCEPTION_IF_NULL(input_x->shape());
  ShapeVector x_shape = input_x->shape()->shape();
  ShapeVector x_min_shape = input_x->shape()->min_shape();
  ShapeVector x_max_shape = input_x->shape()->max_shape();
  CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
  CheckShapeAnyAndPositive(op_name + " x_shape", x_shape);
  CheckShapeAllPositive(op_name + " x_min_shape", x_min_shape);
  CheckShapeAllPositive(op_name + " x_max_shape", x_max_shape);
  AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(input_w);
  MS_EXCEPTION_IF_NULL(input_w->shape());
  ShapeVector w_shape = input_w->shape()->shape();
  CheckShape(op_name, w_shape, input_w);
  const uint64_t n_axis = 0;
  uint64_t c_axis = 1;
  uint64_t h_axis = 2;
  uint64_t w_axis = 3;
  int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
  if (data_format == Format::NHWC) {
    c_axis = 3;
    h_axis = 1;
    w_axis = 2;
  }
  int64_t group = CheckAttrPositiveInt64(op_name, primitive->GetAttr("group"), "group");
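  // For grouped convolution the weight's input-channel dimension holds only
  // x_channels / group channels, so the two must match after dividing by group.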
  if ((x_shape[c_axis] != Shape::SHP_ANY) && (w_shape[c_axis] != Shape::SHP_ANY) &&
      ((x_shape[c_axis] / group) != w_shape[c_axis])) {
    MS_LOG(EXCEPTION) << "x_shape[C_in] / group must be equal to w_shape[C_in] = " << w_shape[c_axis] << ", but got "
                      << (x_shape[c_axis] / group);
  }
  int64_t out_channel = CheckAttrPositiveInt64(op_name, primitive->GetAttr("out_channel"), "out_channel");
  if ((w_shape[n_axis] != Shape::SHP_ANY) && (w_shape[n_axis] != out_channel)) {
    MS_LOG(EXCEPTION) << "w_shape[" << n_axis << "] = " << w_shape[n_axis]
                      << " must be equal to out_channel = " << out_channel;
  }
  const size_t kernel_size_num_element = 2;
  std::vector<int64_t> kernel_size =
    CheckAttrIntOrTuple(op_name, primitive->GetAttr("kernel_size"), 0, kernel_size_num_element);
  if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_size[0])) {
    MS_LOG(EXCEPTION) << "weight height = " << w_shape[h_axis]
                      << " must be equal to kernel_size[0] = " << kernel_size[0];
  }
  if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_size[1])) {
    MS_LOG(EXCEPTION) << "weight width = " << w_shape[w_axis]
                      << " must be equal to kernel_size[1] = " << kernel_size[1];
  }
  std::vector<int64_t> stride =
    CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), stride_start_idx, stride_num_element);
  std::vector<int64_t> dilation =
    CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), dilation_start_idx, dilation_num_element);
  std::vector<int64_t> padding =
    CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), padding_start_idx, padding_num_element);
  int64_t pad_mode;
  CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
  std::vector<int64_t> output_hw;
  std::vector<int64_t> pad_list;
  std::vector<int64_t> output_hw_min;
  std::vector<int64_t> pad_list_min;
  std::vector<int64_t> output_hw_max;
  std::vector<int64_t> pad_list_max;
  Conv2DPadFunction(&output_hw, &pad_list, x_shape[h_axis], x_shape[w_axis], kernel_size, stride, dilation, pad_mode,
                    padding);
  if (x_shape[h_axis] == Shape::SHP_ANY) {
    output_hw[0] = Shape::SHP_ANY;
  }
  if (x_shape[w_axis] == Shape::SHP_ANY) {
    output_hw[1] = Shape::SHP_ANY;
  }
  Conv2DPadFunction(&output_hw_min, &pad_list_min, x_min_shape[h_axis], x_min_shape[w_axis], kernel_size, stride,
                    dilation, pad_mode, padding);
  Conv2DPadFunction(&output_hw_max, &pad_list_max, x_max_shape[h_axis], x_max_shape[w_axis], kernel_size, stride,
                    dilation, pad_mode, padding);
  std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
                                        MakeValue(pad_list[3])};
  primitive->set_attr("pad_list", MakeValue(pad_list_val));
  ShapeVector output_shape;
  ShapeVector output_shape_min;
  ShapeVector output_shape_max;
  if (data_format == Format::NHWC) {
    output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
    output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
    output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
  } else {
    output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
    output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
    output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
  }
  CheckShapeAnyAndPositive(op_name + " output_shape", output_shape);
  CheckShapeAllPositive(op_name + " output_shape_min", output_shape_min);
  CheckShapeAllPositive(op_name + " output_shape_max", output_shape_max);
  TypePtr x_type = input_x->element()->GetTypeTrack();
  if (x_type->type_id() == TypeId::kNumberTypeInt8) {
    x_type = kInt32;
  }
  ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
  return std::make_shared<AbstractTensor>(x_type, output_shape_ptr);
}

AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  constexpr size_t args_size = 2;
  CheckArgsSize(op_name, args_spec_list, args_size);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto bias = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  ShapeVector x_shape = x->shape()->shape();
  MS_EXCEPTION_IF_NULL(bias);
  MS_EXCEPTION_IF_NULL(bias->shape());
  ShapeVector bias_shape = bias->shape()->shape();
  ShapeVector x_min_shape = x->shape()->min_shape();
  ShapeVector x_max_shape = x->shape()->max_shape();
  auto data_format_ptr = primitive->GetAttr("format");
  int64_t data_format = Format::NCHW;
  if (data_format_ptr != nullptr) {
    data_format = GetAndCheckFormat(data_format_ptr);
  }
  auto x_channel = data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1];
  // Only check the channel match when x is fully static: the final inference
  // pass sees the real shape values, so dynamic dims are skipped here.
  bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
  if (x_not_dyn && bias_shape[0] != x_channel) {
    MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
                      << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
  }
  CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape));
}

AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: at least one tensor (y_backprop)
  // Outputs: dbias
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << primitive->name() << " evaluator requires at least 1 parameter, but the input size is "
                      << args_spec_list.size() << ".";
  }

  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
  MS_EXCEPTION_IF_NULL(shape_y);
  ShapeVector y_dims = shape_y->shape();
  if (y_dims.size() < 2) {
    MS_LOG(EXCEPTION) << primitive->name() << " input y_backprop dim should be >= 2, but got " << y_dims.size() << ".";
  }
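  // dbias has one element per channel; this appears to assume an NCHW-style
  // layout where axis 1 of y_backprop is the channel dimension.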
  ShapeVector bias_dims = {y_dims[1]};
  ShapePtr ret_shape = std::make_shared<Shape>(bias_dims);
  AbstractBasePtr ret = args_spec_list[0]->Broaden();
  ret->set_shape(ret_shape);
  return ret;
}

AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // Inputs: a tensor.
  CheckArgsSize(primitive->name(), args_spec_list, 1);
  // Check dtype; types other than float16/float32 come from the CPU backend.
  auto tensor = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
  (void)CheckTensorDType(tensor, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "Input of HSigmoid should be %s");
  return args_spec_list[0]->Broaden();
}

AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
  // Inputs: two tensors.
  CheckArgsSize(primitive->name(), args_spec_list, 2);
  // Check dtype; types other than float16/float32 come from the CPU backend.
  auto dout = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
  auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1);
  (void)CheckTensorDType(dout, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32},
                         "Dout of HSigmoidGrad should be %s");
  (void)CheckTensorDType(x, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "X of HSigmoidGrad should be %s");
  return args_spec_list[1]->Broaden();
}

AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  // Inputs: the forward inputs followed by out and dout.
  AbstractBasePtrList args_list;
  constexpr size_t out_and_dout_size = 2;
  for (size_t i = 0; i < args_spec_list.size() - out_and_dout_size; i++) {
    args_list.push_back(args_spec_list[i]->Broaden());
  }
  return std::make_shared<AbstractTuple>(args_list);
}

AbstractBasePtr InferImplDropout(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  ShapeVector shape = x->shape()->shape();
  ShapeVector min_shape = x->shape()->min_shape();
  ShapeVector max_shape = x->shape()->max_shape();
  CheckMinMaxShape(shape, &min_shape, &max_shape);
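  // Dropout returns (output, mask); here both are inferred with the input's
  // shape and element type, so the same abstract tensor is reused.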
  auto output_shape =
    std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
  AbstractBasePtrList ret = {output_shape, output_shape};
  return std::make_shared<AbstractTuple>(ret);
}

AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
  CheckRequiredArgsSize(primitive->name(), args_spec_list, 5);
  AbstractBasePtrList elements;
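  // The outputs mirror the first three inputs (for this op: var, accum, linear).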
  for (size_t i = 0; i < 3; ++i) {
    elements.push_back(args_spec_list[i]->Clone()->Broaden());
  }
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                    const AbstractBasePtrList &args_spec_list) {
  CheckRequiredArgsSize(primitive->name(), args_spec_list, 7);
  AbstractBasePtrList elements;
  const size_t args_size = 2;
  for (size_t i = 0; i < args_size; ++i) {
    elements.push_back(args_spec_list[i]->Clone()->Broaden());
  }
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list) {
  CheckRequiredArgsSize(primitive->name(), args_spec_list, 6);
  AbstractBasePtrList elements;
  elements.push_back(args_spec_list[0]->Clone()->Broaden());
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplCTCGreedyDecoder(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const AbstractBasePtrList &args_spec_list) {
  // Inputs: inputs, sequence_length
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  constexpr size_t size_expected = 3;
  auto shape = input->shape();
  if (shape->shape().size() != size_expected) {
    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 3.";
  }

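  // The input is laid out as (max_time, batch_size, num_classes). A decoded
  // index is a (batch, time) pair, so at most max_time * batch_size entries
  // exist; that product bounds the dynamic dims below.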
  ShapeVector indices_shape = {Shape::SHP_ANY, 2};
  ShapeVector min_shape = {1, 2};
  ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1], 2};
  auto decoded_indices =
    std::make_shared<AbstractTensor>(kInt64, std::make_shared<Shape>(indices_shape, min_shape, max_shape));

  ShapeVector values_shape = {Shape::SHP_ANY};
  ShapeVector values_min_shape = {1};
  ShapeVector values_max_shape = {shape->shape()[0] * shape->shape()[1]};
  ShapePtr values_shapes = std::make_shared<Shape>(values_shape, values_min_shape, values_max_shape);
  auto decoded_values = std::make_shared<AbstractTensor>(kInt64, values_shapes);

  ShapeVector decoded_shape_shape = {2};
  auto decoded_shape = std::make_shared<AbstractTensor>(kInt64, decoded_shape_shape);

  ShapeVector log_probability_shape = {shape->shape()[1], 1};
  auto log_probability =
    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(log_probability_shape));

  // Outputs: decoded_indices, decoded_values, decoded_shape, log_probability
  AbstractBasePtrList elements = {decoded_indices, decoded_values, decoded_shape, log_probability};
  return std::make_shared<AbstractTuple>(elements);
}

AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto input_shp = arg->shape()->shape();
  MS_EXCEPTION_IF_NULL(primitive);
  auto padding_attr = primitive->GetAttr("paddings");
  MS_EXCEPTION_IF_NULL(padding_attr);
  if (!padding_attr->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Paddings is not a ValueTuple";
  }
  std::vector<ValuePtr> paddings = padding_attr->cast<ValueTuplePtr>()->value();
  std::vector<std::vector<int64_t>> paddings_vec;
  for (ValuePtr paddings_elements : paddings) {
    std::vector<ValuePtr> paddings_elements_tuple = paddings_elements->cast<ValueTuplePtr>()->value();
    std::vector<int64_t> paddings_vec_item;
    (void)std::transform(std::begin(paddings_elements_tuple), std::end(paddings_elements_tuple),
                         std::back_inserter(paddings_vec_item),
                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
    paddings_vec.push_back(paddings_vec_item);
  }

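  // Each output dim is the input dim plus the (before, after) pad amounts,
  // e.g. input_shp[i] = 4 with paddings_vec[i] = {1, 2} gives 4 + 1 + 2 = 7.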
  ShapeVector result_shp;
  size_t length = paddings_vec.size();
  for (size_t i = 0; i < length; ++i) {
    if (paddings_vec[i].size() != 2) {
      MS_LOG(EXCEPTION) << "Paddings' second dim size is not 2";
    }
    result_shp.push_back(input_shp[i] + paddings_vec[i][0] + paddings_vec[i][1]);
  }
  return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
}

AbstractBasePtr InferImplComputeAccidentalHits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                               const AbstractBasePtrList &args_spec_list) {
  // Inputs: true_classes, sampled_candidates
  const std::string op_name = primitive->name();
  constexpr size_t size_expected = 2;
  CheckArgsSize(op_name, args_spec_list, size_expected);
  AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);

  auto shape = input->shape();
  if (shape->shape().size() != size_expected) {
    MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 2.";
  }
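  // The number of hits is dynamic: at most every element of true_classes
  // (shape [batch_size, num_true]) can match a sampled candidate, so the
  // upper bound is the product of its two dims.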
  ShapeVector indices_shape = {Shape::SHP_ANY};
  ShapeVector min_shape = {1};
  ShapeVector max_shape = {shape->shape()[0] * shape->shape()[1]};

  auto indices =
    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(indices_shape, min_shape, max_shape));

  auto weights = std::make_shared<AbstractTensor>(kFloat32, indices_shape);
  weights->set_shape(std::make_shared<Shape>(indices_shape, min_shape, max_shape));
  // Outputs: indices, ids, weights
  AbstractBasePtrList elements = {indices, indices, weights};
  return std::make_shared<AbstractTuple>(elements);
}
}  // namespace abstract
}  // namespace mindspore