/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/optimizer/ir_fusion/inference_weight_preprocess_utils.h"
#include <string>
#include <memory>
#include <algorithm>

namespace mindspore {
namespace opt {

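// Return the default tensor of the Parameter behind a Load node, or nullptr if `load` is not a Load of a
// Parameter with a default value. When `unused` is true, the Parameter is additionally marked to ignore its
// device address.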
tensor::TensorPtr GetParamFromLoad(const CNodePtr &load, const bool unused) {
  if (IsPrimitiveCNode(load, prim::kPrimLoad)) {
    auto anf_node = common::AnfAlgo::GetInputNode(load, kIndex0);
    MS_EXCEPTION_IF_NULL(anf_node);
    if (anf_node->isa<Parameter>()) {
      auto para = anf_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(para);
      if (para->has_default()) {
        auto value = para->default_param();
        MS_EXCEPTION_IF_NULL(value);
        auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
        MS_EXCEPTION_IF_NULL(tensor);
        if (unused) {
          auto param_info = para->param_info();
          MS_EXCEPTION_IF_NULL(param_info);
          param_info->set_ignore_device_addr(true);
        }
        return tensor;
      }
    }
  }
  return nullptr;
}

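// Check whether a MatMul node is a valid candidate for weight concatenation: transpose_a must be false and
// transpose_b true, the weight must come from a Parameter whose dtype is in `valid_dtypes`, the weight must be
// 2-D with an inferred shape matching the default parameter shape, and its k dimension must match *k
// (which is recorded from the first candidate when *k == -1).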
bool CheckFusionValid(const CNodePtr &matmul, int64_t *k, const int trans_a_pos, const int trans_b_pos,
                      const std::vector<TypeId> &valid_dtypes) {
  MS_EXCEPTION_IF_NULL(matmul);
  auto inputs = matmul->inputs();
  auto trans_a_node = GetValueNode(inputs[trans_a_pos]);
  auto trans_b_node = GetValueNode(inputs[trans_b_pos]);
  MS_EXCEPTION_IF_NULL(trans_a_node);
  MS_EXCEPTION_IF_NULL(trans_b_node);
  bool trans_a = GetValue<bool>(trans_a_node);
  bool trans_b = GetValue<bool>(trans_b_node);
  // Only the transpose_a=false, transpose_b=true layout is supported by the fusion.
  if (trans_a) {
    return false;
  }
  if (!trans_b) {
    return false;
  }
  auto weight_node = inputs[kIndex2]->cast<CNodePtr>();
  if (weight_node == nullptr) {
    return false;
  }
  auto w_param = GetParamFromLoad(weight_node, false);
  if (!w_param) {
    return false;
  }
  auto w_type_id = static_cast<TypeId>(w_param->data_type_c());
  if (std::find(valid_dtypes.begin(), valid_dtypes.end(), w_type_id) == valid_dtypes.end()) {
    return false;
  }
  std::vector<int64_t> origin_shape = w_param->shape();
  auto parallel_shape = common::AnfAlgo::GetOutputInferShape(weight_node, kIndex0);
  // When the param is not parallel tiled as expected (the inferred shape differs from the default
  // param's shape), it is not safe to use and concat it, so skip this pass.
  if (parallel_shape.size() != origin_shape.size()) {
    return false;
  }
  for (int i = 0; i < static_cast<int>(parallel_shape.size()); i++) {
    if (parallel_shape[i] != origin_shape[i]) {
      return false;
    }
  }
  // Only 2-D weights can be concatenated.
  constexpr size_t shape_num_two = 2;
  if (origin_shape.size() != shape_num_two) {
    return false;
  }
  // All fused weights must share the same k dimension; record it from the first candidate.
  if (*k == -1) {
    *k = origin_shape[1];
  } else if (*k != origin_shape[1]) {
    return false;
  }
  return true;
}

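// Copy each weight buffer in `data_c_list` into the pre-allocated destination `data_ptr`, with element type T.
// Every source contributes k_len * n_len_list[idx] elements; when `need_rank_offset` is true the copy starts at
// this rank's slice (global_rank_id * count elements) of the source buffer.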
template <typename T>
void ConcatWeightsToNewTensor(void *data_ptr, const std::vector<void *> &data_c_list, const int64_t &k_len,
                              const std::vector<int64_t> &n_len_list, const bool &need_rank_offset,
                              const uint32_t &global_rank_id) {
  const auto data_size = sizeof(T);
  int64_t offset = 0;
  for (int idx = 0; idx < static_cast<int>(data_c_list.size()); idx++) {
    auto count = k_len * n_len_list[idx];
    auto rank_offset = need_rank_offset ? global_rank_id * count : 0;
    auto byte_size = count * data_size;
    auto ret = memcpy_s(reinterpret_cast<T *>(data_ptr) + offset, byte_size,
                        reinterpret_cast<T *>(data_c_list[idx]) + rank_offset, byte_size);
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "Failed to copy weight data, memcpy_s errorno: " << ret;
    }
    offset += count;
  }
}

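// Build a ValueNode holding the concatenated weight tensor: allocate a tensor of `weight_shape`, fill it from
// `data_c_list` according to the element type, and attach default-format kernel build info so the node can be
// used directly as a constant input.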
std::shared_ptr<ValueNode> CreateWeightTensor(TypeId type_id, const std::vector<int64_t> &weight_shape,
                                              const std::vector<void *> &data_c_list,
                                              const std::vector<int64_t> &n_len_list, const int64_t &k_len,
                                              const std::shared_ptr<Type> &w_dtype, const bool &need_rank_offset,
                                              const uint32_t &global_rank_id) {
  tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(type_id, weight_shape);
  MS_EXCEPTION_IF_NULL(assist_tensor);
  auto data_ptr = assist_tensor->data_c();
  if (type_id == TypeId::kNumberTypeBFloat16) {
    ConcatWeightsToNewTensor<bfloat16>(data_ptr, data_c_list, k_len, n_len_list, need_rank_offset, global_rank_id);
  } else if (type_id == TypeId::kNumberTypeFloat16) {
    ConcatWeightsToNewTensor<float16>(data_ptr, data_c_list, k_len, n_len_list, need_rank_offset, global_rank_id);
  } else if (type_id == TypeId::kNumberTypeInt8) {
    ConcatWeightsToNewTensor<int8_t>(data_ptr, data_c_list, k_len, n_len_list, need_rank_offset, global_rank_id);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported weight data type for concat: " << TypeIdToString(type_id);
  }

  TensorTypePtr tensor_type = std::make_shared<TensorType>(w_dtype);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
  assist_tensor->set_device_info(device_info);

  auto assist_const = std::make_shared<ValueNode>(assist_tensor);
  auto assist_abstract = assist_tensor->ToAbstract();
  assist_const->set_abstract(assist_abstract);
  auto assist_kernel_info = std::make_shared<device::KernelInfo>();
  assist_const->set_kernel_info(assist_kernel_info);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat({kOpFormat_DEFAULT});
  builder.SetOutputsDeviceType({common::AnfAlgo::GetOutputInferDataType(assist_const, 0)});
  builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), assist_const.get());
  return assist_const;
}

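// Sort the candidate MatMul nodes by the name of their weight Parameter to get a deterministic concat order.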
void SortWeightNodeList(AnfNodePtrList *node_list) {
  std::sort(node_list->begin(), node_list->end(), [](const AnfNodePtr &a, const AnfNodePtr &b) {
    auto para_a =
      common::AnfAlgo::GetInputNode(a->cast<CNodePtr>()->inputs()[kIndex2]->cast<CNodePtr>(), kIndex0)->cast<ParameterPtr>();
    auto para_b =
      common::AnfAlgo::GetInputNode(b->cast<CNodePtr>()->inputs()[kIndex2]->cast<CNodePtr>(), kIndex0)->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(para_a);
    MS_EXCEPTION_IF_NULL(para_b);
    return para_a->name() < para_b->name();
  });
}

}  // namespace opt
}  // namespace mindspore