1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_MULTITHREAD_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_MULTITHREAD_H_
17
#include <algorithm>
#include <cstdint>
#include <limits>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
23
24 namespace tflite {
25 namespace optimized_ops {
26
27 // TODO(luwa): add multithread to per-channel depthwise_conv
28 // DepthwiseConv can run with multi threads on the dim specified by thread_dim.
29 // Each thread processes output elements on dim, thread_dim, in the range of
30 // [thread_start, thread_end).
// For example, assume thread_start = 2, thread_end = 6, and thread_dim = 1, it
// means that it will calculate DepthwiseConv for output_data[:, 2:6, :, :].
template <typename T, typename TS>
struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
  // Captures one thread's slice of the depthwise conv.
  //
  // Note: params and all shapes are held by const reference, not by value, so
  // the referents must outlive the task. That holds in DepthwiseConv() below,
  // where the tasks are executed synchronously before the arguments go out of
  // scope.
  DepthwiseConvWorkerTask(const DepthwiseParams& params,
                          const RuntimeShape& input_shape, const T* input_data,
                          const RuntimeShape& filter_shape,
                          const T* filter_data, const RuntimeShape& bias_shape,
                          const TS* bias_data, const RuntimeShape& output_shape,
                          T* output_data, const CpuFlags& cpu_flags,
                          int thread_start, int thread_end, int thread_dim)
      : params_(params),
        input_shape_(input_shape),
        input_data_(input_data),
        filter_shape_(filter_shape),
        filter_data_(filter_data),
        bias_shape_(bias_shape),
        bias_data_(bias_data),
        output_shape_(output_shape),
        output_data_(output_data),
        cpu_flags_(cpu_flags),
        thread_start_(thread_start),
        thread_end_(thread_end),
        thread_dim_(thread_dim) {}

  // Runs the conv for output elements in [thread_start_, thread_end_) along
  // dimension thread_dim_ (0 = batch, 1 = output height).
  void Run() override {
    DepthwiseConvImpl(params_, input_shape_, input_data_, filter_shape_,
                      filter_data_, bias_shape_, bias_data_, output_shape_,
                      output_data_, cpu_flags_, thread_start_, thread_end_,
                      thread_dim_);
  }

 private:
  const DepthwiseParams& params_;
  const RuntimeShape& input_shape_;
  const T* input_data_;
  const RuntimeShape& filter_shape_;
  const T* filter_data_;
  const RuntimeShape& bias_shape_;
  const TS* bias_data_;
  const RuntimeShape& output_shape_;
  T* output_data_;
  const CpuFlags& cpu_flags_;
  // Half-open range of the split dimension this task owns.
  int thread_start_;
  int thread_end_;
  int thread_dim_;
};
78
HowManyConvThreads(const RuntimeShape & output_shape,const RuntimeShape & filter_shape)79 inline int HowManyConvThreads(const RuntimeShape& output_shape,
80 const RuntimeShape& filter_shape) {
81 // How many scalar multiplications are needed to make it worth using one
82 // more thread
83 static constexpr int kMinMulPerThread = 1 << 13; // 8k
84 const int filter_height = filter_shape.Dims(1);
85 const int filter_width = filter_shape.Dims(2);
86 const int num_muls = output_shape.FlatSize() * filter_height * filter_width;
87 // Try to avoid real runtime divisions if possible by dividing by a
88 // compile-time constant.
89 int thread_count = std::max(1, num_muls / kMinMulPerThread);
90 return thread_count;
91 }
92
MultithreadAlongBatches(int thread_count,int batches)93 inline bool MultithreadAlongBatches(int thread_count, int batches) {
94 TFLITE_DCHECK_GE(thread_count, 2);
95 // If there are fewer batch entries than the number of threads we want to use,
96 // then better do intra-batch-entry multithreading.
97 if (batches < thread_count) {
98 return false;
99 }
100 // If there are at least 2 batch entries to be handed to each thread, then
101 // it's safe to proceed with batch-wise multithreading: each thread will have
102 // approximately equal number of batch entries to handle, so the load
103 // balancing will be reasonable, and the amount to which the load is not
104 // perfectly balanced will be offset by the inherent advantages of
105 // batch-wise multithreading (each thread is more efficient thanks to working
106 // on larger buffers with less boundary-handling overhead).
107 if (batches >= 2 * thread_count) {
108 return true;
109 }
110 // In the limit case were there are at least 1 but not much more than 1
111 // batch entries per thread, it may be a good idea to do per-batch
112 // multithreading if the number of batch entries is a multiple of the number
113 // of threads, so that each thread will have the same number of batch entries
114 // to process.
115 return ((batches % thread_count) == 0);
116 }
117
// Multithreaded entry point: sizes the thread count from the workload,
// chooses a split dimension (batch or output height), and fans the work out
// to DepthwiseConvImpl via the backend thread pool.
template <typename T, typename TS>
inline void DepthwiseConv(const DepthwiseParams& params,
                          const RuntimeShape& input_shape, const T* input_data,
                          const RuntimeShape& filter_shape,
                          const T* filter_data, const RuntimeShape& bias_shape,
                          const TS* bias_data, const RuntimeShape& output_shape,
                          T* output_data,
                          CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("DepthwiseConv");

  // All tensors are expected in 4-D (NHWC) layout.
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  // Size the thread count from the work, capped by the backend's limit.
  int thread_count = HowManyConvThreads(output_shape, filter_shape);
  const int max_threads = cpu_backend_context->max_num_threads();
  thread_count = std::max(1, std::min(thread_count, max_threads));
#ifndef TFLITE_WITH_RUY
  // Cap the number of threads to 2 for float path to avoid regression in
  // performance (b/132294857).
  if (std::is_floating_point<T>::value) {
    thread_count = std::min(thread_count, 2);
  }
#endif

  const int output_batches = output_shape.Dims(0);
  const int output_height = output_shape.Dims(1);

  CpuFlags cpu_flags;
  GetCpuFlags(&cpu_flags);

  // Single-thread fast path: run the whole output height inline, with no
  // task objects and no thread-pool involvement.
  if (thread_count == 1) {
    DepthwiseConvImpl(params, input_shape, input_data, filter_shape,
                      filter_data, bias_shape, bias_data, output_shape,
                      output_data, cpu_flags, /*thread_start=*/0,
                      /*thread_end=*/output_height, /*thread_dim=*/1);
    return;
  }

  // Split over batch (dim 0) when there are enough batch entries to balance
  // the load; otherwise split over output height (dim 1).
  int thread_dim, thread_dim_size;
  if (MultithreadAlongBatches(thread_count, output_batches)) {
    thread_dim = 0;
    thread_dim_size = output_batches;
  } else {
    thread_dim = 1;
    thread_dim_size = output_height;
  }

  std::vector<DepthwiseConvWorkerTask<T, TS>> tasks;
  // TODO(b/131746020) don't create new heap allocations every time.
  // At least we make it a single heap allocation by using reserve().
  tasks.reserve(thread_count);
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // Give each task an even share of what remains; the half-open ranges are
    // contiguous and together cover exactly [0, thread_dim_size).
    int thread_end =
        thread_start + (thread_dim_size - thread_start) / (thread_count - i);
    tasks.emplace_back(params, input_shape, input_data, filter_shape,
                       filter_data, bias_shape, bias_data, output_shape,
                       output_data, cpu_flags, thread_start, thread_end,
                       thread_dim);
    thread_start = thread_end;
  }
  // Blocks until all tasks finish, so the by-reference members captured in
  // the tasks remain valid for the whole call.
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
}
183
184 } // namespace optimized_ops
185 } // namespace tflite
186
187 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_MULTITHREAD_H_
188