/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/delegate/tensorrt/op/activation_tensorrt.h"

19 namespace mindspore::lite {
IsSupport(const schema::Primitive * primitive,const std::vector<mindspore::MSTensor> & in_tensors,const std::vector<mindspore::MSTensor> & out_tensors)20 int ActivationTensorRT::IsSupport(const schema::Primitive *primitive,
21                                   const std::vector<mindspore::MSTensor> &in_tensors,
22                                   const std::vector<mindspore::MSTensor> &out_tensors) {
23   if (!IsShapeKnown()) {
24     MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
25     return RET_ERROR;
26   }
27   if (in_tensors.size() != 1) {
28     MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
29     return RET_ERROR;
30   }
31   if (out_tensors.size() != 1) {
32     MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
33     return RET_ERROR;
34   }
35   auto activation_op = this->op_primitive_->value_as_Activation();
36   if (activation_op == nullptr) {
37     MS_LOG(ERROR) << "op convert failed";
38     return RET_ERROR;
39   }
40   this->action_code_ = ConvertActivationType(activation_op->activation_type()).activation_type;
41   if (this->action_code_ == nvinfer1::ActivationType::kRELU &&
42       activation_op->activation_type() != schema::ActivationType_RELU) {
43     MS_LOG(ERROR) << "Unsupported op action type for TensorRT: " << activation_op->activation_type();
44     return RET_ERROR;
45   }
46   return RET_OK;
47 }
AddInnerOp(nvinfer1::INetworkDefinition * network)48 int ActivationTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
49   if (network == nullptr) {
50     MS_LOG(ERROR) << "network is invalid";
51     return RET_ERROR;
52   }
53   auto activation_op = this->op_primitive_->value_as_Activation();
54   if (activation_op == nullptr) {
55     MS_LOG(ERROR) << "op convert failed";
56     return RET_ERROR;
57   }
58   float alpha = activation_op->alpha();
59 
60   nvinfer1::IActivationLayer *activation_layer = ActivationTensorRT::AddActivation(
61     network, activation_op->activation_type(), alpha, tensorrt_in_tensors_[0].trt_tensor_);
62   if (activation_layer == nullptr) {
63     MS_LOG(ERROR) << "add activation op failed for TensorRT.";
64     return RET_ERROR;
65   }
66 
67   activation_layer->setName(op_name_.c_str());
68   activation_layer->getOutput(0)->setName((op_name_ + "_output").c_str());
69   this->AddInnerOutTensors(ITensorHelper{activation_layer->getOutput(0), tensorrt_in_tensors_[0].format_});
70 
71   return RET_OK;
72 }
AddActivation(nvinfer1::INetworkDefinition * network,schema::ActivationType activation_type,float alpha,nvinfer1::ITensor * trt_in_tensor)73 nvinfer1::IActivationLayer *ActivationTensorRT::AddActivation(nvinfer1::INetworkDefinition *network,
74                                                               schema::ActivationType activation_type, float alpha,
75                                                               nvinfer1::ITensor *trt_in_tensor) {
76   // Just some action_code correct, unfind code is set to default relu. need double check.
77   lite::ActivationParams action_param = ConvertActivationType(activation_type);
78   if (action_param.activation_type == nvinfer1::ActivationType::kRELU &&
79       activation_type != schema::ActivationType_RELU) {
80     MS_LOG(ERROR) << "Unsupported op action type for TensorRT: " << activation_type;
81     return nullptr;
82   }
83   nvinfer1::IActivationLayer *activation_layer = network->addActivation(*trt_in_tensor, action_param.activation_type);
84   if (activation_layer == nullptr) {
85     MS_LOG(ERROR) << "add activation op failed for TensorRT.";
86     return nullptr;
87   }
88 
89   if (action_param.has_alpha) {
90     activation_layer->setAlpha(alpha);
91   }
92 
93   if (action_param.has_beta) {
94     activation_layer->setBeta(action_param.beta);
95   }
96 
97   return activation_layer;
98 }
99 }  // namespace mindspore::lite