• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "backend/common/graph_kernel/convert_bfloat16.h"
#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "kernel/common_utils.h"
26 
27 namespace mindspore::graphkernel {
28 namespace {
// Returns an abstract equal to `orig_abs` but with its data type replaced by `data_type`:
//   - AbstractTensor: a new tensor abstract with the same shape and element type `data_type`;
//   - AbstractScalar: a clone of the scalar abstract with its type set to `data_type`;
//   - AbstractTuple: a new tuple whose elements are converted recursively.
// nullptr and any other kind of abstract are returned unchanged.
AbstractBasePtr UpdateAbstractDataType(const AbstractBasePtr &orig_abs, const TypePtr &data_type) {
  if (orig_abs == nullptr) {
    return orig_abs;
  }
  if (orig_abs->isa<abstract::AbstractTensor>()) {
    return std::make_shared<abstract::AbstractTensor>(data_type, orig_abs->GetShape());
  }
  if (orig_abs->isa<abstract::AbstractScalar>()) {
    auto new_abs = orig_abs->Clone();
    new_abs->set_type(data_type);
    return new_abs;
  }
  if (orig_abs->isa<abstract::AbstractTuple>()) {
    // The element list is kept alive by orig_abs, so bind it by reference instead of copying the whole vector.
    const auto &elements = orig_abs->cast<abstract::AbstractTuplePtr>()->elements();
    AbstractBasePtrList new_elements;
    new_elements.reserve(elements.size());
    for (const auto &element : elements) {
      new_elements.push_back(UpdateAbstractDataType(element, data_type));
    }
    return std::make_shared<abstract::AbstractTuple>(new_elements);
  }
  return orig_abs;
}
51 
UpdateBuildInfoInputDataType(const AnfNodePtr & node,TypeId orig_type,TypeId new_type)52 void UpdateBuildInfoInputDataType(const AnfNodePtr &node, TypeId orig_type, TypeId new_type) {
53   if (node->kernel_info() == nullptr) {
54     return;
55   }
56   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
57   if (build_info != nullptr) {
58     auto inputs_type = build_info->GetAllInputDeviceTypes();
59     std::replace_if(
60       inputs_type.begin(), inputs_type.end(), [orig_type](TypeId type_id) { return type_id == orig_type; }, new_type);
61     build_info->SetInputsDeviceType(inputs_type);
62   }
63 }
64 
UpdateBuildInfoOutputDataType(const AnfNodePtr & node,TypeId orig_type,TypeId new_type)65 void UpdateBuildInfoOutputDataType(const AnfNodePtr &node, TypeId orig_type, TypeId new_type) {
66   if (node->kernel_info() == nullptr) {
67     return;
68   }
69   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
70   if (build_info != nullptr) {
71     auto outputs_type = build_info->GetAllOutputDeviceTypes();
72     std::replace_if(
73       outputs_type.begin(), outputs_type.end(), [orig_type](TypeId type_id) { return type_id == orig_type; }, new_type);
74     build_info->SetOutputsDeviceType(outputs_type);
75   }
76 }
77 
// Creates `Cast(input_node, dst_type)` in `func_graph` and returns it.
// If output 0 of `input_node` already has type `dst_type`, `input_node` itself is returned and nothing is created.
// The cast node's abstract is `input_node`'s abstract with its data type replaced by `dst_type`. When
// Callback::IsUseDeviceInfo() is true, kernel build info is also attached to the cast node and to its type-id input.
AnfNodePtr NewCastNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, TypeId dst_type) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  const auto src_type = cb->GetOutputType(input_node, 0);
  if (dst_type == src_type) {
    return input_node;
  }
  // The destination type id is fed to Cast as an int64 scalar value node.
  auto type_value = std::make_shared<Int64Imm>(static_cast<int64_t>(dst_type));
  auto type_node = NewValueNode(type_value);
  type_node->set_abstract(type_value->ToAbstract());
  auto cast_node =
    func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input_node, type_node});
  const auto input_abstract = input_node->abstract();
  cast_node->set_abstract(UpdateAbstractDataType(input_abstract, TypeIdToType(dst_type)));
  if (!cb->IsUseDeviceInfo()) {
    return cast_node;
  }

  const auto input_format = cb->GetOutputFormat(input_node, 0);
  const auto input_object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(input_abstract));
  const std::string type_node_format = kOpFormat_DEFAULT;
  const auto type_node_type = kNumberTypeInt64;
  const auto type_node_object_type = kernel::KernelObjectType::SCALAR;

  // Build info for the type-id node: a single default-format int64 scalar output.
  type_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  auto type_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  type_info_builder->SetOutputsFormat(std::vector<std::string>{type_node_format});
  type_info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_node_type});
  type_info_builder->SetOutputsKernelObjectType(std::vector<kernel::KernelObjectType>{type_node_object_type});
  AnfAlgo::SetSelectKernelBuildInfo(type_info_builder->Build(), type_node.get());

  // Build info for the cast node: inputs are (output 0 of input_node, type id); the output keeps the input's format
  // and kernel object type and carries dst_type.
  cast_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  info_builder->SetInputsFormat(std::vector<std::string>{input_format, type_node_format});
  // src_type already holds cb->GetOutputType(input_node, 0); input_node is not modified above, so no re-query.
  info_builder->SetInputsDeviceType(std::vector<TypeId>{src_type, type_node_type});
  info_builder->SetInputsKernelObjectType(
    std::vector<kernel::KernelObjectType>{input_object_type, type_node_object_type});
  info_builder->SetOutputsFormat(std::vector<std::string>{input_format});
  info_builder->SetOutputsDeviceType(std::vector<TypeId>{dst_type});
  info_builder->SetOutputsKernelObjectType(std::vector<kernel::KernelObjectType>{input_object_type});
  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
  return cast_node;
}
123 
124 // {prim_name, {inputs_keep_index}}
125 const HashMap<std::string, std::vector<size_t>> kNeedKeepBF16Ops = {
126   {ops::kNameAssign, {kIndex2}}, {ops::kNameMatMul, {kIndex1, kIndex2}}, {ops::kNameBatchMatMul, {kIndex1, kIndex2}}};
127 
NeedKeepBF16(const CNodePtr & cnode)128 inline bool NeedKeepBF16(const CNodePtr &cnode) {
129   const auto &prim = GetCNodePrimitive(cnode);
130   return prim != nullptr && kNeedKeepBF16Ops.find(prim->name()) != kNeedKeepBF16Ops.end();
131 }
132 }  // namespace
133 
// Returns the cast of `input_node` to `dst_type`, creating it on first request and reusing it afterwards so that
// several consumers of the same node share one Cast (cache: cast_nodes_).
// NOTE(review): the cache key is `input_node` only, so `dst_type` is ignored on a cache hit. That is correct while
// every caller in this file passes kNumberTypeFloat32; revisit before adding callers with other target types.
AnfNodePtr ConvertBFloat16::GetCastedInput(const AnfNodePtr &input_node, TypeId dst_type,
                                           const FuncGraphPtr &func_graph) {
  auto iter = cast_nodes_.find(input_node);
  if (iter != cast_nodes_.end()) {
    return iter->second;
  }
  // Build the node first, then cache it: avoids re-looking it up and leaves the cache untouched if creation throws.
  auto cast_node = NewCastNode(func_graph, input_node, dst_type);
  cast_nodes_[input_node] = cast_node;
  return cast_node;
}
143 
CastTensor(const ValueNodePtr & value_node)144 AnfNodePtr ConvertBFloat16::CastTensor(const ValueNodePtr &value_node) {
145   MS_EXCEPTION_IF_NULL(value_node);
146   auto value = value_node->value();
147   MS_EXCEPTION_IF_NULL(value);
148   auto tensor = value->cast<tensor::TensorPtr>();
149   MS_EXCEPTION_IF_NULL(tensor);
150   auto *src_data = reinterpret_cast<bfloat16 *>(tensor->data_c());
151   MS_EXCEPTION_IF_NULL(src_data);
152   // create float32 tensor
153   auto new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, tensor->shape());
154   MS_EXCEPTION_IF_NULL(new_tensor);
155   auto *dst_data = reinterpret_cast<float *>(new_tensor->data_c());
156   MS_EXCEPTION_IF_NULL(dst_data);
157   for (size_t i = 0; i < tensor->DataSize(); ++i) {
158     dst_data[i] = static_cast<float>(src_data[i]);
159   }
160   // create new value node
161   auto new_value_node = NewValueNode(new_tensor);
162   new_value_node->set_abstract(new_tensor->ToAbstract());
163   if (value_node->kernel_info() != nullptr) {
164     auto build_info = AnfAlgo::GetSelectKernelBuildInfo(value_node);
165     if (build_info != nullptr) {
166       // set build info for new value node
167       new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
168       auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(build_info);
169       builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32});
170       AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_value_node.get());
171     }
172   }
173   return new_value_node;
174 }
175 
CastInput(const CNodePtr & cnode,size_t input_idx,const FuncGraphPtr & func_graph)176 void ConvertBFloat16::CastInput(const CNodePtr &cnode, size_t input_idx, const FuncGraphPtr &func_graph) {
177   input_idx += 1;
178   auto input_node = cnode->input(input_idx);
179   TypeId target_input_type = kNumberTypeFloat32;
180   if (input_node->isa<ValueNode>()) {
181     auto value_node = input_node->cast<ValueNodePtr>();
182     auto new_input = CastTensor(value_node);
183     cnode->set_input(input_idx, new_input);
184   } else if (IsPrimitiveCNode(input_node, prim::kPrimCast)) {
185     // directly link cast's input to current node(because cast bf16 to fp32 needs more intermediate cast)
186     auto cast_node = input_node->cast<CNodePtr>();
187     MS_EXCEPTION_IF_NULL(cast_node);
188     auto new_input = GetCastedInput(cast_node->input(1), target_input_type, func_graph);
189     cnode->set_input(input_idx, new_input);
190   } else {
191     auto new_input = GetCastedInput(input_node, target_input_type, func_graph);
192     cnode->set_input(input_idx, new_input);
193   }
194 }
195 
GetKeepBF16Nodes(const FuncGraphPtr & func_graph)196 void ConvertBFloat16::GetKeepBF16Nodes(const FuncGraphPtr &func_graph) {
197   keep_bf16_nodes_.clear();
198   auto nodes = TopoSort(func_graph->get_return());
199   for (auto node : nodes) {
200     auto cnode = node->cast<CNodePtr>();
201     if (cnode == nullptr) {
202       continue;
203     }
204     if (NeedKeepBF16(cnode)) {
205       // As NeedKeepBF16(cnode), value of GetCNodePrimitive(cnode) is not a nullptr
206       auto prim_name = GetCNodePrimitive(cnode)->name();
207       for (const auto &input_index : kNeedKeepBF16Ops.at(prim_name)) {
208         (void)keep_bf16_nodes_[cnode->input(input_index)].emplace_back(std::make_pair(cnode, input_index));
209       }
210     } else if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
211       auto ret_input = cnode->input(1);
212       MS_EXCEPTION_IF_NULL(ret_input);
213       if (IsPrimitiveCNode(ret_input, prim::kPrimMakeTuple)) {
214         // multiple output
215         last_node_ = ret_input->cast<CNodePtr>();
216         MS_EXCEPTION_IF_NULL(last_node_);
217         for (size_t i = 1; i < last_node_->size(); ++i) {
218           (void)keep_bf16_nodes_[last_node_->input(i)].emplace_back(std::make_pair(last_node_, i));
219         }
220       } else {
221         // single output
222         last_node_ = cnode;
223         (void)keep_bf16_nodes_[ret_input].emplace_back(std::make_pair(last_node_, 1));
224       }
225     }
226   }
227 }
228 
Process(const FuncGraphPtr & func_graph)229 bool ConvertBFloat16::Process(const FuncGraphPtr &func_graph) {
230   cast_nodes_.clear();
231   GetKeepBF16Nodes(func_graph);
232   auto mng = func_graph->manager();
233   if (mng == nullptr) {
234     mng = Manage(func_graph, true);
235     func_graph->set_manager(mng);
236   }
237   bool changed = false;
238   auto cb = Callback::Instance();
239   MS_EXCEPTION_IF_NULL(cb);
240   auto nodes = TopoSort(func_graph->get_return());
241   for (auto node : nodes) {
242     auto cnode = node->cast<CNodePtr>();
243     if (cnode == nullptr || NeedKeepBF16(cnode)) {
244       continue;
245     }
246     if (cnode == last_node_) {
247       break;
248     }
249     // For cast node, directly update its input data type
250     if (IsPrimitiveCNode(node, prim::kPrimCast)) {
251       auto orig_input_type = cb->GetInputType(node, 0);
252       auto cur_input_type = cb->GetOutputType(cnode->input(1), 0);
253       if (cur_input_type != orig_input_type) {
254         UpdateBuildInfoInputDataType(node, orig_input_type, cur_input_type);
255       }
256       continue;
257     }
258     // For other nodes, add cast for its input and update its abstract and build info
259     //   add cast for node's output if node is sub-graph's output
260     bool need_update = false;
261     for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(cnode); ++i) {
262       auto orig_input_type = cb->GetInputType(cnode, i);
263       if (orig_input_type == kNumberTypeBFloat16) {
264         need_update = true;
265         changed = true;
266         CastInput(cnode, i, func_graph);
267       }
268     }
269     if (!need_update) {
270       continue;
271     }
272     auto orig_output_type = cb->GetOutputType(node, 0);
273     // update node abstract
274     auto new_abstract = UpdateAbstractDataType(node->abstract(), kFloat32);
275     node->set_abstract(new_abstract);
276     // update node build info
277     UpdateBuildInfoInputDataType(node, kNumberTypeBFloat16, kNumberTypeFloat32);
278     UpdateBuildInfoOutputDataType(node, kNumberTypeBFloat16, kNumberTypeFloat32);
279     // add cast for current node if it is output node
280     auto cur_output_type = cb->GetOutputType(node, 0);
281     auto iter = keep_bf16_nodes_.find(node);
282     if (iter != keep_bf16_nodes_.end() && cur_output_type != orig_output_type) {
283       auto new_cast_node = NewCastNode(func_graph, node, orig_output_type);
284       for (auto &[user_node, idx] : iter->second) {
285         user_node->set_input(idx, new_cast_node);
286       }
287     }
288   }
289   if (changed) {
290     mng->RemoveRoots();
291     mng->KeepRoots({func_graph});
292   }
293   return changed;
294 }
295 
Run(const FuncGraphPtr & func_graph)296 bool ConvertBFloat16::Run(const FuncGraphPtr &func_graph) {
297   MS_EXCEPTION_IF_NULL(func_graph);
298   auto mng = func_graph->manager();
299   if (mng == nullptr) {
300     mng = Manage(func_graph, true);
301     func_graph->set_manager(mng);
302   }
303   bool changed = false;
304   auto nodes = TopoSort(func_graph->get_return());
305   for (const auto &node : nodes) {
306     if (common::AnfAlgo::IsGraphKernel(node)) {
307       auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
308       MS_EXCEPTION_IF_NULL(sub_graph);
309       changed = Process(sub_graph) || changed;
310     }
311   }
312   if (changed) {
313     GkUtils::UpdateFuncGraphManager(mng, func_graph);
314   }
315   return changed;
316 }
317 }  // namespace mindspore::graphkernel
318