1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/grad/sigmoid_cross_entropy_with_logits_grad.h"
18
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <vector>
24
25 #include "abstract/abstract_value.h"
26 #include "abstract/dshape.h"
27 #include "abstract/ops/op_infer.h"
28 #include "abstract/ops/primitive_infer_map.h"
29 #include "abstract/param_validator.h"
30 #include "abstract/utils.h"
31 #include "base/base.h"
32 #include "ir/anf.h"
33 #include "ir/dtype/number.h"
34 #include "ir/primitive.h"
35 #include "mindapi/src/helper.h"
36 #include "mindspore/core/ops/math_ops.h"
37 #include "mindspore/core/ops/nn_ops.h"
38 #include "ops/op_name.h"
39 #include "ops/primitive_c.h"
40 #include "utils/check_convert_utils.h"
41 #include "utils/convert_utils_base.h"
42 #include "utils/log_adapter.h"
43
44 namespace mindspore {
45 namespace ops {
46 namespace {
SigmoidCrossEntropyWithLogitsGradInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)47 abstract::ShapePtr SigmoidCrossEntropyWithLogitsGradInferShape(const PrimitivePtr &primitive,
48 const std::vector<AbstractBasePtr> &input_args) {
49 MS_EXCEPTION_IF_NULL(primitive);
50 auto prim_name = primitive->name();
51 const int64_t kInputNum = 3;
52 (void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer_shape",
53 SizeToLong(input_args.size()), kGreaterEqual, kInputNum, prim_name);
54 auto x = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, kInputIndex0, kObjectTypeTensorType);
55 auto y = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, kInputIndex1, kObjectTypeTensorType);
56 auto dout = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, kInputIndex2, kObjectTypeTensorType);
57 auto x_ptr = x->GetShape()->cast<abstract::ShapePtr>();
58 abstract::CheckShapeSame(prim_name, x, y);
59 abstract::CheckShapeSame(prim_name, x, dout);
60 MS_EXCEPTION_IF_NULL(x_ptr);
61 return x_ptr;
62 }
63
SigmoidCrossEntropyWithLogitsGradInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)64 TypePtr SigmoidCrossEntropyWithLogitsGradInferType(const PrimitivePtr &primitive,
65 const std::vector<AbstractBasePtr> &input_args) {
66 MS_EXCEPTION_IF_NULL(primitive);
67 auto prim_name = primitive->name();
68 const int64_t kInputNum = 3;
69 (void)CheckAndConvertUtils::CheckInteger("sigmoid_cross_extropy_with_logits_infer_type",
70 SizeToLong(input_args.size()), kGreaterEqual, kInputNum, prim_name);
71 auto x_type = input_args[0]->GetType();
72 auto y_type = input_args[1]->GetType();
73 auto dout_type = input_args[2]->GetType();
74 const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
75 kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
76 std::map<std::string, TypePtr> args;
77 (void)args.emplace("x_type", x_type);
78 (void)args.emplace("y_type", y_type);
79 (void)args.emplace("dout_type", dout_type);
80 (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, primitive->name());
81 return dout_type;
82 }
83 } // namespace
84
85 MIND_API_OPERATOR_IMPL(SigmoidCrossEntropyWithLogitsGrad, BaseOperator);
SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)86 AbstractBasePtr SigmoidCrossEntropyWithLogitsGradInfer(const abstract::AnalysisEnginePtr &,
87 const PrimitivePtr &primitive,
88 const std::vector<AbstractBasePtr> &input_args) {
89 auto infer_type = SigmoidCrossEntropyWithLogitsGradInferType(primitive, input_args);
90 auto infer_shape = SigmoidCrossEntropyWithLogitsGradInferShape(primitive, input_args);
91 return abstract::MakeAbstract(infer_shape, infer_type);
92 }
93
94 // AG means auto generated
95 class MIND_API AGSigmoidCrossEntropyWithLogitsGradInfer : public abstract::OpInferBase {
96 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const97 BaseShapePtr InferShape(const PrimitivePtr &primitive,
98 const std::vector<AbstractBasePtr> &input_args) const override {
99 return SigmoidCrossEntropyWithLogitsGradInferShape(primitive, input_args);
100 }
101
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const102 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
103 return SigmoidCrossEntropyWithLogitsGradInferType(primitive, input_args);
104 }
InferShapeAndType(const abstract::AnalysisEnginePtr & engine,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const105 AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
106 const std::vector<AbstractBasePtr> &input_args) const override {
107 return SigmoidCrossEntropyWithLogitsGradInfer(engine, primitive, input_args);
108 }
109 };
110
111 REGISTER_PRIMITIVE_OP_INFER_IMPL(SigmoidCrossEntropyWithLogitsGrad, prim::kPrimSigmoidCrossEntropyWithLogitsGrad,
112 AGSigmoidCrossEntropyWithLogitsGradInfer, false);
113 } // namespace ops
114 } // namespace mindspore
115