/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/im2col_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {

// Fixed-point per-channel-quantization convolution kernel (optimized,
// GEMM-based).
inline void ConvPerChannel(
    const ConvParams& params, const int32* output_multiplier,
    const int32* output_shift, const RuntimeShape& input_shape,
    const int8* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
    const RuntimeShape& im2col_shape, int8* im2col_data,
    CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("Conv/8bit");
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int dilation_width_factor = params.dilation_width_factor;
  const int dilation_height_factor = params.dilation_height_factor;
  const int32 input_offset = params.input_offset;
  const int32 output_offset = params.output_offset;
  // Set min and max value of the output.
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

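  // Choose the GEMM input. A 1x1 filter with unit stride and no dilation
  // lets the input tensor feed the GEMM directly; otherwise the caller must
  // supply an im2col buffer, which is filled so that each output position's
  // receptive field becomes one column, with padded taps set to the input
  // zero point (so they contribute nothing once the offset is applied).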
  const int8* gemm_input_data = nullptr;
  const RuntimeShape* gemm_input_shape = nullptr;
  const int filter_width = filter_shape.Dims(2);
  const int filter_height = filter_shape.Dims(1);
  const bool need_dilated_im2col =
      dilation_width_factor != 1 || dilation_height_factor != 1;
  const bool need_im2col = stride_width != 1 || stride_height != 1 ||
                           filter_width != 1 || filter_height != 1;
  const int8 input_zero_point = -input_offset;
  const uint8 zero_point_byte =
      *reinterpret_cast<const uint8*>(&input_zero_point);
  if (need_dilated_im2col) {
    TFLITE_DCHECK(im2col_data);
    optimized_ops::DilatedIm2col(params, zero_point_byte, input_shape,
                                 input_data, filter_shape, output_shape,
                                 im2col_data);
    gemm_input_data = im2col_data;
    gemm_input_shape = &im2col_shape;
  } else if (need_im2col) {
    TFLITE_DCHECK(im2col_data);
    optimized_ops::Im2col(params, filter_height, filter_width, zero_point_byte,
                          input_shape, input_data, im2col_shape, im2col_data);
    gemm_input_data = im2col_data;
    gemm_input_shape = &im2col_shape;
  } else {
    TFLITE_DCHECK(!im2col_data);
    gemm_input_data = input_data;
    gemm_input_shape = &input_shape;
  }

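  // Matrix view of the GEMM: the LHS is the filter flattened to
  // [output_depth, filter_height * filter_width * input_depth]; the RHS is
  // the (possibly im2col'd) input viewed as
  // [filter_height * filter_width * input_depth,
  //  batch * output_height * output_width]; the destination is
  // [output_depth, batch * output_height * output_width].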
  const int gemm_input_rows = gemm_input_shape->Dims(3);
  const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
  const int filter_rows = filter_shape.Dims(0);
  const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
  const int output_rows = output_shape.Dims(3);
  // See b/79927784.
  // const int output_cols = FlatSizeSkipDim(output_shape, 3);
  const int output_cols =
      output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
  TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);

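  // Operand descriptions for cpu_backend_gemm. The filter is symmetrically
  // quantized, so the LHS zero point is 0; the RHS zero point is the input
  // zero point (i.e. -input_offset). With the kIntegerWithPerRowMultiplier
  // flavor, the GEMM adds the per-channel bias to the int32 accumulators,
  // rescales each output channel (GEMM row) by its fixed-point multiplier
  // and exponent, adds the output offset, and clamps to the activation range.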
  cpu_backend_gemm::MatrixParams<int8> lhs_params;
  lhs_params.rows = filter_rows;
  lhs_params.cols = filter_cols;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.zero_point = 0;  // filter is symmetric-quantized
  cpu_backend_gemm::MatrixParams<int8> rhs_params;
  rhs_params.rows = gemm_input_rows;
  rhs_params.cols = gemm_input_cols;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.zero_point = -input_offset;
  cpu_backend_gemm::MatrixParams<int8> dst_params;
  dst_params.rows = output_rows;
  dst_params.cols = output_cols;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.zero_point = output_offset;
  cpu_backend_gemm::GemmParams<
      int32, int8,
      cpu_backend_gemm::QuantizationFlavor::kIntegerWithPerRowMultiplier>
      gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  gemm_params.multiplier_fixedpoint_perchannel = output_multiplier;
  gemm_params.multiplier_exponent_perchannel = output_shift;
  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
                         dst_params, output_data, gemm_params,
                         cpu_backend_context);
}
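
// Illustrative call site (a sketch, not part of this header): a 1x1,
// stride-1, undilated convolution, which takes the direct-input path and
// therefore passes no im2col buffer. All shapes and quantization values
// below are made-up example numbers.
//
//   ConvParams params{};
//   params.stride_width = 1;
//   params.stride_height = 1;
//   params.dilation_width_factor = 1;
//   params.dilation_height_factor = 1;
//   params.input_offset = 128;    // == -input_zero_point (zero point -128)
//   params.output_offset = -128;  // == output zero point
//   params.quantized_activation_min = -128;
//   params.quantized_activation_max = 127;
//
//   RuntimeShape input_shape({1, 8, 8, 16});    // NHWC
//   RuntimeShape filter_shape({32, 1, 1, 16});  // [out_depth, h, w, in_depth]
//   RuntimeShape bias_shape({32});
//   RuntimeShape output_shape({1, 8, 8, 32});
//
//   // One requantization multiplier/shift pair per output channel.
//   std::vector<int32> output_multiplier(32), output_shift(32);
//   std::vector<int32> bias(32);
//   std::vector<int8> input(input_shape.FlatSize());
//   std::vector<int8> filter(filter_shape.FlatSize());
//   std::vector<int8> output(output_shape.FlatSize());
//
//   CpuBackendContext context;
//   ConvPerChannel(params, output_multiplier.data(), output_shift.data(),
//                  input_shape, input.data(), filter_shape, filter.data(),
//                  bias_shape, bias.data(), output_shape, output.data(),
//                  /*im2col_shape=*/RuntimeShape(), /*im2col_data=*/nullptr,
//                  &context);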

}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_CONV_H_