• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/prelu.h"
17 
18 #include <cstdint>
19 
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/micro/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace micro {
29 namespace activations {
30 namespace {
31 
CalculatePreluParams(const TfLiteTensor * input,const TfLiteTensor * alpha,TfLiteTensor * output,PreluParams * params)32 TfLiteStatus CalculatePreluParams(const TfLiteTensor* input,
33                                   const TfLiteTensor* alpha,
34                                   TfLiteTensor* output, PreluParams* params) {
35   if (output->type == kTfLiteInt8 || output->type == kTfLiteUInt8 ||
36       output->type == kTfLiteInt16) {
37     double real_multiplier_1 = static_cast<double>(input->params.scale) /
38                                static_cast<double>(output->params.scale);
39     double real_multiplier_2 = static_cast<double>(input->params.scale) *
40                                static_cast<double>(alpha->params.scale) /
41                                static_cast<double>(output->params.scale);
42     QuantizeMultiplier(real_multiplier_1, &params->output_multiplier_1,
43                        &params->output_shift_1);
44     QuantizeMultiplier(real_multiplier_2, &params->output_multiplier_2,
45                        &params->output_shift_2);
46 
47     params->input_offset = -input->params.zero_point;
48     params->alpha_offset = -alpha->params.zero_point;
49     params->output_offset = output->params.zero_point;
50   }
51 
52   return kTfLiteOk;
53 }
54 
55 }  // namespace
56 
BroadcastPrelu4DSlowFloat(const RuntimeShape & unextended_input1_shape,const float * input1_data,const RuntimeShape & unextended_input2_shape,const float * input2_data,const RuntimeShape & unextended_output_shape,float * output_data)57 inline void BroadcastPrelu4DSlowFloat(
58     const RuntimeShape& unextended_input1_shape, const float* input1_data,
59     const RuntimeShape& unextended_input2_shape, const float* input2_data,
60     const RuntimeShape& unextended_output_shape, float* output_data) {
61   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
62   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
63   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
64   const RuntimeShape output_shape =
65       RuntimeShape::ExtendedShape(4, unextended_output_shape);
66 
67   NdArrayDesc<4> desc1;
68   NdArrayDesc<4> desc2;
69   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
70                                       unextended_input2_shape, &desc1, &desc2);
71 
72   for (int b = 0; b < output_shape.Dims(0); ++b) {
73     for (int y = 0; y < output_shape.Dims(1); ++y) {
74       for (int x = 0; x < output_shape.Dims(2); ++x) {
75         for (int c = 0; c < output_shape.Dims(3); ++c) {
76           auto out_idx = Offset(output_shape, b, y, x, c);
77           auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
78           auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
79           auto in1_val = input1_data[in1_idx];
80           auto in2_val = input2_data[in2_idx];
81           output_data[out_idx] = in1_val >= 0.0f ? in1_val : in1_val * in2_val;
82         }
83       }
84     }
85   }
86 }
87 
PreluInit(TfLiteContext * context,const char * buffer,size_t length)88 void* PreluInit(TfLiteContext* context, const char* buffer, size_t length) {
89   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
90   return context->AllocatePersistentBuffer(context, sizeof(PreluParams));
91 }
92 
PreluPrepare(TfLiteContext * context,TfLiteNode * node)93 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
94   TFLITE_DCHECK(node->user_data != nullptr);
95   PreluParams* params = static_cast<PreluParams*>(node->user_data);
96 
97   const TfLiteTensor* input = GetInput(context, node, 0);
98   TF_LITE_ENSURE(context, input != nullptr);
99   const TfLiteTensor* alpha = GetInput(context, node, 1);
100   TF_LITE_ENSURE(context, alpha != nullptr);
101   TfLiteTensor* output = GetOutput(context, node, 0);
102   TF_LITE_ENSURE(context, output != nullptr);
103 
104   return CalculatePreluParams(input, alpha, output, params);
105 }
106 
PreluEval(TfLiteContext * context,TfLiteNode * node)107 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
108   TFLITE_DCHECK(node->user_data != nullptr);
109   const PreluParams& params =
110       *(static_cast<const PreluParams*>(node->user_data));
111 
112   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
113   const TfLiteEvalTensor* alpha = tflite::micro::GetEvalInput(context, node, 1);
114   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
115 
116   switch (input->type) {
117     case kTfLiteFloat32: {
118       BroadcastPrelu4DSlowFloat(tflite::micro::GetTensorShape(input),
119                                 tflite::micro::GetTensorData<float>(input),
120                                 tflite::micro::GetTensorShape(alpha),
121                                 tflite::micro::GetTensorData<float>(alpha),
122                                 tflite::micro::GetTensorShape(output),
123                                 tflite::micro::GetTensorData<float>(output));
124       return kTfLiteOk;
125     } break;
126     case kTfLiteUInt8: {
127       reference_ops::BroadcastPrelu4DSlow(
128           params, tflite::micro::GetTensorShape(input),
129           tflite::micro::GetTensorData<uint8_t>(input),
130           tflite::micro::GetTensorShape(alpha),
131           tflite::micro::GetTensorData<uint8_t>(alpha),
132           tflite::micro::GetTensorShape(output),
133           tflite::micro::GetTensorData<uint8_t>(output));
134       return kTfLiteOk;
135     } break;
136     case kTfLiteInt8: {
137       reference_ops::BroadcastPrelu4DSlow(
138           params, tflite::micro::GetTensorShape(input),
139           tflite::micro::GetTensorData<int8_t>(input),
140           tflite::micro::GetTensorShape(alpha),
141           tflite::micro::GetTensorData<int8_t>(alpha),
142           tflite::micro::GetTensorShape(output),
143           tflite::micro::GetTensorData<int8_t>(output));
144       return kTfLiteOk;
145     } break;
146     default:
147       TF_LITE_KERNEL_LOG(
148           context, "Only float32 and uint8_t are supported currently, got %d.",
149           TfLiteTypeGetName(input->type));
150       return kTfLiteError;
151   }
152 }
153 
154 }  // namespace activations
155 
Register_PRELU()156 TfLiteRegistration Register_PRELU() {
157   return {/*init=*/activations::PreluInit,
158           /*free=*/nullptr,
159           /*prepare=*/activations::PreluPrepare,
160           /*invoke=*/activations::PreluEval,
161           /*profiling_string=*/nullptr,
162           /*builtin_code=*/0,
163           /*custom_name=*/nullptr,
164           /*version=*/0};
165 }
166 
167 }  // namespace micro
168 }  // namespace ops
169 }  // namespace tflite
170