/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/optimizer/ir_fusion/inference_weight_preprocess_utils.h"
#include <string>
#include <memory>
#include <vector>
#include <algorithm>

namespace mindspore {
namespace opt {

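// Returns the default tensor of the Parameter feeding a Load node, or nullptr if `load` is not
// such a node. When `unused` is true, the parameter is additionally marked so that its device
// address can be ignored.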
tensor::TensorPtr GetParamFromLoad(const CNodePtr &load, const bool unused) {
  if (IsPrimitiveCNode(load, prim::kPrimLoad)) {
    auto anf_node = common::AnfAlgo::GetInputNode(load, kIndex0);
    MS_EXCEPTION_IF_NULL(anf_node);
    if (anf_node->isa<Parameter>()) {
      auto para = anf_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(para);
      if (para->has_default()) {
        auto value = para->default_param();
        MS_EXCEPTION_IF_NULL(value);
        auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
        MS_EXCEPTION_IF_NULL(tensor);
        if (unused) {
          auto param_info = para->param_info();
          MS_EXCEPTION_IF_NULL(param_info);
          param_info->set_ignore_device_addr(true);
        }
        return tensor;
      }
    }
  }
  return nullptr;
}

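// Checks whether a MatMul node is a valid candidate for the weight-concat fusion: transpose flags
// must be (trans_a = false, trans_b = true), the weight must come from a Load of a constant
// Parameter with a supported dtype, its inferred shape must equal the stored shape (i.e. the
// parameter is not parallel-tiled), and the weight must be 2-D with the same reduce dimension k
// as the other candidates (*k is recorded on the first call).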
bool CheckFusionValid(const CNodePtr &matmul, int64_t *k, const int trans_a_pos, const int trans_b_pos,
                      const std::vector<TypeId> &valid_dtypes) {
  auto inputs = matmul->inputs();
  auto trans_a_node = GetValueNode(inputs[trans_a_pos]);
  auto trans_b_node = GetValueNode(inputs[trans_b_pos]);
  MS_EXCEPTION_IF_NULL(trans_a_node);
  MS_EXCEPTION_IF_NULL(trans_b_node);
  bool trans_a = GetValue<bool>(trans_a_node);
  bool trans_b = GetValue<bool>(trans_b_node);
  // Only the trans_a == false && trans_b == true layout is handled by this fusion.
  if (trans_a || !trans_b) {
    return false;
  }
  auto weight_node = inputs[kIndex2]->cast<CNodePtr>();
  if (weight_node == nullptr) {
    return false;
  }
  auto w_param = GetParamFromLoad(weight_node, false);
  if (!w_param) {
    return false;
  }
  auto w_type_id = static_cast<TypeId>(w_param->data_type_c());
  if (std::find(valid_dtypes.begin(), valid_dtypes.end(), w_type_id) == valid_dtypes.end()) {
    return false;
  }
  std::vector<int64_t> origin_shape = w_param->shape();
  auto parallel_shape = common::AnfAlgo::GetOutputInferShape(weight_node, kIndex0);
  // When the param is not parallel-tiled, it is not safe to use and concat it, so skip this pass.
  if (parallel_shape.size() != origin_shape.size()) {
    return false;
  }
  for (size_t i = 0; i < parallel_shape.size(); i++) {
    if (parallel_shape[i] != origin_shape[i]) {
      return false;
    }
  }
  const size_t shape_num_two = 2;
  if (origin_shape.size() != shape_num_two) {
    return false;
  }
  // All fused weights must share the same reduce dimension k (the weight shape is [n, k] since trans_b is true).
  if (*k == -1) {
    *k = origin_shape[1];
  } else if (*k != origin_shape[1]) {
    return false;
  }
  return true;
}

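// Copies each source buffer in data_c_list back to back into data_ptr. Each slice holds
// k_len * n_len_list[idx] elements; when need_rank_offset is set, the copy starts at this rank's
// slice inside the corresponding source buffer.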
template <typename T>
void ConcatWeightsToNewTensor(void *data_ptr, const std::vector<void *> &data_c_list, const int64_t &k_len,
                              const std::vector<int64_t> &n_len_list, const bool &need_rank_offset,
                              const uint32_t &global_rank_id) {
  const auto data_size = sizeof(T);
  int64_t offset = 0;
  for (size_t idx = 0; idx < data_c_list.size(); idx++) {
    auto count = k_len * n_len_list[idx];
    auto rank_offset = need_rank_offset ? global_rank_id * count : 0;
    auto byte_size = count * data_size;
    auto ret = memcpy_s(reinterpret_cast<T *>(data_ptr) + offset, byte_size,
                        reinterpret_cast<T *>(data_c_list[idx]) + rank_offset, byte_size);
    if (ret != EOK) {
      MS_LOG(EXCEPTION) << "Failed to copy weight data, memcpy_s error code: " << ret;
    }
    offset += count;
  }
}

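// Creates a ValueNode holding the concatenated weight tensor and attaches the abstract and kernel
// build info (default format, inferred dtype) so it can be used directly as a constant input.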
std::shared_ptr<ValueNode> CreateWeightTensor(TypeId type_id, const std::vector<int64_t> &weight_shape,
                                              const std::vector<void *> &data_c_list,
                                              const std::vector<int64_t> &n_len_list, const int64_t &k_len,
                                              const std::shared_ptr<Type> &w_dtype, const bool &need_rank_offset,
                                              const uint32_t &global_rank_id) {
  tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(type_id, weight_shape);
  MS_EXCEPTION_IF_NULL(assist_tensor);
  auto data_ptr = assist_tensor->data_c();
  // Fill the new tensor by concatenating the source weights; only the dtypes below are supported.
  if (type_id == TypeId::kNumberTypeBFloat16) {
    ConcatWeightsToNewTensor<bfloat16>(data_ptr, data_c_list, k_len, n_len_list, need_rank_offset, global_rank_id);
  } else if (type_id == TypeId::kNumberTypeFloat16) {
    ConcatWeightsToNewTensor<float16>(data_ptr, data_c_list, k_len, n_len_list, need_rank_offset, global_rank_id);
  } else if (type_id == TypeId::kNumberTypeInt8) {
    ConcatWeightsToNewTensor<int8_t>(data_ptr, data_c_list, k_len, n_len_list, need_rank_offset, global_rank_id);
  }

  TensorTypePtr tensor_type = std::make_shared<TensorType>(w_dtype);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
  assist_tensor->set_device_info(device_info);

  auto assist_const = std::make_shared<ValueNode>(assist_tensor);
  auto assist_abstract = assist_tensor->ToAbstract();
  assist_const->set_abstract(assist_abstract);
  auto assist_kernel_info = std::make_shared<device::KernelInfo>();
  assist_const->set_kernel_info(assist_kernel_info);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetOutputsFormat({kOpFormat_DEFAULT});
  builder.SetOutputsDeviceType({common::AnfAlgo::GetOutputInferDataType(assist_const, 0)});
  builder.SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), assist_const.get());
  return assist_const;
}

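// Sorts the candidate nodes by the name of the weight Parameter behind each node's weight input
// (inputs()[2]), so that the concatenation order is deterministic.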
void SortWeightNodeList(AnfNodePtrList *node_list) {
  std::sort(node_list->begin(), node_list->end(), [](const AnfNodePtr &a, const AnfNodePtr &b) {
    auto para_a =
      common::AnfAlgo::GetInputNode(a->cast<CNodePtr>()->inputs()[2]->cast<CNodePtr>(), kIndex0)->cast<ParameterPtr>();
    auto para_b =
      common::AnfAlgo::GetInputNode(b->cast<CNodePtr>()->inputs()[2]->cast<CNodePtr>(), kIndex0)->cast<ParameterPtr>();
    return para_a->name() < para_b->name();
  });
}

}  // namespace opt
}  // namespace mindspore