1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/convert_bfloat16.h"
17 #include <vector>
18 #include <string>
19 #include <memory>
20 #include <utility>
21 #include "mindspore/core/ops/math_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
24 #include "backend/common/graph_kernel/graph_kernel_helper.h"
25 #include "kernel/common_utils.h"
26
27 namespace mindspore::graphkernel {
28 namespace {
// Builds a new abstract whose element data type is data_type, keeping the
// original shape/structure. Tuples are converted element-wise (recursively);
// any other abstract kind is returned unchanged.
AbstractBasePtr UpdateAbstractDataType(const AbstractBasePtr &orig_abs, TypePtr data_type) {
  if (orig_abs == nullptr) {
    return orig_abs;
  }
  if (orig_abs->isa<abstract::AbstractTensor>()) {
    // Rebuild the tensor abstract with the new element type and the old shape.
    return std::make_shared<abstract::AbstractTensor>(data_type, orig_abs->GetShape());
  }
  if (orig_abs->isa<abstract::AbstractScalar>()) {
    // Clone first so the original abstract is left untouched.
    auto scalar_abs = orig_abs->Clone();
    scalar_abs->set_type(data_type);
    return scalar_abs;
  }
  if (orig_abs->isa<abstract::AbstractTuple>()) {
    const auto &elements = orig_abs->cast<abstract::AbstractTuplePtr>()->elements();
    AbstractBasePtrList new_elements;
    new_elements.reserve(elements.size());
    for (const auto &element : elements) {
      new_elements.push_back(UpdateAbstractDataType(element, data_type));
    }
    return std::make_shared<abstract::AbstractTuple>(new_elements);
  }
  return orig_abs;
}
51
UpdateBuildInfoInputDataType(const AnfNodePtr & node,TypeId orig_type,TypeId new_type)52 void UpdateBuildInfoInputDataType(const AnfNodePtr &node, TypeId orig_type, TypeId new_type) {
53 if (node->kernel_info() == nullptr) {
54 return;
55 }
56 auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
57 if (build_info != nullptr) {
58 auto inputs_type = build_info->GetAllInputDeviceTypes();
59 std::replace_if(
60 inputs_type.begin(), inputs_type.end(), [orig_type](TypeId type_id) { return type_id == orig_type; }, new_type);
61 build_info->SetInputsDeviceType(inputs_type);
62 }
63 }
64
UpdateBuildInfoOutputDataType(const AnfNodePtr & node,TypeId orig_type,TypeId new_type)65 void UpdateBuildInfoOutputDataType(const AnfNodePtr &node, TypeId orig_type, TypeId new_type) {
66 if (node->kernel_info() == nullptr) {
67 return;
68 }
69 auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
70 if (build_info != nullptr) {
71 auto outputs_type = build_info->GetAllOutputDeviceTypes();
72 std::replace_if(
73 outputs_type.begin(), outputs_type.end(), [orig_type](TypeId type_id) { return type_id == orig_type; }, new_type);
74 build_info->SetOutputsDeviceType(outputs_type);
75 }
76 }
77
// Creates a Cast CNode that converts input_node's output (output index 0) to
// dst_type. Returns input_node itself when its output type already equals
// dst_type. When the backend uses device info, kernel build info is also
// filled in for both the new type-value node and the cast node.
AnfNodePtr NewCastNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, TypeId dst_type) {
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  auto src_type = cb->GetOutputType(input_node, 0);
  if (dst_type == src_type) {
    // Already the requested type: no cast needed.
    return input_node;
  }
  // Cast's second input is the destination type id carried as an int64 scalar.
  auto type_value = std::make_shared<Int64Imm>(static_cast<int64_t>(dst_type));
  auto type_node = NewValueNode(type_value);
  type_node->set_abstract(type_value->ToAbstract());
  auto cast_node =
    func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), input_node, type_node});
  auto input_abstract = input_node->abstract();
  // The cast output keeps the input's shape/structure; only the element type changes.
  auto cast_abstract = UpdateAbstractDataType(input_abstract, TypeIdToType(dst_type));
  cast_node->set_abstract(cast_abstract);
  if (cb->IsUseDeviceInfo()) {
    auto input_format = cb->GetOutputFormat(input_node, 0);
    auto input_type = cb->GetOutputType(input_node, 0);
    auto input_object_type = kernel::TypeIdToKernelObjectType(AnfAlgo::GetAbstractObjectType(input_abstract));
    // The type node is always a default-format int64 scalar.
    std::string type_node_format = kOpFormat_DEFAULT;
    auto type_node_type = kNumberTypeInt64;
    auto type_node_object_type = kernel::KernelObjectType::SCALAR;
    // set build info for type node
    auto type_kernel_info = std::make_shared<device::KernelInfo>();
    type_node->set_kernel_info(type_kernel_info);
    auto type_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
    type_info_builder->SetOutputsFormat(std::vector<std::string>{type_node_format});
    type_info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_node_type});
    type_info_builder->SetOutputsKernelObjectType(std::vector<kernel::KernelObjectType>{type_node_object_type});
    AnfAlgo::SetSelectKernelBuildInfo(type_info_builder->Build(), type_node.get());
    // set build info for cast node: inputs mirror (data, type) and the output
    // keeps the data input's format/object type with the new device type.
    auto kernel_info = std::make_shared<device::KernelInfo>();
    cast_node->set_kernel_info(kernel_info);
    auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
    info_builder->SetInputsFormat(std::vector<std::string>{input_format, type_node_format});
    info_builder->SetInputsDeviceType(std::vector<TypeId>{input_type, type_node_type});
    info_builder->SetInputsKernelObjectType(
      std::vector<kernel::KernelObjectType>{input_object_type, type_node_object_type});
    info_builder->SetOutputsFormat(std::vector<std::string>{input_format});
    info_builder->SetOutputsDeviceType(std::vector<TypeId>{dst_type});
    info_builder->SetOutputsKernelObjectType(std::vector<kernel::KernelObjectType>{input_object_type});
    AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
  }
  return cast_node;
}
123
124 // {prim_name, {inputs_keep_index}}
125 const HashMap<std::string, std::vector<size_t>> kNeedKeepBF16Ops = {
126 {ops::kNameAssign, {kIndex2}}, {ops::kNameMatMul, {kIndex1, kIndex2}}, {ops::kNameBatchMatMul, {kIndex1, kIndex2}}};
127
NeedKeepBF16(const CNodePtr & cnode)128 inline bool NeedKeepBF16(const CNodePtr &cnode) {
129 const auto &prim = GetCNodePrimitive(cnode);
130 return prim != nullptr && kNeedKeepBF16Ops.find(prim->name()) != kNeedKeepBF16Ops.end();
131 }
132 } // namespace
133
// Returns a node casting input_node to dst_type, creating it on first request
// and reusing the cached node afterwards, so each input is cast at most once.
// NOTE(review): the cache is keyed by input_node only and ignores dst_type;
// callers in this pass always request kNumberTypeFloat32 — confirm before
// reusing with a different target type.
AnfNodePtr ConvertBFloat16::GetCastedInput(const AnfNodePtr &input_node, TypeId dst_type,
                                           const FuncGraphPtr &func_graph) {
  auto iter = cast_nodes_.find(input_node);
  if (iter == cast_nodes_.end()) {
    // Insert directly instead of the original find + operator[] (insert) +
    // operator[] (read) pattern, which performed two extra hash lookups on
    // every cache miss.
    iter = cast_nodes_.emplace(input_node, NewCastNode(func_graph, input_node, dst_type)).first;
  }
  return iter->second;
}
143
// Replaces a bfloat16 tensor value node with a new float32 value node holding
// the element-wise converted data. The caller guarantees the value node wraps
// a tensor whose elements are bfloat16 (this is only reached when the input
// type was detected as kNumberTypeBFloat16); the raw data is reinterpreted
// accordingly without a runtime type check.
AnfNodePtr ConvertBFloat16::CastTensor(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  auto tensor = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor);
  auto *src_data = reinterpret_cast<bfloat16 *>(tensor->data_c());
  MS_EXCEPTION_IF_NULL(src_data);
  // create float32 tensor
  auto new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, tensor->shape());
  MS_EXCEPTION_IF_NULL(new_tensor);
  auto *dst_data = reinterpret_cast<float *>(new_tensor->data_c());
  MS_EXCEPTION_IF_NULL(dst_data);
  // Widen each element; bf16 -> fp32 is exact (bf16 is a truncated fp32).
  for (size_t i = 0; i < tensor->DataSize(); ++i) {
    dst_data[i] = static_cast<float>(src_data[i]);
  }
  // create new value node
  auto new_value_node = NewValueNode(new_tensor);
  new_value_node->set_abstract(new_tensor->ToAbstract());
  if (value_node->kernel_info() != nullptr) {
    auto build_info = AnfAlgo::GetSelectKernelBuildInfo(value_node);
    if (build_info != nullptr) {
      // set build info for new value node: copy the old build info and only
      // change the output device type to float32.
      new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
      auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(build_info);
      builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32});
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), new_value_node.get());
    }
  }
  return new_value_node;
}
175
CastInput(const CNodePtr & cnode,size_t input_idx,const FuncGraphPtr & func_graph)176 void ConvertBFloat16::CastInput(const CNodePtr &cnode, size_t input_idx, const FuncGraphPtr &func_graph) {
177 input_idx += 1;
178 auto input_node = cnode->input(input_idx);
179 TypeId target_input_type = kNumberTypeFloat32;
180 if (input_node->isa<ValueNode>()) {
181 auto value_node = input_node->cast<ValueNodePtr>();
182 auto new_input = CastTensor(value_node);
183 cnode->set_input(input_idx, new_input);
184 } else if (IsPrimitiveCNode(input_node, prim::kPrimCast)) {
185 // directly link cast's input to current node(because cast bf16 to fp32 needs more intermediate cast)
186 auto cast_node = input_node->cast<CNodePtr>();
187 MS_EXCEPTION_IF_NULL(cast_node);
188 auto new_input = GetCastedInput(cast_node->input(1), target_input_type, func_graph);
189 cnode->set_input(input_idx, new_input);
190 } else {
191 auto new_input = GetCastedInput(input_node, target_input_type, func_graph);
192 cnode->set_input(input_idx, new_input);
193 }
194 }
195
// Scans the sub-graph and fills keep_bf16_nodes_: a map from a node to the
// list of (user_cnode, input_index) pairs that must keep receiving bfloat16
// from it. Two kinds of users are recorded: the bf16-keeping inputs of ops in
// kNeedKeepBF16Ops, and the sub-graph outputs (so the graph's output dtype is
// preserved). Also sets last_node_ to the output-producing node (the
// MakeTuple for multi-output graphs, otherwise the Return node itself).
void ConvertBFloat16::GetKeepBF16Nodes(const FuncGraphPtr &func_graph) {
  keep_bf16_nodes_.clear();
  auto nodes = TopoSort(func_graph->get_return());
  for (auto node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    if (NeedKeepBF16(cnode)) {
      // As NeedKeepBF16(cnode), value of GetCNodePrimitive(cnode) is not a nullptr
      auto prim_name = GetCNodePrimitive(cnode)->name();
      // Record every input of this op that must stay bf16.
      for (const auto &input_index : kNeedKeepBF16Ops.at(prim_name)) {
        (void)keep_bf16_nodes_[cnode->input(input_index)].emplace_back(std::make_pair(cnode, input_index));
      }
    } else if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
      auto ret_input = cnode->input(1);
      MS_EXCEPTION_IF_NULL(ret_input);
      if (IsPrimitiveCNode(ret_input, prim::kPrimMakeTuple)) {
        // multiple output: each MakeTuple operand is a graph output
        last_node_ = ret_input->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(last_node_);
        for (size_t i = 1; i < last_node_->size(); ++i) {
          (void)keep_bf16_nodes_[last_node_->input(i)].emplace_back(std::make_pair(last_node_, i));
        }
      } else {
        // single output: Return's operand is the sole graph output
        last_node_ = cnode;
        (void)keep_bf16_nodes_[ret_input].emplace_back(std::make_pair(last_node_, 1));
      }
    }
  }
}
228
// Converts one graph-kernel sub-graph from bfloat16 to float32 compute:
// every node with a bf16 input gets its inputs cast to fp32 and its
// abstract/build info retyped, while nodes in kNeedKeepBF16Ops and the
// sub-graph outputs get a cast back to bf16 so their contract is unchanged.
// Returns true when the sub-graph was modified.
bool ConvertBFloat16::Process(const FuncGraphPtr &func_graph) {
  cast_nodes_.clear();
  GetKeepBF16Nodes(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }
  bool changed = false;
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  auto nodes = TopoSort(func_graph->get_return());
  for (auto node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || NeedKeepBF16(cnode)) {
      continue;
    }
    if (cnode == last_node_) {
      // last_node_ is the output node (MakeTuple/Return) set by
      // GetKeepBF16Nodes; everything after it in topo order is the graph
      // epilogue, so processing stops here.
      break;
    }
    // For cast node, directly update its input data type
    if (IsPrimitiveCNode(node, prim::kPrimCast)) {
      auto orig_input_type = cb->GetInputType(node, 0);
      auto cur_input_type = cb->GetOutputType(cnode->input(1), 0);
      if (cur_input_type != orig_input_type) {
        // The producer was retyped earlier in this topo walk; keep the cast's
        // build info consistent with its (new) input type.
        UpdateBuildInfoInputDataType(node, orig_input_type, cur_input_type);
      }
      continue;
    }
    // For other nodes, add cast for its input and update its abstract and build info
    // add cast for node's output if node is sub-graph's output
    bool need_update = false;
    for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(cnode); ++i) {
      auto orig_input_type = cb->GetInputType(cnode, i);
      if (orig_input_type == kNumberTypeBFloat16) {
        need_update = true;
        changed = true;
        CastInput(cnode, i, func_graph);
      }
    }
    if (!need_update) {
      continue;
    }
    auto orig_output_type = cb->GetOutputType(node, 0);
    // update node abstract
    auto new_abstract = UpdateAbstractDataType(node->abstract(), kFloat32);
    node->set_abstract(new_abstract);
    // update node build info
    UpdateBuildInfoInputDataType(node, kNumberTypeBFloat16, kNumberTypeFloat32);
    UpdateBuildInfoOutputDataType(node, kNumberTypeBFloat16, kNumberTypeFloat32);
    // add cast for current node if it is output node
    auto cur_output_type = cb->GetOutputType(node, 0);
    auto iter = keep_bf16_nodes_.find(node);
    if (iter != keep_bf16_nodes_.end() && cur_output_type != orig_output_type) {
      // Users recorded in keep_bf16_nodes_ must still see the original bf16
      // type, so route them through a cast back to orig_output_type.
      auto new_cast_node = NewCastNode(func_graph, node, orig_output_type);
      for (auto &[user_node, idx] : iter->second) {
        user_node->set_input(idx, new_cast_node);
      }
    }
  }
  if (changed) {
    // Rebuild manager roots so nodes orphaned by the rewiring are dropped.
    mng->RemoveRoots();
    mng->KeepRoots({func_graph});
  }
  return changed;
}
295
Run(const FuncGraphPtr & func_graph)296 bool ConvertBFloat16::Run(const FuncGraphPtr &func_graph) {
297 MS_EXCEPTION_IF_NULL(func_graph);
298 auto mng = func_graph->manager();
299 if (mng == nullptr) {
300 mng = Manage(func_graph, true);
301 func_graph->set_manager(mng);
302 }
303 bool changed = false;
304 auto nodes = TopoSort(func_graph->get_return());
305 for (const auto &node : nodes) {
306 if (common::AnfAlgo::IsGraphKernel(node)) {
307 auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
308 MS_EXCEPTION_IF_NULL(sub_graph);
309 changed = Process(sub_graph) || changed;
310 }
311 }
312 if (changed) {
313 GkUtils::UpdateFuncGraphManager(mng, func_graph);
314 }
315 return changed;
316 }
317 } // namespace mindspore::graphkernel
318