1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ 19 20 #include <map> 21 #include <vector> 22 #include <memory> 23 #include <utility> 24 #include <tuple> 25 26 #include "utils/hash_map.h" 27 #include "frontend/optimizer/irpass.h" 28 #include "frontend/optimizer/optimizer.h" 29 #include "frontend/optimizer/anf_visitor.h" 30 #include "ir/manager.h" 31 #include "ir/func_graph.h" 32 #include "ir/func_graph_cloner.h" 33 #include "frontend/operator/ops.h" 34 35 namespace mindspore { 36 namespace opt { 37 namespace irpass { 38 namespace internal { 39 class SpecializeTransform { 40 public: SpecializeTransform()41 SpecializeTransform() : cache_() {} 42 ~SpecializeTransform() = default; 43 operator()44 FuncGraphPtr operator()(const FuncGraphPtr &func_graph, const std::vector<ValuePtr> &need_eliminate_args) { 45 if (cache_.count(func_graph) == 0) { 46 cache_[func_graph] = {}; 47 } 48 auto &cache = cache_[func_graph]; 49 const auto &key = need_eliminate_args; 50 if (cache.count(key) == 0) { 51 auto mng = func_graph->manager(); 52 MS_EXCEPTION_IF_NULL(mng); 53 FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared<TraceTransform>("sp")); 54 mng->AddFuncGraph(new_fg); 55 std::vector<AnfNodePtr> params = new_fg->parameters(); 56 std::vector<AnfNodePtr> new_params; 57 for (size_t i = 0; i < need_eliminate_args.size(); i++) { 58 // keep the parameter 59 if (need_eliminate_args[i] == nullptr) { 60 new_params.push_back(params[i]); 61 continue; 62 } 63 // replace the parameter with arg in new_fg without changing origin func_graph. 64 (void)mng->Replace(params[i], NewReplaceValueNode(need_eliminate_args[i])); 65 } 66 mng->SetParameters(new_fg, new_params); 67 cache[key] = new_fg; 68 } 69 return cache[key]; 70 } 71 72 private: 73 mindspore::HashMap<FuncGraphPtr, std::map<std::vector<ValuePtr>, FuncGraphPtr>> cache_; NewReplaceValueNode(const ValuePtr & value)74 static ValueNodePtr NewReplaceValueNode(const ValuePtr &value) { 75 MS_EXCEPTION_IF_NULL(value); 76 if (value->isa<FuncGraph>() || value->isa<Primitive>() || value->isa<parse::NameSpace>()) { 77 return NewValueNode(value); 78 } 79 if (value->isa<tensor::Tensor>()) { 80 auto &const_tensor = *(value->cast<tensor::TensorPtr>()); 81 auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor); 82 return NewValueNode(const_tensor_ptr); 83 } 84 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected value:" << value->ToString(); 85 } 86 }; 87 } // namespace internal 88 89 // {G, Xs} 90 class SpecializeOnGraphArguments : public AnfVisitor { 91 public: SpecializeOnGraphArguments()92 SpecializeOnGraphArguments() : specialize_transform_() {} 93 ~SpecializeOnGraphArguments() override = default; 94 operator()95 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 96 if (!node->isa<CNode>() || node->func_graph() == nullptr) { 97 return nullptr; 98 } 99 100 auto &inputs = node->cast<CNodePtr>()->inputs(); 101 if (!IsValueNode<FuncGraph>(inputs[0])) { 102 return nullptr; 103 } 104 105 auto inp0_fg = GetValueNode<FuncGraphPtr>(inputs[0]); 106 if (inp0_fg == nullptr || inp0_fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) || IsSetRecomputed(inp0_fg) || 107 inp0_fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || inp0_fg->recursive()) { 108 return nullptr; 109 } 110 std::vector<ValuePtr> need_eliminated_args; 111 std::vector<AnfNodePtr> new_xs; 112 bool hasVNode = false; 113 for (size_t i = 1; i < inputs.size(); i++) { 114 if (IsValueNode<FuncGraph>(inputs[i]) || IsValueNode<Primitive>(inputs[i]) || 115 IsValueNode<tensor::Tensor>(inputs[i]) || IsValueNode<parse::NameSpace>(inputs[i])) { 116 need_eliminated_args.push_back(GetValueNode(inputs[i])); 117 hasVNode = true; 118 } else { 119 (void)need_eliminated_args.emplace_back(nullptr); 120 new_xs.push_back(inputs[i]); 121 } 122 } 123 if (!hasVNode) { 124 return nullptr; 125 } 126 auto new_fg = specialize_transform_(inp0_fg, need_eliminated_args); 127 (void)new_xs.insert(new_xs.cbegin(), NewValueNode(new_fg)); 128 129 return node->func_graph()->NewCNode(new_xs); 130 } 131 132 private: IsSetRecomputed(const FuncGraphPtr & fg)133 static bool IsSetRecomputed(const FuncGraphPtr &fg) { 134 auto context = MsContext::GetInstance(); 135 MS_EXCEPTION_IF_NULL(context); 136 const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse; 137 return fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) || 138 (cell_reuse && 139 (fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH))); 140 } 141 142 internal::SpecializeTransform specialize_transform_; 143 }; 144 } // namespace irpass 145 } // namespace opt 146 } // namespace mindspore 147 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ 148