• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
21 #include "tensorflow/lite/delegates/gpu/common/types.h"
22 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
23 
24 namespace tflite {
25 namespace gpu {
26 namespace gl {
27 namespace {
28 
29 class DefaultWorkgroupsCalculator : public WorkgroupsCalculator {
30  public:
DefaultWorkgroupsCalculator(const GpuInfo & gpu_info)31   explicit DefaultWorkgroupsCalculator(const GpuInfo& gpu_info)
32       : WorkgroupsCalculator(gpu_info) {}
CalculateInternal(const ShaderCode & shader_code) const33   uint3 CalculateInternal(const ShaderCode& shader_code) const final {
34     const auto& workload = shader_code.workload;
35     if (workload.z >= 64) {
36       return uint3(4, 4, 64);
37     }
38     if (workload.z >= 32) {
39       return uint3(8, 4, 32);
40     }
41     if (workload.z >= 16) {
42       return uint3(8, 8, 16);
43     }
44     if (workload.z >= 8) {
45       return uint3(16, 8, 8);
46     }
47     if (workload.z >= 4) {
48       return uint3(16, 16, 4);
49     }
50     if (workload.z >= 2) {
51       return uint3(32, 16, 2);
52     }
53     return uint3(32, 32, 1);
54   }
55 };
56 
57 class WorkgroupsCalculatorForMali : public WorkgroupsCalculator {
58  public:
WorkgroupsCalculatorForMali(const GpuInfo & gpu_info)59   explicit WorkgroupsCalculatorForMali(const GpuInfo& gpu_info)
60       : WorkgroupsCalculator(gpu_info) {}
CalculateInternal(const ShaderCode & shader_code) const61   uint3 CalculateInternal(const ShaderCode& shader_code) const final {
62     const auto& workload = shader_code.workload;
63     if (workload.z >= 32) {
64       return uint3(2, 2, 32);
65     }
66     if (workload.z >= 16) {
67       return uint3(4, 2, 16);
68     }
69     if (workload.z >= 8) {
70       return uint3(4, 4, 8);
71     }
72     if (workload.z >= 4) {
73       return uint3(8, 4, 4);
74     }
75     if (workload.z >= 2) {
76       return uint3(8, 8, 2);
77     }
78     return uint3(16, 8, 1);
79   }
80 };
81 
82 }  // namespace
83 
NewDefaultWorkgroupsCalculator(const GpuInfo & gpu_info)84 std::unique_ptr<WorkgroupsCalculator> NewDefaultWorkgroupsCalculator(
85     const GpuInfo& gpu_info) {
86   if (gpu_info.IsMali()) {
87     return std::make_unique<WorkgroupsCalculatorForMali>(gpu_info);
88   } else {
89     return std::make_unique<DefaultWorkgroupsCalculator>(gpu_info);
90   }
91 }
92 
93 }  // namespace gl
94 }  // namespace gpu
95 }  // namespace tflite
96