• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/kernels/meta_support.h"
19 
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/resource_mgr.h"
22 #include "tensorflow/core/kernels/quantization_utils.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/mutex.h"
25 
26 #if (defined(GEMMLOWP_NEON_32) || defined(GEMMLOWP_NEON_64)) && \
27     !defined(TENSORFLOW_DISABLE_META) && !defined(__APPLE__)
28 #define TENSORFLOW_USE_META (1)
29 #endif
30 
31 namespace tensorflow {
32 namespace meta {
33 
34 namespace {
35 
// Runtime settings for the meta fast path. They are plain globals, written by
// SetNumThreads / SetEnabled / SetUseLocalContext without synchronization.

// Worker count for the local (gemmlowp WorkersPool) context; 0 means "use the
// device's CPU worker thread count" (see GetWorkersCount).
int g_num_threads{0};
// Returned by IsEnabled(); toggled by SetEnabled().
bool g_enabled{true};
// When true, meta work runs on the gemmlowp WorkersPool instead of the
// TensorFlow device's CPU worker threads.
bool g_use_local_context{false};
39 
40 #ifdef TENSORFLOW_USE_META
41 
// Byte alignment of the pointer returned by Scratch::buffer().
constexpr int kAlignment = 32;
// 2 MiB of usable scratch plus kAlignment bytes of slack, so an aligned
// pointer can always be carved out of the raw allocation.
constexpr int kScratchSize = 2048 * 1024 + kAlignment;
44 
45 class Scratch : public ResourceBase {
46  public:
Scratch()47   Scratch() : scratch_(new uint8_t[kScratchSize]) {
48     // Make sure scratch is aligned to 32 bytes. Scratch object owns the
49     // scratch buffer.
50     scratch_32_aligned_ =
51         scratch_.get() + kAlignment -
52         (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
53   }
54 
buffer()55   uint8_t* buffer() { return scratch_32_aligned_; }
56 
DebugString() const57   string DebugString() const override { return "MetaGemmScratchResource"; }
58 
59  private:
60   std::unique_ptr<uint8_t> scratch_;
61   uint8_t* scratch_32_aligned_;
62 };
63 
GetScratch(OpKernelContext * context)64 uint8_t* GetScratch(OpKernelContext* context) {
65   Scratch* scratch = nullptr;
66   std::function<Status(Scratch**)> creator = [](Scratch** resource) {
67     *resource = new Scratch();
68     return Status::OK();
69   };
70   Status s = context->resource_manager()->LookupOrCreate(
71       "MetaGemm", "ScratchBuffer", &scratch, creator);
72   if (!s.ok()) {
73     context->CtxFailureWithWarning(s);
74     return nullptr;
75   }
76   return scratch->buffer();
77 }
78 
GetWorkersPool()79 gemmlowp::WorkersPool* GetWorkersPool() {
80   static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
81   return pool;
82 }
83 
GetMutex()84 mutex& GetMutex() {
85   static mutex mu(LINKER_INITIALIZED);
86   return mu;
87 }
88 
GetWorkersCount(OpKernelContext * tf_context)89 int GetWorkersCount(OpKernelContext* tf_context) {
90   if (g_num_threads == 0) {
91     return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
92   }
93   return g_num_threads;
94 }
95 
96 typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;
97 
98 template <typename Context, typename Params>
MultiThreadGemm(Context * context,const Params & params)99 void MultiThreadGemm(Context* context, const Params& params) {
100   if (params.m <= 4) {
101     gemmlowp::meta::MultiThreadGemm<
102         Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
103         8, 8>(context, params);
104   } else {
105     if (params.m >= params.n) {
106       gemmlowp::meta::MultiThreadGemm<
107           Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
108           2, 4, 8>(context, params);
109     } else {
110       gemmlowp::meta::MultiThreadGemm<
111           Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
112           2, 4, 8>(context, params);
113     }
114   }
115 }
116 
117 template <typename LeftStream, typename RightStream>
QuantizedGemmImpl(OpKernelContext * tf_context,const quint8 * a_data,const quint8 * b_data,qint32 * c_data,int m,int n,int k,int offset_a,int offset_b,int lda,int ldb,int ldc)118 void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
119                        const quint8* b_data, qint32* c_data, int m, int n,
120                        int k, int offset_a, int offset_b, int lda, int ldb,
121                        int ldc) {
122   typedef gemmlowp::meta::GemmParams<
123       uint8_t, int32_t, LeftStream, RightStream,
124       gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
125       gemmlowp::meta::RowMajor>
126       Params;
127   Params params;
128 
129   params.m = m;
130   params.n = n;
131   params.k = k;
132 
133   params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
134   params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
135   params.result = reinterpret_cast<int32_t*>(&(c_data->value));
136   params.scratch = CHECK_NOTNULL(GetScratch(tf_context));
137 
138   params.left_stream.count = k;
139   params.left_stream.stride = lda;
140   params.left_stream.multiplicative_sum_offset = offset_b;
141   params.left_stream.additive_sum_offset = k * offset_a * offset_b;
142 
143   params.right_stream.count = k;
144   params.right_stream.stride = ldb;
145   params.right_stream.multiplicative_sum_offset = offset_a;
146   params.right_stream.additive_sum_offset = 0;
147 
148   params.fused_kernel.kernel.count = k;
149   params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);
150 
151   if (g_use_local_context) {
152     LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
153     MultiThreadGemm<LocalContext, Params>(&local_context, params);
154   } else {
155     auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
156     TensorflowGemmContext context(workers.num_threads, workers.workers);
157     MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
158   }
159 }
160 
161 template <typename Params, int kernel_size>
MultiThreadTransform1D(OpKernelContext * tf_context,const Params & params)162 void MultiThreadTransform1D(OpKernelContext* tf_context, const Params& params) {
163   if (g_use_local_context) {
164     LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
165     gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
166         &local_context, params);
167   } else {
168     auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
169     TensorflowGemmContext context(workers.num_threads, workers.workers);
170     gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
171                                            kernel_size>(&context, params);
172   }
173 }
174 
// Width of one quantization step: (max - min) spread over the
// 2^bits - 1 intervals of QuantizedType, where bits = 8 * sizeof(QuantizedType).
// The subtraction happens in float before widening to double.
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  constexpr int kBits = 8 * sizeof(QuantizedType);
  const int64_t step_count = (int64_t{1} << kBits) - 1;
  const float span = max - min;
  return static_cast<double>(span) / step_count;
}
181 
// Reciprocal of CalculateRangeScale: (2^bits - 1) / (max - min), where
// bits = 8 * sizeof(QuantizedType). Returns 0.0 for an empty range
// (min == max) instead of dividing by zero.
template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (!(min != max)) {
    return 0.0;
  }
  constexpr int kBits = 8 * sizeof(QuantizedType);
  const int64_t step_count = (int64_t{1} << kBits) - 1;
  const float span = max - min;
  return static_cast<double>(step_count) / span;
}
191 
192 #endif  // TENSORFLOW_USE_META
193 
194 }  // namespace
195 
SetNumThreads(int num_threads)196 void SetNumThreads(int num_threads) { g_num_threads = num_threads; }
197 
GetNumThreads()198 int GetNumThreads() { return g_num_threads; }
199 
SetUseLocalContext(bool use_local_context)200 void SetUseLocalContext(bool use_local_context) {
201   g_use_local_context = use_local_context;
202 }
203 
GetUseLocalContext()204 bool GetUseLocalContext() { return g_use_local_context; }
205 
// True only when this file was compiled with the meta fast path, i.e. when
// TENSORFLOW_USE_META is defined.
bool IsSupported() {
#ifdef TENSORFLOW_USE_META
  constexpr bool kMetaCompiledIn = true;
#else
  constexpr bool kMetaCompiledIn = false;
#endif
  return kMetaCompiledIn;
}
213 
IsEnabled()214 bool IsEnabled() { return g_enabled; }
215 
SetEnabled(bool enabled)216 void SetEnabled(bool enabled) { g_enabled = enabled; }
217 
IsSupportedAndEnabled()218 bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }
219 
QuantizedGemm(OpKernelContext * tf_context,bool transpose_a,bool transpose_b,const quint8 * a_data,const quint8 * b_data,qint32 * c_data,int m,int n,int k,int offset_a,int offset_b,int lda,int ldb,int ldc)220 void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
221                    bool transpose_b, const quint8* a_data, const quint8* b_data,
222                    qint32* c_data, int m, int n, int k, int offset_a,
223                    int offset_b, int lda, int ldb, int ldc) {
224 #ifdef TENSORFLOW_USE_META
225   mutex_lock library_lock(GetMutex());
226   if (transpose_a) {
227     if (transpose_b) {
228       QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
229                         gemmlowp::meta::RowMajorWithSum>(
230           tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
231           ldb, ldc);
232     } else {
233       QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
234                         gemmlowp::meta::ColumnMajorWithSum>(
235           tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
236           ldb, ldc);
237     }
238   } else {
239     if (transpose_b) {
240       QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
241                         gemmlowp::meta::RowMajorWithSum>(
242           tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
243           ldb, ldc);
244     } else {
245       QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
246                         gemmlowp::meta::ColumnMajorWithSum>(
247           tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
248           ldb, ldc);
249     }
250   }
251 #else
252   LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
253 #endif
254 }
255 
Requantize(OpKernelContext * tf_context,const qint32 * input,int count,float input_min,float input_max,float output_min,float output_max,quint8 * output)256 void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
257                 float input_min, float input_max, float output_min,
258                 float output_max, quint8* output) {
259 #ifdef TENSORFLOW_USE_META
260   mutex_lock library_lock(GetMutex());
261   typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
262                                             gemmlowp::meta::Requantize>
263       Params;
264 
265   Params params;
266   params.input = reinterpret_cast<const int32_t*>(input);
267   params.output = reinterpret_cast<uint8_t*>(output);
268   params.kernel.count = count;
269   params.kernel.input_range_min = input_min;
270   params.kernel.output_range_min = output_min;
271   params.kernel.input_range_scale =
272       CalculateRangeScale<int32_t>(input_min, input_max);
273   params.kernel.one_over_output_range_scale =
274       CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
275   params.kernel.input_range_offset =
276       static_cast<float>(std::numeric_limits<int32_t>::lowest());
277 
278   // After adding the output_range_offset the value is cast from float to uint.
279   // The float to int/uint cast in NEON uses round toward 0. To keep the
280   // rounding consistent with Eigen, which uses round toward closest, we can
281   // add 0.5f and exploit the fact that we only operate on non negative values.
282   // TODO(maciekc): fix the actual kernel in gemmlowp/meta
283   params.kernel.output_range_offset =
284       static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;
285 
286   MultiThreadTransform1D<Params, 16>(tf_context, params);
287 #else
288   LOG(FATAL) << "Requantize: Meta fastpath not supported.";
289 #endif
290 }
291 
Dequantize(OpKernelContext * tf_context,const quint8 * input,int count,float range_min,float range_max,float * output)292 void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
293                 float range_min, float range_max, float* output) {
294 #ifdef TENSORFLOW_USE_META
295   mutex_lock library_lock(GetMutex());
296   typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
297                                             gemmlowp::meta::Dequantize>
298       Params;
299 
300   Params params;
301   params.input = reinterpret_cast<const uint8_t*>(input);
302   params.output = reinterpret_cast<float*>(output);
303   params.kernel.count = count;
304   params.kernel.range_min = range_min;
305   params.kernel.range_scale =
306       CalculateRangeScale<uint8_t>(range_min, range_max);
307   params.kernel.range_offset =
308       static_cast<float>(std::numeric_limits<uint8_t>::lowest());
309 
310   MultiThreadTransform1D<Params, 16>(tf_context, params);
311 #else
312   LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
313 #endif
314 }
315 
Quantize(OpKernelContext * tf_context,const float * input,int count,float range_min,float range_max,quint8 * output)316 void Quantize(OpKernelContext* tf_context, const float* input, int count,
317               float range_min, float range_max, quint8* output) {
318 #ifdef TENSORFLOW_USE_META
319   mutex_lock library_lock(GetMutex());
320   typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
321                                             gemmlowp::meta::Quantize>
322       Params;
323 
324   Params params;
325   params.input = reinterpret_cast<const float*>(input);
326   params.output = reinterpret_cast<uint8_t*>(output);
327   params.kernel.count = count;
328   params.kernel.range_min = range_min;
329   params.kernel.range_scale =
330       CalculateOneOverRangeScale<uint8_t>(range_min, range_max);
331 
332   // After adding the range_offset the value is cast from float to uint.
333   // The float to int/uint cast in NEON uses round toward 0. To keep the
334   // rounding consistent with Eigen, which uses round toward closest, we can
335   // add 0.5f and exploit the fact that we only operate on non negative values.
336   // TODO(maciekc): fix the actual kernel in gemmlowp/meta
337   params.kernel.range_offset =
338       static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;
339 
340   MultiThreadTransform1D<Params, 16>(tf_context, params);
341 #else
342   LOG(FATAL) << "Quantize: Meta fastpath not supported.";
343 #endif
344 }
345 
QuantizedBiasAdd(OpKernelContext * tf_context,const quint8 * input,int input_count,const quint8 * bias,int bias_count,float input_min,float input_max,float bias_min,float bias_max,float output_min,float output_max,qint32 * output)346 void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
347                       int input_count, const quint8* bias, int bias_count,
348                       float input_min, float input_max, float bias_min,
349                       float bias_max, float output_min, float output_max,
350                       qint32* output) {
351 #ifdef TENSORFLOW_USE_META
352   mutex_lock library_lock(GetMutex());
353   typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
354                                             gemmlowp::meta::BiasAdd<uint8_t>>
355       Params;
356 
357   Params params;
358   params.input = reinterpret_cast<const uint8_t*>(input);
359   params.output = reinterpret_cast<int32_t*>(output);
360   params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
361   params.kernel.count = bias_count;
362   params.kernel.rows = input_count / bias_count;
363   params.kernel.input_range_min = input_min;
364   params.kernel.bias_range_min = bias_min;
365   params.kernel.input_range_scale =
366       CalculateRangeScale<uint8_t>(input_min, input_max);
367   params.kernel.bias_range_scale =
368       CalculateRangeScale<uint8_t>(bias_min, bias_max);
369   params.kernel.input_range_offset = 0;
370   params.kernel.bias_range_offset = 0;
371   params.kernel.output_range_min = output_min;
372   params.kernel.one_over_output_range_scale =
373       CalculateOneOverRangeScale<int32_t>(output_min, output_max);
374   params.kernel.output_range_offset =
375       static_cast<float>(std::numeric_limits<int32_t>::lowest());
376 
377   // TODO(maciekc): add multithreading to bias add.
378   // Right now this kernel does not support multi threaded execution.
379   gemmlowp::meta::Transform1D<Params, 16>(params);
380 #else
381   LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
382 #endif
383 }
384 
Clamp(OpKernelContext * tf_context,const quint8 * input,int count,quint8 clamp_min,quint8 clamp_max,quint8 * output)385 void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
386            quint8 clamp_min, quint8 clamp_max, quint8* output) {
387 #ifdef TENSORFLOW_USE_META
388   mutex_lock library_lock(GetMutex());
389   typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
390                                             gemmlowp::meta::MinMax<uint8_t>>
391       Params;
392 
393   Params params;
394   params.input = reinterpret_cast<const uint8_t*>(input);
395   params.output = reinterpret_cast<uint8_t*>(output);
396   params.kernel.count = count;
397   params.kernel.min = clamp_min;
398   params.kernel.max = clamp_max;
399 
400   MultiThreadTransform1D<Params, 16>(tf_context, params);
401 #else
402   LOG(FATAL) << "Clamp: Meta fastpath not supported.";
403 #endif
404 }
405 
406 }  // namespace meta
407 }  // namespace tensorflow
408