//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "DelegateUtils.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{

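// Queries the registered backends to check whether a Softmax layer with the given descriptor is supported.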
TfLiteStatus ValidateSoftmaxOperator(DelegateData& delegateData,
                                     TfLiteContext* tfLiteContext,
                                     const armnn::TensorInfo& inputInfo,
                                     const armnn::TensorInfo& outputTensorInfo,
                                     const armnn::SoftmaxDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               tfLiteContext,
                               IsSoftmaxSupported,
                               delegateData.m_Backends,
                               isSupported,
                               inputInfo,
                               outputTensorInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

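// Queries the registered backends to check whether a LogSoftmax layer with the given descriptor is supported.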
TfLiteStatus ValidateLogSoftmaxOperator(DelegateData& delegateData,
                                        TfLiteContext* tfLiteContext,
                                        const armnn::TensorInfo& inputInfo,
                                        const armnn::TensorInfo& outputTensorInfo,
                                        const armnn::LogSoftmaxDescriptor& descriptor)
{
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               tfLiteContext,
                               IsLogSoftmaxSupported,
                               delegateData.m_Backends,
                               isSupported,
                               inputInfo,
                               outputTensorInfo,
                               descriptor);
    return isSupported ? kTfLiteOk : kTfLiteError;
}

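// Handles a Softmax or LogSoftmax node: validates it against the backends and, when a network
// is being built, adds the corresponding Arm NN layer and connects it to the graph.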
TfLiteStatus VisitSoftmaxOperator(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex,
                                  int32_t softmaxOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

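    // Look up the input and output tensors; dynamic tensors are rejected as unsupported.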
    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (IsDynamicTensor(tfLiteInputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in node #%d: ",
            nodeIndex);
        return kTfLiteError;
    }
    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in node #%d: ",
            nodeIndex);
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

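    // No network to add a layer to; this call only checks whether the operator is supported.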
    if (!delegateData.m_Network)
    {
        switch(softmaxOperatorCode)
        {
            case kTfLiteBuiltinSoftmax:
            {
                armnn::SoftmaxDescriptor descriptor;
                auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
                descriptor.m_Beta = params->beta;
                return ValidateSoftmaxOperator(delegateData,
                                               tfLiteContext,
                                               inputTensorInfo,
                                               outputTensorInfo,
                                               descriptor);
            }
            case kTfLiteBuiltinLogSoftmax:
            {
                armnn::LogSoftmaxDescriptor descriptor;
                return ValidateLogSoftmaxOperator(delegateData,
                                                  tfLiteContext,
                                                  inputTensorInfo,
                                                  outputTensorInfo,
                                                  descriptor);
            }
            default:
                return kTfLiteError;
        }
    }

    armnn::IConnectableLayer* softmaxLayer = nullptr;

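    // Add the Softmax or LogSoftmax layer that will carry out the operation.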
    switch(softmaxOperatorCode)
    {
        case kTfLiteBuiltinSoftmax:
        {
            armnn::SoftmaxDescriptor descriptor;
            auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(tfLiteNode->builtin_data);
            descriptor.m_Beta = params->beta;
            softmaxLayer = delegateData.m_Network->AddSoftmaxLayer(descriptor);
            break;
        }
        case kTfLiteBuiltinLogSoftmax:
        {
            armnn::LogSoftmaxDescriptor descriptor;
            softmaxLayer = delegateData.m_Network->AddLogSoftmaxLayer(descriptor);
            break;
        }
        default:
            return kTfLiteError;
    }
    ARMNN_ASSERT(softmaxLayer != nullptr);

    armnn::IOutputSlot& outputSlot = softmaxLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Connect
    return Connect(softmaxLayer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate