• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_
16 #define GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_
17 
18 #include "multi_thread_common.h"
19 #include "single_thread_transform.h"
20 
21 namespace gemmlowp {
22 namespace meta {
23 namespace internal {
24 
// Cost-model constants used by PrepareTransform1DTasks when deciding how many
// tasks to split a 1D transform into. This amount is subtracted from the
// estimated compute cost before dividing by kMinTransformTaskSize. Units are
// whatever Transform1DUtil::EstimateComputeCost returns.
constexpr int kTransformTaskOverhead = 128000;
// Divisor applied to the remaining estimated cost; one task is allowed per
// this many units of remaining cost.
constexpr int kMinTransformTaskSize = 32000;
27 
28 template <typename MultiThreadingContext, typename Params>
PrepareTransform1DTasks(MultiThreadingContext * context,const Params & params,int kernel_size,std::vector<Params> * task_params)29 inline bool PrepareTransform1DTasks(MultiThreadingContext* context,
30                                     const Params& params, int kernel_size,
31                                     std::vector<Params>* task_params) {
32   typedef Transform1DUtil<typename Params::InType, typename Params::OutType,
33                           typename Params::Kernel>
34       Util;
35 
36   const int max_threads = ResolveMaxThreads(context->max_num_threads());
37   const int task_size = Util::EstimateComputeCost(params.kernel);
38   const int max_tasks_by_size =
39       (task_size - kTransformTaskOverhead) / kMinTransformTaskSize;
40 
41   const int real_tasks = std::max(1, std::min(max_threads, max_tasks_by_size));
42 
43   if (real_tasks == 1) {
44     return false;
45   }
46 
47   const int chunk = params.kernel.count / real_tasks;
48   for (int i = 0; i < real_tasks - 1; ++i) {
49     task_params->push_back(params);
50     Params& task = task_params->back();
51     task.kernel.count = chunk;
52     task.input = Util::OffsetInput(params.kernel, params.input, i * chunk);
53     task.output = Util::OffsetOutput(params.kernel, params.output, i * chunk);
54   }
55   task_params->push_back(params);
56   Params& task = task_params->back();
57   const int sum_chunk = (real_tasks - 1) * chunk;
58   task.kernel.count = params.kernel.count - sum_chunk;
59   task.input = Util::OffsetInput(params.kernel, params.input, sum_chunk);
60   task.output = Util::OffsetOutput(params.kernel, params.output, sum_chunk);
61   return true;
62 }
63 
64 template <typename Params, int kernel_size>
65 struct Transform1DTaskRunner : gemmlowp::Task {
Transform1DTaskRunnerTransform1DTaskRunner66   Transform1DTaskRunner(const Params& params) : params(params) {}
67 
RunTransform1DTaskRunner68   void Run() override { Transform1D<Params, kernel_size>(params); }
69 
70   Params params;
71 };
72 
73 }  // namespace internal
74 
75 template <typename MultiThreadingContext, typename Params, int kernel_size>
MultiThreadTransform1D(MultiThreadingContext * context,const Params & params)76 inline void MultiThreadTransform1D(MultiThreadingContext* context,
77                                    const Params& params) {
78   typedef internal::Transform1DTaskRunner<Params, kernel_size> TaskRunnerType;
79 
80   std::vector<Params> task_params;
81   if (!internal::PrepareTransform1DTasks<MultiThreadingContext, Params>(
82           context, params, kernel_size, &task_params)) {
83     Transform1D<Params, kernel_size>(params);
84     return;
85   }
86 
87   auto workers_pool = context->workers_pool();
88   std::vector<Task*> tasks;
89   for (auto& task_param : task_params) {
90     tasks.push_back(new TaskRunnerType(task_param));
91   }
92   workers_pool->Execute(tasks);
93 }
94 
95 }  // namespace meta
96 }  // namespace gemmlowp
97 
98 #endif  // GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_
99