// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemm.h: Multi-threaded GEMM entry point.
// Readers note: To understand this file, it is useful to first
// read and understand the much simpler single_thread_gemm.h.

#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_

#include <pthread.h>
#include <unistd.h>
#include <vector>

#include "single_thread_gemm.h"

namespace gemmlowp {

#ifdef GEMMLOWP_ALLOW_INLINE_ASM
// Where inline asm is allowed, we use some busy-waiting,
// preferably implemented using NOP instructions.
const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#define GEMMLOWP_NOP "nop\n"

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
#define GEMMLOWP_NOP256 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP64)

inline int Do256NOPs() {
  asm volatile(GEMMLOWP_NOP256);
  return 256;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP256
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

#else  // not GEMMLOWP_ALLOW_INLINE_ASM

// It is nontrivial to implement good busy-waiting without
// using asm; NOP instructions have the least side effects
// and the lowest power usage; and since the whole busy-waiting
// story is an optimization, it's not very interesting in places
// where we're slow anyway due to not being able to
// use our inline asm kernels.

const int kMaxBusyWaitNOPs = 0;
inline int Do256NOPs() { return 0; }

#endif  // not GEMMLOWP_ALLOW_INLINE_ASM

inline void WriteBarrier() {
#ifdef GEMMLOWP_ARM_32
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishst" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("sfence" ::: "memory");
#elif defined(__mips__)
  MemoryBarrier();
#else
#error "Unsupported architecture for WriteBarrier."
#endif
}

inline void ReadBarrier() {
#ifdef GEMMLOWP_ARM_32
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishld" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("lfence" ::: "memory");
#elif defined(__mips__)
  MemoryBarrier();
#else
#error "Unsupported architecture for ReadBarrier."
#endif
}
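
// Note on how these barriers are paired in this file: Worker::StartWork()
// issues a WriteBarrier() after storing task_ and before publishing the
// HasWork state, and Worker::ThreadFunc() issues a ReadBarrier() before
// dereferencing task_. Together they order the publication of the task
// against the state change on weakly-ordered architectures.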

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that this
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
                        pthread_mutex_t* mutex) {
  int nops = 0;
  // First, trivial case where the variable already changed value.
  T new_value = *var;
  if (new_value != initial_value) {
    return new_value;
  }
  // Then try busy-waiting.
  while (nops < kMaxBusyWaitNOPs) {
    nops += Do256NOPs();
    new_value = *var;
    if (new_value != initial_value) {
      return new_value;
    }
  }
  // Finally, do real passive waiting.
  pthread_mutex_lock(mutex);
  new_value = *var;
  if (new_value == initial_value) {
    pthread_cond_wait(cond, mutex);
    new_value = *var;
    assert(new_value != initial_value);
  }
  pthread_mutex_unlock(mutex);
  return new_value;
}
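
// Illustrative sketch (not part of the library): the notifying side must
// update the watched variable and signal the condvar under the same mutex,
// as Worker::ChangeState() and BlockingCounter::DecrementCount() do below.
// For a hypothetical flag watched by a waiter, the update would look like:
//
//   pthread_mutex_lock(&mutex);
//   flag = new_value;            // the change WaitForVariableChange observes
//   pthread_cond_signal(&cond);  // wakes the passive waiter, if any
//   pthread_mutex_unlock(&mutex);
//
// Without the signal, a waiter that has already fallen back to
// pthread_cond_wait() would block indefinitely.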

// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  BlockingCounter()
      : cond_(PTHREAD_COND_INITIALIZER),
        mutex_(PTHREAD_MUTEX_INITIALIZER),
        count_(0),
        initial_count_(0) {}

  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    pthread_mutex_lock(&mutex_);
    assert(count_ == 0);
    initial_count_ = initial_count;
    count_ = initial_count_;
    pthread_mutex_unlock(&mutex_);
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    pthread_mutex_lock(&mutex_);
    assert(count_ > 0);
    count_--;
    if (count_ == 0) {
      pthread_cond_signal(&cond_);
    }
    bool retval = count_ == 0;
    pthread_mutex_unlock(&mutex_);
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    ScopedProfilingLabel label("BlockingCounter::Wait");
    while (count_) {
      MemoryBarrier();
      const std::size_t count_value = count_;
      if (count_value) {
        WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
      }
    }
  }

 private:
  pthread_cond_t cond_;
  pthread_mutex_t mutex_;
  std::size_t count_;
  std::size_t initial_count_;
};
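
// Illustrative sketch (not part of the library): this is the pattern that
// WorkersPool and MultiThreadGemm use below. The master resets the counter
// to the number of expected events and waits; each worker decrements it
// exactly once when done:
//
//   BlockingCounter counter;
//   counter.Reset(workers_count);  // master: expect workers_count events
//   // ... each of the workers_count workers eventually calls:
//   //   counter.DecrementCount();
//   counter.Wait();                // master: blocks until the count hits 0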

// A workload for a worker.
struct Task {
  Task() : local_allocator(nullptr) {}
  virtual ~Task() {}
  virtual void Run() const = 0;
  Allocator* local_allocator;
};
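
// Illustrative sketch (not part of the library): a concrete workload derives
// from Task and implements Run(). The worker fills in local_allocator before
// running the task, so Run() can reserve scratch storage from it, Commit(),
// use it, and Decommit(), as GemmWithPackedRhsTask::Run() does below.
//
//   struct HypotheticalTask : Task {  // hypothetical example type
//     void Run() const override {
//       // ... reserve scratch buffers from local_allocator, then
//       // local_allocator->Commit(); ... work ...; local_allocator->Decommit();
//     }
//   };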

// A worker thread.
class Worker {
 public:
  enum class State {
    ThreadStartup,  // The initial state before the thread main loop runs.
    Ready,          // Is not working, has not yet received new work to do.
    HasWork,        // Has work to do.
    ExitAsSoonAsPossible  // Should exit at earliest convenience.
  };
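
  // Summary of the legal state transitions enforced by ChangeState():
  //
  //   ThreadStartup -> Ready                 (worker thread, at startup)
  //   Ready         -> HasWork               (master, via StartWork())
  //   Ready         -> ExitAsSoonAsPossible  (master, via ~Worker())
  //   HasWork       -> Ready                 (worker, after running the task)
  //   HasWork       -> ExitAsSoonAsPossible  (master, via ~Worker())
  //
  // Every transition into Ready also decrements the master's BlockingCounter.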

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_cond_(PTHREAD_COND_INITIALIZER),
        state_mutex_(PTHREAD_MUTEX_INITIALIZER),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    pthread_create(&thread_, nullptr, ThreadFunc, this);
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    pthread_join(thread_, nullptr);
  }

  // Changes State; may be called from either the worker thread
  // or the master thread; however, not all state transitions are legal,
  // as guarded by the assertions below.
  void ChangeState(State new_state) {
    ScopedProfilingLabel label("Worker::ChangeState");
    pthread_mutex_lock(&state_mutex_);
    assert(new_state != state_);
    switch (state_) {
      case State::ThreadStartup:
        assert(new_state == State::Ready);
        break;
      case State::Ready:
        assert(new_state == State::HasWork ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      case State::HasWork:
        assert(new_state == State::Ready ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      default:
        abort();
    }
    state_ = new_state;
    pthread_cond_signal(&state_cond_);
    if (state_ == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
    pthread_mutex_unlock(&state_mutex_);
  }

  // Thread entry point.
  void ThreadFunc() {
    ScopedProfilingLabel label("Worker::ThreadFunc");
    RegisterCurrentThreadForProfiling();

    ChangeState(State::Ready);

    // Thread main loop
    while (true) {
      // Get a state to act on.
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon = WaitForVariableChange(
          &state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
        case State::HasWork:
          // Got work to do! So do it, and then revert to 'Ready' state.
          ReadBarrier();
          assert(task_);
          task_->Run();
          delete task_;
          task_ = nullptr;
          ChangeState(State::Ready);
          break;
        case State::ExitAsSoonAsPossible:
          return;
        default:
          abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this while the worker is in the 'Ready'
  // state (see the assertion below).
  void StartWork(Task* task) {
    assert(!task_);
    task->local_allocator = &local_allocator_;
    task_ = task;
    WriteBarrier();
    assert(state_ == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  pthread_t thread_;

  // The task to be worked on.
  const Task* task_;

  // The condition variable and mutex guarding state changes.
  pthread_cond_t state_cond_;
  pthread_mutex_t state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  State state_;

  // Each worker has a local allocator so it can allocate temporary
  // buffers without blocking the other workers.
  Allocator local_allocator_;

  // Pointer to the master thread's BlockingCounter object, used to notify
  // the master thread when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};

// A very simple pool of workers that only supports the very
// specific parallelization pattern that we use here:
// a fixed number of workers can be given work, and one then
// waits for all of them to finish.
class WorkersPool {
 public:
  WorkersPool() {}

  ~WorkersPool() {
    for (auto w : workers_) {
      delete w;
    }
  }

  BlockingCounter& counter_to_decrement_when_ready() {
    return counter_to_decrement_when_ready_;
  }

  // Give work to a specific worker.
  void StartWorker(int index, Task* task_) {
    assert(static_cast<std::size_t>(index) < workers_.size());
    workers_[index]->StartWork(task_);
  }

  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(new Worker(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

 private:
  // Copy construction disallowed.
  WorkersPool(const WorkersPool&) = delete;

  // The workers in this pool. They are owned by the pool:
  // the pool creates workers and destroys them in its destructor.
  std::vector<Worker*> workers_;

  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;
};
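
// Illustrative sketch (not part of the library): this is how MultiThreadGemm
// drives the pool below. Each round of work re-arms the shared counter before
// dispatching, then waits for all workers to become Ready again:
//
//   WorkersPool pool;
//   pool.CreateWorkers(workers_count);  // only ever grows the pool
//   pool.counter_to_decrement_when_ready().Reset(workers_count);
//   for (int i = 0; i < workers_count; i++) {
//     // SomeTask is a hypothetical Task subclass; the worker deletes the
//     // task after running it.
//     pool.StartWorker(i, new SomeTask(...));
//   }
//   pool.counter_to_decrement_when_ready().Wait();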

// The task we use to implement a multi-threaded Gemm: a block of the
// RHS has been packed by the master thread; each worker thread
// then has to pack a block of the LHS and accumulate the Gemm of these
// packed LHS and RHS blocks.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
struct GemmWithPackedRhsTask : Task {
  typedef PackedSideBlock<typename KernelFormat::Lhs> PackedLhs;
  typedef PackedSideBlock<typename KernelFormat::Rhs> PackedRhs;
  GemmWithPackedRhsTask(const KernelBase& _kernel,
                        const MatrixMap<const InputScalar, LhsOrder>& _lhs,
                        const PackedRhs& _packed_rhs,
                        MatrixMap<OutputScalar, ResultOrder>* _result,
                        const LhsOffset& _lhs_offset,
                        const RhsOffset& _rhs_offset,
                        const OutputPipelineType& _output_pipeline)
      : kernel(_kernel),
        lhs(_lhs),
        packed_rhs(_packed_rhs),
        result(*_result),
        lhs_offset(_lhs_offset),
        rhs_offset(_rhs_offset),
        output_pipeline(_output_pipeline) {}

  void Run() const override {
    ScopedProfilingLabel label("GemmWithPackedRhsTask");

    const int rows = result.rows();
    const int cols = result.cols();
    const int depth = lhs.cols();

    BlockParams block_params;
    block_params.Init<KernelFormat>(rows, cols, depth, 1);

    PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params);

    PackedResult packed_result(local_allocator, block_params);

    local_allocator->Commit();

    for (int c = 0; c < cols; c += block_params.l2_cols) {
      int cs = std::min(block_params.l2_cols, cols - c);

      for (int r = 0; r < rows; r += block_params.l2_rows) {
        int rs = std::min(block_params.l2_rows, rows - r);

        PackLhs<BitDepthParams>(&packed_lhs, lhs.block(r, 0, rs, depth));

        Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);

        auto result_block = result.block(r, c, rs, cs);
        UnpackResult<BitDepthParams>(&result_block, packed_result, depth,
                                     packed_lhs.sums_of_each_slice(),
                                     packed_rhs.sums_of_each_slice(),
                                     lhs_offset, rhs_offset, output_pipeline);
      }
    }

    local_allocator->Decommit();
  }

  const KernelBase& kernel;
  const MatrixMap<const InputScalar, LhsOrder> lhs;
  const PackedRhs packed_rhs;
  MatrixMap<OutputScalar, ResultOrder> result;
  const LhsOffset& lhs_offset;
  const RhsOffset& rhs_offset;
  const OutputPipelineType& output_pipeline;
};

class MultiThreadGemmContext : public SingleThreadGemmContext {
 public:
  MultiThreadGemmContext() : max_num_threads_(0) {}

  void set_max_num_threads(int n) { max_num_threads_ = n; }

  int max_num_threads() const { return max_num_threads_; }

  WorkersPool* workers_pool() { return &workers_pool_; }

  Allocator* main_thread_task_allocator() {
    return &main_thread_task_allocator_;
  }

 protected:
  // The workers pool used by MultiThreadGemm. Making
  // this part of the context allows it to be persistent,
  // avoiding recreating threads on every Gemm.
  WorkersPool workers_pool_;

  // The maximum number of worker threads to use (in addition
  // to the master thread).
  // The default value 0 means the default behavior of
  // detecting the number of hardware threads. Nonzero values mean
  // skipping and overriding hardware detection.
  int max_num_threads_;

  // For N-threaded operations, we will use only N-1 worker threads
  // while the last task will be run directly on the main thread.
  // It will then use this main_thread_task_allocator_; having a
  // dedicated allocator for that (separate from the base allocator_)
  // lets us use the same code for all tasks regardless of which
  // thread they run on.
  Allocator main_thread_task_allocator_;
};
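
// Illustrative sketch (not part of the library): a caller that wants to cap
// the number of threads configures the context once and reuses it across
// GEMM calls, so the worker threads persist:
//
//   MultiThreadGemmContext context;
//   context.set_max_num_threads(2);  // 0 (the default) means auto-detect
//   // ... pass &context to each MultiThreadGemm call; reusing the same
//   // context keeps its WorkersPool threads alive between GEMMs.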

// Determines how many threads should be used for a given Gemm
// operation.
template <int KernelRows>
inline int HowManyThreads(MultiThreadGemmContext* context, int rows, int cols,
                          int depth) {
  // First check if the user set an explicit maximum number of threads.
  int max_count = context->max_num_threads();
  if (!max_count) {
    // No user-set maximum number of threads, so we need to
    // do some hardware detection.
    // This is expensive to query so we do it only once.
    // Too bad for dynamicness. Also, we don't use the C++11 standard getter
    // because Google's coding style currently bans #include <thread>.
    static const int hardware_threads_count =
        static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));

    max_count = hardware_threads_count;
  }

  // Basic calculation: take into account max pool size, and
  // how many rows we have to feed our kernel.
  // The motivation for an absolute minimum number of rows per thread,
  // potentially higher than KernelRows, is that very thin thread workloads
  // currently defeat assumptions of the AddMod generator, resulting
  // in substantial bias in TestWithRealData on 24 threads.
  // Ideally, the AddMod generator should be aware of global (r,c) coordinates
  // so as to be independent of the number of threads.
  static const int AbsoluteMinRowsPerThread = 16;
  static const int MinRowsPerThread = KernelRows > AbsoluteMinRowsPerThread
                                          ? KernelRows
                                          : AbsoluteMinRowsPerThread;
  int thread_count = std::min(max_count, CeilQuotient(rows, MinRowsPerThread));

  // At this point for small products we already have thread_count==1 so
  // we can avoid doing more work; otherwise, we still want to check
  // that the cubic size (rows*cols*depth) is big enough to keep
  // workers busy.
  if (thread_count > 1) {
    // Empirically determined value.
    static const std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // Compute the cubic size in 64-bit arithmetic, as the product of three
    // int dimensions could overflow 32 bits.
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count =
        std::min(thread_count, int(cubic_size / min_cubic_size_per_thread));

    if (thread_count < 1) {
      thread_count = 1;
    }
  }

  assert(thread_count > 0 && thread_count <= max_count);
  return thread_count;
}
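
// Worked example (illustrative, with hypothetical sizes): suppose
// KernelRows = 4 and the device reports 8 hardware threads.
//  - For rows = cols = depth = 32: MinRowsPerThread = 16, so
//    thread_count = min(8, CeilQuotient(32, 16)) = 2; but the cubic size,
//    32*32*32 = 32768, is below 2 * 64K, so the cubic-size check reduces it
//    to 1 and the GEMM runs single-threaded.
//  - For rows = cols = depth = 1024: thread_count = min(8, 64) = 8, and the
//    cubic size (2^30) easily clears 8 * 64K, so all 8 threads are used.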

// The main multi-threaded Gemm function.
// To understand it, first read the code of SingleThreadGemm().
// The parallelization scheme used here is to have this master function
// pack a block of RHS and then start worker threads to pack a block of LHS
// each, and accumulate the corresponding products.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
void MultiThreadGemm(MultiThreadGemmContext* context, const KernelBase& kernel,
                     const MatrixMap<const InputScalar, LhsOrder>& lhs,
                     const MatrixMap<const InputScalar, RhsOrder>& rhs,
                     MatrixMap<OutputScalar, ResultOrder>* result,
                     const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                     const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label("gemmlowp::MultiThreadGemm");

  assert(lhs.cols() == rhs.rows());

  int rows = result->rows();
  int cols = result->cols();
  int depth = lhs.cols();

  assert(rows > 0);
  assert(cols > 0);
  assert(depth > 0);

  const int thread_count =
      HowManyThreads<KernelFormat::kRows>(context, rows, cols, depth);
  if (thread_count == 1) {
    return SingleThreadGemm<KernelFormat, InputScalar, OutputScalar,
                            BitDepthParams>(context, kernel, lhs, rhs, result,
                                            lhs_offset, rhs_offset,
                                            output_pipeline);
  }
  assert(thread_count > 1);

  // We choose to use a worker thread for all but one
  // of the thread workloads. The remaining thread workload will be
  // executed immediately on the current thread.
  // In this way, the total number of threads (1 master, N-1 workers)
  // equals the value returned by HowManyThreads. This simple
  // 1:1 mapping of threads to physical cores is very important
  // to getting good multithreaded performance, especially for
  // not-very-large GEMMs, and especially on Android.
  const int workers_count = thread_count - 1;

  Allocator* allocator = context->allocator();
  WorkersPool* workers_pool = context->workers_pool();

  workers_pool->CreateWorkers(workers_count);

  BlockParams block_params;
  block_params.Init<KernelFormat>(rows, cols, depth, workers_count);

  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(
      Side::Rhs, allocator, block_params);
  allocator->Commit();

  // We loop over large blocks of the RHS.
  for (int c = 0; c < cols; c += block_params.l2_cols) {
    int cs = std::min(block_params.l2_cols, cols - c);

    // Pack a large block of the RHS.
    PackRhs<BitDepthParams>(&packed_rhs, rhs.block(0, c, depth, cs));

    // Give work to each worker.
    int next_start_row = 0;
    workers_pool->counter_to_decrement_when_ready().Reset(workers_count);
    for (int thread = 0; thread < thread_count; thread++) {
      int start_row = next_start_row;
      next_start_row = std::min(rows, RoundUp<KernelFormat::kRows>(
                                          rows * (thread + 1) / thread_count));
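      // Worked example (illustrative, with hypothetical numbers): with
      // rows = 100, thread_count = 3 and KernelFormat::kRows = 4, the rounded
      // boundaries are 36, 68 and 100, so the threads get the row ranges
      // [0, 36), [36, 68) and [68, 100). Rounding each boundary up to a
      // multiple of kRows keeps every block (except possibly the last)
      // kernel-format-aligned.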

      int block_rows = next_start_row - start_row;
      auto lhs_block = lhs.block(start_row, 0, block_rows, depth);
      auto result_block = result->block(start_row, c, block_rows, cs);
      typedef GemmWithPackedRhsTask<KernelFormat, InputScalar, OutputScalar,
                                    BitDepthParams, LhsOrder, RhsOrder,
                                    ResultOrder, LhsOffset, RhsOffset,
                                    OutputPipelineType>
          TaskType;
      auto task = new TaskType(kernel, lhs_block, packed_rhs, &result_block,
                               lhs_offset, rhs_offset, output_pipeline);
      if (thread < workers_count) {
        workers_pool->StartWorker(thread, task);
      } else {
        // Execute the remaining workload immediately on the current thread.
        task->local_allocator = context->main_thread_task_allocator();
        task->Run();
        delete task;
      }
    }
    // Wait for the workers.
    workers_pool->counter_to_decrement_when_ready().Wait();
  }

  allocator->Decommit();
}

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_