1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "plugin/device/gpu/optimizer/trt_pass/trt_converter_context.h"
18
19 #include <utility>
20 #include <algorithm>
21 #include "plugin/device/gpu/hal/device/trt_loader.h"
22 #include "plugin/device/gpu/optimizer/trt_pass/trt_op_factory.h"
23 #include "plugin/device/gpu/kernel/trt/trt_utils.h"
24 #include "include/common/utils/convert_utils.h"
25 #include "include/common/utils/utils.h"
26 #include "utils/singleton.h"
27 #include "utils/ms_context.h"
28
29 namespace mindspore::opt {
Init()30 bool TrtConverterContext::Init() {
31 auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
32 builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
33 MS_EXCEPTION_IF_NULL(builder_);
34
35 auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
36 network_ = TrtPtr(builder_->createNetworkV2(batch_type));
37 MS_EXCEPTION_IF_NULL(network_);
38
39 config_ = TrtPtr(builder_->createBuilderConfig());
40 MS_EXCEPTION_IF_NULL(config_);
41
42 InitInputTable();
43 InitValueNodeTable();
44 return true;
45 }
46
Parser()47 bool TrtConverterContext::Parser() {
48 std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
49 const auto &converter_factory = TrtOpFactory::GetInstance();
50 for (auto node : node_list) {
51 if (!node->isa<CNode>()) {
52 continue;
53 }
54
55 // Transform AnfNode To Trt layer.
56 // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
57 std::string op_name = common::AnfAlgo::GetCNodePrimitive(node)->name();
58 if (!AnfUtils::IsRealKernel(node) && op_name != "Return") {
59 continue;
60 }
61
62 ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
63 auto result = convert_func(node, this->shared_from_this());
64 if (!result.first) {
65 MS_LOG(WARNING) << op_name << " converter failed.";
66 return false;
67 }
68 auto ret = StoreLayerOutput(node, result.second);
69 if (!ret) {
70 MS_LOG(WARNING) << op_name << " converter failed.";
71 return false;
72 }
73 }
74
75 return true;
76 }
77
Serialize(std::string * model)78 bool TrtConverterContext::Serialize(std::string *model) {
79 MS_EXCEPTION_IF_NULL(model);
80 builder_->setMaxBatchSize(batch_size_);
81 config_->setMaxWorkspaceSize(workspace_size_);
82
83 // Set precision mode
84 const auto &context = MsContext::GetInstance();
85 const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
86 if (precision_mode == "fp16") {
87 MS_LOG(INFO) << "Inference with mixed precision mode";
88 config_->setFlag(nvinfer1::BuilderFlag::kFP16);
89 }
90
91 MS_LOG(WARNING) << "It will take few minutes for operators selection.";
92 engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
93 MS_EXCEPTION_IF_NULL(engine_);
94
95 std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize());
96 *model = string(static_cast<const char *>(model_data->data()), model_data->size());
97 return true;
98 }
99
InitInputTable()100 bool TrtConverterContext::InitInputTable() {
101 const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters();
102 for (auto input_node : graph_inputs) {
103 if (!input_node->isa<Parameter>()) {
104 continue;
105 }
106
107 auto input = input_node->cast<ParameterPtr>();
108 if (common::AnfAlgo::IsParameterWeight(input)) {
109 const auto ¶m_value = input->default_param();
110 MS_EXCEPTION_IF_NULL(param_value);
111 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
112 MS_EXCEPTION_IF_NULL(tensor);
113
114 nvinfer1::Weights weight;
115 weight.values = tensor->data_c();
116 std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
117 TRT_VARIANT_CHECK(type, 1UL, false);
118 weight.type = std::get<nvinfer1::DataType>(type);
119 weight.count = tensor->DataSize();
120 output_map_[input_node][0] = LayerInput(weight, tensor->shape());
121 }
122 }
123 return true;
124 }
125
InitValueNodeTable()126 bool TrtConverterContext::InitValueNodeTable() {
127 MS_EXCEPTION_IF_NULL(func_graph_);
128 const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
129 for (const auto &node : node_list) {
130 MS_EXCEPTION_IF_NULL(node);
131 if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
132 auto value_node = node->cast<ValueNodePtr>();
133 auto &node_value = value_node->value();
134 MS_EXCEPTION_IF_NULL(node_value);
135
136 if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
137 std::vector<tensor::BaseTensorPtr> tensors;
138 TensorValueToTensor(node_value, &tensors);
139 for (size_t i = 0; i < tensors.size(); i++) {
140 const auto &tensor = tensors[i];
141 nvinfer1::Weights weight;
142 weight.values = tensor->data_c();
143 std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
144 TRT_VARIANT_CHECK(type, 1UL, false);
145 weight.type = std::get<nvinfer1::DataType>(type);
146 weight.count = tensor->DataSize();
147 output_map_[value_node][i] = LayerInput(weight, tensor->shape());
148 }
149 }
150 }
151 }
152 return true;
153 }
154
StoreLayerOutput(const AnfNodePtr & node,const std::vector<nvinfer1::ITensor * > & nv_tensors)155 bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
156 if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
157 MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
158 << ", while got: " << nv_tensors.size();
159 }
160
161 for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
162 if (nv_tensors[tensor_index] != nullptr) {
163 const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions();
164 const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(dim);
165 output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape);
166
167 std::ostringstream oss;
168 oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
169 for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
170 oss << dim.d[dim_index] << " ";
171 }
172 oss << "]";
173 MS_LOG(INFO) << oss.str();
174 }
175 }
176 return true;
177 }
178
LoadInputOnDemand(const AnfNodePtr & node)179 LayerInput *TrtConverterContext::LoadInputOnDemand(const AnfNodePtr &node) {
180 MS_EXCEPTION_IF_NULL(node);
181 auto input = node->cast<ParameterPtr>();
182 std::variant<bool, nvinfer1::DataType> type =
183 TrtUtils::MsDtypeToTrtDtype(common::AnfAlgo::GetOutputInferDataType(node, 0));
184 TRT_VARIANT_CHECK(type, 1UL, nullptr);
185 const auto &trt_dtype = std::get<nvinfer1::DataType>(type);
186 const nvinfer1::Dims &trt_dims = TrtUtils::MsDimsToTrtDims(common::AnfAlgo::GetOutputInferShape(node, 0), false);
187 nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
188 const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
189 output_map_[node][0] = LayerInput(tensor, shape);
190 return &output_map_[node][0];
191 }
192
LoadLayerInput(const AnfNodePtr & node,std::vector<LayerInput> * inputs)193 bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
194 std::vector<session::KernelWithIndex> real_inputs;
195 common::AnfAlgo::GetRealInputs(node, &real_inputs);
196 for (auto item : real_inputs) {
197 auto node_iter = output_map_.find(item.first);
198 if (node_iter == output_map_.end()) {
199 if (item.first->isa<Parameter>()) {
200 LayerInput *input = LoadInputOnDemand(item.first);
201 if (input == nullptr) {
202 MS_LOG(WARNING) << "LoadLayerInput failed.";
203 return false;
204 }
205 inputs->push_back(*input);
206 continue;
207 }
208 MS_LOG(WARNING) << "node: " << node->DebugString() << " not found.";
209 return false;
210 }
211
212 auto out_iter = node_iter->second.find(item.second);
213 if (out_iter == node_iter->second.end()) {
214 MS_LOG(WARNING) << "node: " << node->DebugString() << "output index: " << item.second << " not found.";
215 return false;
216 }
217
218 inputs->push_back(out_iter->second);
219 }
220 return true;
221 }
222
GetGraphInputs() const223 std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() const {
224 // Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
225 mindspore::HashMap<std::string, AnfNodePtr> graph_inputs;
226 for (const auto &input_node : func_graph_->parameters()) {
227 if (!input_node->isa<Parameter>()) {
228 continue;
229 }
230
231 auto input = input_node->cast<ParameterPtr>();
232 if (!common::AnfAlgo::IsParameterWeight(input)) {
233 (void)graph_inputs.emplace(input->name(), input_node);
234 }
235 }
236
237 // Keep the graph inputs in order of the binding name.
238 std::vector<AnfNodePtr> trt_inputs;
239 for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
240 if (!engine_->bindingIsInput(i)) {
241 continue;
242 }
243 auto iter = graph_inputs.find(engine_->getBindingName(i));
244 if (iter == graph_inputs.end()) {
245 MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i);
246 }
247 trt_inputs.push_back(iter->second);
248 }
249 return trt_inputs;
250 }
251
GetGraphOutputs() const252 std::tuple<std::map<size_t, size_t>, std::vector<session::KernelWithIndex>> TrtConverterContext::GetGraphOutputs()
253 const {
254 std::vector<session::KernelWithIndex> anf_output_list;
255 common::AnfAlgo::GetRealInputs(func_graph_->get_return(), &anf_output_list);
256
257 std::map<size_t, size_t> anf_trt_index_map;
258 std::vector<session::KernelWithIndex> trt_output_list(anf_output_list.size());
259 size_t trt_index = 0;
260 for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
261 if (!engine_->bindingIsInput(i)) {
262 const std::string &name = engine_->getBindingName(i);
263 size_t pos = name.find_first_not_of("return_output_");
264 size_t anf_index = atoi(name.substr(pos).c_str());
265
266 (void)anf_trt_index_map.emplace(anf_index, trt_index);
267 trt_output_list[trt_index] = anf_output_list[anf_index];
268 trt_index++;
269 }
270 }
271
272 return std::make_tuple(anf_trt_index_map, trt_output_list);
273 }
274
CreateTempWeight(const TypeId & type,const ShapeVector & shape)275 std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type, const ShapeVector &shape) {
276 auto tensor = std::make_shared<tensor::Tensor>(type, shape);
277 temp_weights_.push_back(tensor);
278 return tensor;
279 }
280 } // namespace mindspore::opt
281