• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/trt_pass/trt_converter_context.h"
18 
19 #include "runtime/device/gpu/trt_loader.h"
20 #include "backend/optimizer/trt_pass/trt_op_factory.h"
21 #include "backend/kernel_compiler/gpu/trt/trt_utils.h"
22 #include "utils/convert_utils.h"
23 #include "utils/utils.h"
24 #include "utils/singleton.h"
25 #include "utils/ms_context.h"
26 
27 namespace mindspore::opt {
Init()28 bool TrtConverterContext::Init() {
29   auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
30   builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
31   MS_EXCEPTION_IF_NULL(builder_);
32 
33   auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
34   network_ = TrtPtr(builder_->createNetworkV2(batch_type));
35   MS_EXCEPTION_IF_NULL(network_);
36 
37   config_ = TrtPtr(builder_->createBuilderConfig());
38   MS_EXCEPTION_IF_NULL(config_);
39 
40   InitInputTable();
41   InitValueNodeTable();
42   return true;
43 }
44 
Parser()45 bool TrtConverterContext::Parser() {
46   std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
47   const auto &converter_factory = TrtOpFactory::GetInstance();
48   for (auto node : node_list) {
49     if (!node->isa<CNode>()) {
50       continue;
51     }
52 
53     // Transform AnfNode To Trt layer.
54     // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
55     std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
56     if (!AnfAlgo::IsRealKernel(node) && op_name != "Return") {
57       continue;
58     }
59 
60     ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
61     auto result = convert_func(node, this->shared_from_this());
62     if (!result.first) {
63       MS_LOG(WARNING) << op_name << " converter failed.";
64       return false;
65     }
66     auto ret = StoreLayerOutput(node, result.second);
67     if (!ret) {
68       MS_LOG(WARNING) << op_name << " converter failed.";
69       return false;
70     }
71   }
72 
73   return true;
74 }
75 
Serialize(std::string * model)76 bool TrtConverterContext::Serialize(std::string *model) {
77   MS_EXCEPTION_IF_NULL(model);
78   builder_->setMaxBatchSize(batch_size_);
79   config_->setMaxWorkspaceSize(workspace_size_);
80 
81   // Set precision mode
82   const auto &context = MsContext::GetInstance();
83   const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
84   if (precision_mode == "fp16") {
85     MS_LOG(INFO) << "Inference with mixed precision mode";
86     config_->setFlag(nvinfer1::BuilderFlag::kFP16);
87   }
88 
89   MS_LOG(WARNING) << "It will take few minutes for operators selection.";
90   engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
91   MS_EXCEPTION_IF_NULL(engine_);
92 
93   std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize());
94   *model = string(static_cast<const char *>(model_data->data()), model_data->size());
95   return true;
96 }
97 
InitInputTable()98 bool TrtConverterContext::InitInputTable() {
99   const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters();
100   for (auto input_node : graph_inputs) {
101     if (!input_node->isa<Parameter>()) {
102       continue;
103     }
104 
105     auto input = input_node->cast<ParameterPtr>();
106     if (AnfAlgo::IsParameterWeight(input)) {
107       const auto &param_value = input->default_param();
108       MS_EXCEPTION_IF_NULL(param_value);
109       auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
110       MS_EXCEPTION_IF_NULL(tensor);
111 
112       nvinfer1::Weights weight;
113       weight.values = tensor->data_c();
114       std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
115       TRT_VARIANT_CHECK(type, 1UL, false);
116       weight.type = std::get<nvinfer1::DataType>(type);
117       weight.count = tensor->DataSize();
118       output_map_[input_node][0] = LayerInput(weight, tensor->shape());
119     }
120   }
121   return true;
122 }
123 
InitValueNodeTable()124 bool TrtConverterContext::InitValueNodeTable() {
125   MS_EXCEPTION_IF_NULL(func_graph_);
126   const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
127   for (const auto &node : node_list) {
128     MS_EXCEPTION_IF_NULL(node);
129     if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
130       auto value_node = node->cast<ValueNodePtr>();
131       auto &node_value = value_node->value();
132       MS_EXCEPTION_IF_NULL(node_value);
133 
134       if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
135         std::vector<tensor::TensorPtr> tensors;
136         TensorValueToTensor(node_value, &tensors);
137         for (size_t i = 0; i < tensors.size(); i++) {
138           const auto &tensor = tensors[i];
139           nvinfer1::Weights weight;
140           weight.values = tensor->data_c();
141           std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
142           TRT_VARIANT_CHECK(type, 1UL, false);
143           weight.type = std::get<nvinfer1::DataType>(type);
144           weight.count = tensor->DataSize();
145           output_map_[value_node][i] = LayerInput(weight, tensor->shape());
146         }
147       }
148     }
149   }
150   return true;
151 }
152 
StoreLayerOutput(const AnfNodePtr & node,const std::vector<nvinfer1::ITensor * > & nv_tensors)153 bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
154   if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
155     MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
156                  << ", while got: " << nv_tensors.size();
157   }
158 
159   for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
160     if (nv_tensors[tensor_index] != nullptr) {
161       const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions();
162       const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(dim);
163       output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape);
164 
165       std::ostringstream oss;
166       oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
167       for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
168         oss << dim.d[dim_index] << " ";
169       }
170       oss << "]";
171       MS_LOG(INFO) << oss.str();
172     }
173   }
174   return true;
175 }
176 
LoadInputOnDemand(const AnfNodePtr & node)177 LayerInput *TrtConverterContext::LoadInputOnDemand(const AnfNodePtr &node) {
178   MS_EXCEPTION_IF_NULL(node);
179   auto input = node->cast<ParameterPtr>();
180   std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(node, 0));
181   TRT_VARIANT_CHECK(type, 1UL, nullptr);
182   const auto &trt_dtype = std::get<nvinfer1::DataType>(type);
183   const nvinfer1::Dims &trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(node, 0), false);
184   nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
185   const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
186   output_map_[node][0] = LayerInput(tensor, shape);
187   return &output_map_[node][0];
188 }
189 
LoadLayerInput(const AnfNodePtr & node,std::vector<LayerInput> * inputs)190 bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
191   std::vector<session::KernelWithIndex> real_inputs;
192   AnfAlgo::GetRealInputs(node, &real_inputs);
193   for (auto item : real_inputs) {
194     auto node_iter = output_map_.find(item.first);
195     if (node_iter == output_map_.end()) {
196       if (item.first->isa<Parameter>()) {
197         LayerInput *input = LoadInputOnDemand(item.first);
198         if (input == nullptr) {
199           MS_LOG(WARNING) << "LoadLayerInput failed.";
200           return false;
201         }
202         inputs->push_back(*input);
203         continue;
204       }
205       MS_LOG(WARNING) << "node: " << node->DebugString() << " not found.";
206       return false;
207     }
208 
209     auto out_iter = node_iter->second.find(item.second);
210     if (out_iter == node_iter->second.end()) {
211       MS_LOG(WARNING) << "node: " << node->DebugString() << "output index: " << item.second << " not found.";
212       return false;
213     }
214 
215     inputs->push_back(out_iter->second);
216   }
217   return true;
218 }
219 
GetGraphInputs() const220 std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() const {
221   // Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
222   std::unordered_map<std::string, AnfNodePtr> graph_inputs;
223   for (const auto &input_node : func_graph_->parameters()) {
224     if (!input_node->isa<Parameter>()) {
225       continue;
226     }
227 
228     auto input = input_node->cast<ParameterPtr>();
229     if (!AnfAlgo::IsParameterWeight(input)) {
230       graph_inputs.insert(std::make_pair(input->name(), input_node));
231     }
232   }
233 
234   // Keep the graph inputs in order of the binding name.
235   std::vector<AnfNodePtr> trt_inputs;
236   for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
237     if (!engine_->bindingIsInput(i)) {
238       continue;
239     }
240     auto iter = graph_inputs.find(engine_->getBindingName(i));
241     if (iter == graph_inputs.end()) {
242       MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i);
243     }
244     trt_inputs.push_back(iter->second);
245   }
246   return trt_inputs;
247 }
248 
GetGraphOutputs() const249 std::tuple<std::map<size_t, size_t>, std::vector<session::KernelWithIndex>> TrtConverterContext::GetGraphOutputs()
250   const {
251   std::vector<session::KernelWithIndex> anf_output_list;
252   AnfAlgo::GetRealInputs(func_graph_->get_return(), &anf_output_list);
253 
254   std::map<size_t, size_t> anf_trt_index_map;
255   std::vector<session::KernelWithIndex> trt_output_list(anf_output_list.size());
256   size_t trt_index = 0;
257   for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
258     if (!engine_->bindingIsInput(i)) {
259       const std::string &name = engine_->getBindingName(i);
260       size_t pos = name.find_first_not_of("return_output_");
261       size_t anf_index = atoi(name.substr(pos).c_str());
262 
263       anf_trt_index_map.insert(std::make_pair(anf_index, trt_index));
264       trt_output_list[trt_index] = anf_output_list[anf_index];
265       trt_index++;
266     }
267   }
268 
269   return std::make_tuple(anf_trt_index_map, trt_output_list);
270 }
271 
CreateTempWeight(const TypeId & type,const std::vector<size_t> & shape)272 std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type,
273                                                                       const std::vector<size_t> &shape) {
274   ShapeVector shape_int;
275   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), SizeToLong);
276   auto tensor = std::make_shared<tensor::Tensor>(type, shape_int);
277   temp_weights_.push_back(tensor);
278   return tensor;
279 }
280 }  // namespace mindspore::opt
281