1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/infer_functions.h"
18 #include "abstract/utils.h"
19 #include "abstract/param_validator.h"
20 #include "utils/ms_utils.h"
21 #include "utils/check_convert_utils.h"
22
23 namespace mindspore {
24 namespace abstract {
InferImplMinOrMaxGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)25 AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
26 const AbstractBasePtrList &args_spec_list) {
27 // Inputs: three tensors.
28 constexpr auto kMinMaxGradInputNum = 3;
29 const size_t dout_index = 2;
30 const std::string op_name = primitive->name();
31 CheckArgsSize(op_name, args_spec_list, kMinMaxGradInputNum);
32 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
33 auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
34 auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, dout_index);
35 (void)CheckTensorsDTypeSame({input_x, input_y, dout}, {kInt, kUInt, kFloat},
36 op_name + "evaluator three inputs should be %s");
37
38 AbstractBasePtr dx = input_x->Broaden();
39 AbstractBasePtr dy = input_y->Broaden();
40
41 return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
42 }
43
InferImplSqrt(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)44 AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
45 const AbstractBasePtrList &args_spec_list) {
46 // Inputs: three tensors.
47 constexpr auto kSqrtInputNum = 1;
48 const std::string op_name = primitive->name();
49 CheckArgsSize(op_name, args_spec_list, kSqrtInputNum);
50 auto inp = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
51 return inp->Clone()->Broaden();
52 }
53
InferImplSqrtGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)54 AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
55 const AbstractBasePtrList &args_spec_list) {
56 // Inputs: two tensors.
57 constexpr auto kSqrtGradInputNum = 2;
58 const std::string op_name = primitive->name();
59 CheckArgsSize(op_name, args_spec_list, kSqrtGradInputNum);
60 auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
61 auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
62 (void)CheckDtypeSame(op_name, out, dout);
63 (void)CheckShapeSame(op_name, out, dout);
64
65 return out->Broaden();
66 }
67
InferImplReduceFuncCheckAxis(const int64_t & axis,const size_t dim)68 int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {
69 int64_t dim_ = static_cast<int64_t>(dim);
70 if (axis < -dim_ || axis >= dim_) {
71 MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
72 }
73 int64_t ret_axis = axis;
74 if (axis >= -dim_ && axis < 0) {
75 ret_axis += dim_;
76 }
77 return ret_axis;
78 }
79
InferImplReduceFuncCalShape(ShapeVector * shape,const ShapeVector & x_shape,const ValuePtr & axis,bool keep_dims_value)80 void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
81 bool keep_dims_value) {
82 MS_EXCEPTION_IF_NULL(axis);
83 if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
84 auto axis_ptr_list =
85 axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
86 if (!axis_ptr_list.size()) {
87 if (keep_dims_value) (void)shape->insert(shape->end(), x_shape.size(), 1);
88 } else {
89 (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
90 ValuePtrList axis_items = axis_ptr_list;
91 ValuePtrList::iterator it;
92 ValuePtrList::reverse_iterator it_re;
93 int64_t axis_value;
94 if (keep_dims_value) {
95 for (it = axis_items.begin(); it != axis_items.end(); ++it) {
96 axis_value = GetValue<int64_t>(*it);
97 axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
98 shape->at(LongToSize(axis_value)) = 1;
99 }
100 } else {
101 std::sort(axis_items.begin(), axis_items.end());
102 for (it_re = axis_items.rbegin(); it_re != axis_items.rend(); ++it_re) {
103 axis_value = GetValue<int64_t>(*it_re);
104 axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
105 (void)shape->erase(shape->begin() + axis_value);
106 }
107 }
108 }
109 } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
110 (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
111 auto axis_value = GetValue<int64_t>(axis);
112 axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
113 if (keep_dims_value) {
114 shape->at(LongToSize(axis_value)) = 1;
115 } else {
116 (void)shape->erase(shape->begin() + axis_value);
117 }
118 } else {
119 MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
120 }
121 return;
122 }
123
124 // To reduce code repeat, use InferImplReduceFunc. Currently registered with ReduceMean, ReduceSum,
125 // ReduceAll, ReduceAny, ReduceMax, ReduceMin.
InferImplReduceFunc(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)126 AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
127 const AbstractBasePtrList &args_spec_list) {
128 const auto kReduceInputNum = 1;
129 const std::string op_name = primitive->name();
130 CheckArgsSize(op_name, args_spec_list, kReduceInputNum);
131 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
132 MS_EXCEPTION_IF_NULL(input_x);
133 MS_EXCEPTION_IF_NULL(input_x->element());
134
135 ValuePtr keep_dims = primitive->GetAttr("keep_dims");
136 MS_EXCEPTION_IF_NULL(keep_dims);
137 if (!keep_dims->isa<BoolImm>()) {
138 MS_LOG(EXCEPTION) << "Keep_dims should be Bool.";
139 }
140 bool keep_dims_value = GetValue<bool>(keep_dims);
141
142 ValuePtr axis = primitive->GetAttr("axis");
143 MS_EXCEPTION_IF_NULL(axis);
144
145 ShapeVector shape = {};
146 ShapeVector x_shape = input_x->shape()->shape();
147 InferImplReduceFuncCalShape(&shape, x_shape, axis, keep_dims_value);
148
149 bool x_is_dyn = (!input_x->shape()->min_shape().empty() && !input_x->shape()->max_shape().empty());
150 if (x_is_dyn) {
151 ShapeVector shape_min = {};
152 ShapeVector shape_max = {};
153 ShapeVector x_shape_min = input_x->shape()->min_shape();
154 ShapeVector x_shape_max = input_x->shape()->max_shape();
155 InferImplReduceFuncCalShape(&shape_min, x_shape_min, axis, keep_dims_value);
156 InferImplReduceFuncCalShape(&shape_max, x_shape_max, axis, keep_dims_value);
157 return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
158 }
159 return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
160 }
161
InferImplBinaryBase(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)162 AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
163 const AbstractBasePtrList &args_spec_list) {
164 constexpr auto kBinaryBaseInputNum = 2;
165 const std::string op_name = primitive->name();
166 CheckArgsSize(op_name, args_spec_list, kBinaryBaseInputNum);
167 auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
168 MS_EXCEPTION_IF_NULL(input_x);
169 MS_EXCEPTION_IF_NULL(input_x->shape());
170
171 auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
172 MS_EXCEPTION_IF_NULL(input_y);
173 MS_EXCEPTION_IF_NULL(input_y->shape());
174
175 auto x_shape = input_x->shape()->shape();
176 auto y_shape = input_y->shape()->shape();
177 auto output_shape = BroadcastShape(x_shape, y_shape);
178
179 auto x_type = input_x->BuildType();
180 MS_EXCEPTION_IF_NULL(x_type);
181 MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
182 auto y_type = input_y->BuildType();
183 MS_EXCEPTION_IF_NULL(y_type);
184 MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());
185
186 auto x_element = x_type->cast<TensorTypePtr>()->element();
187 MS_EXCEPTION_IF_NULL(x_element);
188 auto y_element = y_type->cast<TensorTypePtr>()->element();
189 MS_EXCEPTION_IF_NULL(y_element);
190
191 auto x_element_type = x_element->number_type();
192 auto y_element_type = y_element->number_type();
193
194 auto x_priority = type_priority_map.find(x_element_type);
195 if (x_priority == type_priority_map.end()) {
196 MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type.";
197 }
198 auto y_priority = type_priority_map.find(y_element_type);
199 if (y_priority == type_priority_map.end()) {
200 MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type.";
201 }
202
203 if (x_priority->second >= y_priority->second) {
204 return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
205 } else {
206 return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
207 }
208 }
209
InferImplMinimum(const AnalysisEnginePtr & engine_ptr,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)210 AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
211 const AbstractBasePtrList &args_spec_list) {
212 return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
213 }
214
InferImplDivNoNan(const AnalysisEnginePtr & engine_ptr,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)215 AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
216 const AbstractBasePtrList &args_spec_list) {
217 return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
218 }
219
InferImplLinSpace(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)220 AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
221 const AbstractBasePtrList &args_spec_list) {
222 constexpr auto kLinSpaceInputNum = 3;
223 const std::string op_name = primitive->name();
224 CheckArgsSize(op_name, args_spec_list, kLinSpaceInputNum);
225 auto start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
226 MS_EXCEPTION_IF_NULL(start);
227 MS_EXCEPTION_IF_NULL(start->shape());
228 auto stop = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
229 MS_EXCEPTION_IF_NULL(stop);
230 MS_EXCEPTION_IF_NULL(stop->shape());
231 (void)CheckTensorDType(start, {kFloat32}, "Input 0 (start) for LinSpace should be %s");
232 (void)CheckTensorDType(stop, {kFloat32}, "Input 1 (stop) for LinSpace should be %s");
233 ShapeVector shape;
234 ShapeVector max_shape;
235 ShapeVector min_shape;
236 int64_t num_val = 0;
237 // 3rd input is a Tensor when LinSpace is a dynamic shape operator
238 const size_t tensor_index = 2;
239 auto abs_num = args_spec_list[tensor_index];
240 if (abs_num->isa<AbstractTensor>()) {
241 auto num = abs_num->cast<AbstractTensorPtr>();
242 MS_EXCEPTION_IF_NULL(num);
243 auto num_value_ptr = num->BuildValue();
244 MS_EXCEPTION_IF_NULL(num_value_ptr);
245 auto num_tensor = num_value_ptr->cast<tensor::TensorPtr>();
246 MS_EXCEPTION_IF_NULL(num_tensor);
247 num_val = *static_cast<int64_t *>(num_tensor->data_c());
248 } else if (abs_num->isa<AbstractScalar>()) {
249 auto num = abs_num->cast<AbstractScalarPtr>();
250 num_val = GetValue<int64_t>(num->BuildValue());
251 } else {
252 MS_LOG(EXCEPTION) << "Invalid abstract type:" << abs_num->type_name();
253 }
254 shape.emplace_back(num_val);
255 if (shape[0] < 0) {
256 MS_LOG(EXCEPTION) << "num must be >= 0 in LinSpace";
257 }
258 max_shape.emplace_back(num_val);
259 min_shape.emplace_back(num_val);
260 AbstractTensorPtr ret =
261 std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
262 return ret;
263 }
264
InferImplMatMul(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)265 AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
266 const AbstractBasePtrList &args_spec_list) {
267 constexpr auto kMatMulInputNum = 2;
268 const std::string op_name = primitive->name();
269 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kGreaterEqual,
270 kMatMulInputNum, op_name);
271 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
272 MS_EXCEPTION_IF_NULL(x);
273 MS_EXCEPTION_IF_NULL(x->shape());
274 auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
275 MS_EXCEPTION_IF_NULL(y);
276 MS_EXCEPTION_IF_NULL(y->shape());
277 auto x_shp = x->shape()->shape();
278 auto y_shp = y->shape()->shape();
279 const size_t SHAPE_SIZE = 2;
280 if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
281 MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
282 }
283 ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
284 ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
285 bool transpose_a = GetValue<bool>(transpose_a_ptr);
286 bool transpose_b = GetValue<bool>(transpose_b_ptr);
287 ShapeVector x_min_shape = x->shape()->min_shape();
288 ShapeVector x_max_shape = x->shape()->max_shape();
289 ShapeVector y_min_shape = y->shape()->min_shape();
290 ShapeVector y_max_shape = y->shape()->max_shape();
291 CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
292 CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
293 // Additional check for dynamic shape
294 // Last infer will be real shape values
295 bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
296 bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
297 if (x_not_dyn && y_not_dyn) {
298 auto x_col = x_shp[(transpose_a ? 0 : 1)];
299 auto y_row = y_shp[(transpose_b ? 1 : 0)];
300 if (x_col != y_row) {
301 MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
302 << ". In MatMul x_col and y_row should be equal.";
303 }
304 }
305 ShapeVector ret_shape;
306 ShapeVector ret_min_shape;
307 ShapeVector ret_max_shape;
308 auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
309 const ShapeVector yshp) -> void {
310 output.push_back(xshp[(transpose_a ? 1 : 0)]);
311 output.push_back(yshp[(transpose_b ? 0 : 1)]);
312 return;
313 };
314 make_shape(ret_shape, x_shp, y_shp);
315 make_shape(ret_min_shape, x_min_shape, y_min_shape);
316 make_shape(ret_max_shape, x_max_shape, y_max_shape);
317 TypePtr x_type = x->element()->GetTypeTrack();
318 if (x_type->type_id() == TypeId::kNumberTypeInt8) {
319 x_type = kInt32;
320 }
321 if (primitive->HasAttr("cast_type")) {
322 auto out_type = primitive->GetAttr("cast_type");
323 MS_EXCEPTION_IF_NULL(out_type);
324 if (!out_type->isa<Type>()) {
325 MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
326 }
327 x_type = out_type->cast<TypePtr>();
328 }
329 return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
330 }
331
InferImplBatchMatMul(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)332 AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
333 const AbstractBasePtrList &args_spec_list) {
334 constexpr auto kBatchMatMulInputNum = 2;
335 const std::string op_name = primitive->name();
336 CheckArgsSize(op_name, args_spec_list, kBatchMatMulInputNum);
337 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
338 MS_EXCEPTION_IF_NULL(x);
339 MS_EXCEPTION_IF_NULL(x->shape());
340 auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
341 MS_EXCEPTION_IF_NULL(y);
342 MS_EXCEPTION_IF_NULL(y->shape());
343 auto x_shp = x->shape()->shape();
344 auto y_shp = y->shape()->shape();
345 constexpr size_t minimum_shape = 3;
346 if (x_shp.size() != y_shp.size() || x_shp.size() < minimum_shape) {
347 MS_LOG(EXCEPTION)
348 << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
349 }
350 ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
351 ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
352 bool transpose_a = GetValue<bool>(transpose_a_ptr);
353 bool transpose_b = GetValue<bool>(transpose_b_ptr);
354 ShapeVector x_min_shape = x->shape()->min_shape();
355 ShapeVector x_max_shape = x->shape()->max_shape();
356 ShapeVector y_min_shape = y->shape()->min_shape();
357 ShapeVector y_max_shape = y->shape()->max_shape();
358 CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
359 CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
360 // Additional check for dynamic shape
361 // Last infer will be real shape values
362 bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
363 bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
364 if (x_not_dyn && y_not_dyn) {
365 size_t offset = x_shp.size() - 2;
366 auto x_col = x_shp[offset + (transpose_a ? 0 : 1)];
367 auto y_row = y_shp[offset + (transpose_b ? 1 : 0)];
368 if (x_col != y_row) {
369 MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
370 << ". In BatchMatMul x_col and y_row should be equal.";
371 }
372 }
373 ShapeVector ret_shape;
374 ShapeVector ret_min_shape;
375 ShapeVector ret_max_shape;
376 auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
377 const ShapeVector yshp) -> void {
378 for (size_t i = 0; i < xshp.size() - 2; i++) {
379 if (xshp[i] != yshp[i]) {
380 if (xshp[i] > 0 && yshp[i] > 0) {
381 MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << ".";
382 }
383 output.push_back(Shape::SHP_ANY);
384 } else {
385 output.push_back(xshp[i]);
386 }
387 }
388 const size_t bias = 2;
389 size_t offset = xshp.size() - bias;
390 output.push_back(xshp[offset + (transpose_a ? 1 : 0)]);
391 output.push_back(yshp[offset + (transpose_b ? 0 : 1)]);
392 return;
393 };
394 make_shape(ret_shape, x_shp, y_shp);
395 make_shape(ret_min_shape, x_min_shape, y_min_shape);
396 make_shape(ret_max_shape, x_max_shape, y_max_shape);
397 TypePtr x_type = x->element()->GetTypeTrack();
398 if (x_type->type_id() == TypeId::kNumberTypeInt8) {
399 x_type = kInt32;
400 }
401 if (primitive->HasAttr("cast_type")) {
402 auto out_type = primitive->GetAttr("cast_type");
403 MS_EXCEPTION_IF_NULL(out_type);
404 if (!out_type->isa<Type>()) {
405 MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
406 }
407 x_type = out_type->cast<TypePtr>();
408 }
409 return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
410 }
411
InferImplLess(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)412 AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
413 const AbstractBasePtrList &args_spec_list) {
414 constexpr auto kLessInputNum = 2;
415 const std::string op_name = primitive->name();
416 CheckArgsSize(op_name, args_spec_list, kLessInputNum);
417 auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
418 MS_EXCEPTION_IF_NULL(x);
419 MS_EXCEPTION_IF_NULL(x->shape());
420 ShapeVector x_shape = x->shape()->shape();
421 ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
422 ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
423
424 auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
425 MS_EXCEPTION_IF_NULL(y);
426 MS_EXCEPTION_IF_NULL(y->shape());
427 ShapeVector y_shape = y->shape()->shape();
428 ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
429 ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
430
431 auto out_shape = BroadcastShape(x_shape, y_shape);
432 auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
433 auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
434 auto output_type = std::make_shared<Bool>();
435 return std::make_shared<AbstractTensor>(output_type,
436 std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
437 }
438 } // namespace abstract
439 } // namespace mindspore
440