/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA && GOOGLE_TENSORRT

#include <array>
#include <cstdint>

#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

// Converts the TF `LogSoftmax` op into a chain of TensorRT layers that
// normalizes the logits over their innermost dimension.
class ConvertLogSoftmax : public OpConverterBase<ConvertLogSoftmax> {
 public:
  explicit ConvertLogSoftmax(OpConverterParams *params)
      : OpConverterBase<ConvertLogSoftmax>(params) {}

  // Data types accepted for the logits.
  // The array is sized to exactly the listed types. An oversized array would
  // be padded with value-initialized entries (DataType(0), i.e. DT_INVALID),
  // which would then silently count as an allowed type.
  static constexpr std::array<DataType, 2> AllowedDataTypes() {
    return {DataType::DT_FLOAT, DataType::DT_HALF};
  }

  // A single input: the logits, which must be a TensorRT tensor.
  static constexpr std::array<InputArgSpec, 1> InputSpec() {
    return std::array<InputArgSpec, 1>{
        InputArgSpec::Create("logits", TrtInputArg::kTensor)};
  }

  // Rejects logits whose innermost dimension cannot be reduced by TensorRT.
  Status Validate() {
    const auto &params = *this->params_;
    const auto &inputs = params.inputs;

    ITensorProxyPtr logits_tensor = inputs.at(0).tensor();

    // Implicit-batch mode: nbDims excludes the batch dimension, so 0 means the
    // only dimension left to reduce over is the batch dimension.
    // Explicit-batch mode: 0 means a scalar. Convert() would then build its
    // reduction mask with a negative shift count, which is undefined behavior.
    const int num_trt_dims = logits_tensor->getDimensions().nbDims;
    if (num_trt_dims < 1) {
      if (params.use_implicit_batch) {
        return errors::InvalidArgument(
            "TensorRT LogSoftmax cannot apply on the batch dimension");
      }
      return errors::InvalidArgument(
          "TensorRT LogSoftmax requires logits with at least one dimension");
    }

    return Status::OK();
  }

  // Builds the layers that compute, over the innermost dimension:
  //   shifted    = logits - reduce_max(logits, axis)
  //   logsoftmax = shifted - log(reduce_sum(exp(shifted), axis))
  // This equals logits - log(sum(exp(logits))) mathematically.
  // Subtracting the max first keeps exp() from overflowing. This matters
  // notably in FP16, whose largest finite value is about exp(11.1).
  Status Convert() {
    const auto &params = *this->params_;
    const auto &inputs = params.inputs;
    const auto &node_def = params.node_def;
    auto *network = params.converter->network();

    ITensorProxyPtr logits_tensor = inputs.at(0).tensor();
    const int num_trt_dims = logits_tensor->getDimensions().nbDims;

    // Bit mask selecting the innermost dimension; Validate() guarantees
    // num_trt_dims >= 1. Reductions keep the reduced dimension (as size 1),
    // so their outputs broadcast against the logits in the subtractions below.
    const uint32_t reduce_axes = uint32_t{1} << (num_trt_dims - 1);

    // Maximum of the logits along the reduction axis.
    nvinfer1::IReduceLayer *reduced_max = network->addReduce(
        *logits_tensor->trt_tensor(), nvinfer1::ReduceOperation::kMAX,
        reduce_axes, /*keepDimensions=*/true);
    TFTRT_RETURN_ERROR_IF_NULLPTR(reduced_max, node_def.name());
    params.converter->SetLayerName(reduced_max, node_def, "reduced_max");

    // logits - max(logits). For finite logits every entry is <= 0, so the
    // exponent below never exceeds 1.
    nvinfer1::IElementWiseLayer *shifted = network->addElementWise(
        *logits_tensor->trt_tensor(), *reduced_max->getOutput(0),
        nvinfer1::ElementWiseOperation::kSUB);
    TFTRT_RETURN_ERROR_IF_NULLPTR(shifted, node_def.name());
    params.converter->SetLayerName(shifted, node_def, "shifted_logits");

    // Exponent of the shifted logits.
    nvinfer1::IUnaryLayer *exp = network->addUnary(
        *shifted->getOutput(0), nvinfer1::UnaryOperation::kEXP);
    TFTRT_RETURN_ERROR_IF_NULLPTR(exp, node_def.name());
    params.converter->SetLayerName(exp, node_def, "exp");

    // Sum of the exponents along the reduction axis.
    nvinfer1::IReduceLayer *reduced_sum = network->addReduce(
        *exp->getOutput(0), nvinfer1::ReduceOperation::kSUM, reduce_axes,
        /*keepDimensions=*/true);
    TFTRT_RETURN_ERROR_IF_NULLPTR(reduced_sum, node_def.name());
    params.converter->SetLayerName(reduced_sum, node_def, "reduced_sum");

    // Logarithm of reduced_sum.
    nvinfer1::IUnaryLayer *log_reduced_sum = network->addUnary(
        *reduced_sum->getOutput(0), nvinfer1::UnaryOperation::kLOG);
    TFTRT_RETURN_ERROR_IF_NULLPTR(log_reduced_sum, node_def.name());
    params.converter->SetLayerName(log_reduced_sum, node_def,
                                   "log_reduced_sum");

    // Final output: shifted logits minus log_reduced_sum.
    nvinfer1::IElementWiseLayer *sub = network->addElementWise(
        *shifted->getOutput(0), *log_reduced_sum->getOutput(0),
        nvinfer1::ElementWiseOperation::kSUB);
    TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name());
    params.converter->SetLayerName(sub, node_def, "sub");

    params.outputs->push_back(TRT_TensorOrWeights(sub->getOutput(0)));
    return Status::OK();
  }
};

REGISTER_DEFAULT_TRT_OP_CONVERTER(MakeConverterFunction<ConvertLogSoftmax>(),
                                  "LogSoftmax");

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT