• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/gpu/optimizer/trt_pass/trt_converter_context.h"
18 
19 #include <utility>
20 #include <algorithm>
21 #include "plugin/device/gpu/hal/device/trt_loader.h"
22 #include "plugin/device/gpu/optimizer/trt_pass/trt_op_factory.h"
23 #include "plugin/device/gpu/kernel/trt/trt_utils.h"
24 #include "include/common/utils/convert_utils.h"
25 #include "include/common/utils/utils.h"
26 #include "utils/singleton.h"
27 #include "utils/ms_context.h"
28 
29 namespace mindspore::opt {
Init()30 bool TrtConverterContext::Init() {
31   auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
32   builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
33   MS_EXCEPTION_IF_NULL(builder_);
34 
35   auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
36   network_ = TrtPtr(builder_->createNetworkV2(batch_type));
37   MS_EXCEPTION_IF_NULL(network_);
38 
39   config_ = TrtPtr(builder_->createBuilderConfig());
40   MS_EXCEPTION_IF_NULL(config_);
41 
42   InitInputTable();
43   InitValueNodeTable();
44   return true;
45 }
46 
Parser()47 bool TrtConverterContext::Parser() {
48   std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
49   const auto &converter_factory = TrtOpFactory::GetInstance();
50   for (auto node : node_list) {
51     if (!node->isa<CNode>()) {
52       continue;
53     }
54 
55     // Transform AnfNode To Trt layer.
56     // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
57     std::string op_name = common::AnfAlgo::GetCNodePrimitive(node)->name();
58     if (!AnfUtils::IsRealKernel(node) && op_name != "Return") {
59       continue;
60     }
61 
62     ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
63     auto result = convert_func(node, this->shared_from_this());
64     if (!result.first) {
65       MS_LOG(WARNING) << op_name << " converter failed.";
66       return false;
67     }
68     auto ret = StoreLayerOutput(node, result.second);
69     if (!ret) {
70       MS_LOG(WARNING) << op_name << " converter failed.";
71       return false;
72     }
73   }
74 
75   return true;
76 }
77 
Serialize(std::string * model)78 bool TrtConverterContext::Serialize(std::string *model) {
79   MS_EXCEPTION_IF_NULL(model);
80   builder_->setMaxBatchSize(batch_size_);
81   config_->setMaxWorkspaceSize(workspace_size_);
82 
83   // Set precision mode
84   const auto &context = MsContext::GetInstance();
85   const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
86   if (precision_mode == "fp16") {
87     MS_LOG(INFO) << "Inference with mixed precision mode";
88     config_->setFlag(nvinfer1::BuilderFlag::kFP16);
89   }
90 
91   MS_LOG(WARNING) << "It will take few minutes for operators selection.";
92   engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
93   MS_EXCEPTION_IF_NULL(engine_);
94 
95   std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize());
96   *model = string(static_cast<const char *>(model_data->data()), model_data->size());
97   return true;
98 }
99 
InitInputTable()100 bool TrtConverterContext::InitInputTable() {
101   const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters();
102   for (auto input_node : graph_inputs) {
103     if (!input_node->isa<Parameter>()) {
104       continue;
105     }
106 
107     auto input = input_node->cast<ParameterPtr>();
108     if (common::AnfAlgo::IsParameterWeight(input)) {
109       const auto &param_value = input->default_param();
110       MS_EXCEPTION_IF_NULL(param_value);
111       auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
112       MS_EXCEPTION_IF_NULL(tensor);
113 
114       nvinfer1::Weights weight;
115       weight.values = tensor->data_c();
116       std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
117       TRT_VARIANT_CHECK(type, 1UL, false);
118       weight.type = std::get<nvinfer1::DataType>(type);
119       weight.count = tensor->DataSize();
120       output_map_[input_node][0] = LayerInput(weight, tensor->shape());
121     }
122   }
123   return true;
124 }
125 
InitValueNodeTable()126 bool TrtConverterContext::InitValueNodeTable() {
127   MS_EXCEPTION_IF_NULL(func_graph_);
128   const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
129   for (const auto &node : node_list) {
130     MS_EXCEPTION_IF_NULL(node);
131     if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
132       auto value_node = node->cast<ValueNodePtr>();
133       auto &node_value = value_node->value();
134       MS_EXCEPTION_IF_NULL(node_value);
135 
136       if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
137         std::vector<tensor::BaseTensorPtr> tensors;
138         TensorValueToTensor(node_value, &tensors);
139         for (size_t i = 0; i < tensors.size(); i++) {
140           const auto &tensor = tensors[i];
141           nvinfer1::Weights weight;
142           weight.values = tensor->data_c();
143           std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
144           TRT_VARIANT_CHECK(type, 1UL, false);
145           weight.type = std::get<nvinfer1::DataType>(type);
146           weight.count = tensor->DataSize();
147           output_map_[value_node][i] = LayerInput(weight, tensor->shape());
148         }
149       }
150     }
151   }
152   return true;
153 }
154 
StoreLayerOutput(const AnfNodePtr & node,const std::vector<nvinfer1::ITensor * > & nv_tensors)155 bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
156   if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
157     MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
158                  << ", while got: " << nv_tensors.size();
159   }
160 
161   for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
162     if (nv_tensors[tensor_index] != nullptr) {
163       const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions();
164       const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(dim);
165       output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape);
166 
167       std::ostringstream oss;
168       oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
169       for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
170         oss << dim.d[dim_index] << " ";
171       }
172       oss << "]";
173       MS_LOG(INFO) << oss.str();
174     }
175   }
176   return true;
177 }
178 
LoadInputOnDemand(const AnfNodePtr & node)179 LayerInput *TrtConverterContext::LoadInputOnDemand(const AnfNodePtr &node) {
180   MS_EXCEPTION_IF_NULL(node);
181   auto input = node->cast<ParameterPtr>();
182   std::variant<bool, nvinfer1::DataType> type =
183     TrtUtils::MsDtypeToTrtDtype(common::AnfAlgo::GetOutputInferDataType(node, 0));
184   TRT_VARIANT_CHECK(type, 1UL, nullptr);
185   const auto &trt_dtype = std::get<nvinfer1::DataType>(type);
186   const nvinfer1::Dims &trt_dims = TrtUtils::MsDimsToTrtDims(common::AnfAlgo::GetOutputInferShape(node, 0), false);
187   nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
188   const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
189   output_map_[node][0] = LayerInput(tensor, shape);
190   return &output_map_[node][0];
191 }
192 
LoadLayerInput(const AnfNodePtr & node,std::vector<LayerInput> * inputs)193 bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
194   std::vector<session::KernelWithIndex> real_inputs;
195   common::AnfAlgo::GetRealInputs(node, &real_inputs);
196   for (auto item : real_inputs) {
197     auto node_iter = output_map_.find(item.first);
198     if (node_iter == output_map_.end()) {
199       if (item.first->isa<Parameter>()) {
200         LayerInput *input = LoadInputOnDemand(item.first);
201         if (input == nullptr) {
202           MS_LOG(WARNING) << "LoadLayerInput failed.";
203           return false;
204         }
205         inputs->push_back(*input);
206         continue;
207       }
208       MS_LOG(WARNING) << "node: " << node->DebugString() << " not found.";
209       return false;
210     }
211 
212     auto out_iter = node_iter->second.find(item.second);
213     if (out_iter == node_iter->second.end()) {
214       MS_LOG(WARNING) << "node: " << node->DebugString() << "output index: " << item.second << " not found.";
215       return false;
216     }
217 
218     inputs->push_back(out_iter->second);
219   }
220   return true;
221 }
222 
GetGraphInputs() const223 std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() const {
224   // Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
225   mindspore::HashMap<std::string, AnfNodePtr> graph_inputs;
226   for (const auto &input_node : func_graph_->parameters()) {
227     if (!input_node->isa<Parameter>()) {
228       continue;
229     }
230 
231     auto input = input_node->cast<ParameterPtr>();
232     if (!common::AnfAlgo::IsParameterWeight(input)) {
233       (void)graph_inputs.emplace(input->name(), input_node);
234     }
235   }
236 
237   // Keep the graph inputs in order of the binding name.
238   std::vector<AnfNodePtr> trt_inputs;
239   for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
240     if (!engine_->bindingIsInput(i)) {
241       continue;
242     }
243     auto iter = graph_inputs.find(engine_->getBindingName(i));
244     if (iter == graph_inputs.end()) {
245       MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i);
246     }
247     trt_inputs.push_back(iter->second);
248   }
249   return trt_inputs;
250 }
251 
GetGraphOutputs() const252 std::tuple<std::map<size_t, size_t>, std::vector<session::KernelWithIndex>> TrtConverterContext::GetGraphOutputs()
253   const {
254   std::vector<session::KernelWithIndex> anf_output_list;
255   common::AnfAlgo::GetRealInputs(func_graph_->get_return(), &anf_output_list);
256 
257   std::map<size_t, size_t> anf_trt_index_map;
258   std::vector<session::KernelWithIndex> trt_output_list(anf_output_list.size());
259   size_t trt_index = 0;
260   for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
261     if (!engine_->bindingIsInput(i)) {
262       const std::string &name = engine_->getBindingName(i);
263       size_t pos = name.find_first_not_of("return_output_");
264       size_t anf_index = atoi(name.substr(pos).c_str());
265 
266       (void)anf_trt_index_map.emplace(anf_index, trt_index);
267       trt_output_list[trt_index] = anf_output_list[anf_index];
268       trt_index++;
269     }
270   }
271 
272   return std::make_tuple(anf_trt_index_map, trt_output_list);
273 }
274 
CreateTempWeight(const TypeId & type,const ShapeVector & shape)275 std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type, const ShapeVector &shape) {
276   auto tensor = std::make_shared<tensor::Tensor>(type, shape);
277   temp_weights_.push_back(tensor);
278   return tensor;
279 }
280 }  // namespace mindspore::opt
281