/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <string>

#include "tools/graph_kernel/common/utils.h"
#include "src/tensor.h"

namespace mindspore::graphkernel {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
// SplitString splits raw_str on delimiter. Empty tokens between consecutive
// delimiters (and a leading empty token) are kept; a trailing empty token is dropped.
std::vector<std::string> SplitString(const std::string &raw_str, char delimiter) {
  std::vector<std::string> res;
  std::string::size_type last_pos = 0;
  auto cur_pos = raw_str.find(delimiter);
  while (cur_pos != std::string::npos) {
    (void)res.emplace_back(raw_str.substr(last_pos, cur_pos - last_pos));
    cur_pos++;
    last_pos = cur_pos;
    cur_pos = raw_str.find(delimiter, cur_pos);
  }
  if (last_pos < raw_str.size()) {
    (void)res.emplace_back(raw_str.substr(last_pos));
  }
  return res;
}
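// Illustrative examples (hypothetical inputs):
//   SplitString("2,3,4", ',')  -> {"2", "3", "4"}
//   SplitString("a,,b", ',')   -> {"a", "", "b"}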

// GetCustomShape parses a shape attribute string into a list of shapes. The
// attribute is a flat comma-separated list in which each shape is encoded as
// its rank followed by that many dimension values.
int GetCustomShape(const std::string &attr, std::vector<std::vector<int>> *shapes) {
  auto split_shape_str = SplitString(attr, ',');
  for (size_t i = 0; i < split_shape_str.size(); i++) {
    // The current token is the rank of the next shape.
    size_t dim = std::stoul(split_shape_str[i]);
    if (i + dim >= split_shape_str.size()) {
      MS_LOG(ERROR) << "Shape string is invalid. The shape dim is " << dim << ", but only "
                    << split_shape_str.size() - i - 1 << " values follow.";
      return RET_ERROR;
    }
    // The following `dim` tokens are the dimension values of this shape.
    std::vector<int> shape;
    for (size_t j = i + 1; j <= i + dim; j++) {
      shape.push_back(std::stoi(split_shape_str[j]));
    }
    i += dim;
    shapes->push_back(shape);
  }
  return RET_OK;
}
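// Illustrative example (hypothetical attribute value): for attr "2,3,4,3,5,6,7"
// the parser reads rank 2 -> shape {3, 4}, then rank 3 -> shape {5, 6, 7},
// so *shapes becomes {{3, 4}, {5, 6, 7}}.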

// GetCustomIndex parses a comma-separated list of dynamic input indices into *index.
void GetCustomIndex(const std::string &dynamic_input_index, std::vector<size_t> *index) {
  auto split_index_str = SplitString(dynamic_input_index, ',');
  for (size_t i = 0; i < split_index_str.size(); i++) {
    index->push_back(std::stoul(split_index_str[i]));
  }
}
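// Illustrative example (hypothetical value): GetCustomIndex("0,2", &index)
// leaves index == {0, 2}.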

// CalculateDynamicBatchSize derives the dynamic batch size by comparing the
// runtime input shapes with the shapes saved in the model. Only axis 0 may
// differ, and every queried input must imply the same batch size.
int CalculateDynamicBatchSize(const TensorC *const *inputs, size_t inputs_size,
                              const std::vector<std::vector<int>> &shapes, const std::vector<size_t> &index,
                              int *batch) {
  if (shapes.size() != inputs_size) {
    MS_LOG(ERROR) << "The number of saved shapes is not equal to inputs_size: " << shapes.size() << " vs "
                  << inputs_size;
    return RET_ERROR;
  }
  bool changed = false;
  for (auto i : index) {
    if (i >= shapes.size()) {
      MS_LOG(ERROR) << "The input num is " << shapes.size() << ", but the queried index is " << i;
      return RET_ERROR;
    }
    if (shapes[i].size() > MAX_SHAPE_SIZE) {
      MS_LOG(ERROR) << "The input shape size " << shapes[i].size() << " is greater than max size " << MAX_SHAPE_SIZE;
      return RET_ERROR;
    }
    for (size_t j = 0; j < shapes[i].size(); j++) {
      if (j == 0) {
        // The batch size is the ratio between the runtime dim and the saved dim on axis 0.
        int bs = inputs[i]->shape_[0] / shapes[i][0];
        if (bs < 1) {
          MS_LOG(ERROR) << "AKG doesn't support batch size smaller than 1";
          return RET_ERROR;
        }
        if (bs != (*batch)) {
          if (!changed) {
            *batch = bs;
            changed = true;
          } else {
            MS_LOG(ERROR) << "AKG doesn't support inputs with different batch size";
            return RET_ERROR;
          }
        }
      } else if (inputs[i]->shape_[j] != shapes[i][j]) {
        // All axes except axis 0 must match the saved shape exactly.
        MS_LOG(ERROR) << "AKG only supports dynamic shape on axis 0";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
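// Illustrative example (hypothetical shapes): if the saved shape of input 0 is
// {2, 4} and its runtime shape is {6, 4}, the derived batch size is 6 / 2 = 3;
// a runtime shape of {6, 5} would be rejected because only axis 0 may change.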

// SetAnfKernelInfoFormatFromAToB attaches a kernel info carrying the given
// output formats to node_b, reusing node_a's existing kernel info object when
// it has one and creating a fresh one otherwise.
void SetAnfKernelInfoFormatFromAToB(const AnfNodePtr &node_a, const CNodePtr &node_b,
                                    const std::vector<std::string> &formats) {
  std::shared_ptr<device::KernelInfo> kernel_info = nullptr;
  auto kernel_info_builder = kernel::KernelBuildInfo::KernelBuildInfoBuilder();
  kernel_info_builder.SetOutputsFormat(formats);
  if (node_a->kernel_info_ptr() == nullptr) {
    kernel_info = std::make_shared<device::KernelInfo>();
  } else {
    kernel_info = std::dynamic_pointer_cast<device::KernelInfo>(node_a->kernel_info_ptr());
  }
  kernel_info->set_select_kernel_build_info(kernel_info_builder.Build());
  node_b->set_kernel_info(kernel_info);
}

// SetKernelInfoWithFormatToAnfNode creates a fresh kernel info recording the
// given output formats and attaches it to node.
void SetKernelInfoWithFormatToAnfNode(const AnfNodePtr &node, const std::vector<std::string> &format) {
  auto kernel_info_builder = kernel::KernelBuildInfo::KernelBuildInfoBuilder();
  kernel_info_builder.SetOutputsFormat(format);
  auto kernel_build_info = kernel_info_builder.Build();
  auto kernel_info = std::make_shared<device::KernelInfo>();
  kernel_info->set_select_kernel_build_info(kernel_build_info);
  node->set_kernel_info(kernel_info);
}
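// Illustrative call (the format string is an assumed example): a node with a
// single NHWC output could be tagged with
//   SetKernelInfoWithFormatToAnfNode(node, {"NHWC"});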

// GetKernelInfo returns the selected kernel build info of node, or nullptr if
// the node has no kernel info attached.
kernel::KernelBuildInfoPtr GetKernelInfo(const AnfNodePtr &node) {
  if (!node->has_user_data("kernel_info")) {
    return nullptr;
  }
  auto kernel_info_ptr = node->kernel_info_ptr();
  if (kernel_info_ptr == nullptr) {
    return nullptr;
  }
  auto kernel_info = std::dynamic_pointer_cast<device::KernelInfo>(kernel_info_ptr);
  if (kernel_info == nullptr) {
    MS_LOG(ERROR) << "kernel info from " << node->fullname_with_scope() << " is nullptr.";
    return nullptr;
  }
  auto kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  if (kernel_build_info == nullptr) {
    MS_LOG(ERROR) << "kernel build info from " << node->fullname_with_scope() << " is nullptr.";
    return nullptr;
  }
  return kernel_build_info;
}
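// Typical call pattern (illustrative): callers should handle the nullptr case,
// e.g. `if (auto build_info = GetKernelInfo(node)) { ... }` before using it.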

// GetOutputFormatFromAnfNode returns the format of node's output_idx-th output;
// it raises an exception if the node has no kernel build info or the index is
// out of range.
std::string GetOutputFormatFromAnfNode(const AnfNodePtr &node, size_t output_idx) {
  auto kernel_build_info = GetKernelInfo(node);
  if (kernel_build_info == nullptr) {
    MS_LOG(EXCEPTION) << "kernel build info from " << node->fullname_with_scope() << " is empty.";
  }
  auto vec_size = kernel_build_info->GetOutputNum();
  if (output_idx >= vec_size) {
    MS_LOG(EXCEPTION) << "Index " << output_idx << " is out of range of the node's output vector, whose size is "
                      << vec_size << ". node is " << node->fullname_with_scope();
  }
  return kernel_build_info->GetOutputFormat(output_idx);
}
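// Illustrative use (the format value is an assumed example): for a node tagged
// via SetKernelInfoWithFormatToAnfNode(node, {"NHWC"}),
// GetOutputFormatFromAnfNode(node, 0) returns "NHWC".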

}  // namespace mindspore::graphkernel