• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/infer_functions.h"
18 #include "abstract/utils.h"
19 #include "abstract/param_validator.h"
20 #include "utils/ms_utils.h"
21 #include "utils/check_convert_utils.h"
22 
23 namespace mindspore {
24 namespace abstract {
InferImplMinOrMaxGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)25 AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
26                                       const AbstractBasePtrList &args_spec_list) {
27   // Inputs: three tensors.
28   constexpr auto kMinMaxGradInputNum = 3;
29   const size_t dout_index = 2;
30   const std::string op_name = primitive->name();
31   CheckArgsSize(op_name, args_spec_list, kMinMaxGradInputNum);
32   auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
33   auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
34   auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, dout_index);
35   (void)CheckTensorsDTypeSame({input_x, input_y, dout}, {kInt, kUInt, kFloat},
36                               op_name + "evaluator three inputs should be %s");
37 
38   AbstractBasePtr dx = input_x->Broaden();
39   AbstractBasePtr dy = input_y->Broaden();
40 
41   return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
42 }
43 
// Infers the output abstract for Sqrt: the result tensor has the same element
// type and shape as its single input.
AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
  // Inputs: one tensor.  (Previous comment said "three tensors", contradicting
  // kSqrtInputNum below.)
  constexpr auto kSqrtInputNum = 1;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kSqrtInputNum);
  auto inp = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  // Clone then Broaden: keep type/shape but drop any concrete value.
  return inp->Clone()->Broaden();
}
53 
InferImplSqrtGrad(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)54 AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
55                                   const AbstractBasePtrList &args_spec_list) {
56   // Inputs: two tensors.
57   constexpr auto kSqrtGradInputNum = 2;
58   const std::string op_name = primitive->name();
59   CheckArgsSize(op_name, args_spec_list, kSqrtGradInputNum);
60   auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
61   auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
62   (void)CheckDtypeSame(op_name, out, dout);
63   (void)CheckShapeSame(op_name, out, dout);
64 
65   return out->Broaden();
66 }
67 
// Validates a reduction axis and normalizes it into [0, dim).
// A negative axis counts from the back (Python-style); values outside
// [-dim, dim) raise an exception.  Returns the canonical non-negative axis.
int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {
  const int64_t dim_ = static_cast<int64_t>(dim);
  if (axis < -dim_ || axis >= dim_) {
    MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
  }
  // After the range check above, any negative axis is in [-dim, 0), so the
  // extra `axis >= -dim_` test the old code performed was redundant.
  return axis < 0 ? axis + dim_ : axis;
}
79 
InferImplReduceFuncCalShape(ShapeVector * shape,const ShapeVector & x_shape,const ValuePtr & axis,bool keep_dims_value)80 void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
81                                  bool keep_dims_value) {
82   MS_EXCEPTION_IF_NULL(axis);
83   if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
84     auto axis_ptr_list =
85       axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
86     if (!axis_ptr_list.size()) {
87       if (keep_dims_value) (void)shape->insert(shape->end(), x_shape.size(), 1);
88     } else {
89       (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
90       ValuePtrList axis_items = axis_ptr_list;
91       ValuePtrList::iterator it;
92       ValuePtrList::reverse_iterator it_re;
93       int64_t axis_value;
94       if (keep_dims_value) {
95         for (it = axis_items.begin(); it != axis_items.end(); ++it) {
96           axis_value = GetValue<int64_t>(*it);
97           axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
98           shape->at(LongToSize(axis_value)) = 1;
99         }
100       } else {
101         std::sort(axis_items.begin(), axis_items.end());
102         for (it_re = axis_items.rbegin(); it_re != axis_items.rend(); ++it_re) {
103           axis_value = GetValue<int64_t>(*it_re);
104           axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
105           (void)shape->erase(shape->begin() + axis_value);
106         }
107       }
108     }
109   } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
110     (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
111     auto axis_value = GetValue<int64_t>(axis);
112     axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
113     if (keep_dims_value) {
114       shape->at(LongToSize(axis_value)) = 1;
115     } else {
116       (void)shape->erase(shape->begin() + axis_value);
117     }
118   } else {
119     MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
120   }
121   return;
122 }
123 
124 // To reduce code repeat, use InferImplReduceFunc. Currently registered with ReduceMean, ReduceSum,
125 // ReduceAll, ReduceAny, ReduceMax, ReduceMin.
InferImplReduceFunc(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)126 AbstractBasePtr InferImplReduceFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
127                                     const AbstractBasePtrList &args_spec_list) {
128   const auto kReduceInputNum = 1;
129   const std::string op_name = primitive->name();
130   CheckArgsSize(op_name, args_spec_list, kReduceInputNum);
131   auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
132   MS_EXCEPTION_IF_NULL(input_x);
133   MS_EXCEPTION_IF_NULL(input_x->element());
134 
135   ValuePtr keep_dims = primitive->GetAttr("keep_dims");
136   MS_EXCEPTION_IF_NULL(keep_dims);
137   if (!keep_dims->isa<BoolImm>()) {
138     MS_LOG(EXCEPTION) << "Keep_dims should be Bool.";
139   }
140   bool keep_dims_value = GetValue<bool>(keep_dims);
141 
142   ValuePtr axis = primitive->GetAttr("axis");
143   MS_EXCEPTION_IF_NULL(axis);
144 
145   ShapeVector shape = {};
146   ShapeVector x_shape = input_x->shape()->shape();
147   InferImplReduceFuncCalShape(&shape, x_shape, axis, keep_dims_value);
148 
149   bool x_is_dyn = (!input_x->shape()->min_shape().empty() && !input_x->shape()->max_shape().empty());
150   if (x_is_dyn) {
151     ShapeVector shape_min = {};
152     ShapeVector shape_max = {};
153     ShapeVector x_shape_min = input_x->shape()->min_shape();
154     ShapeVector x_shape_max = input_x->shape()->max_shape();
155     InferImplReduceFuncCalShape(&shape_min, x_shape_min, axis, keep_dims_value);
156     InferImplReduceFuncCalShape(&shape_max, x_shape_max, axis, keep_dims_value);
157     return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape, shape_min, shape_max));
158   }
159   return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
160 }
161 
InferImplBinaryBase(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)162 AbstractBasePtr InferImplBinaryBase(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
163                                     const AbstractBasePtrList &args_spec_list) {
164   constexpr auto kBinaryBaseInputNum = 2;
165   const std::string op_name = primitive->name();
166   CheckArgsSize(op_name, args_spec_list, kBinaryBaseInputNum);
167   auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
168   MS_EXCEPTION_IF_NULL(input_x);
169   MS_EXCEPTION_IF_NULL(input_x->shape());
170 
171   auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
172   MS_EXCEPTION_IF_NULL(input_y);
173   MS_EXCEPTION_IF_NULL(input_y->shape());
174 
175   auto x_shape = input_x->shape()->shape();
176   auto y_shape = input_y->shape()->shape();
177   auto output_shape = BroadcastShape(x_shape, y_shape);
178 
179   auto x_type = input_x->BuildType();
180   MS_EXCEPTION_IF_NULL(x_type);
181   MS_EXCEPTION_IF_NULL(x_type->cast<TensorTypePtr>());
182   auto y_type = input_y->BuildType();
183   MS_EXCEPTION_IF_NULL(y_type);
184   MS_EXCEPTION_IF_NULL(y_type->cast<TensorTypePtr>());
185 
186   auto x_element = x_type->cast<TensorTypePtr>()->element();
187   MS_EXCEPTION_IF_NULL(x_element);
188   auto y_element = y_type->cast<TensorTypePtr>()->element();
189   MS_EXCEPTION_IF_NULL(y_element);
190 
191   auto x_element_type = x_element->number_type();
192   auto y_element_type = y_element->number_type();
193 
194   auto x_priority = type_priority_map.find(x_element_type);
195   if (x_priority == type_priority_map.end()) {
196     MS_LOG(EXCEPTION) << "input_x type is " << x_element_type << ", it's not number type.";
197   }
198   auto y_priority = type_priority_map.find(y_element_type);
199   if (y_priority == type_priority_map.end()) {
200     MS_LOG(EXCEPTION) << "input_y type is " << y_element_type << ", it's not number type.";
201   }
202 
203   if (x_priority->second >= y_priority->second) {
204     return std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(output_shape));
205   } else {
206     return std::make_shared<AbstractTensor>(input_y->element(), std::make_shared<Shape>(output_shape));
207   }
208 }
209 
// Minimum is a broadcasting binary element-wise op; it shares the generic
// InferImplBinaryBase shape/type inference above.
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
}
214 
// DivNoNan is a broadcasting binary element-wise op; it shares the generic
// InferImplBinaryBase shape/type inference above.
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &engine_ptr, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
  return InferImplBinaryBase(engine_ptr, primitive, args_spec_list);
}
219 
InferImplLinSpace(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)220 AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
221                                   const AbstractBasePtrList &args_spec_list) {
222   constexpr auto kLinSpaceInputNum = 3;
223   const std::string op_name = primitive->name();
224   CheckArgsSize(op_name, args_spec_list, kLinSpaceInputNum);
225   auto start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
226   MS_EXCEPTION_IF_NULL(start);
227   MS_EXCEPTION_IF_NULL(start->shape());
228   auto stop = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
229   MS_EXCEPTION_IF_NULL(stop);
230   MS_EXCEPTION_IF_NULL(stop->shape());
231   (void)CheckTensorDType(start, {kFloat32}, "Input 0 (start) for LinSpace should be %s");
232   (void)CheckTensorDType(stop, {kFloat32}, "Input 1 (stop) for LinSpace should be %s");
233   ShapeVector shape;
234   ShapeVector max_shape;
235   ShapeVector min_shape;
236   int64_t num_val = 0;
237   // 3rd input is a Tensor when LinSpace is a dynamic shape operator
238   const size_t tensor_index = 2;
239   auto abs_num = args_spec_list[tensor_index];
240   if (abs_num->isa<AbstractTensor>()) {
241     auto num = abs_num->cast<AbstractTensorPtr>();
242     MS_EXCEPTION_IF_NULL(num);
243     auto num_value_ptr = num->BuildValue();
244     MS_EXCEPTION_IF_NULL(num_value_ptr);
245     auto num_tensor = num_value_ptr->cast<tensor::TensorPtr>();
246     MS_EXCEPTION_IF_NULL(num_tensor);
247     num_val = *static_cast<int64_t *>(num_tensor->data_c());
248   } else if (abs_num->isa<AbstractScalar>()) {
249     auto num = abs_num->cast<AbstractScalarPtr>();
250     num_val = GetValue<int64_t>(num->BuildValue());
251   } else {
252     MS_LOG(EXCEPTION) << "Invalid abstract type:" << abs_num->type_name();
253   }
254   shape.emplace_back(num_val);
255   if (shape[0] < 0) {
256     MS_LOG(EXCEPTION) << "num must be >= 0 in LinSpace";
257   }
258   max_shape.emplace_back(num_val);
259   min_shape.emplace_back(num_val);
260   AbstractTensorPtr ret =
261     std::make_shared<AbstractTensor>(start->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
262   return ret;
263 }
264 
InferImplMatMul(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)265 AbstractBasePtr InferImplMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
266                                 const AbstractBasePtrList &args_spec_list) {
267   constexpr auto kMatMulInputNum = 2;
268   const std::string op_name = primitive->name();
269   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kGreaterEqual,
270                                            kMatMulInputNum, op_name);
271   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
272   MS_EXCEPTION_IF_NULL(x);
273   MS_EXCEPTION_IF_NULL(x->shape());
274   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
275   MS_EXCEPTION_IF_NULL(y);
276   MS_EXCEPTION_IF_NULL(y->shape());
277   auto x_shp = x->shape()->shape();
278   auto y_shp = y->shape()->shape();
279   const size_t SHAPE_SIZE = 2;
280   if (x_shp.size() != SHAPE_SIZE || y_shp.size() != SHAPE_SIZE) {
281     MS_LOG(EXCEPTION) << "MatMul inputs should have the same dimension size and equal to 2.";
282   }
283   ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
284   ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
285   bool transpose_a = GetValue<bool>(transpose_a_ptr);
286   bool transpose_b = GetValue<bool>(transpose_b_ptr);
287   ShapeVector x_min_shape = x->shape()->min_shape();
288   ShapeVector x_max_shape = x->shape()->max_shape();
289   ShapeVector y_min_shape = y->shape()->min_shape();
290   ShapeVector y_max_shape = y->shape()->max_shape();
291   CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
292   CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
293   // Additional check for dynamic shape
294   // Last infer will be real shape values
295   bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
296   bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
297   if (x_not_dyn && y_not_dyn) {
298     auto x_col = x_shp[(transpose_a ? 0 : 1)];
299     auto y_row = y_shp[(transpose_b ? 1 : 0)];
300     if (x_col != y_row) {
301       MS_LOG(EXCEPTION) << "MatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
302                         << ". In MatMul x_col and y_row should be equal.";
303     }
304   }
305   ShapeVector ret_shape;
306   ShapeVector ret_min_shape;
307   ShapeVector ret_max_shape;
308   auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
309                                                  const ShapeVector yshp) -> void {
310     output.push_back(xshp[(transpose_a ? 1 : 0)]);
311     output.push_back(yshp[(transpose_b ? 0 : 1)]);
312     return;
313   };
314   make_shape(ret_shape, x_shp, y_shp);
315   make_shape(ret_min_shape, x_min_shape, y_min_shape);
316   make_shape(ret_max_shape, x_max_shape, y_max_shape);
317   TypePtr x_type = x->element()->GetTypeTrack();
318   if (x_type->type_id() == TypeId::kNumberTypeInt8) {
319     x_type = kInt32;
320   }
321   if (primitive->HasAttr("cast_type")) {
322     auto out_type = primitive->GetAttr("cast_type");
323     MS_EXCEPTION_IF_NULL(out_type);
324     if (!out_type->isa<Type>()) {
325       MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
326     }
327     x_type = out_type->cast<TypePtr>();
328   }
329   return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
330 }
331 
InferImplBatchMatMul(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)332 AbstractBasePtr InferImplBatchMatMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
333                                      const AbstractBasePtrList &args_spec_list) {
334   constexpr auto kBatchMatMulInputNum = 2;
335   const std::string op_name = primitive->name();
336   CheckArgsSize(op_name, args_spec_list, kBatchMatMulInputNum);
337   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
338   MS_EXCEPTION_IF_NULL(x);
339   MS_EXCEPTION_IF_NULL(x->shape());
340   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
341   MS_EXCEPTION_IF_NULL(y);
342   MS_EXCEPTION_IF_NULL(y->shape());
343   auto x_shp = x->shape()->shape();
344   auto y_shp = y->shape()->shape();
345   constexpr size_t minimum_shape = 3;
346   if (x_shp.size() != y_shp.size() || x_shp.size() < minimum_shape) {
347     MS_LOG(EXCEPTION)
348       << "BatchMatMul input x, y should have the same dimension size and should be greater or equal to 3.";
349   }
350   ValuePtr transpose_a_ptr = primitive->GetAttr("transpose_a");
351   ValuePtr transpose_b_ptr = primitive->GetAttr("transpose_b");
352   bool transpose_a = GetValue<bool>(transpose_a_ptr);
353   bool transpose_b = GetValue<bool>(transpose_b_ptr);
354   ShapeVector x_min_shape = x->shape()->min_shape();
355   ShapeVector x_max_shape = x->shape()->max_shape();
356   ShapeVector y_min_shape = y->shape()->min_shape();
357   ShapeVector y_max_shape = y->shape()->max_shape();
358   CheckMinMaxShape(x_shp, &x_min_shape, &x_max_shape);
359   CheckMinMaxShape(y_shp, &y_min_shape, &y_max_shape);
360   // Additional check for dynamic shape
361   // Last infer will be real shape values
362   bool x_not_dyn = std::all_of(x_shp.begin(), x_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
363   bool y_not_dyn = std::all_of(y_shp.begin(), y_shp.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
364   if (x_not_dyn && y_not_dyn) {
365     size_t offset = x_shp.size() - 2;
366     auto x_col = x_shp[offset + (transpose_a ? 0 : 1)];
367     auto y_row = y_shp[offset + (transpose_b ? 1 : 0)];
368     if (x_col != y_row) {
369       MS_LOG(EXCEPTION) << "BatchMatMul shape error, got x_col: " << x_col << ", y_row: " << y_row
370                         << ". In BatchMatMul x_col and y_row should be equal.";
371     }
372   }
373   ShapeVector ret_shape;
374   ShapeVector ret_min_shape;
375   ShapeVector ret_max_shape;
376   auto make_shape = [&transpose_a, &transpose_b](ShapeVector &output, const ShapeVector xshp,
377                                                  const ShapeVector yshp) -> void {
378     for (size_t i = 0; i < xshp.size() - 2; i++) {
379       if (xshp[i] != yshp[i]) {
380         if (xshp[i] > 0 && yshp[i] > 0) {
381           MS_LOG(EXCEPTION) << "BatchMatMul input x, y are different at index " << i << ".";
382         }
383         output.push_back(Shape::SHP_ANY);
384       } else {
385         output.push_back(xshp[i]);
386       }
387     }
388     const size_t bias = 2;
389     size_t offset = xshp.size() - bias;
390     output.push_back(xshp[offset + (transpose_a ? 1 : 0)]);
391     output.push_back(yshp[offset + (transpose_b ? 0 : 1)]);
392     return;
393   };
394   make_shape(ret_shape, x_shp, y_shp);
395   make_shape(ret_min_shape, x_min_shape, y_min_shape);
396   make_shape(ret_max_shape, x_max_shape, y_max_shape);
397   TypePtr x_type = x->element()->GetTypeTrack();
398   if (x_type->type_id() == TypeId::kNumberTypeInt8) {
399     x_type = kInt32;
400   }
401   if (primitive->HasAttr("cast_type")) {
402     auto out_type = primitive->GetAttr("cast_type");
403     MS_EXCEPTION_IF_NULL(out_type);
404     if (!out_type->isa<Type>()) {
405       MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
406     }
407     x_type = out_type->cast<TypePtr>();
408   }
409   return std::make_shared<AbstractTensor>(x_type, std::make_shared<Shape>(ret_shape, ret_min_shape, ret_max_shape));
410 }
411 
InferImplLess(const AnalysisEnginePtr &,const PrimitivePtr & primitive,const AbstractBasePtrList & args_spec_list)412 AbstractBasePtr InferImplLess(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
413                               const AbstractBasePtrList &args_spec_list) {
414   constexpr auto kLessInputNum = 2;
415   const std::string op_name = primitive->name();
416   CheckArgsSize(op_name, args_spec_list, kLessInputNum);
417   auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
418   MS_EXCEPTION_IF_NULL(x);
419   MS_EXCEPTION_IF_NULL(x->shape());
420   ShapeVector x_shape = x->shape()->shape();
421   ShapeVector x_shape_min = x->shape()->min_shape().empty() ? x_shape : x->shape()->min_shape();
422   ShapeVector x_shape_max = x->shape()->max_shape().empty() ? x_shape : x->shape()->max_shape();
423 
424   auto y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
425   MS_EXCEPTION_IF_NULL(y);
426   MS_EXCEPTION_IF_NULL(y->shape());
427   ShapeVector y_shape = y->shape()->shape();
428   ShapeVector y_shape_min = y->shape()->min_shape().empty() ? y_shape : y->shape()->min_shape();
429   ShapeVector y_shape_max = y->shape()->max_shape().empty() ? y_shape : y->shape()->max_shape();
430 
431   auto out_shape = BroadcastShape(x_shape, y_shape);
432   auto out_shape_min = BroadcastShape(x_shape_min, y_shape_min);
433   auto out_shape_max = BroadcastShape(x_shape_max, y_shape_max);
434   auto output_type = std::make_shared<Bool>();
435   return std::make_shared<AbstractTensor>(output_type,
436                                           std::make_shared<Shape>(out_shape, out_shape_min, out_shape_max));
437 }
438 }  // namespace abstract
439 }  // namespace mindspore
440