• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
19 
20 #include <memory>
21 #include <tuple>
22 #include <utility>
23 #include <vector>
24 #include "include/backend/optimizer/optimizer.h"
25 #include "backend/common/graph_kernel/add_atomic_clean.h"
26 #include "include/backend/kernel_graph.h"
27 
28 namespace mindspore::graphkernel {
29 class StitchAtomicCleanInserter : public AtomicCleanInserter {
30  public:
StitchAtomicCleanInserter()31   StitchAtomicCleanInserter() : AtomicCleanInserter("stitch_atomic_clean") {}
32   ~StitchAtomicCleanInserter() override = default;
33   bool Run(const FuncGraphPtr &func_graph) override;
34 
35  protected:
36   void CorrectKernelBuildInfo(const AnfNodePtr &composite_node,
37                               const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &inplace_infos) override;
38   void ProcessOriginCNode(
39     const AnfNodePtr &composite_node,
40     const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr) override;
41 
42  private:
43   CNodePtr CreateAssignNode(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter,
44                             const InplaceAssignerInfo &info) const;
45   std::vector<std::pair<AnfNodePtr, int>> FindInnerCNodeUsers(const AnfNodePtr &inner_node,
46                                                               const CNodePtr &target) const;
47   std::pair<bool, InplaceAssignerInfo> IsStitchWithAtomic(const AnfNodePtr &anf_node);
48 
49   void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
50                  const AnfNodePtr &user_node, int index) const;
51 
52   AnfNodePtr stitch_node_{nullptr};
53 };
54 using StitchAtomicCleanInserterPtr = std::shared_ptr<StitchAtomicCleanInserter>;
55 }  // namespace mindspore::graphkernel
56 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_STITCH_ATOMIC_CLEAN_GPU_H_
57