• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
19 
#include <cstdint>
#include <vector>
#include <NvInfer.h>
22 
23 namespace mindspore::opt {
24 // Tensor-RT layer inputs include weight or tensor.
25 // Tensor: Anf-graph inputs or feature map which values change during inference.
26 // Weight: Anf-graph inputs or value node which remain unchanged during inference.
27 class LayerInput {
28  public:
LayerInput()29   LayerInput() : type_(InputType::kUnknown), weight_(), tensor_(nullptr) {}
LayerInput(nvinfer1::Weights & w,const std::vector<int64_t> & s)30   explicit LayerInput(nvinfer1::Weights &w, const std::vector<int64_t> &s)
31       : type_(InputType::kWeight), weight_(w), tensor_(nullptr), shape_(s) {}
LayerInput(nvinfer1::ITensor * t,const std::vector<int64_t> & s)32   explicit LayerInput(nvinfer1::ITensor *t, const std::vector<int64_t> &s)
33       : type_(InputType::kTensor), weight_(), tensor_(t), shape_(s) {}
34 
IsTensor()35   bool IsTensor() const { return type_ == InputType::kTensor; }
IsWeight()36   bool IsWeight() const { return type_ == InputType::kWeight; }
37 
weight()38   nvinfer1::Weights *weight() {
39     if (!IsWeight()) {
40       MS_LOG(WARNING) << "weight not initialized.";
41       return nullptr;
42     }
43     return &weight_;
44   }
45 
tensor()46   nvinfer1::ITensor *tensor() const {
47     if (!IsTensor()) {
48       MS_LOG(WARNING) << "tensor not initialized.";
49       return nullptr;
50     }
51     return tensor_;
52   }
53 
shape()54   const std::vector<int64_t> &shape() const { return shape_; }
55 
56  private:
57   enum class InputType : char { kUnknown = 0, kTensor, kWeight };
58   InputType type_;
59   // Keep the copy rather than point cause Weights created as a local variable.
60   nvinfer1::Weights weight_;
61   // Keep the point as ITensor created/held by nvinfer1::INetworkDefinition.
62   nvinfer1::ITensor *tensor_;
63   // Keep the shape of tensor or weight.
64   std::vector<int64_t> shape_;
65 };
66 }  // namespace mindspore::opt
67 
68 #endif  // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
69