1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
19
20 #include "fixedpoint.h"
21 #include "gemmlowp.h"
22 #include "../common.h"
23 #include "../types.h"
24
25 namespace android {
26 namespace nn {
27 namespace reference_ops {
28
29 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)30 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
31 int32 input_offset, const uint8* filter_data,
32 const Dims<4>& filter_dims, int32 filter_offset,
33 const int32* bias_data, const Dims<4>& bias_dims,
34 int stride_width, int stride_height,
35 int pad_width, int pad_height, int depth_multiplier,
36 int32 output_offset, int32 output_multiplier,
37 int output_shift, int32 output_activation_min,
38 int32 output_activation_max, uint8* output_data,
39 const Dims<4>& output_dims) {
40 static_assert(Ac == FusedActivationFunctionType::kNone ||
41 Ac == FusedActivationFunctionType::kRelu ||
42 Ac == FusedActivationFunctionType::kRelu6 ||
43 Ac == FusedActivationFunctionType::kRelu1,
44 "");
45 DCHECK_LE(output_activation_min, output_activation_max);
46 if (Ac == FusedActivationFunctionType::kNone) {
47 DCHECK_EQ(output_activation_min, 0);
48 DCHECK_EQ(output_activation_max, 255);
49 }
50 const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
51 const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
52 const int input_height = ArraySize(input_dims, 2);
53 const int input_width = ArraySize(input_dims, 1);
54 const int input_depth = ArraySize(input_dims, 0);
55 const int filter_height = ArraySize(filter_dims, 2);
56 const int filter_width = ArraySize(filter_dims, 1);
57 const int output_height = ArraySize(output_dims, 2);
58 const int output_width = ArraySize(output_dims, 1);
59 DCHECK(output_depth == input_depth * depth_multiplier);
60
61 for (int b = 0; b < batches; ++b) {
62 for (int out_y = 0; out_y < output_height; ++out_y) {
63 for (int out_x = 0; out_x < output_width; ++out_x) {
64 for (int ic = 0; ic < input_depth; ++ic) {
65 for (int m = 0; m < depth_multiplier; m++) {
66 const int oc = m + ic * depth_multiplier;
67 const int in_x_origin = (out_x * stride_width) - pad_width;
68 const int in_y_origin = (out_y * stride_height) - pad_height;
69 int32 acc = 0;
70 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
71 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
72 const int in_x = in_x_origin + filter_x;
73 const int in_y = in_y_origin + filter_y;
74 // If the location is outside the bounds of the input image,
75 // use zero as a default value.
76 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
77 (in_y < input_height)) {
78 int32 input_val =
79 input_data[Offset(input_dims, ic, in_x, in_y, b)];
80 int32 filter_val = filter_data[Offset(filter_dims, oc,
81 filter_x, filter_y, 0)];
82 acc +=
83 (filter_val + filter_offset) * (input_val + input_offset);
84 }
85 }
86 }
87 if (bias_data) {
88 acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
89 }
90 acc = MultiplyByQuantizedMultiplierSmallerThanOne(
91 acc, output_multiplier, output_shift);
92 acc += output_offset;
93 acc = std::max(acc, output_activation_min);
94 acc = std::min(acc, output_activation_max);
95 output_data[Offset(output_dims, oc, out_x, out_y, b)] =
96 static_cast<uint8>(acc);
97 }
98 }
99 }
100 }
101 }
102 }
103
104 } // end namespace reference_ops
105 } // namespace nn
106 } // namespace android
107
108 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
109