/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_

#include <vector>
#include <string>
#include <memory>
#include <tuple>
#include <map>
#include <NvInfer.h>
#include "utils/hash_map.h"
#include "base/base.h"
#include "ir/anf.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/gpu/optimizer/trt_pass/layer_input.h"

namespace mindspore {
namespace opt {
// The const number 4GB in bytes (used below as the default TensorRT workspace size).
constexpr size_t kFourGBytes = 4UL << 30;

// Class transforming an ANF graph to a TensorRT network.
// It converts the operators in the ANF graph to TensorRT layers according to the topological order.
// During the conversion, the cache keeps the map between ANF node outputs and TensorRT layer outputs.
// Before starting the operator conversion, it first caches the weights and constant nodes in the ANF graph.
// While performing operator transformation, it obtains the inputs of the operator from the cache.
// After conversion is completed, it stores the outputs of the operator to the cache.
44 class TrtConverterContext : public std::enable_shared_from_this<TrtConverterContext> { 45 public: TrtConverterContext(FuncGraphPtr fg)46 explicit TrtConverterContext(FuncGraphPtr fg) 47 : func_graph_(fg), 48 batch_size_(1), 49 workspace_size_(kFourGBytes), 50 builder_(nullptr), 51 network_(nullptr), 52 config_(nullptr), 53 engine_(nullptr) {} 54 ~TrtConverterContext() = default; 55 56 // Create Tensor-RT object and cache the ANF graph inputs and constant node. 57 bool Init(); 58 59 // Parser KernelGraph to trt graph 60 bool Parser(); 61 62 // Serialize trt models. 63 bool Serialize(std::string *model); 64 65 // Get trt graph inputs without weights. The inputs keep same order as binding name. 66 std::vector<AnfNodePtr> GetGraphInputs() const; 67 68 // Get trt graph outputs. All outputs are flatten to vector with concret shape. 69 std::tuple<std::map<size_t, size_t>, std::vector<session::KernelWithIndex>> GetGraphOutputs() const; 70 71 // Store trt layer outputs to the cache. 72 bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<nvinfer1::ITensor *> &inputs); 73 74 // Get trt layer inputs from the cache. 75 bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs); 76 77 // Create and keep temporary weight, as constant folding demanding new weight excluded in graph, 78 // which should release until building finish. 
79 std::shared_ptr<tensor::Tensor> CreateTempWeight(const TypeId &type, const ShapeVector &shape); 80 network()81 std::shared_ptr<nvinfer1::INetworkDefinition> network() const { return network_; } 82 83 private: 84 bool InitInputTable(); 85 bool InitValueNodeTable(); 86 LayerInput *LoadInputOnDemand(const AnfNodePtr &node); 87 88 FuncGraphPtr func_graph_; 89 uint32_t batch_size_; 90 size_t workspace_size_; 91 std::shared_ptr<nvinfer1::IBuilder> builder_; 92 std::shared_ptr<nvinfer1::INetworkDefinition> network_; 93 std::shared_ptr<nvinfer1::IBuilderConfig> config_; 94 std::shared_ptr<nvinfer1::ICudaEngine> engine_; 95 96 // Cache (AnfNode + output_index : ILayer output). 97 mindspore::HashMap<AnfNodePtr, mindspore::HashMap<size_t, LayerInput>> output_map_; 98 std::vector<std::shared_ptr<tensor::Tensor>> temp_weights_; 99 }; 100 } // namespace opt 101 } // namespace mindspore 102 103 #endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_HELPER_H_ 104