/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/common_utils.h"
#include <sys/stat.h>  // mkdir and the file-mode constants used below (may already come in via common_utils.h)
#include <unordered_map>
#include <map>
#include <iostream>
#include <utility>
#include <fstream>
#include <algorithm>
#include <thread>
#include "nlohmann/json.hpp"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "utils/ms_context.h"
#include "mindspore/ccsrc/debug/common.h"

namespace mindspore {
namespace kernel {
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
const std::unordered_map<std::string, TypeId> type_id_maps = {{"float", TypeId::kNumberTypeFloat32},
                                                              {"float16", TypeId::kNumberTypeFloat16},
                                                              {"float32", TypeId::kNumberTypeFloat32},
                                                              {"float64", TypeId::kNumberTypeFloat64},
                                                              {"int", TypeId::kNumberTypeInt},
                                                              {"int8", TypeId::kNumberTypeInt8},
                                                              {"int16", TypeId::kNumberTypeInt16},
                                                              {"int32", TypeId::kNumberTypeInt32},
                                                              {"int64", TypeId::kNumberTypeInt64},
                                                              {"uint", TypeId::kNumberTypeUInt},
                                                              {"uint8", TypeId::kNumberTypeUInt8},
                                                              {"uint16", TypeId::kNumberTypeUInt16},
                                                              {"uint32", TypeId::kNumberTypeUInt32},
                                                              {"uint64", TypeId::kNumberTypeUInt64},
                                                              {"bool", TypeId::kNumberTypeBool},
                                                              {"complex64", TypeId::kNumberTypeComplex64},
                                                              {"complex128", TypeId::kNumberTypeComplex128}};

const std::map<TypeId, std::string> type_id_str_map = {{TypeId::kNumberTypeFloat32, "float32"},
                                                       {TypeId::kNumberTypeFloat16, "float16"},
                                                       {TypeId::kNumberTypeFloat, "float"},
                                                       {TypeId::kNumberTypeFloat64, "float64"},
                                                       {TypeId::kNumberTypeInt, "int"},
                                                       {TypeId::kNumberTypeInt8, "int8"},
                                                       {TypeId::kNumberTypeInt16, "int16"},
                                                       {TypeId::kNumberTypeInt32, "int32"},
                                                       {TypeId::kNumberTypeInt64, "int64"},
                                                       {TypeId::kNumberTypeUInt, "uint"},
                                                       {TypeId::kNumberTypeUInt8, "uint8"},
                                                       {TypeId::kNumberTypeUInt16, "uint16"},
                                                       {TypeId::kNumberTypeUInt32, "uint32"},
                                                       {TypeId::kNumberTypeUInt64, "uint64"},
                                                       {TypeId::kNumberTypeBool, "bool"},
                                                       {TypeId::kNumberTypeComplex64, "complex64"},
                                                       {TypeId::kNumberTypeComplex128, "complex128"}};

const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"},    {"int16", "i16"},  {"int32", "i32"},
  {"int64", "i64"},   {"uint8", "u8"},    {"uint16", "u16"},  {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
};

// Byte widths below assume a 4-byte int/float.
const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  {"float16", sizeof(float) / 2},  {"float32", sizeof(float)},  {"float64", sizeof(float) * 2},
  {"int8", sizeof(int) / 4},       {"int16", sizeof(int) / 2},  {"int32", sizeof(int)},
  {"int64", sizeof(int) * 2},      {"uint8", sizeof(int) / 4},  {"uint16", sizeof(int) / 2},
  {"uint32", sizeof(int)},         {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  {"complex64", sizeof(float) * 2}};

// Define all patterns here for different schedule
const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  {FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
  {FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
  {FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
  {FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
  {FusionType::ELEMWISE, "ElemWise"},
  {FusionType::PURE_BROADCAST, "PureBroadcast"},
  {FusionType::COMMREDUCE, "CommReduce"},
  {FusionType::SEGMENT, "Segment"},
  {FusionType::INPLACE, "Inplace"},
  {FusionType::MATMUL, "Matmul"},
  {FusionType::MATMUL_V2, "Matmul_v2"},
  {FusionType::GEMM, "GEMM"},
  {FusionType::CONV, "Convolution"},
  {FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
  {FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
  {FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
  {FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
  {FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
  {FusionType::OPAQUE, "Opaque"},
  {FusionType::BN_REDUCE, "bn_reduce"},
  {FusionType::BN_UPDATE, "bn_update"},
  {FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
  {FusionType::L2_NORMALIZE, "l2_normalize"},
  {FusionType::SOFTMAX, "softmax_pattern"},
  {FusionType::L2_LOSS, "l2_loss"},
  {FusionType::ASCEND_QUANT, "quant"},
  {FusionType::ASCEND_DEQUANT, "dequant"},
  {FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
  {FusionType::STRIDED_READ, "strided_read"},
  {FusionType::STRIDED_WRITE, "strided_write"},
  {FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
  {FusionType::ASCEND_REQUANT, "requant"},
  {FusionType::ASCEND_REQUANT_S16, "requant_s16"},
  {FusionType::MAX_POOL, "MaxPool"},
  {FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
  {FusionType::CONV3D, "Conv3d"},
  {FusionType::POOL2D, "Pool2d"},
  {FusionType::POOL3D, "Pool3d"},
  {FusionType::READ_SELECT, "read_select"},
  {FusionType::WRITE_SELECT, "write_select"},
  {FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
  {FusionType::DILATION_PATTERN, "dilation"},
  {FusionType::BROAD_CAST, "Broadcast"},
  {FusionType::BATCH_MATMUL, "BatchMatmul"},
  {FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
  {FusionType::UNKNOWN_FUSION_TYPE, ""}};

std::string GetFusionNameByType(const kernel::FusionType &type) {
  auto iter = fusion_type_name_maps.find(type);
  if (iter == fusion_type_name_maps.end()) {
    MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
  }
  return iter->second;
}

FusionType GetFusionTypeByName(const std::string &name) {
  std::string fusion_name_upper = name;
  std::transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
  auto iter =
    std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
      std::string name_upper = it.second;
      std::transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
      return fusion_name_upper == name_upper;
    });
  if (iter == fusion_type_name_maps.end()) {
    MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
  }
  return iter->first;
}
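
// Example round-trip (illustrative only; values come from fusion_type_name_maps above):
//   GetFusionNameByType(FusionType::ELEMWISE)  -> "ElemWise"
//   GetFusionTypeByName("elemwise")            -> FusionType::ELEMWISE
// The reverse lookup is case-insensitive because both sides are upper-cased first.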

void KernelMeta::Initialize() {
  kernel_meta_path_ = std::string(kGpuKernelMeta) + "/";

#if defined(_WIN32) || defined(_WIN64)
  auto ret = mkdir(kernel_meta_path_.c_str());
#else
  auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
#endif
  if (ret != 0) {
    MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "] will be created later";
  }
  initialized_ = true;
}

std::string KernelMeta::Search(const std::string &kernel_name) const {
  if (!initialized_) {
    return "";
  }

  auto iter = kernel_meta_map_.find(kernel_name);
  if (iter == kernel_meta_map_.end()) {
    return "";
  } else {
    return iter->second;
  }
}

bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  if (!initialized_) {
    return false;
  }
  kernel_meta_map_[kernel_name] = kernel_json;
  return true;
}

bool CheckCache(const std::string &kernel_name) {
  // check cache.
  KernelMeta *bin_map = KernelMeta::GetInstance();
  if (bin_map == nullptr) {
    MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
    return false;
  }
  std::string kernel_json = bin_map->Search(kernel_name);
  bool ret = (!kernel_json.empty());
  if (ret) {
    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has been registered.";
  } else {
    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will be registered.";
  }
  return ret;
}

KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  // search cache.
  KernelMeta *bin_map = KernelMeta::GetInstance();
  if (bin_map == nullptr) {
    MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
    return nullptr;
  }

  std::string kernel_json = bin_map->Search(kernel_name);
  if (!kernel_json.empty()) {
    KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
    // just a temporary solution.
    if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
      MS_LOG(ERROR) << "Read cache json and bin file failed [" << kernel_json << "].";
      return nullptr;
    } else {
      return kernel_pack;
    }
  } else {
    MS_LOG(INFO) << "Cache kernel not found [" << kernel_name << "].";
    return nullptr;
  }
}

KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processor:" << processor;
  KernelMeta *bin_map = KernelMeta::GetInstance();
  if (bin_map == nullptr) {
    MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name: " << kernel_name;
    return nullptr;
  }
  std::string kernel_json;
  if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
    kernel_json = kCceKernelMeta;
  } else {
    kernel_json = bin_map->kernel_meta_path();
  }
  (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
    MS_LOG(ERROR) << "Read json and bin file failed [" << kernel_json << "].";
    return nullptr;
  }

  if (bin_map->Insert(kernel_name, kernel_json)) {
    MS_LOG(INFO) << "Kernel insert cache success [" << kernel_json << "], kernel name [" << kernel_name << "].";
  }
  return kernel_pack;
}
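
// Typical usage of the cache helpers above (an illustrative sketch; the
// variable names are hypothetical): probe the cache first and only build and
// insert the kernel on a miss.
//   KernelPackPtr pack;
//   if (CheckCache(kernel_name)) {
//     pack = SearchCache(kernel_name, processor);
//   } else {
//     // ... compile the kernel and emit <kernel_name>.json ...
//     pack = InsertCache(kernel_name, processor);
//   }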

TypeId DtypeToTypeId(const std::string &dtypes) {
  auto iter = type_id_maps.find(dtypes);
  if (iter != type_id_maps.end()) {
    return iter->second;
  } else {
    MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  }
}

std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
  auto iter = type_id_str_map.find(type_id);
  if (iter == type_id_str_map.end()) {
    if (!unknown_as_default) {
      MS_EXCEPTION(ArgumentError) << "Illegal input dtype: " << TypeIdLabel(type_id);
    }
    MS_LOG(INFO) << "Using default dtype: float32";
    return "float32";
  }
  return iter->second;
}

std::string Dtype2ShortType(const std::string &dtype) {
  auto iter = dtype_shortdtype_map_.find(dtype);
  if (iter != dtype_shortdtype_map_.end()) {
    return iter->second;
  } else {
    MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  }
}

size_t GetDtypeNbyte(const std::string &dtype) {
  auto iter = dtype_nbyte_map.find(dtype);
  if (iter != dtype_nbyte_map.end()) {
    return iter->second;
  } else {
    MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  }
}
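
// Example values (given the 4-byte int/float assumption in dtype_nbyte_map):
//   GetDtypeNbyte("float16") == 2, GetDtypeNbyte("int64") == 8,
//   GetDtypeNbyte("bool") == 1; a dtype missing from the map (e.g. "complex128")
//   raises an ArgumentError exception.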

bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                               size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
                               const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  MS_EXCEPTION_IF_NULL(builder);

  std::vector<TypeId> inputs_device_type;
  std::vector<std::string> inputs_format;
  size_t dyn_input_idx = 0;
  size_t kernel_info_index = 0;
  MS_EXCEPTION_IF_NULL(inputs[0]);
  size_t kernel_info_cnt = inputs[0]->dtypes().size();

  for (const auto &input : inputs) {
    MS_EXCEPTION_IF_NULL(input);
    std::string param_type = input->param_type();
    std::vector<std::string> dtypes = input->dtypes();
    std::vector<std::string> formats = input->formats();
    if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
      MS_LOG(DEBUG) << "Set input kernel builder info failed, dtypes size != formats size. dtypes size: "
                    << dtypes.size() << ", formats size: " << formats.size();
      return false;
    }

    if (param_type == "dynamic") {
      if (dyn_input_sizes.empty()) {
        MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes is empty when param_type is dynamic";
        return false;
      }

      for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
        kernel_info_index++;
        auto type_id = DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
      }
      dyn_input_idx++;
    } else if (param_type == "required") {
      kernel_info_index++;
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      inputs_device_type.push_back(type_id);
      inputs_format.push_back(formats[builder_idex]);
    } else {
      if (kernel_info_index < real_input_num) {
        MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is: " << kernel_info_index;
        kernel_info_index++;
        auto type_id = DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
      }
    }
  }

  builder->SetInputsDeviceType(inputs_device_type);
  builder->SetInputsFormat(inputs_format);
  return true;
}

bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
                                const size_t &real_output_num,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  // TODO: fully support the dynamic output case in the future.
  MS_EXCEPTION_IF_NULL(builder);

  size_t output_idx = 0;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_format;
  MS_EXCEPTION_IF_NULL(outputs[0]);
  size_t kernel_info_cnt = outputs[0]->dtypes().size();

  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output_idx >= real_output_num) {
      MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
      continue;
    }
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output does not support multiple outputs!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      if (output_idx < real_output_num) {
        MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is: " << output_idx;
        output_num = 1;
      }
    }

    for (size_t i = 0; i < output_num; i++) {
      std::vector<std::string> dtypes = output->dtypes();
      std::vector<std::string> formats = output->formats();
      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
        MS_LOG(DEBUG) << "Set output kernel builder info, dtypes size != formats size.";
        return false;
      }
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      outputs_device_type.push_back(type_id);
      outputs_format.push_back(formats[builder_idex]);
      output_idx++;
    }
  }

  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_device_type);
  return true;
}

void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
                        const std::shared_ptr<const OpInfo> &op_info_ptr) {
  MS_EXCEPTION_IF_NULL(builder);
  MS_EXCEPTION_IF_NULL(op_info_ptr);

  auto imply_type = op_info_ptr->imply_type();
  builder->SetProcessor(processor);
  std::string fusion_name = op_info_ptr->fusion_type();
  auto fusion_type = GetFusionTypeByName(fusion_name);
  builder->SetFusionType(fusion_type);

  if (imply_type == kAKG) {
    builder->SetKernelType(AKG_KERNEL);
  } else if (imply_type == kAICPU) {
    builder->SetKernelType(AICPU_KERNEL);
  } else {
    builder->SetKernelType(TBE_KERNEL);
  }
}

bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int64_t> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (!inputs.empty()) {
    if (inputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);

      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
        return false;
      }

      if (!outputs.empty()) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
          return false;
        }
      }

      kernel_info_list->push_back(builder->Build());
    }
  } else if (!outputs.empty()) {
    if (outputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);

      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
        return false;
      }

      kernel_info_list->push_back(builder->Build());
    }
  } else {
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}

void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  std::string path = base_path + json_name + kInfoSuffix;
  auto realpath = Common::CreatePrefixPath(path);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path failed, path=" << path;
    return;
  }
  ChangeFileMode(realpath.value(), S_IWUSR);
  std::ofstream filewrite(realpath.value());
  if (!filewrite.is_open()) {
    MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
    return;
  }
  filewrite << info << std::endl;
  filewrite.close();
  ChangeFileMode(realpath.value(), S_IRUSR);
}

Processor GetProcessor(const string &processor) {
  if (processor == kProcessorAiCore) return Processor::AICORE;
  if (processor == kProcessorAiCpu) return Processor::AICPU;
  if (processor == kProcessorCuda) return Processor::CUDA;
  MS_LOG(DEBUG) << "Unknown processor type.";
  return Processor::UNKNOWN;
}

std::string GetProcessor(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::string device;
  switch (AnfAlgo::GetProcessor(anf_node)) {
    case Processor::AICORE:
      device = kProcessorAiCore;
      break;

    case Processor::AICPU:
      device = kProcessorAiCpu;
      break;

    case Processor::CUDA:
      device = kProcessorCuda;
      break;

    default:
      MS_LOG(DEBUG) << "Unknown processor type.";
      break;
  }
  return device;
}

bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  if (shape_a.size() != shape_b.size()) {
    return false;
  }
  for (size_t i = 0; i < shape_a.size(); ++i) {
    if (shape_a[i] != shape_b[i]) {
      return false;
    }
  }
  return true;
}

int Sign(float x) {
  if (x > 0) {
    return 1;
  }
  if (x < 0) {
    return -1;
  }
  return 0;
}

std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);

  if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
    MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs. Node info: ["
                                << anf_node->DebugString() << "]";
  }

  auto cnode = anf_node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return AnfAlgo::VisitKernel(anf_node, 0);
  } else {
    return AnfAlgo::VisitKernel(cnode->input(index + 1), 0);
  }
}

std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }

    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }

        std::vector<int64_t> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int64_t>>(prim->GetAttr(kAttrDynInputSizes));
        }

        if (dyn_input_sizes.empty()) {
          (void)input_index.emplace_back(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0));
          found = true;
          break;
        }
        // For dynamic inputs, map the flat input position back to a
        // (dynamic-input group, offset within group) pair.
        int used_as_idx = input_user.second - 1;
        int accum_idx = 0;
        size_t dyn_i = 0;
        for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
          accum_idx += LongToInt(dyn_input_sizes[dyn_i]);
          if (used_as_idx < accum_idx) {
            (void)input_index.emplace_back(
              anf_node,
              std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - LongToInt(dyn_input_sizes[dyn_i])))));
            break;
          }
        }
        if (dyn_i != dyn_input_sizes.size()) {
          found = true;
          break;
        }
      }
      if (found) {
        break;
      }
    }

    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}

std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
                                                          const std::vector<AnfNodePtr> &input_list,
                                                          const std::vector<AnfNodePtr> &output_list) {
  std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  for (size_t i = 0; i < output_list.size(); ++i) {
    auto const &output = output_list[i];
    MS_EXCEPTION_IF_NULL(output);
    bool found = false;
    auto pree_node = AnfAlgo::VisitKernel(output, 0);
    auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
    if (pos != std::end(node_list)) {
      output_index.push_back(pree_node);
      continue;
    }
    auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
    if (ret != std::end(input_list)) {
      output_index.push_back(std::make_pair(pree_node.first, 0));
      found = true;
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
                                  << output->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return output_index;
}

void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  MS_EXCEPTION_IF_NULL(node_list);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  for (auto const &node : node_lists) {
    if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
      node_list->push_back(node);
    }
  }
}

void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
                         std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node_list);
  MS_EXCEPTION_IF_NULL(input_list);

  GetValidKernelNodes(func_graph, node_list);

  auto parameters = func_graph->parameters();
  input_list->insert(input_list->begin(), parameters.begin(), parameters.end());

  GetFuncGraphOutputNodes(func_graph, output_list);
}

void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(output_list);
  auto func_output = func_graph->output();
  MS_EXCEPTION_IF_NULL(func_output);
  if (func_output->isa<CNode>()) {
    // multi output.
    auto cnode = func_output->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(kAnfPrimitiveIndex);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
        auto input_node = cnode->input(input_idx);
        MS_EXCEPTION_IF_NULL(input_node);
        if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
          continue;
        }
        output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
      }
    } else {
      // single output.
      output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
    }
  } else {
    // single output.
    output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  }
}

bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }

  auto input_node = cnode->input(input_idx + 1);
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }

  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Value of input node is nullptr, op: [" << input_node->DebugString() << "]";
    return false;
  }

  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  if (tensor->DataSize() > 1) {
    // not a scalar const tensor.
    MS_LOG(WARNING) << "Will not take the value of a tensor whose data size is greater than 1, ["
                    << input_node->DebugString(2) << "]";
    return false;
  }

  if (type_id == kFloat64->type_id()) {
    (*node_json)["value"] = static_cast<double *>(data)[0];
  } else if (type_id == kFloat32->type_id()) {
    (*node_json)["value"] = static_cast<float *>(data)[0];
  } else if (type_id == kFloat16->type_id()) {
    float16 *val = static_cast<float16 *>(data);
    (*node_json)["value"] = static_cast<float>(val[0]);
  } else if (type_id == kUInt64->type_id()) {
    (*node_json)["value"] = static_cast<uint64_t *>(data)[0];
  } else if (type_id == kUInt32->type_id()) {
    (*node_json)["value"] = static_cast<uint32_t *>(data)[0];
  } else if (type_id == kUInt16->type_id()) {
    (*node_json)["value"] = static_cast<uint16_t *>(data)[0];
  } else if (type_id == kUInt8->type_id()) {
    (*node_json)["value"] = static_cast<uint8_t *>(data)[0];
  } else if (type_id == kInt64->type_id()) {
    (*node_json)["value"] = static_cast<int64_t *>(data)[0];
  } else if (type_id == kInt32->type_id()) {
    (*node_json)["value"] = static_cast<int32_t *>(data)[0];
  } else if (type_id == kInt16->type_id()) {
    (*node_json)["value"] = static_cast<int16_t *>(data)[0];
  } else if (type_id == kInt8->type_id()) {
    (*node_json)["value"] = static_cast<int8_t *>(data)[0];
  } else if (type_id == kBool->type_id()) {
    (*node_json)["value"] = static_cast<bool *>(data)[0];
  } else {
    MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  }
  return true;
}

bool IsWeightBoundary(const AnfNodePtr &node) {
  if (node->isa<ValueNode>()) {
    return true;
  }
  if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
    return true;
  }
  return false;
}

std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  if (AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputTensorNum(cnode) != 1) {
    MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output.";
  }
  std::vector<int64_t> axis;
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto axis_attr = primitive->GetAttr(kAxis);
  if (axis_attr == nullptr) {
    MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
    return std::vector<int64_t>();
  }
  std::vector<int64_t> axis_list;
  if (axis_attr->isa<Int64Imm>()) {
    (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  } else {
    axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  }
  for (const auto &elem : axis_list) {
    if (elem < 0) {
      // Normalize a negative axis to rank + axis.
      (void)axis.emplace_back(SizeToLong(input_shape.size()) + elem);
    } else {
      (void)axis.emplace_back(elem);
    }
  }
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  return axis;
}
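
// Example: for a reduce node whose input shape is (2, 3, 4) and whose axis
// attribute is -1, the negative axis is normalized to rank + axis = 3 + (-1) = 2,
// so the returned (and re-attached) axis vector is {2}.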

std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::string processor = kProcessorUnknown;
  auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  // we may call this before kernel select.
  if (build_info == nullptr) {
    return processor;
  }
  switch (build_info->processor()) {
    case Processor::AICORE:
      processor = kProcessorAiCore;
      break;

    case Processor::AICPU:
      processor = kProcessorAiCpu;
      break;

    case Processor::CUDA:
      processor = kProcessorCuda;
      break;

    default:
      MS_LOG(ERROR) << "Unknown processor type.";
      break;
  }

  return processor;
}

Processor GetProcessorFromContext() {
  kernel::Processor processor = kernel::Processor::UNKNOWN;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_info == kGPUDevice) {
    processor = kernel::Processor::CUDA;
  } else if (device_info == kAscendDevice) {
    processor = kernel::Processor::AICORE;
  }
  return processor;
}

std::string GetStrProcessorFromContext() {
  auto processor = GetProcessorFromContext();
  string str_processor = kernel::kProcessorUnknown;
  if (processor == kernel::Processor::CUDA) {
    str_processor = kernel::kProcessorCuda;
  } else if (processor == kernel::Processor::AICORE) {
    str_processor = kernel::kProcessorAiCore;
  }
  return str_processor;
}

float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
                                         : in_size / static_cast<float>(out_size);
}
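
// Example: resizing 4 samples down to 2 gives a scale of 4 / 2 = 2.0 without
// align_corners, but (4 - 1) / (2 - 1) = 3.0 with align_corners, matching the
// two common image-resize conventions.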

float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }

void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
                                 CachedInterpolation *interpolation) {
  interpolation[out_size].lower = 0;
  interpolation[out_size].upper = 0;
  for (size_t i = 0; i < out_size; ++i) {
    const float in = ScaleGrid(i, scale);
    const float in_f = std::floor(in);
    interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
    interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
    interpolation[i].lerp = in - in_f;
  }
}
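
// Example: with out_size = 2, in_size = 4 and scale = 2.0 (align_corners off),
// output 0 maps to input 0.0 (lower = 0, upper = 0, lerp = 0.0) and output 1
// maps to input 2.0 (lower = 2, upper = 2, lerp = 0.0); the extra entry at
// [out_size] only has lower/upper zeroed as a sentinel.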

bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i) {
  MS_EXCEPTION_IF_NULL(type_ptr);
  size_t type_byte = GetTypeByte(type_ptr);
  if (type_byte == 0) {
    return false;
  }
  for (size_t j = 0; j < shape.size(); j++) {
    size_i[0] = LongMulWithOverflowCheck(size_i[0], static_cast<int>(shape[j]));
  }
  size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  return true;
}
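
// Example: with size_i[0] initialized to 1 by the caller, shape {2, 3} and a
// float32 type (4 bytes) accumulate to 1 * 2 * 3 * 4 = 24 bytes, with each
// multiplication overflow-checked.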

void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  MS_EXCEPTION_IF_NULL(long_shape);
  (void)std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
}

void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
                     const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
    MS_LOG(EXCEPTION)
      << "TensorCopySlices requires begin, stride and end to have equal length, no longer than the input dimension.";
  }

  size_t size = start.size();
  for (size_t i = 0; i < size; ++i) {
    if (stop[i] <= start[i]) {
      MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << ", " << step[i] << ")";
    }
    // The operator needs to be generalized in the future. Only copying continuous memory is supported now.
    if (step[i] != 1) {
      MS_LOG(EXCEPTION) << "Each element in step must be 1, but got: " << step;
    }
  }

  size_t slice_pos = size;
  for (size_t i = 0; i < size; ++i) {
    if (stop[i] - start[i] > 1) {
      slice_pos = i;
      break;
    }
  }

  for (size_t i = slice_pos + 1; i < size; ++i) {
    if (stop[i] - start[i] != input_shape[i]) {
      MS_LOG(EXCEPTION) << "Only copying continuous memory is supported now. For example, tensor[0, 0:100] is fine, "
                           "but tensor[0:100, 0] is not supported.";
    }
  }
}

size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
                   const std::vector<int64_t> &stop) {
  for (size_t i = 0; i < start.size(); ++i) {
    if (stop[i] - start[i] != 1) {
      return SizetMulWithOverflowCheck(LongToSize(stop[i] - start[i]), LongToSize(dim_offset[i]));
    }
  }
  return LongToSize(dim_offset[start.size() - 1]);
}

std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  std::vector<int64_t> dim_offset;
  int64_t offset = 1;
  for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
    dim_offset.push_back(offset);
    offset = offset * (*iter);
  }
  std::reverse(dim_offset.begin(), dim_offset.end());
  return dim_offset;
}
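
// Example: for input_shape {2, 3, 4} the row-major strides are {12, 4, 1}:
// one step along dim 0 skips 3 * 4 = 12 elements, along dim 1 skips 4, and
// along dim 2 skips 1.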

size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
                 const std::vector<int64_t> &dim_offset) {
  size_t size = start.size();
  size_t offset = 0;
  for (size_t i = 0; i < size; ++i) {
    offset += SizetMulWithOverflowCheck(LongToSize(dim_offset[i]), LongToSize(start[i]));
    if (stop[i] - start[i] != 1) {
      break;
    }
  }
  return offset;
}
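
// Example: with dim_offset {12, 4, 1} (from shape {2, 3, 4}), start {1, 2, 0}
// and stop {2, 3, 4}, the flat offset is 1*12 + 2*4 + 0*1 = 20 elements; the
// loop stops accumulating after the first dimension that is actually sliced
// (stop[i] - start[i] != 1).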

size_t UnitSizeInBytes(const mindspore::TypeId &t) {
  size_t bytes = 0;
  switch (t) {
    case kNumberTypeBool:
    case kNumberTypeInt8:
    case kNumberTypeUInt8:
      bytes = sizeof(int8_t);
      break;
    case kNumberTypeInt16:
    case kNumberTypeUInt16:
    case kNumberTypeFloat16:
      bytes = sizeof(int16_t);
      break;
    case kNumberTypeInt:
    case kNumberTypeUInt:
    case kNumberTypeInt32:
    case kNumberTypeUInt32:
    case kNumberTypeFloat:
    case kNumberTypeFloat32:
      bytes = sizeof(int32_t);
      break;
    case kNumberTypeUInt64:
    case kNumberTypeInt64:
    case kNumberTypeFloat64:
      bytes = sizeof(int64_t);
      break;
    default:
      MS_LOG(EXCEPTION) << "Invalid type " << t;
      break;
  }

  return bytes;
}
}  // namespace kernel
}  // namespace mindspore