• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
17 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/lite/kernels/internal/types.h"
22 
23 namespace tflite {
24 
25 namespace reference_ops {
26 
// Maximum tensor rank handled by the TFLite Pad reference kernels. PadImpl
// extends both shapes and both padding arrays to exactly this many dimensions
// before walking the output.
constexpr int PadKernelMaxDimensionCount() {
  return 4;
}
29 
30 // There are two versions of pad: Pad and PadV2.  In PadV2 there is a second
31 // scalar input that provides the padding value.  Therefore pad_value_ptr can be
32 // equivalent to a simple input1_data.  For Pad, it should point to a zero
33 // value.
34 //
35 // Note that two typenames are required, so that T=P=int32 is considered a
36 // specialization distinct from P=int32.
37 template <typename T, typename P>
PadImpl(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)38 inline void PadImpl(const tflite::PadParams& op_params,
39                     const RuntimeShape& input_shape, const T* input_data,
40                     const P* pad_value_ptr, const RuntimeShape& output_shape,
41                     T* output_data) {
42   const RuntimeShape ext_input_shape =
43       RuntimeShape::ExtendedShape(PadKernelMaxDimensionCount(), input_shape);
44   const RuntimeShape ext_output_shape =
45       RuntimeShape::ExtendedShape(PadKernelMaxDimensionCount(), output_shape);
46   TFLITE_DCHECK_LE(op_params.left_padding_count, PadKernelMaxDimensionCount());
47   TFLITE_DCHECK_LE(op_params.right_padding_count, PadKernelMaxDimensionCount());
48 
49   // Runtime calls are currently fixed at 4 dimensions. Copy inputs so we can
50   // pad them to 4 dims (yes, we are "padding the padding").
51   int left_padding_copy[PadKernelMaxDimensionCount()];
52   for (int i = 0; i < PadKernelMaxDimensionCount(); i++) {
53     left_padding_copy[i] = 0;
54   }
55   for (int i = 0; i < op_params.left_padding_count; ++i) {
56     left_padding_copy[i + PadKernelMaxDimensionCount() -
57                       op_params.left_padding_count] = op_params.left_padding[i];
58   }
59   int right_padding_copy[PadKernelMaxDimensionCount()];
60   for (int i = 0; i < PadKernelMaxDimensionCount(); i++) {
61     right_padding_copy[i] = 0;
62   }
63   for (int i = 0; i < op_params.right_padding_count; ++i) {
64     right_padding_copy[i + PadKernelMaxDimensionCount() -
65                        op_params.right_padding_count] =
66         op_params.right_padding[i];
67   }
68 
69   const int output_batch = ext_output_shape.Dims(0);
70   const int output_height = ext_output_shape.Dims(1);
71   const int output_width = ext_output_shape.Dims(2);
72   const int output_depth = ext_output_shape.Dims(3);
73 
74   const int left_b_padding = left_padding_copy[0];
75   const int left_h_padding = left_padding_copy[1];
76   const int left_w_padding = left_padding_copy[2];
77   const int left_d_padding = left_padding_copy[3];
78 
79   const int right_b_padding = right_padding_copy[0];
80   const int right_h_padding = right_padding_copy[1];
81   const int right_w_padding = right_padding_copy[2];
82   const int right_d_padding = right_padding_copy[3];
83 
84   const T pad_value = *pad_value_ptr;
85 
86   const T* in_ptr = input_data;
87   T* out_ptr = output_data;
88   for (int out_b = 0; out_b < output_batch; ++out_b) {
89     for (int out_h = 0; out_h < output_height; ++out_h) {
90       for (int out_w = 0; out_w < output_width; ++out_w) {
91         for (int out_d = 0; out_d < output_depth; ++out_d) {
92           if (out_b < left_b_padding ||
93               out_b >= output_batch - right_b_padding ||
94               out_h < left_h_padding ||
95               out_h >= output_height - right_h_padding ||
96               out_w < left_w_padding ||
97               out_w >= output_width - right_w_padding ||
98               out_d < left_d_padding ||
99               out_d >= output_depth - right_d_padding) {
100             *out_ptr++ = pad_value;
101           } else {
102             *out_ptr++ = *in_ptr++;
103           }
104         }
105       }
106     }
107   }
108 }
109 
110 template <typename T, typename P>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)111 inline void Pad(const tflite::PadParams& op_params,
112                 const RuntimeShape& input_shape, const T* input_data,
113                 const P* pad_value_ptr, const RuntimeShape& output_shape,
114                 T* output_data) {
115   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
116           output_data);
117 }
118 
119 // The second (pad-value) input can be int32 when, say, the first is uint8.
120 template <typename T>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const int32 * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)121 inline void Pad(const tflite::PadParams& op_params,
122                 const RuntimeShape& input_shape, const T* input_data,
123                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
124                 T* output_data) {
125   const T converted_pad_value = static_cast<T>(*pad_value_ptr);
126   PadImpl(op_params, input_shape, input_data, &converted_pad_value,
127           output_shape, output_data);
128 }
129 
130 // This version avoids conflicting template matching.
131 template <>
Pad(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const int32 * input_data,const int32 * pad_value_ptr,const RuntimeShape & output_shape,int32 * output_data)132 inline void Pad(const tflite::PadParams& op_params,
133                 const RuntimeShape& input_shape, const int32* input_data,
134                 const int32* pad_value_ptr, const RuntimeShape& output_shape,
135                 int32* output_data) {
136   PadImpl(op_params, input_shape, input_data, pad_value_ptr, output_shape,
137           output_data);
138 }
139 
140 // One could make all PadImageStyle calls simply delegate the work to the
141 // ordinary Pad.  However, it is better that the reference code asserts false in
142 // similar cases.
143 template <typename T, typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const T * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,T * output_data)144 inline void PadImageStyle(const tflite::PadParams& op_params,
145                           const RuntimeShape& input_shape, const T* input_data,
146                           const P* pad_value_ptr,
147                           const RuntimeShape& output_shape, T* output_data) {
148   TFLITE_ASSERT_FALSE;
149 }
150 
151 template <typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const uint8 * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,uint8 * output_data)152 inline void PadImageStyle(const tflite::PadParams& op_params,
153                           const RuntimeShape& input_shape,
154                           const uint8* input_data, const P* pad_value_ptr,
155                           const RuntimeShape& output_shape,
156                           uint8* output_data) {
157   Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
158       output_data);
159 }
160 
161 template <typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const int8_t * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,int8_t * output_data)162 inline void PadImageStyle(const tflite::PadParams& op_params,
163                           const RuntimeShape& input_shape,
164                           const int8_t* input_data, const P* pad_value_ptr,
165                           const RuntimeShape& output_shape,
166                           int8_t* output_data) {
167   Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
168       output_data);
169 }
170 
171 template <typename P>
PadImageStyle(const tflite::PadParams & op_params,const RuntimeShape & input_shape,const float * input_data,const P * pad_value_ptr,const RuntimeShape & output_shape,float * output_data)172 inline void PadImageStyle(const tflite::PadParams& op_params,
173                           const RuntimeShape& input_shape,
174                           const float* input_data, const P* pad_value_ptr,
175                           const RuntimeShape& output_shape,
176                           float* output_data) {
177   Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape,
178       output_data);
179 }
180 
181 }  // namespace reference_ops
182 }  // namespace tflite
183 
184 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PAD_H_
185