/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include <cuda_runtime_api.h>
#include <map>
#include <unordered_set>
#include <numeric>
#include <functional>
#include <iomanip>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/op/cast_plugin.h"
#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h"

namespace mindspore::lite {
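// Build a Dims of rank `size` with every dimension set to `data`; nbDims stays -1 on failure.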
nvinfer1::Dims ConvertCudaDims(int data, size_t size) {
  nvinfer1::Dims dims{};
  dims.nbDims = -1;
  if (size > static_cast<size_t>(dims.MAX_DIMS)) {
    MS_LOG(ERROR) << "invalid shape size: " << size;
    return dims;
  }
  dims.nbDims = size;
  for (size_t i = 0; i < size; i++) {
    dims.d[i] = data;
  }
  return dims;
}

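// Read `size` shape values of element type T from a raw buffer into a Dims.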
template <typename T>
nvinfer1::Dims ConvertCudaDimsWithType(const void *data, int64_t size) {
  nvinfer1::Dims dims{};
  dims.nbDims = -1;
  if (size > static_cast<int64_t>(dims.MAX_DIMS)) {
    MS_LOG(ERROR) << "invalid shape size: " << size;
    return dims;
  }
  dims.nbDims = size;

  auto *dims_data = static_cast<const T *>(data);
  for (int i = 0; i < size; i++) {
    dims.d[i] = static_cast<int>(dims_data[i]);
  }
  return dims;
}

nvinfer1::Dims ConvertCudaDims(const std::vector<int> &data) {
  return ConvertCudaDimsWithType<int>(data.data(), data.size());
}

nvinfer1::Dims ConvertCudaDims(const TensorInfo &ms_tensor) {
  auto data = ms_tensor.Data();
  auto size = ms_tensor.ElementNum();
  auto ms_dtype = ms_tensor.DataType();

  nvinfer1::Dims dims{};
  if (ms_dtype == DataType::kNumberTypeInt32) {
    dims = ConvertCudaDimsWithType<int>(data, size);
  } else if (ms_dtype == DataType::kNumberTypeInt64) {
    dims = ConvertCudaDimsWithType<int64_t>(data, size);
  } else {
    MS_LOG(ERROR) << "invalid DataType: " << ms_dtype;
  }
  return dims;
}

std::string CudaDimsAsString(const nvinfer1::Dims &dims) {
  std::stringstream str_stream;
  str_stream << "[" << dims.nbDims << ":";
  for (int i = 0; i < dims.nbDims; i++) {
    str_stream << dims.d[i];
    if (i + 1 != dims.nbDims) {
      str_stream << ",";
    }
  }
  str_stream << "]";
  return str_stream.str();
}

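// Flatten a const int32/int64 tensor into a vector<int32_t>; int64 values are truncated to 32 bits.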
std::vector<int32_t> ConvertTensorAsIntVector(const TensorInfo &ms_tensor) {
  if (!ms_tensor.IsConst()) {
    MS_LOG(ERROR) << "Expect tensor to be const tensor, but got var tensor";
    return {};
  }
  auto data = ms_tensor.Data();
  if (data == nullptr) {
    MS_LOG(ERROR) << "Const data cannot be nullptr";
    return {};
  }
  std::vector<int32_t> vals;
  auto ms_dtype = ms_tensor.DataType();
  auto size = ms_tensor.ElementNum();
  if (ms_dtype == DataType::kNumberTypeInt32 || static_cast<TypeId>(ms_dtype) == TypeId::kMetaTypeTypeType) {
    auto int_data = reinterpret_cast<const int32_t *>(data);
    for (int64_t i = 0; i < size; i++) {
      vals.push_back(int_data[i]);
    }
  } else if (ms_dtype == DataType::kNumberTypeInt64) {
    auto int_data = reinterpret_cast<const int64_t *>(data);
    for (int64_t i = 0; i < size; i++) {
      vals.push_back(static_cast<int32_t>(int_data[i]));
    }
  } else {
    MS_LOG(ERROR) << "invalid DataType: " << ms_dtype;
  }
  return vals;
}

bool SameDims(nvinfer1::Dims dims, const std::vector<int64_t> &shape) {
  if (dims.nbDims != static_cast<int>(shape.size())) {
    return false;
  }
  // dynamic dims (-1) are skipped; only the known dims are compared
  for (int i = 0; i < dims.nbDims; i++) {
    if (dims.d[i] == -1) {
      continue;
    }
    if (dims.d[i] != shape[i]) {
      return false;
    }
  }
  return true;
}

std::vector<int64_t> ConvertMSShape(const nvinfer1::Dims dims) {
  std::vector<int64_t> shape;
  for (int i = 0; i < dims.nbDims; i++) {
    shape.push_back(dims.d[i]);
  }
  return shape;
}

std::vector<int64_t> NHWC2NCHW(std::vector<int64_t> nhwc_shape) {
  std::vector<int64_t> nchw_shape;
  if (nhwc_shape.size() != DIMENSION_4D) {
    return nhwc_shape;
  }
  nchw_shape.push_back(nhwc_shape[kNHWC_N]);
  nchw_shape.push_back(nhwc_shape[kNHWC_C]);
  nchw_shape.push_back(nhwc_shape[kNHWC_H]);
  nchw_shape.push_back(nhwc_shape[kNHWC_W]);
  return nchw_shape;
}

nvinfer1::IShuffleLayer *SetTranspose(TensorRTContext *ctx, const nvinfer1::ITensor &input,
                                      nvinfer1::Permutation permutation) {
  nvinfer1::IShuffleLayer *layer = ctx->network()->addShuffle(const_cast<nvinfer1::ITensor &>(input));
  if (layer == nullptr) {
    MS_LOG(ERROR) << "failed to create ShuffleLayer when creating transpose op.";
    return nullptr;
  }
  layer->setFirstTranspose(permutation);
  return layer;
}

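// Map a MindSpore DataType onto the nearest TensorRT DataType. kNumberTypeInt64 is deliberately
// narrowed to kINT32 (the TensorRT versions targeted here have no native INT64); unknown types
// fall back to kFLOAT with a log message.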
nvinfer1::DataType ConvertDataType(DataType type_id) {
  std::map<DataType, nvinfer1::DataType> data_type_map = {
#if TRT_VERSION_GE(7, 2)
    {DataType::kNumberTypeBool, nvinfer1::DataType::kBOOL},
#endif
    {DataType::kNumberTypeInt8, nvinfer1::DataType::kINT8},
    {DataType::kNumberTypeInt32, nvinfer1::DataType::kINT32},
    {DataType::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
    {DataType::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
    {DataType::kNumberTypeInt64, nvinfer1::DataType::kINT32},
  };
  auto iter = data_type_map.find(type_id);
  nvinfer1::DataType data_type;
  if (iter != data_type_map.end()) {
    data_type = iter->second;
  } else {
    data_type = nvinfer1::DataType::kFLOAT;
    MS_LOG(INFO) << "invalid data_type for TensorRT, need check: " << static_cast<int>(type_id);
  }
  return data_type;
}

cudaDataType ConvertDataType(nvinfer1::DataType type_id) {
  std::map<nvinfer1::DataType, cudaDataType> data_type_map = {
    {nvinfer1::DataType::kINT8, CUDA_R_8I},
    {nvinfer1::DataType::kINT32, CUDA_R_32I},
    {nvinfer1::DataType::kFLOAT, CUDA_R_32F},
    {nvinfer1::DataType::kHALF, CUDA_R_16F},
  };
  auto iter = data_type_map.find(type_id);
  cudaDataType data_type;
  if (iter != data_type_map.end()) {
    data_type = iter->second;
  } else {
    data_type = CUDA_R_32F;
    MS_LOG(WARNING) << "invalid data_type for TensorRT, need check: " << static_cast<int>(type_id);
  }
  return data_type;
}

nvinfer1::IShuffleLayer *NHWC2NCHW(TensorRTContext *ctx, const nvinfer1::ITensor &input) {
  // NHWC 0123 -> NCHW 0312
  nvinfer1::Permutation perm{{0, 3, 1, 2}};
  return SetTranspose(ctx, input, perm);
}

nvinfer1::IShuffleLayer *NCHW2NHWC(TensorRTContext *ctx, const nvinfer1::ITensor &input) {
  // NCHW 0123 -> NHWC 0231
  nvinfer1::Permutation perm{{0, 2, 3, 1}};
  return SetTranspose(ctx, input, perm);
}

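// Wrap a const MS tensor as a TensorRT constant layer output. Shapes that fail conversion are
// treated as scalars ([1]). kBOOL weights are widened to INT32, since boolean constants are not
// handled on this path; note the converted buffer appears to be left alive deliberately, as
// TensorRT requires weight memory to outlive engine build, and it is not freed here.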
nvinfer1::ITensor *ConvertConstantTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor,
                                         const std::string &op_name) {
  if (ctx == nullptr || ctx->network() == nullptr) {
    MS_LOG(ERROR) << "context or network is null for ConvertConstantTensor";
    return nullptr;
  }
  nvinfer1::Dims dims = ConvertCudaDims(ms_tensor.Shape());
  if (dims.nbDims == -1) {
    MS_LOG(INFO) << ms_tensor.Name() << " ConvertCudaDims failed, convert as scalar.";
    dims.nbDims = 1;
    dims.d[0] = 1;
  }
  nvinfer1::DataType data_type = ConvertDataType(ms_tensor.DataType());
  if (!ms_tensor.IsConst()) {
    MS_LOG(ERROR) << "ConvertConstantTensor from a MSTensor with nullptr data: " << ms_tensor.Name();
    return nullptr;
  }
  nvinfer1::Weights weights{data_type, ms_tensor.Data(), ms_tensor.ElementNum()};
  if (data_type == nvinfer1::DataType::kBOOL) {
    weights.type = nvinfer1::DataType::kINT32;
    void *data_int32 = malloc(ms_tensor.ElementNum() * sizeof(int32_t));
    if (data_int32 == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
      return nullptr;
    }
    auto src = static_cast<const bool *>(ms_tensor.Data());
    auto dst = static_cast<int32_t *>(data_int32);
    for (int i = 0; i < ms_tensor.ElementNum(); i++) {
      dst[i] = static_cast<int32_t>(src[i]);
    }
    weights.values = data_int32;
  }
  nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights);
  if (constant_tensor == nullptr) {
    MS_LOG(ERROR) << "create constant_tensor failed.";
    return nullptr;
  }
  ctx->RegisterLayer(constant_tensor, ms_tensor.Name() + "_" + op_name);
  return constant_tensor->getOutput(0);
}

nvinfer1::ITensor *ConvertScalarToITensor(TensorRTContext *ctx, size_t shape_size, const void *value,
                                          const DataType data_type, const std::string &op_name) {
  nvinfer1::Dims dims = ConvertCudaDims(1, shape_size);
  if (dims.nbDims == -1) {
    MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name;
    return nullptr;
  }
  nvinfer1::Weights weights{ConvertDataType(data_type), value, 1};
  nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights);
  if (constant_tensor == nullptr) {
    MS_LOG(ERROR) << "create constant_tensor failed.";
    return nullptr;
  }
  ctx->RegisterLayer(constant_tensor, op_name + "_constant");
  return constant_tensor->getOutput(0);
}

nvinfer1::ITensor *ConvertScalarToITensor(TensorRTContext *ctx, size_t shape_size, const TensorInfo &ms_tensor,
                                          const DataType data_type, const std::string &op_name) {
  return ConvertScalarToITensor(ctx, shape_size, ms_tensor.Data(), data_type, op_name);
}

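// Map a MindSpore activation onto TensorRT IActivationLayer parameters
// (type, has_alpha, alpha, has_beta, beta). GELU and SWISH have no direct TensorRT
// activation and are implemented via plugins elsewhere; their entries here are placeholders.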
std::experimental::optional<ActivationParams> TryConvertActivationType(ActivationType activation_type) {
  std::map<ActivationType, ActivationParams> action_map = {
    {ActivationType::RELU, ActivationParams{nvinfer1::ActivationType::kRELU, false, 0, false, 0}},
    {ActivationType::SIGMOID, ActivationParams{nvinfer1::ActivationType::kSIGMOID, false, 0, false, 0}},
    {ActivationType::TANH, ActivationParams{nvinfer1::ActivationType::kTANH, false, 0, false, 0}},
    {ActivationType::LEAKY_RELU, ActivationParams{nvinfer1::ActivationType::kLEAKY_RELU, true, 0, false, 0}},
    {ActivationType::ELU, ActivationParams{nvinfer1::ActivationType::kELU, true, 0, false, 0}},
    {ActivationType::SELU, ActivationParams{nvinfer1::ActivationType::kSELU, true, 0, true, 0}},
    {ActivationType::SOFTSIGN, ActivationParams{nvinfer1::ActivationType::kSOFTSIGN, false, 0, false, 0}},
    {ActivationType::SOFTPLUS, ActivationParams{nvinfer1::ActivationType::kSOFTPLUS, false, 0, false, 0}},
    {ActivationType::THRESHOLDRELU, ActivationParams{nvinfer1::ActivationType::kTHRESHOLDED_RELU, true, 0, false, 0}},
    {ActivationType::RELU6, ActivationParams{nvinfer1::ActivationType::kCLIP, true, 0, true, 6}},
    {ActivationType::RELU1, ActivationParams{nvinfer1::ActivationType::kCLIP, true, 0, true, 1}},
    {ActivationType::HARD_TANH, ActivationParams{nvinfer1::ActivationType::kCLIP, true, -1, true, 1}},
    {ActivationType::HSIGMOID, ActivationParams{nvinfer1::ActivationType::kHARD_SIGMOID, true, 1.f / 6, true, 0.5f}},
    // handled by plugin
    {ActivationType::GELU, ActivationParams{nvinfer1::ActivationType::kTHRESHOLDED_RELU, false, 0, false, 0}},
    {ActivationType::SWISH, ActivationParams{nvinfer1::ActivationType::kSIGMOID, false, 0, false, 0}}};
  return action_map.find(activation_type) != action_map.end()
           ? std::experimental::optional<ActivationParams>(action_map[activation_type])
           : std::experimental::nullopt;
}

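// Broadcast-rank alignment helpers: IsComfortableAlign checks whether `in_shape` can be
// right-aligned against `out_shape` ending at `index` (each dim must match or be 1), and
// BackComfortableAlign pads `in_shape` with leading/trailing 1s until the ranks agree.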
bool IsComfortableAlign(std::vector<int64_t> *in_shape_ptr, const std::vector<int64_t> &out_shape, int index) {
  if (in_shape_ptr->size() > out_shape.size()) {
    return false;
  }
  int out_index = index;
  int in_index = in_shape_ptr->size() - 1;
  while (in_index >= 0 && (in_shape_ptr->at(in_index) == out_shape[out_index] || in_shape_ptr->at(in_index) == 1)) {
    in_index--;
    out_index--;
  }
  return in_index < 0;
}

void BackComfortableAlign(std::vector<int64_t> *in_shape_ptr, const std::vector<int64_t> &out_shape) {
  if (in_shape_ptr->size() >= out_shape.size()) {
    return;
  }
  int out_index = out_shape.size() - 1;
  bool is_comfortable = false;
  while (out_index >= static_cast<int>(in_shape_ptr->size()) - 1) {
    if (IsComfortableAlign(in_shape_ptr, out_shape, out_index)) {
      is_comfortable = true;
      break;
    }
    out_index--;
  }
  if (!is_comfortable) {
    MS_LOG(INFO) << "failed to align constant tensor";
    return;
  }
  while (static_cast<int>(in_shape_ptr->size()) - 1 < out_index) {
    in_shape_ptr->insert(in_shape_ptr->begin(), 1);
  }
  while (in_shape_ptr->size() < out_shape.size()) {
    in_shape_ptr->insert(in_shape_ptr->end(), 1);
  }
  DebugDims("constant : ", ConvertCudaDims(*in_shape_ptr));
}

void AlignShapeRank(std::vector<int64_t> *in_shape_ptr, const std::vector<int64_t> &out_shape) {
  const size_t last_dim = in_shape_ptr->size() - 1;
  const int in_rank = in_shape_ptr->size();
  int index = out_shape.size() - 1;
  for (; index >= 0; index--) {
    if (out_shape[index] == in_shape_ptr->at(last_dim)) {
      break;
    }
  }
  const int align_rank = index + 1;
  if (index <= 0 || align_rank == in_rank) {
    return;
  }
  for (int i = 0; i < index + 1 - in_rank; i++) {
    in_shape_ptr->insert(in_shape_ptr->begin(), 1);
  }
}

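// Convert a const tensor whose rank differs from `expect_shape` by first padding its shape
// with 1s (broadcast alignment), then emitting it as a TensorRT constant.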
nvinfer1::ITensor *ConvertTensorWithExpandDims(TensorRTContext *ctx, const TensorInfo &ms_tensor,
                                               const std::vector<int64_t> &expect_shape, const std::string &op_name) {
  if (ctx == nullptr || ctx->network() == nullptr) {
    MS_LOG(ERROR) << "network is null for ConvertTensorWithExpandDims";
    return nullptr;
  }
  if (!ms_tensor.IsConst()) {
    MS_LOG(ERROR) << "ConvertTensorWithExpandDims from a MSTensor with nullptr data";
    return nullptr;
  }
  auto origin_shape = ms_tensor.Shape();
  std::vector<int64_t> convert_shape(expect_shape);
  BackComfortableAlign(&origin_shape, convert_shape);
  if (ms_tensor.ElementNum() !=
      std::accumulate(origin_shape.begin(), origin_shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>())) {
    MS_LOG(ERROR) << "ExpandDims failed for " << op_name;
    return nullptr;
  }
  nvinfer1::Dims dims = ConvertCudaDims(origin_shape);
  if (dims.nbDims == -1) {
    MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name;
    return nullptr;
  }
  nvinfer1::DataType data_type = ConvertDataType(ms_tensor.DataType());
  nvinfer1::Weights weights{data_type, ms_tensor.Data(), ms_tensor.ElementNum()};
  nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights);
  if (constant_tensor == nullptr) {
    MS_LOG(ERROR) << "create constant_tensor failed.";
    return nullptr;
  }
  ctx->RegisterLayer(constant_tensor, ms_tensor.Name() + "_" + op_name);
  return constant_tensor->getOutput(0);
}

nvinfer1::ITensor *ConvertConstantTensor1D(TensorRTContext *ctx, int *weights_vec, nvinfer1::DataType data_type) {
  constexpr int nchw_dims_count = 4;
  nvinfer1::Weights weights{data_type, weights_vec, nchw_dims_count};
  nvinfer1::Dims dims;
  dims.nbDims = 1;
  dims.d[0] = nchw_dims_count;
  nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights);
  if (constant_tensor == nullptr) {
    MS_LOG(ERROR) << "create constant_tensor failed.";
    return nullptr;
  }
  return constant_tensor->getOutput(0);
}

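// Dispatch a const tensor to the right conversion: scalars and single-element tensors are
// broadcast to the rank of `expect_shape`, rank-matched tensors convert directly, and
// everything else goes through the expand-dims path.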
nvinfer1::ITensor *ConvertConstantTensorWithDims(TensorRTContext *ctx, const TensorInfo &ms_tensor,
                                                 const std::vector<int64_t> &expect_shape, const std::string &op_name) {
  nvinfer1::ITensor *constant_input{nullptr};
  std::string tensor_name = op_name + "_" + ms_tensor.Name();
  if (ms_tensor.Shape().size() == 0 || ms_tensor.ElementNum() == 1) {
    constant_input =
      lite::ConvertScalarToITensor(ctx, expect_shape.size(), ms_tensor, ms_tensor.DataType(), tensor_name);
    if (constant_input == nullptr) {
      MS_LOG(ERROR) << "create ITensor from scalar tensor failed: " << tensor_name;
      return nullptr;
    }
  } else if (ms_tensor.Shape().size() == expect_shape.size()) {
    constant_input = lite::ConvertConstantTensor(ctx, ms_tensor, tensor_name);
    if (constant_input == nullptr) {
      MS_LOG(ERROR) << "create ITensor from constant tensor failed: " << tensor_name;
      return nullptr;
    }
  } else if (ms_tensor.ElementNum() >= 1) {
    constant_input = ConvertTensorWithExpandDims(ctx, ms_tensor, expect_shape, tensor_name);
    if (constant_input == nullptr) {
      MS_LOG(ERROR) << "create ITensor from ConvertTensorWithExpandDims failed: " << tensor_name;
      return nullptr;
    }
  } else {
    MS_LOG(ERROR) << "const tensor value needs check: " << tensor_name;
  }
  return constant_input;
}

nvinfer1::Weights TransposeWeight2D(const TensorInfo &ms_tensor, void **pack_weight) {
  // Usage note: the malloc'ed buffer is stored in *pack_weight; the caller must keep the
  // pointer and free it on destruction.
  nvinfer1::Weights weights{};
  weights.count = ms_tensor.ElementNum();
  auto weight_shape = ms_tensor.Shape();
  if (weight_shape.size() != DIMENSION_2D) {
    MS_LOG(ERROR) << ms_tensor.Name() << " dims is " << weight_shape.size();
    return weights;
  }
  if (!ms_tensor.IsConst()) {
    MS_LOG(ERROR) << ms_tensor.Name() << " has null data";
    return weights;
  }
  void *pack_weight_tmp = malloc(ms_tensor.DataSize());
  if (pack_weight_tmp == nullptr) {
    MS_LOG(ERROR) << "Malloc buffer failed.";
    return weights;
  }
  *pack_weight = pack_weight_tmp;
  weights.values = pack_weight_tmp;

  int row = weight_shape[0];
  int col = weight_shape[1];

  switch (ms_tensor.DataType()) {
    case DataType::kNumberTypeFloat16: {
      weights.type = nvinfer1::DataType::kHALF;
      auto src = static_cast<const uint16_t *>(ms_tensor.Data());
      auto dst = static_cast<uint16_t *>(pack_weight_tmp);
      for (int r = 0; r < row; ++r) {
        for (int c = 0; c < col; ++c) {
          dst[c * row + r] = src[r * col + c];
        }
      }
      break;
    }
    case DataType::kNumberTypeFloat32: {
      weights.type = nvinfer1::DataType::kFLOAT;
      auto dst = static_cast<float *>(pack_weight_tmp);
      auto src = static_cast<const float *>(ms_tensor.Data());
      for (int r = 0; r < row; ++r) {
        for (int c = 0; c < col; ++c) {
          dst[c * row + r] = src[r * col + c];
        }
      }
      break;
    }
    default: {
      MS_LOG(ERROR) << ms_tensor.Name() << " has unsupported tensor datatype for transpose data: "
                    << static_cast<int>(ms_tensor.DataType());
    }
  }
  return weights;
}

nvinfer1::Weights ConvertWeight(const TensorInfo &ms_tensor) {
  nvinfer1::Weights weights{};
  weights.type = ConvertDataType(ms_tensor.DataType());
  weights.values = ms_tensor.Data();
  weights.count = ms_tensor.ElementNum();
  if (weights.values == nullptr) {
    MS_LOG(ERROR) << "ConvertWeight from a MSTensor with nullptr data";
  }
  return weights;
}

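// Insert a cast for trt_tensor. On TensorRT >= 7.2, kBOOL targets are retargeted to INT32,
// FLOAT->INT32 casts go through the CastPlugin, and all other casts use an identity layer
// with setOutputType; older TensorRT always uses the plugin.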
nvinfer1::ITensor *TRTTensorCast(TensorRTContext *ctx, nvinfer1::ITensor *trt_tensor, nvinfer1::DataType data_type,
                                 const std::string &name) {
#if TRT_VERSION_GE(7, 2)
  data_type = data_type == nvinfer1::DataType::kBOOL ? nvinfer1::DataType::kINT32 : data_type;
  if (data_type == nvinfer1::DataType::kINT32 && trt_tensor->getType() == nvinfer1::DataType::kFLOAT) {
    auto plugin = std::make_shared<CastPlugin>(name, data_type);
    nvinfer1::ITensor *inputTensors[] = {trt_tensor};
    nvinfer1::IPluginV2Layer *cast_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin);
    cast_layer->setName(name.c_str());
    nvinfer1::ITensor *cast_out = cast_layer->getOutput(0);
    cast_out->setName((name + "_output").c_str());
    return cast_out;
  }
  auto cast_layer = ctx->network()->addIdentity(*trt_tensor);
#else
  auto plugin = std::make_shared<CastPlugin>(name, data_type);
  nvinfer1::ITensor *inputTensors[] = {trt_tensor};
  nvinfer1::IPluginV2Layer *cast_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin);
#endif
  if (cast_layer == nullptr) {
    MS_LOG(ERROR) << "create cast layer failed for: " << name;
    return nullptr;
  }
#if TRT_VERSION_GE(7, 2)
  cast_layer->setOutputType(0, data_type);
#endif
  cast_layer->setName(name.c_str());
  nvinfer1::ITensor *cast_out = cast_layer->getOutput(0);
  cast_out->setName((name + "_output").c_str());
  return cast_out;
}

int SetCudaDevice(std::shared_ptr<GPUDeviceInfo> device_info_) {
  return SetCudaDevice(static_cast<int>(device_info_->GetDeviceID()));
}

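// Bind the calling thread to `device_id` after validating it against the visible device count.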
int SetCudaDevice(int device_id) {
  int device = 0;
  auto ret = cudaGetDevice(&device);
  if (ret != cudaSuccess) {
    MS_LOG(ERROR) << "cudaGetDevice failed, device id is untrustworthy. error code: " << ret;
    return RET_ERROR;
  }
  int set_device_id = device_id;
  int deviceCnt = 0;

  ret = cudaGetDeviceCount(&deviceCnt);
  if (ret != cudaSuccess) {
    MS_LOG(ERROR) << "cudaGetDeviceCount failed.";
    return RET_ERROR;
  }

  if (set_device_id > deviceCnt - 1) {
    MS_LOG(ERROR) << "invalid input device id " << set_device_id << " for current device count " << deviceCnt;
    return RET_ERROR;
  }
  if (device != set_device_id) {
    ret = cudaSetDevice(set_device_id);
    if (ret != cudaSuccess) {
      MS_LOG(ERROR) << "cudaSetDevice failed, error code: " << ret;
      return RET_ERROR;
    }
  }
  if (cudaGetDevice(&device) != cudaSuccess) {
    MS_LOG(ERROR) << "cudaGetDevice failed, device id is untrustworthy.";
    return RET_ERROR;
  }
  MS_LOG(DEBUG) << "cuda is running on device: " << device;
  return RET_OK;
}

Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm) {
  if (input_format == Format::NHWC) {
    if (perm.order[kNHWC_N] == kNHWC_N && perm.order[kNHWC_H] == kNHWC_C && perm.order[kNHWC_W] == kNHWC_W &&
        perm.order[kNHWC_C] == kNHWC_H) {
      return Format::NCHW;
    }
  } else if (input_format == Format::NCHW) {
    if (perm.order[kNCHW_N] == kNCHW_N && perm.order[kNCHW_C] == kNCHW_H && perm.order[kNCHW_H] == kNCHW_W &&
        perm.order[kNCHW_W] == kNCHW_C) {
      return Format::NHWC;
    }
  }
  MS_LOG(WARNING) << "transpose out format needs to be checked for " << input_format;
  return input_format;
}

int ConvertAxisFromNHWC2NCHW(int nhwc_axis) {
  // The early return below disables the remapping, so the axis is currently passed through
  // unchanged; the mapping code is kept for reference.
  return nhwc_axis;
  // N0H1W2C3 -> N0C1H2W3
  if (nhwc_axis > kNHWC_C) {
    return nhwc_axis;
  }
  switch (nhwc_axis) {
    case kNHWC_N:
      return kNCHW_N;
    case kNHWC_H:
      return kNCHW_H;
    case kNHWC_W:
      return kNCHW_W;
    case kNHWC_C:
      return kNCHW_C;
    default:
      MS_LOG(ERROR) << "invalid input axis for nhwc: " << nhwc_axis;
  }
  return nhwc_axis;
}

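// Transpose fp16 data from NHWC to NCHW, working in 8x8 tiles (C8NUM) for cache locality.
// The plane dimension is partitioned across workers via task_id/thread_count; a
// thread_count of 0 means single-threaded over the whole plane.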
void PackNHWCToNCHWFp16(const void *src, void *dst, size_t batches, size_t plane, size_t channel, size_t task_id,
                        size_t thread_count) {
  size_t hw8 = plane / C8NUM;
  size_t task_start = 0;
  size_t task_end = plane;
  if (thread_count > 0) {
    size_t offset_hw = UP_DIV(hw8, thread_count) * C8NUM;
    task_start = offset_hw * task_id;
    size_t count = plane - task_start;
    if (count == 0) {
      return;
    }
    task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw);
    hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0);
  } else {
    hw8 *= C8NUM;
  }
  size_t c8 = channel / C8NUM * C8NUM;
  size_t batch = plane * channel;
  for (size_t n = 0; n < batches; n++) {
    const uint16_t *src_batch = static_cast<const uint16_t *>(src) + n * batch;
    uint16_t *dst_batch = static_cast<uint16_t *>(dst) + n * batch;
    size_t hw = task_start;
    for (; hw < hw8; hw += C8NUM) {
      size_t c = 0;
      for (; c < c8; c += C8NUM) {
        const uint16_t *src_ptr = src_batch + hw * channel + c;
        uint16_t *dst_ptr = dst_batch + c * plane + hw;
        for (size_t tr = 0; tr < C8NUM; tr++) {
          for (size_t tc = 0; tc < C8NUM; tc++) {
            dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
          }
        }
      }
      for (; c < channel; c++) {
        const uint16_t *src_ptr = src_batch + hw * channel + c;
        uint16_t *dst_ptr = dst_batch + c * plane + hw;
        for (size_t i = 0; i < C8NUM; i++) {
          dst_ptr[i] = src_ptr[i * channel];
        }
      }
    }
    for (; hw < task_end; hw++) {
      const uint16_t *src_ptr = src_batch + hw * channel;
      uint16_t *dst_ptr = dst_batch + hw;
      for (size_t i = 0; i < channel; i++) {
        dst_ptr[i * plane] = src_ptr[i];
      }
    }
  }
}

std::string GetTensorFormat(nvinfer1::ITensor *trt_tensor, mindspore::Format format, bool is_same) {
  nvinfer1::Dims dims = trt_tensor->getDimensions();
  std::string is_same_string = is_same ? " is same with ms tensor " : " is different from ms tensor ";
  std::string out_string = "tensor " + std::string(trt_tensor->getName()) + ": format (NHWC:1, NCHW:0) is " +
                           std::to_string(static_cast<int>(format)) + is_same_string + ", dims is ";
  std::string dim_string = "[";
  for (int i = 0; i < dims.nbDims; i++) {
    dim_string += std::to_string(dims.d[i]);
    if (i != dims.nbDims - 1) {
      dim_string += ", ";
    }
  }
  dim_string += "]";
  out_string += dim_string;
  return out_string;
}

std::string GetTensorFormat(ITensorHelper tensor_helper) {
  return GetTensorFormat(tensor_helper.trt_tensor_, tensor_helper.format_, tensor_helper.same_format_);
}

std::string GetTensorFormat(nvinfer1::ITensor *trt_tensor) { return GetTensorFormat(trt_tensor, Format::NHWC, true); }

std::experimental::optional<nvinfer1::ReduceOperation> TryConvertTRTReduceMode(ReduceMode mode) {
  std::map<ReduceMode, nvinfer1::ReduceOperation> reduce_ops_ = {
    {ReduceMode::Reduce_Mean, nvinfer1::ReduceOperation::kAVG},
    {ReduceMode::Reduce_Max, nvinfer1::ReduceOperation::kMAX},
    {ReduceMode::Reduce_Min, nvinfer1::ReduceOperation::kMIN},
    {ReduceMode::Reduce_Prod, nvinfer1::ReduceOperation::kPROD},
    {ReduceMode::Reduce_L2, nvinfer1::ReduceOperation::kSUM},
    {ReduceMode::Reduce_Sum, nvinfer1::ReduceOperation::kSUM},
  };
  return reduce_ops_.find(mode) != reduce_ops_.end()
           ? std::experimental::optional<nvinfer1::ReduceOperation>(reduce_ops_[mode])
           : std::experimental::nullopt;
}

int PreprocessInputs2SameDim(TensorRTContext *ctx, ITensorHelper input_tensor_helper,
                             ITensorHelper *out_tensor_helper) {
  if (input_tensor_helper.trt_tensor_ == nullptr) {
    MS_LOG(ERROR) << "input trt tensor is nullptr";
    return RET_ERROR;
  }
  out_tensor_helper->trt_tensor_ = input_tensor_helper.trt_tensor_;
  out_tensor_helper->format_ = input_tensor_helper.format_;
  out_tensor_helper->same_format_ = true;
  if (input_tensor_helper.trt_tensor_->getDimensions().nbDims == DIMENSION_4D && !input_tensor_helper.same_format_) {
    if (input_tensor_helper.format_ == Format::NCHW) {
      // transpose: NCHW -> NHWC
      nvinfer1::IShuffleLayer *transpose_layer_in = NCHW2NHWC(ctx, *input_tensor_helper.trt_tensor_);
      if (transpose_layer_in == nullptr) {
        MS_LOG(ERROR) << "op action convert failed";
        return RET_ERROR;
      }
      transpose_layer_in->setName(
        (std::string(input_tensor_helper.trt_tensor_->getName()) + "_input_transpose2NHWC").c_str());
      out_tensor_helper->trt_tensor_ = transpose_layer_in->getOutput(0);
      out_tensor_helper->format_ = Format::NHWC;
    } else {
      // transpose: NHWC -> NCHW
      nvinfer1::IShuffleLayer *transpose_layer_in = NHWC2NCHW(ctx, *input_tensor_helper.trt_tensor_);
      if (transpose_layer_in == nullptr) {
        MS_LOG(ERROR) << "op action convert failed";
        return RET_ERROR;
      }
      transpose_layer_in->setName(
        (std::string(input_tensor_helper.trt_tensor_->getName()) + "_input_transpose2NCHW").c_str());
      out_tensor_helper->trt_tensor_ = transpose_layer_in->getOutput(0);
      out_tensor_helper->format_ = Format::NCHW;
    }
  }
  return RET_OK;
}

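// Return the number of elements implied by a shape; empty or invalid shapes count as 0.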
int GetDimsVolume(const nvinfer1::Dims &dims) {
  if (dims.nbDims <= 0) {
    return 0;
  }
  return std::accumulate(dims.d, dims.d + dims.nbDims, 1, std::multiplies<int64_t>());
}

int GetDimsVolume(const std::vector<int64_t> &shape) {
  if (shape.empty()) {
    return 0;
  }
  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
}

std::experimental::optional<nvinfer1::Dims> SqueezeDims(const nvinfer1::Dims &in_dims, int pos) {
  if (in_dims.nbDims <= 1) {
    MS_LOG(ERROR) << "invalid shape size: " << in_dims.nbDims << " for squeeze.";
    return {};
  }
  nvinfer1::Dims out_dims;
  int i = 0;
  for (int j = 0; j < in_dims.nbDims; ++j) {
    if (j != pos) {
      out_dims.d[i++] = in_dims.d[j];
    }
  }
  out_dims.nbDims = in_dims.nbDims - 1;
  return std::experimental::optional<nvinfer1::Dims>(out_dims);
}

std::experimental::optional<nvinfer1::Dims> UnsqueezeDims(const nvinfer1::Dims &in_dims, int pos, int val) {
  if (in_dims.nbDims >= in_dims.MAX_DIMS) {
    MS_LOG(ERROR) << "invalid shape size: " << in_dims.nbDims << " for unsqueeze.";
    return {};
  }
  nvinfer1::Dims out_dims;
  int i = 0;
  for (int j = 0; j <= in_dims.nbDims; ++j) {
    if (j == pos) {
      out_dims.d[j] = val;
    } else {
      out_dims.d[j] = in_dims.d[i++];
    }
  }
  out_dims.nbDims = in_dims.nbDims + 1;
  return std::experimental::optional<nvinfer1::Dims>(out_dims);
}

int ParseData2Vector(const TensorInfo &ms_tensor, std::vector<float> *dst) {
  if (!ms_tensor.IsConst()) {
    MS_LOG(ERROR) << "ignore tensor: " << ms_tensor.Name();
    return RET_ERROR;
  }
  dst->clear();
  dst->resize(ms_tensor.ElementNum());
  switch (ms_tensor.DataType()) {
    case DataType::kNumberTypeInt64: {
      Data2Vector<int64_t>(dst, ms_tensor.Data());
      break;
    }
    case DataType::kNumberTypeInt32: {
      Data2Vector<int>(dst, ms_tensor.Data());
      break;
    }
    default: {
      MS_LOG(ERROR) << ms_tensor.Name() << " has an unsupported data type to parse";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

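// Insert a size-1 dimension at `axis` via a shuffle layer. For dynamic shapes (-1 dims) with a
// non-trailing axis, the 1 is appended at the end first and then moved into place with a
// transpose, since a reshape cannot place a new dim among unknown ones (-1 becomes the 0
// "copy from input" placeholder).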
nvinfer1::ITensor *ExpandDim(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor, int axis) {
  // input has to be preprocessed to NCHW
  auto input_dims = input_tensor->getDimensions();
  nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor);
  // if the expanded dim is not the last dim and the shape is dynamic, expand at the last dim and transpose
  bool special_expand = false;
  for (int i = 0; i < input_dims.nbDims; i++) {
    special_expand = special_expand || input_dims.d[i] == -1;
  }
  special_expand = special_expand && (axis != -1 && axis != input_dims.nbDims);

  if (special_expand) {
    std::vector<int64_t> new_shape;
    for (int i = 0; i < input_dims.nbDims; i++) {
      new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]);
    }
    new_shape.push_back(1);
    nvinfer1::Dims new_dims = ConvertCudaDims(new_shape);
    if (new_dims.nbDims == -1) {
      return nullptr;
    }

    shuffle_layer->setReshapeDimensions(new_dims);
    // transpose the trailing 1 into position `axis`
    nvinfer1::Permutation perm{};
    for (int i = 0; i < new_dims.nbDims; i++) {
      if (i < axis) {
        perm.order[i] = i;
      } else if (i == axis) {
        perm.order[i] = new_dims.nbDims - 1;
      } else {
        perm.order[i] = i - 1;
      }
    }
    nvinfer1::IShuffleLayer *trans_layer = ctx->network()->addShuffle(*shuffle_layer->getOutput(0));
    if (trans_layer == nullptr) {
      MS_LOG(ERROR) << "add transpose layer failed for special expand dims op";
      return nullptr;
    }
    trans_layer->setFirstTranspose(perm);
    return trans_layer->getOutput(0);
  } else {
    std::vector<int64_t> new_shape;
    for (int i = 0; i < input_dims.nbDims; i++) {
      if (axis == i) {
        new_shape.push_back(1);
      }
      new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]);
    }
    if (axis == -1 || axis == input_dims.nbDims) {
      new_shape.push_back(1);
    }
    nvinfer1::Dims new_dims = ConvertCudaDims(new_shape);
    if (new_dims.nbDims == -1) {
      return nullptr;
    }
    shuffle_layer->setReshapeDimensions(new_dims);
    return shuffle_layer->getOutput(0);
  }
}

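// Broadcast `input` to the runtime shape held in `shape` (a 1-D INT32 tensor) using a slice
// layer in kWRAP mode: setInput(2, shape) supplies the output size, and wrapping reads
// replicate size-1 axes, which reproduces broadcast semantics.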
nvinfer1::ITensor *Broadcast(TensorRTContext *ctx, nvinfer1::ITensor *input, nvinfer1::ITensor *shape) {
  int rank = shape->getDimensions().d[0];

  nvinfer1::Dims starts{rank};
  std::fill(starts.d, starts.d + rank, 0);
  nvinfer1::Dims strides{rank};
  std::fill(strides.d, strides.d + rank, 1);

  auto slice_layer = ctx->network()->addSlice(*input, starts, {}, strides);
  slice_layer->setMode(nvinfer1::SliceMode::kWRAP);
  const int INPUT2 = 2;
  slice_layer->setInput(INPUT2, *shape);

  auto shuffler_output = slice_layer->getOutput(0);
  if (shuffler_output == nullptr) {
    MS_LOG(ERROR) << "add slice layer failed";
  }
  return shuffler_output;
}

nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const std::vector<int64_t> &shape) {
  return Reshape(ctx, input, ConvertCudaDims(shape));
}

nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const nvinfer1::Dims &shape) {
  auto reshape_layer = ctx->network()->addShuffle(*input);
  if (reshape_layer == nullptr) {
    MS_LOG(ERROR) << "add reshape_layer failed";
    return nullptr;
  }
  reshape_layer->setReshapeDimensions(shape);
  return reshape_layer->getOutput(0);
}

void DebugDims(const std::string &key, const nvinfer1::Dims &dims) {
  MS_LOG(DEBUG) << key << ":" << dims.nbDims;
  for (int i = 0; i != dims.nbDims; ++i) {
    MS_LOG(DEBUG) << dims.d[i];
  }
}

template <>
nvinfer1::DataType GetNvinferDataType<float>() {
  return nvinfer1::DataType::kFLOAT;
}

template <>
nvinfer1::DataType GetNvinferDataType<int>() {
  return nvinfer1::DataType::kINT32;
}

template nvinfer1::DataType GetNvinferDataType<float>();
template nvinfer1::DataType GetNvinferDataType<int>();

#ifdef PROFILER_
void SimpleProfiler::reportLayerTime(const char *layerName, float ms) noexcept {
  mProfile_[layerName].count++;
  mProfile_[layerName].time += ms;
  if (std::find(mLayerNames_.begin(), mLayerNames_.end(), layerName) == mLayerNames_.end()) {
    mLayerNames_.push_back(layerName);
  }
}

SimpleProfiler::SimpleProfiler(const char *name, const std::vector<SimpleProfiler> &srcProfilers) : mName_(name) {
  for (const auto &srcProfiler : srcProfilers) {
    for (const auto &rec : srcProfiler.mProfile_) {
      auto it = mProfile_.find(rec.first);
      if (it == mProfile_.end()) {
        mProfile_.insert(rec);
      } else {
        it->second.time += rec.second.time;
        it->second.count += rec.second.count;
      }
    }
  }
}

std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value) {
  out << "========== " << value.mName_ << " profile ==========" << std::endl;
  float totalTime = 0;
  std::string layerNameStr = "TensorRT layer name";
  int maxLayerNameLength = std::max(static_cast<int>(layerNameStr.size()), 70);
  for (const auto &elem : value.mProfile_) {
    totalTime += elem.second.time;
    maxLayerNameLength = std::max(maxLayerNameLength, static_cast<int>(elem.first.size()));
  }

  auto old_settings = out.flags();
  auto old_precision = out.precision();
  // Output header
  {
    out << std::setw(maxLayerNameLength) << layerNameStr << " ";
    out << std::setw(C12NUM) << "Runtime, "
        << "%"
        << " ";
    out << std::setw(C12NUM) << "Invocations"
        << " ";
    out << std::setw(C12NUM) << "Runtime, ms" << std::endl;
  }
  for (size_t i = 0; i < value.mLayerNames_.size(); i++) {
    const std::string layerName = value.mLayerNames_[i];
    auto elem = value.mProfile_.at(layerName);
    out << std::setw(maxLayerNameLength) << layerName << " ";
    out << std::setw(C12NUM) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%"
        << " ";
    out << std::setw(C12NUM) << elem.count << " ";
    out << std::setw(C12NUM) << std::fixed << std::setprecision(C2NUM) << elem.time << std::endl;
  }
  out.flags(old_settings);
  out.precision(old_precision);
  out << "========== " << value.mName_ << " total runtime = " << totalTime << " ms ==========" << std::endl;

  return out;
}
#endif  // PROFILER_
}  // namespace mindspore::lite