1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/trt_pass/trt_converter_context.h"
18
19 #include "runtime/device/gpu/trt_loader.h"
20 #include "backend/optimizer/trt_pass/trt_op_factory.h"
21 #include "backend/kernel_compiler/gpu/trt/trt_utils.h"
22 #include "utils/convert_utils.h"
23 #include "utils/utils.h"
24 #include "utils/singleton.h"
25 #include "utils/ms_context.h"
26
27 namespace mindspore::opt {
Init()28 bool TrtConverterContext::Init() {
29 auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
30 builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
31 MS_EXCEPTION_IF_NULL(builder_);
32
33 auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
34 network_ = TrtPtr(builder_->createNetworkV2(batch_type));
35 MS_EXCEPTION_IF_NULL(network_);
36
37 config_ = TrtPtr(builder_->createBuilderConfig());
38 MS_EXCEPTION_IF_NULL(config_);
39
40 InitInputTable();
41 InitValueNodeTable();
42 return true;
43 }
44
Parser()45 bool TrtConverterContext::Parser() {
46 std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return());
47 const auto &converter_factory = TrtOpFactory::GetInstance();
48 for (auto node : node_list) {
49 if (!node->isa<CNode>()) {
50 continue;
51 }
52
53 // Transform AnfNode To Trt layer.
54 // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
55 std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
56 if (!AnfAlgo::IsRealKernel(node) && op_name != "Return") {
57 continue;
58 }
59
60 ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name);
61 auto result = convert_func(node, this->shared_from_this());
62 if (!result.first) {
63 MS_LOG(WARNING) << op_name << " converter failed.";
64 return false;
65 }
66 auto ret = StoreLayerOutput(node, result.second);
67 if (!ret) {
68 MS_LOG(WARNING) << op_name << " converter failed.";
69 return false;
70 }
71 }
72
73 return true;
74 }
75
Serialize(std::string * model)76 bool TrtConverterContext::Serialize(std::string *model) {
77 MS_EXCEPTION_IF_NULL(model);
78 builder_->setMaxBatchSize(batch_size_);
79 config_->setMaxWorkspaceSize(workspace_size_);
80
81 // Set precision mode
82 const auto &context = MsContext::GetInstance();
83 const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
84 if (precision_mode == "fp16") {
85 MS_LOG(INFO) << "Inference with mixed precision mode";
86 config_->setFlag(nvinfer1::BuilderFlag::kFP16);
87 }
88
89 MS_LOG(WARNING) << "It will take few minutes for operators selection.";
90 engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
91 MS_EXCEPTION_IF_NULL(engine_);
92
93 std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize());
94 *model = string(static_cast<const char *>(model_data->data()), model_data->size());
95 return true;
96 }
97
InitInputTable()98 bool TrtConverterContext::InitInputTable() {
99 const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters();
100 for (auto input_node : graph_inputs) {
101 if (!input_node->isa<Parameter>()) {
102 continue;
103 }
104
105 auto input = input_node->cast<ParameterPtr>();
106 if (AnfAlgo::IsParameterWeight(input)) {
107 const auto ¶m_value = input->default_param();
108 MS_EXCEPTION_IF_NULL(param_value);
109 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
110 MS_EXCEPTION_IF_NULL(tensor);
111
112 nvinfer1::Weights weight;
113 weight.values = tensor->data_c();
114 std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
115 TRT_VARIANT_CHECK(type, 1UL, false);
116 weight.type = std::get<nvinfer1::DataType>(type);
117 weight.count = tensor->DataSize();
118 output_map_[input_node][0] = LayerInput(weight, tensor->shape());
119 }
120 }
121 return true;
122 }
123
InitValueNodeTable()124 bool TrtConverterContext::InitValueNodeTable() {
125 MS_EXCEPTION_IF_NULL(func_graph_);
126 const std::vector<AnfNodePtr> &node_list = TopoSort(func_graph_->get_return());
127 for (const auto &node : node_list) {
128 MS_EXCEPTION_IF_NULL(node);
129 if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
130 auto value_node = node->cast<ValueNodePtr>();
131 auto &node_value = value_node->value();
132 MS_EXCEPTION_IF_NULL(node_value);
133
134 if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
135 std::vector<tensor::TensorPtr> tensors;
136 TensorValueToTensor(node_value, &tensors);
137 for (size_t i = 0; i < tensors.size(); i++) {
138 const auto &tensor = tensors[i];
139 nvinfer1::Weights weight;
140 weight.values = tensor->data_c();
141 std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type());
142 TRT_VARIANT_CHECK(type, 1UL, false);
143 weight.type = std::get<nvinfer1::DataType>(type);
144 weight.count = tensor->DataSize();
145 output_map_[value_node][i] = LayerInput(weight, tensor->shape());
146 }
147 }
148 }
149 }
150 return true;
151 }
152
StoreLayerOutput(const AnfNodePtr & node,const std::vector<nvinfer1::ITensor * > & nv_tensors)153 bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &nv_tensors) {
154 if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) {
155 MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node)
156 << ", while got: " << nv_tensors.size();
157 }
158
159 for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) {
160 if (nv_tensors[tensor_index] != nullptr) {
161 const nvinfer1::Dims &dim = nv_tensors[tensor_index]->getDimensions();
162 const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(dim);
163 output_map_[node][tensor_index] = LayerInput(nv_tensors[tensor_index], shape);
164
165 std::ostringstream oss;
166 oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ ";
167 for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) {
168 oss << dim.d[dim_index] << " ";
169 }
170 oss << "]";
171 MS_LOG(INFO) << oss.str();
172 }
173 }
174 return true;
175 }
176
LoadInputOnDemand(const AnfNodePtr & node)177 LayerInput *TrtConverterContext::LoadInputOnDemand(const AnfNodePtr &node) {
178 MS_EXCEPTION_IF_NULL(node);
179 auto input = node->cast<ParameterPtr>();
180 std::variant<bool, nvinfer1::DataType> type = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(node, 0));
181 TRT_VARIANT_CHECK(type, 1UL, nullptr);
182 const auto &trt_dtype = std::get<nvinfer1::DataType>(type);
183 const nvinfer1::Dims &trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(node, 0), false);
184 nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims);
185 const std::vector<int64_t> &shape = TrtUtils::TrtDimsToMsDims(trt_dims);
186 output_map_[node][0] = LayerInput(tensor, shape);
187 return &output_map_[node][0];
188 }
189
LoadLayerInput(const AnfNodePtr & node,std::vector<LayerInput> * inputs)190 bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) {
191 std::vector<session::KernelWithIndex> real_inputs;
192 AnfAlgo::GetRealInputs(node, &real_inputs);
193 for (auto item : real_inputs) {
194 auto node_iter = output_map_.find(item.first);
195 if (node_iter == output_map_.end()) {
196 if (item.first->isa<Parameter>()) {
197 LayerInput *input = LoadInputOnDemand(item.first);
198 if (input == nullptr) {
199 MS_LOG(WARNING) << "LoadLayerInput failed.";
200 return false;
201 }
202 inputs->push_back(*input);
203 continue;
204 }
205 MS_LOG(WARNING) << "node: " << node->DebugString() << " not found.";
206 return false;
207 }
208
209 auto out_iter = node_iter->second.find(item.second);
210 if (out_iter == node_iter->second.end()) {
211 MS_LOG(WARNING) << "node: " << node->DebugString() << "output index: " << item.second << " not found.";
212 return false;
213 }
214
215 inputs->push_back(out_iter->second);
216 }
217 return true;
218 }
219
GetGraphInputs() const220 std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() const {
221 // Get Anf-graph inputs without weights. All weights were binded to Trt-graph.
222 std::unordered_map<std::string, AnfNodePtr> graph_inputs;
223 for (const auto &input_node : func_graph_->parameters()) {
224 if (!input_node->isa<Parameter>()) {
225 continue;
226 }
227
228 auto input = input_node->cast<ParameterPtr>();
229 if (!AnfAlgo::IsParameterWeight(input)) {
230 graph_inputs.insert(std::make_pair(input->name(), input_node));
231 }
232 }
233
234 // Keep the graph inputs in order of the binding name.
235 std::vector<AnfNodePtr> trt_inputs;
236 for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
237 if (!engine_->bindingIsInput(i)) {
238 continue;
239 }
240 auto iter = graph_inputs.find(engine_->getBindingName(i));
241 if (iter == graph_inputs.end()) {
242 MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i);
243 }
244 trt_inputs.push_back(iter->second);
245 }
246 return trt_inputs;
247 }
248
GetGraphOutputs() const249 std::tuple<std::map<size_t, size_t>, std::vector<session::KernelWithIndex>> TrtConverterContext::GetGraphOutputs()
250 const {
251 std::vector<session::KernelWithIndex> anf_output_list;
252 AnfAlgo::GetRealInputs(func_graph_->get_return(), &anf_output_list);
253
254 std::map<size_t, size_t> anf_trt_index_map;
255 std::vector<session::KernelWithIndex> trt_output_list(anf_output_list.size());
256 size_t trt_index = 0;
257 for (int32_t i = 0; i < engine_->getNbBindings(); ++i) {
258 if (!engine_->bindingIsInput(i)) {
259 const std::string &name = engine_->getBindingName(i);
260 size_t pos = name.find_first_not_of("return_output_");
261 size_t anf_index = atoi(name.substr(pos).c_str());
262
263 anf_trt_index_map.insert(std::make_pair(anf_index, trt_index));
264 trt_output_list[trt_index] = anf_output_list[anf_index];
265 trt_index++;
266 }
267 }
268
269 return std::make_tuple(anf_trt_index_map, trt_output_list);
270 }
271
CreateTempWeight(const TypeId & type,const std::vector<size_t> & shape)272 std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type,
273 const std::vector<size_t> &shape) {
274 ShapeVector shape_int;
275 std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), SizeToLong);
276 auto tensor = std::make_shared<tensor::Tensor>(type, shape_int);
277 temp_weights_.push_back(tensor);
278 return tensor;
279 }
280 } // namespace mindspore::opt
281