1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_CONVERT_BFLOAT_H_ 17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_CONVERT_BFLOAT_H_ 18 19 #include <string> 20 #include <utility> 21 #include <vector> 22 #include "include/backend/optimizer/optimizer.h" 23 24 namespace mindspore::graphkernel { 25 /** 26 * @brief Add Cast for op's inputs if the input data type is bfloat16 27 * @example 28 * sub_graph(p0: bfloat16, p1: bfloat16) { 29 * %0 = Op(p0, p1) 30 * return %0 31 * } 32 * ----------> 33 * sub_graph(p0: bfloat16, p1: bfloat16) { 34 * %0 = Cast(p0, float32) 35 * %1 = Cast(p1, float32) 36 * %2 = Op(%0, %1) 37 * %3 = Cast(%2, bfloat16) 38 * return %3 39 * } 40 */ 41 class ConvertBFloat16 : public opt::Pass { 42 public: ConvertBFloat16()43 ConvertBFloat16() : Pass("convert_bfloat16") {} 44 ~ConvertBFloat16() override = default; 45 bool Run(const FuncGraphPtr &func_graph) override; 46 47 private: 48 AnfNodePtr GetCastedInput(const AnfNodePtr &input_node, TypeId dst_type, const FuncGraphPtr &func_graph); 49 AnfNodePtr CastTensor(const ValueNodePtr &value_node); 50 void CastInput(const CNodePtr &cnode, size_t input_idx, const FuncGraphPtr &func_graph); 51 void GetKeepBF16Nodes(const FuncGraphPtr &func_graph); 52 bool Process(const FuncGraphPtr &func_graph); 53 HashMap<AnfNodePtr, AnfNodePtr> cast_nodes_; 54 // (keep_bf16_node, {node_user, input_idx}), node_user's input[input_idx] is keep_bf16_node 55 HashMap<AnfNodePtr, std::vector<std::pair<CNodePtr, size_t>>> keep_bf16_nodes_; 56 CNodePtr last_node_; 57 }; 58 } // namespace mindspore::graphkernel 59 #endif // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_CONVERT_BFLOAT_H_ 60