• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include "abstract/abstract_value.h"
23 #include "abstract/dshape.h"
24 #include "abstract/ops/op_infer.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "abstract/utils.h"
27 #include "base/base.h"
28 #include "ir/anf.h"
29 #include "ir/dtype/number.h"
30 #include "ir/dtype/type.h"
31 #include "ir/primitive.h"
32 #include "mindapi/base/type_id.h"
33 #include "mindapi/src/helper.h"
34 #include "mindspore/core/ops/math_ops.h"
35 #include "ops/mod.h"
36 #include "ops/op_utils.h"
37 #include "ops/primitive_c.h"
38 #include "utils/check_convert_utils.h"
39 #include "utils/convert_utils_base.h"
40 #include "utils/log_adapter.h"
41 
42 namespace mindspore {
43 namespace ops {
44 namespace {
ModInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)45 abstract::ShapePtr ModInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
46   MS_EXCEPTION_IF_NULL(primitive);
47   auto prim_name = primitive->name();
48   return BroadCastInferShape(prim_name, input_args);
49 }
50 
ModInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)51 TypePtr ModInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
52   for (const auto &item : input_args) {
53     MS_EXCEPTION_IF_NULL(item);
54   }
55   auto op_name = prim->name();
56   const int64_t kInputNum = 2;
57   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
58                                            op_name);
59   std::map<std::string, TypePtr> types;
60   (void)types.emplace("x", input_args[0]->GetType());
61   (void)types.emplace("y", input_args[1]->GetType());
62 
63   auto type_x = input_args[0]->GetType();
64   auto type_y = input_args[1]->GetType();
65   MS_EXCEPTION_IF_NULL(type_x);
66   MS_EXCEPTION_IF_NULL(type_y);
67   if (type_x->isa<Complex>() || type_y->isa<Complex>()) {
68     if (type_x->type_id() == kNumberTypeComplex64 && type_y->type_id() == kNumberTypeComplex64) {
69       return type_x;
70     } else if (type_x->type_id() == kNumberTypeComplex64 && type_y->type_id() == kNumberTypeFloat32) {
71       return type_x;
72     } else if (type_x->type_id() == kNumberTypeComplex128 && type_y->type_id() == kNumberTypeComplex128) {
73       return type_x;
74     } else if (type_x->type_id() == kNumberTypeComplex128 && type_y->type_id() == kNumberTypeFloat64) {
75       return type_x;
76     } else if (type_x->type_id() == kNumberTypeFloat32 && type_y->type_id() == kNumberTypeComplex64) {
77       return type_y;
78     } else if (type_x->type_id() == kNumberTypeFloat64 && type_y->type_id() == kNumberTypeComplex128) {
79       return type_y;
80     } else {
81       MS_EXCEPTION(TypeError)
82         << "For '" << op_name
83         << "', complex math binary op expecting Tensor [complex64, complex64],[complex64, float32], [float32, "
84            "complex64], [complex128, complex128], [complex128, float64] or [float64, complex128], but got ["
85         << type_x->ToString() << ", " << type_y->ToString() << "].";
86     }
87   }
88   (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name());
89   return type_x;
90 }
91 }  // namespace
ModInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)92 AbstractBasePtr ModInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
93                          const std::vector<AbstractBasePtr> &input_args) {
94   auto infer_type = ModInferType(primitive, input_args);
95   auto infer_shape = ModInferShape(primitive, input_args);
96   return abstract::MakeAbstract(infer_shape, infer_type);
97 }
98 MIND_API_OPERATOR_IMPL(Mod, BaseOperator);
99 
100 // AG means auto generated
101 class MIND_API AGModInfer : public abstract::OpInferBase {
102  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const103   BaseShapePtr InferShape(const PrimitivePtr &primitive,
104                           const std::vector<AbstractBasePtr> &input_args) const override {
105     return ModInferShape(primitive, input_args);
106   }
107 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const108   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
109     return ModInferType(primitive, input_args);
110   }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const111   AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
112                                     const std::vector<AbstractBasePtr> &input_args) const override {
113     return ModInfer(engine, primitive, input_args);
114   }
115 };
116 
117 REGISTER_PRIMITIVE_OP_INFER_IMPL(Mod, prim::kPrimMod, AGModInfer, false);
118 }  // namespace ops
119 }  // namespace mindspore
120