/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/meta_support.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

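// The meta fastpath is only compiled in for ARM NEON builds (32- or 64-bit),
// and is left out on Apple platforms or when TENSORFLOW_DISABLE_META is set.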
#if (defined(GEMMLOWP_NEON_32) || defined(GEMMLOWP_NEON_64)) && \
    !defined(TENSORFLOW_DISABLE_META) && !defined(__APPLE__)
#define TENSORFLOW_USE_META (1)
#endif

namespace tensorflow {
namespace meta {

namespace {

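// Process-wide settings for the meta fastpath; they are read by the kernels
// below and adjusted through the Set*/Get* functions at the bottom of this
// file.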
int g_num_threads = 0;
bool g_enabled = true;
bool g_use_local_context = false;

#ifdef TENSORFLOW_USE_META

const int kAlignment = 32;
const int kScratchSize = 2048 * 1024 + kAlignment;

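// Resource holding the scratch memory used by the meta GEMM kernels. The
// object owns the allocation and hands out a pointer rounded up to the next
// 32-byte boundary.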
class Scratch : public ResourceBase {
 public:
  Scratch() : scratch_(new uint8_t[kScratchSize]) {
    // Make sure scratch is aligned to 32 bytes. Scratch object owns the
    // scratch buffer.
    scratch_32_aligned_ =
        scratch_.get() + kAlignment -
        (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
  }

  uint8_t* buffer() { return scratch_32_aligned_; }

  string DebugString() const override { return "MetaGemmScratchResource"; }

 private:
  std::unique_ptr<uint8_t[]> scratch_;
  uint8_t* scratch_32_aligned_;
};

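// Returns the shared scratch buffer for this device, creating it on first use
// via the resource manager. On failure the op context is flagged and nullptr
// is returned.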
uint8_t* GetScratch(OpKernelContext* context) {
  Scratch* scratch = nullptr;
  std::function<Status(Scratch**)> creator = [](Scratch** resource) {
    *resource = new Scratch();
    return Status::OK();
  };
  Status s = context->resource_manager()->LookupOrCreate(
      "MetaGemm", "ScratchBuffer", &scratch, creator);
  if (!s.ok()) {
    context->CtxFailureWithWarning(s);
    return nullptr;
  }
  return scratch->buffer();
}

gemmlowp::WorkersPool* GetWorkersPool() {
  static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
  return pool;
}

mutex& GetMutex() {
  static mutex mu(LINKER_INITIALIZED);
  return mu;
}

int GetWorkersCount(OpKernelContext* tf_context) {
  if (g_num_threads == 0) {
    return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
  }
  return g_num_threads;
}

typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;

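// Chooses a gemmlowp::meta executor from the output shape: narrow outputs
// (m <= 4) use a 1x8x8 LHS-cache-friendly configuration; otherwise a 2x4x8
// configuration that packs the right-hand side when m >= n and the left-hand
// side otherwise.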
template <typename Context, typename Params>
void MultiThreadGemm(Context* context, const Params& params) {
  if (params.m <= 4) {
    gemmlowp::meta::MultiThreadGemm<
        Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
        8, 8>(context, params);
  } else {
    if (params.m >= params.n) {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    } else {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    }
  }
}

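// Sets up gemmlowp::meta GEMM parameters for a quantized uint8 * uint8 ->
// int32 product and dispatches it. The multiplicative/additive sum offsets on
// the two streams fold the zero-point (offset_a, offset_b) cross terms of the
// quantized product into the packed row/column sums.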
template <typename LeftStream, typename RightStream>
void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
                       const quint8* b_data, qint32* c_data, int m, int n,
                       int k, int offset_a, int offset_b, int lda, int ldb,
                       int ldc) {
  typedef gemmlowp::meta::GemmParams<
      uint8_t, int32_t, LeftStream, RightStream,
      gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
      gemmlowp::meta::RowMajor>
      Params;
  Params params;

  params.m = m;
  params.n = n;
  params.k = k;

  params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
  params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
  params.result = reinterpret_cast<int32_t*>(&(c_data->value));
  params.scratch = CHECK_NOTNULL(GetScratch(tf_context));

  params.left_stream.count = k;
  params.left_stream.stride = lda;
  params.left_stream.multiplicative_sum_offset = offset_b;
  params.left_stream.additive_sum_offset = k * offset_a * offset_b;

  params.right_stream.count = k;
  params.right_stream.stride = ldb;
  params.right_stream.multiplicative_sum_offset = offset_a;
  params.right_stream.additive_sum_offset = 0;

  params.fused_kernel.kernel.count = k;
  params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);

  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    MultiThreadGemm<LocalContext, Params>(&local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
  }
}

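// Runs a gemmlowp::meta 1-D transform, using either a private worker pool
// (when a local context is requested) or TensorFlow's intra-op thread pool.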
template <typename Params, int kernel_size>
void MultiThreadTransform1D(OpKernelContext* tf_context,
                            const Params& params) {
  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
        &local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
                                           kernel_size>(&context, params);
  }
}

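// Helpers mapping a float range [min, max] onto the full range of the
// quantized type. CalculateRangeScale returns the real-valued step per
// quantized unit, (max - min) / (2^bits - 1); CalculateOneOverRangeScale
// returns its reciprocal, guarding against a degenerate min == max range.
// For example, uint8 over [-1.0f, 1.0f] gives a step of 2/255 ~= 0.00784.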
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>(max - min) /
         ((static_cast<int64_t>(1) << bits) - 1);
}

template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (min == max) {
    return 0.0;
  }
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
         (max - min);
}

#endif  // TENSORFLOW_USE_META

}  // namespace

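// Public knobs for the meta fastpath: thread count (0 means use TensorFlow's
// intra-op pool size), whether to use a private gemmlowp worker pool, and a
// global enable/disable switch. IsSupported() reflects compile-time support
// only.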
void SetNumThreads(int num_threads) { g_num_threads = num_threads; }

int GetNumThreads() { return g_num_threads; }

void SetUseLocalContext(bool use_local_context) {
  g_use_local_context = use_local_context;
}

bool GetUseLocalContext() { return g_use_local_context; }

bool IsSupported() {
#if defined(TENSORFLOW_USE_META)
  return true;
#else
  return false;
#endif
}

bool IsEnabled() { return g_enabled; }

void SetEnabled(bool enabled) { g_enabled = enabled; }

bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }

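// Quantized matrix multiply entry point. The left and right stream layouts
// (row- vs. column-major, with running sums) are chosen from the transpose
// flags, selecting one of the four QuantizedGemmImpl instantiations below.
// The global lock serializes meta kernels, which share scratch space and a
// worker pool.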
void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
                   bool transpose_b, const quint8* a_data,
                   const quint8* b_data, qint32* c_data, int m, int n, int k,
                   int offset_a, int offset_b, int lda, int ldb, int ldc) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  if (transpose_a) {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  } else {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  }
#else
  LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
#endif
}

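// Converts int32 quantized values in [input_min, input_max] to uint8 values
// in [output_min, output_max] using the Requantize transform kernel.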
void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
                float input_min, float input_max, float output_min,
                float output_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
                                            gemmlowp::meta::Requantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const int32_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.input_range_min = input_min;
  params.kernel.output_range_min = output_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<int32_t>(input_min, input_max);
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
  params.kernel.input_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // After adding the output_range_offset the value is cast from float to uint.
  // The float to int/uint cast in NEON uses round toward 0. To keep the
  // rounding consistent with Eigen, which uses round toward closest, we can
  // add 0.5f and exploit the fact that we only operate on non negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Requantize: Meta fastpath not supported.";
#endif
}

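// Converts uint8 quantized values back to floats in [range_min, range_max].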
void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
                float range_min, float range_max, float* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
                                            gemmlowp::meta::Dequantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<float*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
#endif
}

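// Quantizes floats in [range_min, range_max] to uint8.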
void Quantize(OpKernelContext* tf_context, const float* input, int count,
              float range_min, float range_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
                                            gemmlowp::meta::Quantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const float*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateOneOverRangeScale<uint8_t>(range_min, range_max);

  // After adding the range_offset the value is cast from float to uint.
  // The float to int/uint cast in NEON uses round toward 0. To keep the
  // rounding consistent with Eigen, which uses round toward closest, we can
  // add 0.5f and exploit the fact that we only operate on non negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Quantize: Meta fastpath not supported.";
#endif
}

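// Adds a quantized bias vector row-wise to a quantized tensor, producing
// int32 results in [output_min, output_max]. Runs single-threaded for now
// (see the TODO below).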
void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
                      int input_count, const quint8* bias, int bias_count,
                      float input_min, float input_max, float bias_min,
                      float bias_max, float output_min, float output_max,
                      qint32* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
                                            gemmlowp::meta::BiasAdd<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<int32_t*>(output);
  params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
  params.kernel.count = bias_count;
  params.kernel.rows = input_count / bias_count;
  params.kernel.input_range_min = input_min;
  params.kernel.bias_range_min = bias_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<uint8_t>(input_min, input_max);
  params.kernel.bias_range_scale =
      CalculateRangeScale<uint8_t>(bias_min, bias_max);
  params.kernel.input_range_offset = 0;
  params.kernel.bias_range_offset = 0;
  params.kernel.output_range_min = output_min;
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<int32_t>(output_min, output_max);
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // TODO(maciekc): add multithreading to bias add.
  // Right now this kernel does not support multi threaded execution.
  gemmlowp::meta::Transform1D<Params, 16>(params);
#else
  LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
#endif
}

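// Element-wise clamp of uint8 data to [clamp_min, clamp_max].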
void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
           quint8 clamp_min, quint8 clamp_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
                                            gemmlowp::meta::MinMax<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.min = clamp_min;
  params.kernel.max = clamp_max;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Clamp: Meta fastpath not supported.";
#endif
}

}  // namespace meta
}  // namespace tensorflow