• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_CONTEXT_H_
17 #define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_CONTEXT_H_
18 
19 #include <experimental/optional>
20 #include <NvInfer.h>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 #include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h"
25 
26 namespace mindspore::lite {
27 struct ITensorHelper {
28   nvinfer1::ITensor *trt_tensor_{nullptr};
29   mindspore::Format format_{Format::NCHW};
30   bool same_format_{true};
31   bool is_tensor{true};
32 };
33 class TensorRTContext {
34  public:
35   TensorRTContext() = default;
36   ~TensorRTContext();
37   bool Init();
38   void SetRuntime(TensorRTRuntime *runtime);
39   nvinfer1::INetworkDefinition *network();
40   void RegisterLayer(nvinfer1::ILayer *layer, const std::string &basename);
41   void RegisterTensor(ITensorHelper tensor, const std::string &basename);
42   void RegisterTensorWithSameName(ITensorHelper tensor, const std::string &basename);
43   bool HasTensor(const std::string &name) const;
44   ITensorHelper MsName2Tensor(const std::string &ms_name);
45 
46   template <typename T>
47   nvinfer1::ITensor *ConvertTo0DTensor(T value);
48   template <typename T>
49   nvinfer1::ITensor *ConvertTo1DTensor(T value);
50   template <typename T>
51   nvinfer1::ITensor *ConvertTo1DTensor(const std::vector<T> &values);
52 
53  private:
54   int counter_{0};
55   nvinfer1::INetworkDefinition *network_{nullptr};
56   std::unordered_map<std::string, ITensorHelper> ms_name2trt_tensor_;
57   TensorRTRuntime *runtime_{nullptr};
58   std::vector<void *> owner_memorys_;
59 };
60 }  // namespace mindspore::lite
61 #endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_CONTEXT_H_
62