// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemm.h: Multi-threaded GEMM entry point.
// Readers note: To understand this file, it is useful to first
// read and understand the much simpler single_thread_gemm.h.

#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_

#include <pthread.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

#include "single_thread_gemm.h"

namespace gemmlowp {

// On x86 and ARM platforms we busy-wait (spin) for a while before waiting on a
// pthread condition variable. In order to implement that correctly we need
// some explicit memory load/store barriers.

#if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \
    (defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86))

#define GEMMLOWP_USE_BUSYWAIT

const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#define GEMMLOWP_NOP "nop\n"

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
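// For instance, GEMMLOWP_NOP4 expands to "nop\n" "nop\n" "nop\n" "nop\n",
// so GEMMLOWP_NOP64 is 64 "nop" instructions concatenated into a single
// string literal suitable for inline asm.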

inline int Do64NOPs() {
  asm volatile(GEMMLOWP_NOP64);
  return 64;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

inline void WriteBarrier() {
#if defined(_MSC_VER)
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_32)
  asm volatile("" ::: "memory");
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishst" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("sfence" ::: "memory");
#else
#error "Unsupported architecture for WriteBarrier."
#endif
}

inline void ReadBarrier() {
#if defined(_MSC_VER)
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_32)
  asm volatile("" ::: "memory");
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishld" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("lfence" ::: "memory");
#else
#error "Unsupported architecture for ReadBarrier."
#endif
}
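
// In this file these barriers are used in pairs: producers
// (BlockingCounter::DecrementCount(), Worker::StartWork()) issue
// WriteBarrier() after their store, and consumers that poll for the change
// (WaitForVariableChange(), BlockingCounter::Wait()) issue ReadBarrier()
// before acting on the value they observed.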

#endif

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that this
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting on the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
                        pthread_mutex_t* mutex) {
#ifdef GEMMLOWP_USE_BUSYWAIT
  // If we are on a platform that supports it, spin for some time.
  {
    int nops = 0;
    // First, trivial case where the variable already changed value.
    T new_value = *var;
    if (new_value != initial_value) {
      ReadBarrier();
      return new_value;
    }
    // Then try busy-waiting.
    while (nops < kMaxBusyWaitNOPs) {
      nops += Do64NOPs();
      new_value = *var;
      if (new_value != initial_value) {
        ReadBarrier();
        return new_value;
      }
    }
  }
#endif

  // Finally, do real passive waiting.
  pthread_mutex_lock(mutex);
  T new_value = *var;
  // Loop rather than doing a single wait, to be robust against spurious
  // wakeups of pthread_cond_wait.
  while (new_value == initial_value) {
    pthread_cond_wait(cond, mutex);
    new_value = *var;
  }
  pthread_mutex_unlock(mutex);
  return new_value;
}
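
// Usage sketch (illustration only; `flag`, `cond` and `mutex` are hypothetical
// names, not part of this file): one thread waits for another to flip a flag.
//
//   volatile bool flag = false;   // shared state
//   pthread_cond_t cond;          // initialized with pthread_cond_init
//   pthread_mutex_t mutex;        // initialized with pthread_mutex_init
//   ...
//   // Waiter: spins briefly, then blocks on `cond` until flag != false.
//   bool observed = WaitForVariableChange(&flag, false, &cond, &mutex);
//   assert(observed == true);
//   ...
//   // Notifier: takes the mutex around the store and the signal, as
//   // Worker::ChangeState() below does for its state variable.
//   pthread_mutex_lock(&mutex);
//   flag = true;
//   pthread_cond_signal(&cond);
//   pthread_mutex_unlock(&mutex);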

// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  BlockingCounter() : count_(0), initial_count_(0) {
    pthread_cond_init(&cond_, nullptr);
    pthread_mutex_init(&mutex_, nullptr);
  }

  ~BlockingCounter() {
    pthread_cond_destroy(&cond_);
    pthread_mutex_destroy(&mutex_);
  }

  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    pthread_mutex_lock(&mutex_);
    assert(count_ == 0);
    initial_count_ = initial_count;
    count_ = initial_count_;
    pthread_mutex_unlock(&mutex_);
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    pthread_mutex_lock(&mutex_);
    assert(count_ > 0);
    count_--;
#ifdef GEMMLOWP_USE_BUSYWAIT
    WriteBarrier();
#endif
    if (count_ == 0) {
      pthread_cond_signal(&cond_);
    }
    bool retval = count_ == 0;
    pthread_mutex_unlock(&mutex_);
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    ScopedProfilingLabel label("BlockingCounter::Wait");
    while (count_) {
#ifdef GEMMLOWP_USE_BUSYWAIT
      ReadBarrier();
#else
      // This is likely unnecessary, but is kept to ensure regressions are not
      // introduced.
#ifndef _WIN32
      asm volatile("" ::: "memory");
#endif
#endif
      const std::size_t count_value = count_;
      if (count_value) {
        WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
      }
    }
  }

 private:
  pthread_cond_t cond_;
  pthread_mutex_t mutex_;
  std::size_t count_;
  std::size_t initial_count_;
};
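
// Usage sketch (illustration only): the master arms the counter for two
// events, hands work to two other threads, and blocks until both have
// called DecrementCount().
//
//   BlockingCounter counter;
//   counter.Reset(2);
//   // ... each of the two worker threads eventually does:
//   //   counter.DecrementCount();
//   // meanwhile the master thread does:
//   counter.Wait();  // returns once both decrements have happened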

// A workload for a worker.
struct Task {
  Task() : local_allocator(nullptr) {}
  virtual ~Task() {}
  virtual void Run() = 0;
  Allocator* local_allocator;
};
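
// A minimal sketch (illustration only; `MyTask` is a hypothetical name) of a
// Task subclass as consumed by Worker and WorkersPool below:
//
//   struct MyTask : Task {
//     void Run() override {
//       // May allocate scratch memory from local_allocator, which is set by
//       // Worker::StartWork() (or by WorkersPool::Execute() for the task run
//       // on the main thread) before Run() is called.
//     }
//   };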

// A worker thread.
class Worker {
 public:
  enum class State {
    ThreadStartup,  // The initial state before the thread main loop runs.
    Ready,          // Is not working, has not yet received new work to do.
    HasWork,        // Has work to do.
    ExitAsSoonAsPossible  // Should exit at earliest convenience.
  };
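
  // Legal state transitions (enforced by the assertions in ChangeState()):
  //   ThreadStartup -> Ready
  //   Ready         -> HasWork or ExitAsSoonAsPossible
  //   HasWork       -> Ready or ExitAsSoonAsPossible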

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    pthread_cond_init(&state_cond_, nullptr);
    pthread_mutex_init(&state_mutex_, nullptr);
    pthread_create(&thread_, nullptr, ThreadFunc, this);
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    pthread_join(thread_, nullptr);
    pthread_cond_destroy(&state_cond_);
    pthread_mutex_destroy(&state_mutex_);
  }

  // Changes State; may be called from either the worker thread
  // or the master thread. However, not all state transitions are legal;
  // illegal transitions are caught by assertions.
  void ChangeState(State new_state) {
    ScopedProfilingLabel label("Worker::ChangeState");
    pthread_mutex_lock(&state_mutex_);
    assert(new_state != state_);
    switch (state_) {
      case State::ThreadStartup:
        assert(new_state == State::Ready);
        break;
      case State::Ready:
        assert(new_state == State::HasWork ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      case State::HasWork:
        assert(new_state == State::Ready ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      default:
        abort();
    }
    state_ = new_state;
    pthread_cond_signal(&state_cond_);
    if (state_ == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
    pthread_mutex_unlock(&state_mutex_);
  }

  // Thread entry point.
  void ThreadFunc() {
    ScopedProfilingLabel label("Worker::ThreadFunc");
    RegisterCurrentThreadForProfiling();

    ChangeState(State::Ready);

    // Thread main loop.
    while (true) {
      // Get a state to act on.
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon = WaitForVariableChange(
          &state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
        case State::HasWork:
          // Got work to do! So do it, and then revert to 'Ready' state.
          assert(task_);
          task_->Run();
          task_ = nullptr;
          ChangeState(State::Ready);
          break;
        case State::ExitAsSoonAsPossible:
          return;
        default:
          abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this if the worker is currently in the
  // 'Ready' state.
  void StartWork(Task* task) {
    assert(!task_);
    task->local_allocator = &local_allocator_;
    task_ = task;
#ifdef GEMMLOWP_USE_BUSYWAIT
    WriteBarrier();
#endif
    assert(state_ == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  pthread_t thread_;

  // The task to be worked on.
  Task* task_;

  // The condition variable and mutex guarding state changes.
  pthread_cond_t state_cond_;
  pthread_mutex_t state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  State state_;

  // Each worker has its own local allocator so that workers can allocate
  // temporary buffers without blocking each other.
  Allocator local_allocator_;

  // Pointer to the master thread's BlockingCounter object, used to notify the
  // master thread when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};
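
// Master-side lifecycle sketch (illustration only; WorkersPool below is the
// real user of this class, and `task` is a Task* prepared by the caller, who
// keeps ownership of it):
//
//   BlockingCounter ready_counter;
//   ready_counter.Reset(1);
//   Worker* worker = new Worker(&ready_counter);  // spawns the thread
//   ready_counter.Wait();     // the worker has reached the 'Ready' state
//   ready_counter.Reset(1);
//   worker->StartWork(task);  // task->Run() executes on the worker thread
//   ready_counter.Wait();     // worker is 'Ready' again: the task is done
//   delete worker;            // joins the thread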

// A very simple pool of workers that only allows the very
// specific parallelization pattern that we use here:
// a fixed number of workers can be given work, and one then
// waits for all of them to finish.
//
// See MultiThreadGemmContextBase for how other WorkersPool implementations
// can be plugged in; any such implementation only needs to provide the same
// public Execute() method, and may distribute the given tasks across its
// threads in any way it likes.
class WorkersPool {
 public:
  WorkersPool() {}

  ~WorkersPool() {
    for (auto w : workers_) {
      delete w;
    }
  }

  void Execute(const std::vector<Task*>& tasks) {
    assert(tasks.size() >= 1);
    // One of the tasks will be run on the current thread.
    std::size_t workers_count = tasks.size() - 1;
    CreateWorkers(workers_count);
    assert(workers_count <= workers_.size());
    counter_to_decrement_when_ready_.Reset(workers_count);
    int n = 0;
    std::for_each(tasks.begin(), tasks.end() - 1,
                  [this, &n](Task* task) { workers_[n++]->StartWork(task); });
    // Execute the remaining workload immediately on the current thread.
    Task* task = tasks.back();
    task->local_allocator = &main_thread_task_allocator_;
    task->Run();
    // Wait for the workers submitted above to finish.
    counter_to_decrement_when_ready_.Wait();
    // Cleanup tasks (best to do this from the same thread that allocated
    // the memory).
    std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; });
  }

 private:
  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(new Worker(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

  // Copy construction disallowed.
  WorkersPool(const WorkersPool&) = delete;

  // The workers in this pool. They are owned by the pool:
  // the pool creates workers and destroys them in its destructor.
  std::vector<Worker*> workers_;

  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;

  // For N-threaded operations, we will use only N-1 worker threads,
  // while the last task will be run directly on the main thread.
  // That task will then use this main_thread_task_allocator_; having a
  // dedicated allocator for it (separate from the context's main allocator)
  // allows using the same code for all tasks regardless of which
  // thread they run on.
  Allocator main_thread_task_allocator_;
};
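
// Usage sketch (illustration only; MakeSomeTask stands for any code that news
// up a Task subclass). Note that Execute() deletes the tasks when done, and
// that one of them runs on the calling thread itself.
//
//   WorkersPool pool;
//   std::vector<Task*> tasks;
//   for (int i = 0; i < 4; i++) {
//     tasks.push_back(MakeSomeTask(i));
//   }
//   pool.Execute(tasks);  // returns once all 4 tasks have run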

// The task we use to implement a multi-threaded Gemm: a block of the
// RHS has been packed by the master thread; each worker thread
// then has to pack a block of the LHS and accumulate the Gemm of these
// packed LHS and RHS blocks.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType, typename GemmContextType>
struct GemmWithPackedRhsTask : Task {
  typedef PackedSideBlock<typename KernelFormat::Lhs> PackedLhs;
  typedef PackedSideBlock<typename KernelFormat::Rhs> PackedRhs;
  GemmWithPackedRhsTask(GemmContextType* _context, const KernelBase& _kernel,
                        const MatrixMap<const InputScalar, LhsOrder>& _lhs,
                        const PackedRhs& _packed_rhs,
                        MatrixMap<OutputScalar, ResultOrder>* _result,
                        const MatrixBlockBounds& _result_block,
                        const LhsOffset& _lhs_offset,
                        const RhsOffset& _rhs_offset,
                        const BlockParams& _block_params,
                        const OutputPipelineType& _output_pipeline)
      : context(_context),
        kernel(_kernel),
        lhs(_lhs),
        packed_rhs(_packed_rhs),
        result(*_result),
        result_block(_result_block),
        lhs_offset(_lhs_offset),
        rhs_offset(_rhs_offset),
        block_params(_block_params),
        output_pipeline(_output_pipeline) {}

  void Run() override {
    ScopedProfilingLabel label("GemmWithPackedRhsTask");

    const int rows = result_block.rows;
    const int cols = result_block.cols;
    const int depth = lhs.cols();

    PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params);

    PackedResult packed_result(local_allocator, block_params);

    local_allocator->Commit();

    for (int c = 0; c < cols; c += block_params.l2_cols) {
      int cs = std::min(block_params.l2_cols, cols - c);

      for (int r = 0; r < rows; r += block_params.l2_rows) {
        int rs = std::min(block_params.l2_rows, rows - r);

        PackLhs(&packed_lhs, lhs.block(r, 0, rs, depth));

        Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs,
                depth);

        auto curr_result_block = MatrixBlockBounds(
            result_block.start_row + r, result_block.start_col + c, rs, cs);
        UnpackResult<KernelFormat>(
            &result, curr_result_block, packed_result, depth,
            packed_lhs.sums_of_each_slice(), packed_rhs.sums_of_each_slice(),
            lhs_offset.block(curr_result_block.start_row, rs),
            rhs_offset.block(curr_result_block.start_col, cs), output_pipeline);
      }
    }

    local_allocator->Decommit();
  }

  const GemmContextType* context;
  const KernelBase& kernel;
  const MatrixMap<const InputScalar, LhsOrder> lhs;
  const PackedRhs packed_rhs;
  MatrixMap<OutputScalar, ResultOrder> result;
  const MatrixBlockBounds result_block;
  const LhsOffset& lhs_offset;
  const RhsOffset& rhs_offset;
  const BlockParams& block_params;
  const OutputPipelineType& output_pipeline;
};

// This base class for multi-threading allows subclasses to implement their own
// workers_pool() method. See MultiThreadGemmContext below for an example;
// any other implementation of workers_pool() must return an object with the
// same public methods as WorkersPool.
class MultiThreadGemmContextBase : public SingleThreadGemmContext {
 public:
  void set_max_num_threads(int n) { max_num_threads_ = n; }

  int max_num_threads() const { return max_num_threads_; }

 protected:
  // The maximum number of worker threads to use (including
  // the master thread).
  // The default value 1 means single-threading. That is the default
  // because gemmlowp's primary target is mobile hardware, where thermal
  // constraints usually mean that it may not be realistic to use more
  // than 1 CPU core even if multiple cores are present.
  // The special value 0 means try to detect the number of hardware threads.
  // Note: this assumes that all CPU cores are equivalent. That assumption
  // is defeated on big.LITTLE ARM devices, where we have no API to query
  // the number of big cores (which is typically what we would want to use,
  // leaving aside the above-mentioned thermal issues). That is the other
  // reason why the best compromise here is to let max_num_threads_ default
  // to 1, so users who want multi-threading have to decide for themselves
  // how many threads to use.
  int max_num_threads_ = 1;
};
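
// A minimal sketch (hypothetical; MyWorkersPool is not part of gemmlowp) of a
// custom context that plugs in its own pool:
//
//   class MyGemmContext : public MultiThreadGemmContextBase {
//    public:
//     MyWorkersPool* workers_pool() { return &pool_; }
//    private:
//     MyWorkersPool pool_;  // must expose Execute(const std::vector<Task*>&)
//   };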

class MultiThreadGemmContext : public MultiThreadGemmContextBase {
 public:
  WorkersPool* workers_pool() { return &workers_pool_; }

 private:
  // The workers pool used by MultiThreadGemm. Making
  // this part of the context allows it to be persistent,
  // avoiding recreating threads on every Gemm.
  WorkersPool workers_pool_;
};
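
// Typical setup sketch (illustration only): create a context once, optionally
// raise the thread budget, and reuse the context across GEMMs so that the
// worker threads persist.
//
//   MultiThreadGemmContext context;
//   context.set_max_num_threads(4);  // 0 would mean "auto-detect"
//   // ... then pass &context to MultiThreadGemm() below (or to the public
//   // gemmlowp entry points that eventually call it) for each GEMM.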

// Determines how many threads should be used for a given Gemm
// operation.
template <int KernelRows>
inline int HowManyThreads(int max_num_threads, int rows, int cols, int depth) {
  // Early-exit in the default case where multi-threading is disabled.
  if (max_num_threads == 1) {
    return 1;
  }

  // Determine the maximum number of threads.
  int max_count = GetHardwareConcurrency(max_num_threads);

  // Basic calculation: take into account max pool size, and
  // how many rows we have to feed our kernel.
  // The motivation for an absolute minimum number of rows per thread,
  // potentially higher than KernelRows, is that very thin per-thread
  // workloads currently defeat assumptions of the AddMod generator,
  // resulting in substantial bias in TestWithRealData on 24 threads.
  // Ideally, the AddMod generator should be aware of global (r,c) coordinates
  // so as to be independent of the number of threads.
  static const int AbsoluteMinRowsPerThread = 16;
  static const int MinRowsPerThread = KernelRows > AbsoluteMinRowsPerThread
                                          ? KernelRows
                                          : AbsoluteMinRowsPerThread;
  int thread_count = std::min(max_count, CeilQuotient(rows, MinRowsPerThread));

  // At this point for small products we already have thread_count==1 so
  // we can avoid doing more work; otherwise, we still want to check
  // that the cubic size (rows*cols*depth) is big enough to keep
  // workers_ busy.
  if (thread_count > 1) {
    // Empirically determined value.
    static const std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // Compute the cubic size in 64-bit arithmetic to avoid any risk of
    // overflowing 32-bit int.
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count =
        std::min(thread_count, int(cubic_size / min_cubic_size_per_thread));

    if (thread_count < 1) {
      thread_count = 1;
    }
  }

  assert(thread_count > 0 && thread_count <= max_count);
  return thread_count;
}
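
// Worked example (hypothetical numbers; assumes GetHardwareConcurrency(8)
// simply returns 8): with KernelRows == 4 and rows == cols == depth == 100,
// MinRowsPerThread == 16, so thread_count == min(8, CeilQuotient(100, 16))
// == min(8, 7) == 7. Then cubic_size == 1000000 and 1000000 / 65536 == 15,
// so the cubic-size check leaves thread_count at 7.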

// The main multi-threaded Gemm function.
// To understand it, first read the code of SingleThreadGemm().
// The parallelization scheme used here is to have this master function
// pack a block of RHS and then start worker threads to pack a block of LHS
// each, and accumulate the corresponding products.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType, typename GemmContextType>
void MultiThreadGemm(GemmContextType* context, const KernelBase& kernel,
                     const MatrixMap<const InputScalar, LhsOrder>& lhs,
                     const MatrixMap<const InputScalar, RhsOrder>& rhs,
                     MatrixMap<OutputScalar, ResultOrder>* result,
                     const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                     const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label("gemmlowp::MultiThreadGemm");

  assert(lhs.cols() == rhs.rows());

  int rows = result->rows();
  int cols = result->cols();
  int depth = lhs.cols();

  // Zero sizes should have been caught earlier and early-returned.
  assert(rows > 0);
  assert(cols > 0);
  assert(depth > 0);

  // The case of rows<cols should have been caught earlier and transposed.
  assert(rows >= cols);

  const int thread_count = HowManyThreads<KernelFormat::kRows>(
      context->max_num_threads(), rows, cols, depth);
  if (thread_count == 1) {
    return SingleThreadGemm<KernelFormat, InputScalar, OutputScalar,
                            BitDepthParams>(context, kernel, lhs, rhs, result,
                                            lhs_offset, rhs_offset,
                                            output_pipeline);
  }
  assert(thread_count > 1);

  // Simple 1:1 mapping of tasks to physical cores, which is very important
  // for getting good multithreaded performance, especially for not-very-large
  // GEMMs, and especially on Android.
  const int task_count = thread_count;

  Allocator* allocator = context->allocator();
  auto* workers_pool = context->workers_pool();

  BlockParams block_params;
  block_params.Init<KernelFormat>(
      rows, cols, depth, task_count, context->l1_bytes_to_use(),
      context->l2_bytes_to_use(), context->l2_rhs_factor());

  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(Side::Rhs, allocator,
                                                         block_params);
  allocator->Commit();

  // We loop over large blocks of the RHS.
  for (int c = 0; c < cols; c += block_params.l2_cols) {
    int cs = std::min(block_params.l2_cols, cols - c);

    // Pack a large block of the RHS.
    PackRhs(&packed_rhs, rhs.block(0, c, depth, cs));

    // Give work to each worker.
    std::vector<Task*> tasks;
    int next_start_row = 0;
    for (int n = 0; n < task_count; ++n) {
      int start_row = next_start_row;
      next_start_row = std::min(
          rows, RoundUp<KernelFormat::kRows>(rows * (n + 1) / task_count));
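      // For example (sketch): with rows == 100, task_count == 3 and
      // KernelFormat::kRows == 4, successive iterations yield the row ranges
      // [0, 36), [36, 68), [68, 100), since RoundUp<4>(100 * 1 / 3) == 36 and
      // RoundUp<4>(100 * 2 / 3) == 68, and the last range is clamped to rows.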

      int block_rows = next_start_row - start_row;
      auto lhs_block = lhs.block(start_row, 0, block_rows, depth);
      typedef GemmWithPackedRhsTask<KernelFormat, InputScalar, OutputScalar,
                                    BitDepthParams, LhsOrder, RhsOrder,
                                    ResultOrder, LhsOffset, RhsOffset,
                                    OutputPipelineType, GemmContextType>
          TaskType;
      tasks.push_back(
          new TaskType(context, kernel, lhs_block, packed_rhs, result,
                       MatrixBlockBounds(start_row, c, block_rows, cs),
                       lhs_offset, rhs_offset, block_params, output_pipeline));
    }
    // Execute the work on the workers (and partially on this thread).
    workers_pool->Execute(tasks);
  }

  allocator->Decommit();
}

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_