• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "Operations.h"
#include "CpuOperationUtils.h"

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"

#include <limits>
#include <memory>
#include <mutex>
#include <new>
21 
22 namespace android {
23 namespace nn {
24 
// If possible we will use this static buffer for the im2col tensor instead of
// allocating one per call.
static constexpr size_t kStaticBufferSize = 1605632;
// The buffer is handed out as float* / uint8_t* via reinterpret_cast, so it must
// be at least as strictly aligned as float; a plain char array is only
// guaranteed 1-byte alignment, which would make the float accesses misaligned.
alignas(float) static char static_scratch_buffer[kStaticBufferSize];

// executionMutex is used to protect concurrent access of the static_scratch_buffer
// and other non-threadsafe resources like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;
33 
// tflite::optimized_ops::Conv indexes the im2col buffer with int offsets
// (http://b/77982879), so the buffer must stay strictly below this many bytes.
static constexpr uint64_t kMaxIm2colByteSize = 0x7fffffff;

// Multiplies *value by factor and stores the product back into *value, but only
// when neither factor nor the product reaches limit. Otherwise returns false
// and leaves *value unchanged. Callers start *value below limit and use a limit
// no larger than 2^32, so the 64-bit multiplication itself can never wrap.
static inline bool multiplyBelowLimit(uint64_t* value, uint64_t factor, uint64_t limit) {
    if (factor >= limit) {
        return false;
    }
    const uint64_t product = *value * factor;
    if (product >= limit) {
        return false;
    }
    *value = product;
    return true;
}

// Declares, in the calling function's scope, the values shared by the float and
// quant8 conv paths:
//   batches, filterHeight, filterWidth, outHeight, outWidth, inDepth,
//   paddingHeight, paddingWidth - geometry read from the shapes/arguments;
//   im2colDim                   - tflite dims/strides of the im2col tensor;
//   im2colData                  - scratch storage: static_scratch_buffer when it
//                                 is large enough, otherwise a heap allocation
//                                 owned by im2colGuard.
// Expects inputShape, filterShape, outputShape, padding_top and padding_left to
// be in scope. Makes the enclosing function `return false` when the im2col
// buffer would be too large or cannot be allocated. Callers must hold
// executionMutex while writing through im2colData, since it may alias the
// shared static_scratch_buffer.
#define ANDROID_NN_CONV_PARAMETERS(Type)                                        \
    uint32_t batches      = getSizeOfDimension(outputShape, 0);                 \
    uint32_t filterHeight = getSizeOfDimension(filterShape, 1);                 \
    uint32_t filterWidth  = getSizeOfDimension(filterShape, 2);                 \
    uint32_t outHeight    = getSizeOfDimension(outputShape, 1);                 \
    uint32_t outWidth     = getSizeOfDimension(outputShape, 2);                 \
    uint32_t inDepth      = getSizeOfDimension(inputShape, 3);                  \
                                                                                \
    uint32_t paddingHeight = (uint32_t)padding_top;                             \
    uint32_t paddingWidth = (uint32_t)padding_left;                             \
                                                                                \
    /* Size the im2col tensor in 64-bit arithmetic, checking every partial  */  \
    /* product, so a huge shape is rejected instead of wrapping around to a */  \
    /* small buffer that Conv would then overrun.                           */  \
    uint64_t im2colDepth = 1;                                                   \
    uint64_t im2colByteSize = sizeof(Type);                                     \
    if (!multiplyBelowLimit(&im2colDepth, inDepth, kMaxIm2colByteSize) ||       \
        !multiplyBelowLimit(&im2colDepth, filterHeight, kMaxIm2colByteSize) ||  \
        !multiplyBelowLimit(&im2colDepth, filterWidth, kMaxIm2colByteSize) ||   \
        !multiplyBelowLimit(&im2colByteSize, im2colDepth, kMaxIm2colByteSize) || \
        !multiplyBelowLimit(&im2colByteSize, outWidth, kMaxIm2colByteSize) ||   \
        !multiplyBelowLimit(&im2colByteSize, outHeight, kMaxIm2colByteSize) ||  \
        !multiplyBelowLimit(&im2colByteSize, batches, kMaxIm2colByteSize)) {    \
        LOG(ERROR) << "Conv size is too large, not enough memory";              \
        return false;                                                           \
    }                                                                           \
                                                                                \
    /* Every size and stride below is a checked partial product (or a     */   \
    /* factor already bounded by the limit), so each one fits in an int.   */   \
    tflite::Dims<4> im2colDim;                                                  \
    im2colDim.sizes[3] = (int)batches;                                          \
    im2colDim.sizes[2] = (int)outHeight;                                        \
    im2colDim.sizes[1] = (int)outWidth;                                         \
    im2colDim.sizes[0] = (int)im2colDepth;                                      \
                                                                                \
    im2colDim.strides[0] = 1;                                                   \
    for (int i = 1; i < 4; i++) {                                               \
        im2colDim.strides[i] = im2colDim.strides[i - 1] * im2colDim.sizes[i - 1]; \
    }                                                                           \
                                                                                \
    Type* im2colData = nullptr;                                                 \
    std::unique_ptr<Type[]> im2colGuard;                                        \
    if (im2colByteSize <= kStaticBufferSize) {                                  \
        im2colData = reinterpret_cast<Type*>(static_scratch_buffer);            \
    } else {                                                                    \
        im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)];    \
        if (im2colData == nullptr) {                                            \
            LOG(ERROR) << "Conv size is too large, not enough memory";          \
            return false;                                                       \
        }                                                                       \
        im2colGuard.reset(im2colData);                                          \
    }
78 
convFloat32(const float * inputData,const Shape & inputShape,const float * filterData,const Shape & filterShape,const float * biasData,const Shape & biasShape,int32_t padding_left,int32_t padding_right,int32_t padding_top,int32_t padding_bottom,int32_t stride_width,int32_t stride_height,int32_t activation,float * outputData,const Shape & outputShape)79 bool convFloat32(const float* inputData, const Shape& inputShape,
80                  const float* filterData, const Shape& filterShape,
81                  const float* biasData, const Shape& biasShape,
82                  int32_t padding_left, int32_t padding_right,
83                  int32_t padding_top, int32_t padding_bottom,
84                  int32_t stride_width, int32_t stride_height,
85                  int32_t activation,
86                  float* outputData, const Shape& outputShape) {
87 
88     ANDROID_NN_CONV_PARAMETERS(float)
89 
90     float output_activation_min, output_activation_max;
91     CalculateActivationRangeFloat(activation, &output_activation_min,
92                                   &output_activation_max);
93 
94     // Prevent concurrent executions that may access the scratch buffer.
95     std::unique_lock<std::mutex> lock(executionMutex);
96     tflite::optimized_ops::Conv(
97             inputData, convertShapeToDims(inputShape),
98             filterData, convertShapeToDims(filterShape),
99             biasData, convertShapeToDims(biasShape),
100             stride_width, stride_height, paddingWidth, paddingHeight,
101             output_activation_min, output_activation_max,
102             outputData, convertShapeToDims(outputShape),
103             im2colData, im2colDim);
104     return true;
105 }
106 
convQuant8(const uint8_t * inputData,const Shape & inputShape,const uint8_t * filterData,const Shape & filterShape,const int32_t * biasData,const Shape & biasShape,int32_t padding_left,int32_t padding_right,int32_t padding_top,int32_t padding_bottom,int32_t stride_width,int32_t stride_height,int32_t activation,uint8_t * outputData,const Shape & outputShape)107 bool convQuant8(const uint8_t* inputData, const Shape& inputShape,
108                 const uint8_t* filterData, const Shape& filterShape,
109                 const int32_t* biasData, const Shape& biasShape,
110                 int32_t padding_left, int32_t padding_right,
111                 int32_t padding_top, int32_t padding_bottom,
112                 int32_t stride_width, int32_t stride_height,
113                 int32_t activation,
114                 uint8_t* outputData, const Shape& outputShape) {
115 
116     ANDROID_NN_CONV_PARAMETERS(uint8_t)
117 
118     int32_t inputOffset = -inputShape.offset;
119     int32_t filterOffset = -filterShape.offset;
120     int32_t outputOffset = outputShape.offset;
121 
122     float real_multiplier = 0.0;
123     int32_t output_multiplier = 0;
124     int32_t output_shift = 0;
125     int32_t output_activation_min = 0;
126     int32_t output_activation_max = 0;
127 
128     if (!GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape,
129                                           outputShape, &real_multiplier) ||
130             !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
131                                               &output_shift)){
132         return false;
133     }
134     CalculateActivationRangeUint8(activation, outputShape,
135                                   &output_activation_min,
136                                   &output_activation_max);
137 
138     static gemmlowp::GemmContext gemm_context;
139 
140     // Prevent concurrent executions that may access the scratch buffer and
141     // gemm_context.
142     std::unique_lock<std::mutex> lock(executionMutex);
143     // Alow gemmlowp automatically decide how many threads to use.
144     gemm_context.set_max_num_threads(0);
145     tflite::optimized_ops::Conv(
146             inputData, convertShapeToDims(inputShape), inputOffset,
147             filterData, convertShapeToDims(filterShape), filterOffset,
148             biasData, convertShapeToDims(biasShape),
149             stride_width, stride_height, paddingWidth, paddingHeight,
150             outputOffset, output_multiplier, output_shift,
151             output_activation_min, output_activation_max,
152             outputData, convertShapeToDims(outputShape),
153             im2colData, im2colDim, &gemm_context);
154     return true;
155 }
156 
157 #undef ANDROID_NN_CONV_PARAMETERS
158 }  // namespace nn
159 }  // namespace android
160