/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/delegate/tensorrt/op/softmax_tensorrt.h"

namespace mindspore::lite {
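// Checks whether this softmax op can be delegated to TensorRT: the input
// shape must be known at build time, the primitive must cast to Softmax,
// and the op must have exactly one input and one output tensor.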
int SoftMaxTensorRT::IsSupport(const schema::Primitive *primitive, const std::vector<mindspore::MSTensor> &in_tensors,
                               const std::vector<mindspore::MSTensor> &out_tensors) {
  if (!IsShapeKnown()) {
    MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_;
    return RET_ERROR;
  }
  softmax_op_ = primitive->value_as_Softmax();
  if (softmax_op_ == nullptr) {
    MS_LOG(ERROR) << "convert to Softmax primitive failed: " << op_name_;
    return RET_ERROR;
  }

  if (in_tensors.size() != 1) {
    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
    return RET_ERROR;
  }
  if (out_tensors.size() != 1) {
    MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size();
    return RET_ERROR;
  }
  return RET_OK;
}
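// Builds the TensorRT softmax layer via AddSoftMaxOp, names the layer and
// its output tensor, and registers the output (carrying over the input's
// format) so downstream ops can consume it.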
int SoftMaxTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
  if (network == nullptr) {
    MS_LOG(ERROR) << "network is invalid";
    return RET_ERROR;
  }
  nvinfer1::ISoftMaxLayer *softmax_layer = AddSoftMaxOp(network);
  if (softmax_layer == nullptr) {
    MS_LOG(ERROR) << "add softmax op failed for TensorRT.";
    return RET_ERROR;
  }
  softmax_layer->setName((op_name_ + "_softmax").c_str());

  nvinfer1::ITensor *out_tensor = softmax_layer->getOutput(0);
  if (out_tensor == nullptr) {
    MS_LOG(ERROR) << "softmax output tensor create failed for TensorRT.";
    return RET_ERROR;
  }
  out_tensor->setName((op_name_ + "_output").c_str());
  this->AddInnerOutTensors(ITensorHelper{out_tensor, tensorrt_in_tensors_[0].format_});
  return RET_OK;
}

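// Creates the ISoftMaxLayer and resolves the softmax axis: validates its
// range, remaps it from NHWC to NCHW when the runtime tensor is NCHW, and
// encodes it as the bitmask TensorRT expects.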
nvinfer1::ISoftMaxLayer *SoftMaxTensorRT::AddSoftMaxOp(nvinfer1::INetworkDefinition *network) {
  nvinfer1::ISoftMaxLayer *softmax_layer = network->addSoftMax(*tensorrt_in_tensors_[0].trt_tensor_);
  if (softmax_layer == nullptr) {
    MS_LOG(ERROR) << "add softmax op failed for TensorRT.";
    return nullptr;
  }
  auto axis = softmax_op_->axis();
  if (axis == nullptr || axis->size() == 0) {
    MS_LOG(ERROR) << "softmax axis attribute is empty: " << op_name_;
    return nullptr;
  }
  auto axis_val = std::vector<int64_t>(axis->begin(), axis->end());

  if (axis_val.size() != 1) {
    MS_LOG(WARNING) << "softmax expects a single axis, got " << axis_val.size() << ", using the first one: "
                    << op_name_;
  }

  int nb_dims = tensorrt_in_tensors_[0].trt_tensor_->getDimensions().nbDims;
  if (axis_val[0] < 0 || axis_val[0] >= nb_dims) {
    MS_LOG(ERROR) << "axis " << axis_val[0] << " is out of range for input tensor with " << nb_dims << " dims.";
    return nullptr;
  }
  int64_t axis_format_value = axis_val[0];
  if (nb_dims == DIMENSION_4D && tensorrt_in_tensors_[0].format_ == Format::NCHW) {
    // The model expresses the axis in NHWC order; remap it to the NCHW layout
    // of the underlying TensorRT tensor (e.g. channel axis 3 -> 1).
    axis_format_value = ConvertAxisFromNHWC2NCHW(axis_val[0]);
  }
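  // ISoftMaxLayer::setAxes() takes a bitmask of dimensions rather than an
  // index: setting bit i applies the softmax over dimension i.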
  uint32_t axis_bit = 1u << axis_format_value;
  MS_LOG(DEBUG) << op_name_ << " set axis bitmask to " << axis_bit;
  softmax_layer->setAxes(axis_bit);
  return softmax_layer;
}
}  // namespace mindspore::lite