• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
18 #include <unordered_map>
19 #include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h"
20 
21 namespace mindspore::lite {
TensorRTOp(const BaseOperatorPtr & base_operator,const std::vector<TensorInfo> & in_tensors,const std::vector<TensorInfo> & out_tensors,std::string name)22 TensorRTOp::TensorRTOp(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
23                        const std::vector<TensorInfo> &out_tensors, std::string name)
24     : base_operator_(base_operator), in_tensors_(in_tensors), out_tensors_(out_tensors), op_name_(std::move(name)) {
25   MS_EXCEPTION_IF_NULL(base_operator);
26 
27   this->type_ = base_operator->name();
28   auto primitive_c = base_operator->GetPrim();
29   if (primitive_c != nullptr) {
30     return;
31   }
32 }
33 
GetBaseOperator()34 const BaseOperatorPtr &TensorRTOp::GetBaseOperator() { return this->base_operator_; }
35 
GetOpName()36 std::string TensorRTOp::GetOpName() { return this->op_name_; }
37 
inputs()38 std::vector<TensorInfo> &TensorRTOp::inputs() { return this->in_tensors_; }
39 
outputs()40 std::vector<TensorInfo> &TensorRTOp::outputs() { return this->out_tensors_; }
41 
// Resolves the i-th input of this op to its TensorRT tensor registered in ctx.
// If this op does not support bool inputs, a bool input is transparently cast
// to INT32; the cast result is registered under "<name>_to_int32" so the cast
// layer is created at most once per tensor.
ITensorHelper TensorRTOp::input(TensorRTContext *ctx, size_t i) {
  auto in_ms_tensor = in_tensors_[i];
  ITensorHelper in_trt_tensor = ctx->MsName2Tensor(in_ms_tensor.Name());

  if (!GetSupportInputBool() && in_ms_tensor.DataType() == DataType::kNumberTypeBool) {
    // Reuse a previously created INT32 cast of this tensor, if one exists.
    ITensorHelper in_trt_tensor_cast = ctx->MsName2Tensor(in_ms_tensor.Name() + "_to_int32");
    if (in_trt_tensor_cast.trt_tensor_ == nullptr) {
      // NOTE(review): assumes in_trt_tensor.trt_tensor_ is non-null here — it
      // is handed to TRTTensorCast unchecked; confirm MsName2Tensor cannot
      // return a null tensor for a registered input.
      auto cast_trt_tensor =
        TRTTensorCast(ctx, in_trt_tensor.trt_tensor_, nvinfer1::DataType::kINT32, in_ms_tensor.Name() + "_cast_int32");
      in_trt_tensor_cast = ITensorHelper{cast_trt_tensor, in_ms_tensor.format(), true};
      ctx->RegisterTensor(in_trt_tensor_cast, in_ms_tensor.Name() + "_to_int32");
    }
    return in_trt_tensor_cast;
  }
  return in_trt_tensor;
}
58 
output(TensorRTContext * ctx,size_t i)59 ITensorHelper TensorRTOp::output(TensorRTContext *ctx, size_t i) { return ctx->MsName2Tensor(out_tensors_[i].Name()); }
60 
type() const61 const std::string &TensorRTOp::type() const { return this->type_; }
62 
GetQuantType() const63 schema::QuantType TensorRTOp::GetQuantType() const { return this->quant_type_; }
64 
set_in_ops(const std::vector<TensorRTOp * > & in_ops)65 void TensorRTOp::set_in_ops(const std::vector<TensorRTOp *> &in_ops) { this->in_ops_ = in_ops; }
66 
set_out_ops(const std::vector<TensorRTOp * > & out_ops)67 void TensorRTOp::set_out_ops(const std::vector<TensorRTOp *> &out_ops) { this->out_ops_ = out_ops; }
68 
in_ops() const69 const std::vector<TensorRTOp *> &TensorRTOp::in_ops() const { return this->in_ops_; }
70 
out_ops() const71 const std::vector<TensorRTOp *> &TensorRTOp::out_ops() const { return this->out_ops_; }
72 
SetRuntime(TensorRTRuntime * runtime)73 void TensorRTOp::SetRuntime(TensorRTRuntime *runtime) {
74   this->runtime_ = runtime;
75   device_id_ = runtime_->GetDeviceID();
76 }
77 
HasConst() const78 bool TensorRTOp::HasConst() const {
79   return std::any_of(in_tensors_.begin(), in_tensors_.end(),
80                      [](const TensorInfo &tensor) { return tensor.Data() != nullptr && tensor.IsConst(); });
81 }
82 
ReadyInputsNumber(TensorRTContext * ctx) const83 int TensorRTOp::ReadyInputsNumber(TensorRTContext *ctx) const {
84   return std::count_if(in_tensors_.begin(), in_tensors_.end(),
85                        [&](const TensorInfo &tensor) { return ctx->HasTensor(tensor.Name()); });
86 }
87 
// Base implementation: report shapes as known.
// NOTE(review): presumably a virtual hook overridden by ops with dynamic
// shapes — confirm against the declaration in tensorrt_op.h.
bool TensorRTOp::IsShapeKnown() { return true; }
89 
IsDynamicInput(TensorRTContext * ctx,size_t k)90 bool TensorRTOp::IsDynamicInput(TensorRTContext *ctx, size_t k) {
91   nvinfer1::Dims dims = input(ctx, k).trt_tensor_->getDimensions();
92   return std::any_of(dims.d, dims.d + dims.nbDims, [](int d) { return d == -1; });
93 }
94 
Prepare(void ** network_tensor_bindings,nvinfer1::ICudaEngine * engine)95 int TensorRTOp::Prepare(void **network_tensor_bindings, nvinfer1::ICudaEngine *engine) {
96   if (op_binding_tensor_.size() != 0) {
97     MS_LOG(ERROR) << "need special op Prepare for " << op_name_;
98     return RET_ERROR;
99   }
100   return RET_OK;
101 }
102 
GetDynamicShapeParams() const103 DynamicShapeParams TensorRTOp::GetDynamicShapeParams() const { return this->dynamic_shape_params_; }
104 
SetInt8DynamicRange(TensorRTContext * ctx)105 int TensorRTOp::SetInt8DynamicRange(TensorRTContext *ctx) {
106   // setting param layer_ forcely
107   if (this->layer_ == nullptr) {
108     MS_LOG(WARNING) << op_name_ << " layer is nullptr.";
109     return RET_OK;
110   }
111   if (in_tensors_.empty() || out_tensors_.empty()) {
112     MS_LOG(ERROR) << "input or output tensor empty.";
113     return RET_ERROR;
114   }
115   return RET_OK;
116 }
117 
SetTransposeDynamicRange()118 int TensorRTOp::SetTransposeDynamicRange() {
119   if (this->transpose_layer_ == nullptr) {
120     MS_LOG(INFO) << op_name_ << " transpose_layer is nullptr.";
121     return RET_OK;
122   }
123   return RET_OK;
124 }
125 
GetSupportInputBool()126 bool TensorRTOp::GetSupportInputBool() { return this->support_input_bool_; }
127 
SetSupportInputBool(bool support_input_bool)128 void TensorRTOp::SetSupportInputBool(bool support_input_bool) { this->support_input_bool_ = support_input_bool; }
129 
PrintTrtInputs(TensorRTContext * ctx)130 void TensorRTOp::PrintTrtInputs(TensorRTContext *ctx) {
131   MS_LOG(DEBUG) << "Op " << op_name_ << " type: " << type_;
132   for (size_t i = 0; i < in_tensors_.size(); i++) {
133     if (in_tensors_[i].IsConst()) {
134       MS_LOG(DEBUG) << "-input " << i << "  " << in_tensors_[i].Shape() << " " << in_tensors_[i].DataType();
135     } else {
136       auto tensor = input(ctx, i);
137       if (tensor.trt_tensor_) {
138         MS_LOG(DEBUG) << "-input " << i << "  " << CudaDimsAsString(tensor.trt_tensor_->getDimensions()) << " "
139                       << in_tensors_[i].DataType();
140       }
141     }
142   }
143 }
144 
PrintTrtOutputs(TensorRTContext * ctx)145 void TensorRTOp::PrintTrtOutputs(TensorRTContext *ctx) {
146   MS_LOG(DEBUG) << "Op " << op_name_ << " type: " << type_;
147   for (size_t i = 0; i < out_tensors_.size(); i++) {
148     auto tensor = output(ctx, i);
149     if (tensor.trt_tensor_) {
150       MS_LOG(DEBUG) << "-output " << i << "  " << CudaDimsAsString(tensor.trt_tensor_->getDimensions()) << " "
151                     << out_tensors_[i].DataType();
152     }
153   }
154 }
155 }  // namespace mindspore::lite
156