1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INPLACE_ASSIGN_BUILDER_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INPLACE_ASSIGN_BUILDER_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <string> 23 #include "include/backend/optimizer/optimizer.h" 24 25 namespace mindspore::graphkernel { 26 struct InplaceAssignerInfo { 27 // inplace-assigner node, which result will be written to inplace-assignee(an input of the func graph) 28 CNodePtr op_node{nullptr}; 29 // inplace-assigner's index among all the func graph's outputs(inplace-assigner must be an output of func graph) 30 size_t real_output_index{0}; 31 // num of inputs of inplace-assigner's func graph 32 size_t real_output_num{0}; 33 // inplace-assignee's index among all the inputs; if inplace-assignee is a new additional input, set it to -1 34 int inplace_to_origin_input{-1}; 35 }; 36 37 struct InplaceAssignUserInfo { 38 AnfNodePtr inplace_assignee_addr{nullptr}; 39 AnfNodePtr work_node{nullptr}; 40 AnfNodePtr user_node{nullptr}; 41 size_t user_input_idx{0}; 42 }; 43 44 class InplaceAssignBuilder : public opt::Pass { 45 public: Pass(name)46 explicit InplaceAssignBuilder(const std::string &name = "inplace_assign_builder") : Pass(name) {} 47 ~InplaceAssignBuilder() override = default; 48 49 protected: 50 virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, 51 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &inplace_infos); 52 virtual CNodePtr CreateCleanCompositeNode(const InplaceAssignerInfo &op_info, const FuncGraphPtr &main_graph, 53 TypeId dst_type); 54 void CreateAssignNodeAndCorrectReturn( 55 const FuncGraphPtr &sub_graph, 56 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> ¶meters_infos) const; 57 virtual void ProcessOriginCNode( 58 const AnfNodePtr &composite_node, 59 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr); 60 virtual void ProcessOriginCNodeUser( 61 const FuncGraphPtr &main_graph, const AnfNodePtr &composite_node, 62 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr, 63 const FuncGraphManagerPtr &mng) const; SetTargetAttrs(const CNodePtr &)64 virtual void SetTargetAttrs(const CNodePtr &) {} 65 66 private: 67 std::vector<InplaceAssignUserInfo> FindOriginCNodeUsers( 68 const AnfNodePtr &composite_node, 69 const std::vector<std::pair<InplaceAssignerInfo, AnfNodePtr>> &info_and_inplace_assignee_addr, 70 const FuncGraphManagerPtr &mng) const; 71 }; 72 } // namespace mindspore::graphkernel 73 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_INPLACE_ASSIGN_BUILDER_H_ 74