1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H 18 #define MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include "include/backend/optimizer/optimizer.h" 24 25 namespace mindspore { 26 namespace opt { 27 class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 28 public: 29 explicit SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR( 30 const std::string &name = "sparse_softmax_cross_entropy_with_logits_unify_mindir", bool multigraph = true) PatternProcessPass(name,multigraph)31 : PatternProcessPass(name, multigraph) {} 32 ~SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 33 const BaseRef DefinePattern() const override; 34 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 35 36 private: 37 std::vector<std::string> MustExistPrimitiveName() const override; 38 }; 39 40 class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 41 public: 42 explicit GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true) 43 : PatternProcessPass("grad_sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {} 44 ~GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 45 const BaseRef DefinePattern() const override; 46 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 47 }; 48 49 class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProcessPass { 50 public: 51 explicit GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2(bool multigraph = true) 52 : PatternProcessPass("grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v2", multigraph) {} 53 ~GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2() override = default; 54 const BaseRef DefinePattern() const override; 55 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 56 }; 57 58 class PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR { 59 public: 60 explicit PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true) 61 : SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR("pynative_sparse_softmax_cross_entropy_with_logits_unify_mindir", 62 multigraph) {} 63 ~PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 64 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 65 }; 66 67 class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 68 public: 69 explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR(bool multigraph = true) 70 : PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir", multigraph) {} 71 ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 72 const BaseRef DefinePattern() const override; 73 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 74 75 private: 76 std::vector<std::string> MustExistPrimitiveName() const override; 77 }; 78 79 class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProcessPass { 80 public: 81 explicit PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2(bool multigraph = true) 82 : PatternProcessPass("pynative_grad_sparse_softmax_cross_entropy_with_logits_unify_mindir_v2", multigraph) {} 83 ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2() override = default; 84 const BaseRef DefinePattern() const override; 85 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 86 87 private: 88 std::vector<std::string> MustExistPrimitiveName() const override; 89 }; 90 91 class GeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { 92 public: 93 explicit GeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR( 94 const std::string &name = "ge_sparse_softmax_cross_entropy_with_logits_unify_mindir", bool multigraph = true) PatternProcessPass(name,multigraph)95 : PatternProcessPass(name, multigraph) {} 96 ~GeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; 97 const BaseRef DefinePattern() const override; 98 const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; 99 100 private: 101 std::vector<std::string> MustExistPrimitiveName() const override; 102 }; 103 } // namespace opt 104 } // namespace mindspore 105 #endif // MINDSPORE_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_UNIFY_MINDIR_H 106