• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
18 #define ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
19 
20 #include "fixedpoint.h"
21 #include "gemmlowp.h"
22 #include "../common.h"
23 #include "../types.h"
24 
25 namespace android {
26 namespace nn {
27 namespace reference_ops {
28 
29 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)30 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
31                    int32 input_offset, const uint8* filter_data,
32                    const Dims<4>& filter_dims, int32 filter_offset,
33                    const int32* bias_data, const Dims<4>& bias_dims,
34                    int stride_width, int stride_height,
35                    int pad_width, int pad_height, int depth_multiplier,
36                    int32 output_offset, int32 output_multiplier,
37                    int output_shift, int32 output_activation_min,
38                    int32 output_activation_max, uint8* output_data,
39                    const Dims<4>& output_dims) {
40   static_assert(Ac == FusedActivationFunctionType::kNone ||
41                     Ac == FusedActivationFunctionType::kRelu ||
42                     Ac == FusedActivationFunctionType::kRelu6 ||
43                     Ac == FusedActivationFunctionType::kRelu1,
44                 "");
45   DCHECK_LE(output_activation_min, output_activation_max);
46   if (Ac == FusedActivationFunctionType::kNone) {
47     DCHECK_EQ(output_activation_min, 0);
48     DCHECK_EQ(output_activation_max, 255);
49   }
50   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
51   const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
52   const int input_height = ArraySize(input_dims, 2);
53   const int input_width = ArraySize(input_dims, 1);
54   const int input_depth = ArraySize(input_dims, 0);
55   const int filter_height = ArraySize(filter_dims, 2);
56   const int filter_width = ArraySize(filter_dims, 1);
57   const int output_height = ArraySize(output_dims, 2);
58   const int output_width = ArraySize(output_dims, 1);
59   DCHECK(output_depth == input_depth * depth_multiplier);
60 
61   for (int b = 0; b < batches; ++b) {
62     for (int out_y = 0; out_y < output_height; ++out_y) {
63       for (int out_x = 0; out_x < output_width; ++out_x) {
64         for (int ic = 0; ic < input_depth; ++ic) {
65           for (int m = 0; m < depth_multiplier; m++) {
66             const int oc = m + ic * depth_multiplier;
67             const int in_x_origin = (out_x * stride_width) - pad_width;
68             const int in_y_origin = (out_y * stride_height) - pad_height;
69             int32 acc = 0;
70             for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
71               for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
72                 const int in_x = in_x_origin + filter_x;
73                 const int in_y = in_y_origin + filter_y;
74                 // If the location is outside the bounds of the input image,
75                 // use zero as a default value.
76                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
77                     (in_y < input_height)) {
78                   int32 input_val =
79                       input_data[Offset(input_dims, ic, in_x, in_y, b)];
80                   int32 filter_val = filter_data[Offset(filter_dims, oc,
81                                                         filter_x, filter_y, 0)];
82                   acc +=
83                       (filter_val + filter_offset) * (input_val + input_offset);
84                 }
85               }
86             }
87             if (bias_data) {
88               acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
89             }
90             acc = MultiplyByQuantizedMultiplierSmallerThanOne(
91                 acc, output_multiplier, output_shift);
92             acc += output_offset;
93             acc = std::max(acc, output_activation_min);
94             acc = std::min(acc, output_activation_max);
95             output_data[Offset(output_dims, oc, out_x, out_y, b)] =
96                 static_cast<uint8>(acc);
97           }
98         }
99       }
100     }
101   }
102 }
103 
104 }  // end namespace reference_ops
105 }  // namespace nn
106 }  // namespace android
107 
108 #endif  // ANDROID_ML_NN_COMMON_OPERATIONS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
109