1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 #include "abstract/abstract_value.h"
23 #include "abstract/dshape.h"
24 #include "abstract/ops/op_infer.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "abstract/utils.h"
27 #include "base/base.h"
28 #include "ir/anf.h"
29 #include "ir/dtype/number.h"
30 #include "ir/dtype/type.h"
31 #include "ir/primitive.h"
32 #include "mindapi/base/type_id.h"
33 #include "mindapi/src/helper.h"
34 #include "mindspore/core/ops/math_ops.h"
35 #include "ops/mod.h"
36 #include "ops/op_utils.h"
37 #include "ops/primitive_c.h"
38 #include "utils/check_convert_utils.h"
39 #include "utils/convert_utils_base.h"
40 #include "utils/log_adapter.h"
41
42 namespace mindspore {
43 namespace ops {
44 namespace {
ModInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)45 abstract::ShapePtr ModInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
46 MS_EXCEPTION_IF_NULL(primitive);
47 auto prim_name = primitive->name();
48 return BroadCastInferShape(prim_name, input_args);
49 }
50
ModInferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)51 TypePtr ModInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
52 for (const auto &item : input_args) {
53 MS_EXCEPTION_IF_NULL(item);
54 }
55 auto op_name = prim->name();
56 const int64_t kInputNum = 2;
57 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, kInputNum,
58 op_name);
59 std::map<std::string, TypePtr> types;
60 (void)types.emplace("x", input_args[0]->GetType());
61 (void)types.emplace("y", input_args[1]->GetType());
62
63 auto type_x = input_args[0]->GetType();
64 auto type_y = input_args[1]->GetType();
65 MS_EXCEPTION_IF_NULL(type_x);
66 MS_EXCEPTION_IF_NULL(type_y);
67 if (type_x->isa<Complex>() || type_y->isa<Complex>()) {
68 if (type_x->type_id() == kNumberTypeComplex64 && type_y->type_id() == kNumberTypeComplex64) {
69 return type_x;
70 } else if (type_x->type_id() == kNumberTypeComplex64 && type_y->type_id() == kNumberTypeFloat32) {
71 return type_x;
72 } else if (type_x->type_id() == kNumberTypeComplex128 && type_y->type_id() == kNumberTypeComplex128) {
73 return type_x;
74 } else if (type_x->type_id() == kNumberTypeComplex128 && type_y->type_id() == kNumberTypeFloat64) {
75 return type_x;
76 } else if (type_x->type_id() == kNumberTypeFloat32 && type_y->type_id() == kNumberTypeComplex64) {
77 return type_y;
78 } else if (type_x->type_id() == kNumberTypeFloat64 && type_y->type_id() == kNumberTypeComplex128) {
79 return type_y;
80 } else {
81 MS_EXCEPTION(TypeError)
82 << "For '" << op_name
83 << "', complex math binary op expecting Tensor [complex64, complex64],[complex64, float32], [float32, "
84 "complex64], [complex128, complex128], [complex128, float64] or [float64, complex128], but got ["
85 << type_x->ToString() << ", " << type_y->ToString() << "].";
86 }
87 }
88 (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim->name());
89 return type_x;
90 }
91 } // namespace
ModInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)92 AbstractBasePtr ModInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
93 const std::vector<AbstractBasePtr> &input_args) {
94 auto infer_type = ModInferType(primitive, input_args);
95 auto infer_shape = ModInferShape(primitive, input_args);
96 return abstract::MakeAbstract(infer_shape, infer_type);
97 }
98 MIND_API_OPERATOR_IMPL(Mod, BaseOperator);
99
100 // AG means auto generated
101 class MIND_API AGModInfer : public abstract::OpInferBase {
102 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const103 BaseShapePtr InferShape(const PrimitivePtr &primitive,
104 const std::vector<AbstractBasePtr> &input_args) const override {
105 return ModInferShape(primitive, input_args);
106 }
107
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const108 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
109 return ModInferType(primitive, input_args);
110 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const111 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
112 const std::vector<AbstractBasePtr> &input_args) const override {
113 return ModInfer(engine, primitive, input_args);
114 }
115 };
116
117 REGISTER_PRIMITIVE_OP_INFER_IMPL(Mod, prim::kPrimMod, AGModInfer, false);
118 } // namespace ops
119 } // namespace mindspore
120