1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H 18 #define MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H 19 20 #include <memory> 21 #include <string> 22 #include "backend/optimizer/common/optimizer.h" 23 24 namespace mindspore { 25 namespace opt { 26 class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 27 public: 28 explicit SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR( 29 const std::string &name = "sparse_softmax_cross_entropy_with_logits_unify_mindir", bool multigraph = true) PatternProcessPass(name,multigraph)30 : PatternProcessPass(name, multigraph) {} 31 ~SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 32 const BaseRef DefinePattern() const override; 33 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 34 }; 35 36 class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 37 public: 38 explicit GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true) 39 : PatternProcessPass("grad_sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {} 40 ~GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 41 const BaseRef DefinePattern() const override; 42 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 43 }; 44 45 class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProcessPass { 46 public: 47 explicit GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2(bool multigraph = true) 48 : PatternProcessPass("grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v2", multigraph) {} 49 ~GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2() override = default; 50 const BaseRef DefinePattern() const override; 51 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 52 }; 53 54 class PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR { 55 public: 56 explicit PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true) 57 : SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR("pynative_sparse_softmax_cross_entropy_with_logits_unify_mindir", 58 multigraph) {} 59 ~PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 60 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 61 }; 62 63 class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 64 public: 65 explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true) 66 : PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {} 67 ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 68 const BaseRef DefinePattern() const override; 69 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 70 }; 71 } // namespace opt 72 } // namespace mindspore 73 #endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H 74