1 /** 2 * Copyright 2022-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_EXPANDER_H_ 17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_EXPANDER_H_ 18 #include <memory> 19 #include "backend/common/graph_kernel/core/expander.h" 20 #include "ir/func_graph.h" 21 #include "include/backend/visible.h" 22 #include <nlohmann/json.hpp> 23 24 namespace mindspore::graphkernel { 25 class ComplexOpDecorator : public ExpanderDecorator { 26 public: ComplexOpDecorator(const ExpanderPtr & decorated)27 explicit ComplexOpDecorator(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} 28 ~ComplexOpDecorator() override = default; Creator(const ExpanderPtr & decorated)29 static ExpanderPtr Creator(const ExpanderPtr &decorated) { 30 return std::static_pointer_cast<Expander>(std::make_shared<ComplexOpDecorator>(decorated)); 31 } 32 AnfNodePtr Run(const AnfNodePtr &node) override; 33 }; 34 35 class ArgWithValueDeco : public ExpanderDecorator { 36 public: ArgWithValueDeco(const ExpanderPtr & decorated)37 explicit ArgWithValueDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} 38 ~ArgWithValueDeco() override = default; Creator(const ExpanderPtr & decorated)39 static ExpanderPtr Creator(const ExpanderPtr &decorated) { 40 return std::static_pointer_cast<Expander>(std::make_shared<ArgWithValueDeco>(decorated)); 41 } 42 AnfNodePtr Run(const AnfNodePtr &node) override; 43 }; 44 45 class UnfoldMakeTupleDeco : public ExpanderDecorator { 46 public: UnfoldMakeTupleDeco(const ExpanderPtr & decorated)47 explicit UnfoldMakeTupleDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} 48 ~UnfoldMakeTupleDeco() override = default; Creator(const ExpanderPtr & decorated)49 static ExpanderPtr Creator(const ExpanderPtr &decorated) { 50 return std::static_pointer_cast<Expander>(std::make_shared<UnfoldMakeTupleDeco>(decorated)); 51 } 52 AnfNodePtr Run(const AnfNodePtr &node) override; 53 }; 54 55 class ProcessCustomOpDeco : public ExpanderDecorator { 56 public: ProcessCustomOpDeco(const ExpanderPtr & decorated)57 explicit ProcessCustomOpDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} 58 ~ProcessCustomOpDeco() override = default; Creator(const ExpanderPtr & decorated)59 static ExpanderPtr Creator(const ExpanderPtr &decorated) { 60 return std::static_pointer_cast<Expander>(std::make_shared<ProcessCustomOpDeco>(decorated)); 61 } 62 AnfNodePtr Run(const AnfNodePtr &node) override; 63 }; 64 65 class SetDynamicShapeAttrDeco : public ExpanderDecorator { 66 public: SetDynamicShapeAttrDeco(const ExpanderPtr & decorated)67 explicit SetDynamicShapeAttrDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} 68 ~SetDynamicShapeAttrDeco() override = default; Creator(const ExpanderPtr & decorated)69 static ExpanderPtr Creator(const ExpanderPtr &decorated) { 70 return std::static_pointer_cast<Expander>(std::make_shared<SetDynamicShapeAttrDeco>(decorated)); 71 } 72 AnfNodePtr Run(const AnfNodePtr &node) override; 73 }; 74 75 /** 76 * Get the Expander which is used to expand a cnode to a funcgraph which composite same function with core ops. 77 */ 78 BACKEND_EXPORT ExpanderPtr GetExpander(const AnfNodePtr &node, const ExpanderPtr &init); 79 80 /** 81 * Get the Expander which is used to expand a cnode to a funcgraph which composite same function with core ops. 82 */ 83 BACKEND_EXPORT ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract = true); 84 85 /** 86 * Inline the expanded func graph to main graph. 87 */ 88 BACKEND_EXPORT void InlineExpandFuncGraph(const AnfNodePtr &expanding_node, const FuncGraphPtr &expanded_graph); 89 90 /** 91 * Try Expand cnode with check func. 92 */ 93 BACKEND_EXPORT AnfNodePtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const CNodePtr &)> &func); 94 95 /** 96 * Check if node can be expanded fallback. 97 */ 98 BACKEND_EXPORT bool CanExpandFallback(const AnfNodePtr &node); 99 100 bool IsComplexOp(const AnfNodePtr &node); 101 } // namespace mindspore::graphkernel 102 #endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_ADAPTER_EXPANDER_H_ 103