1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator_from_metadata.h"
17
18 #ifndef TFLITE_GPU_BINARY_RELEASE
19
20 #include <memory>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
25 #include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h"
28 #include "tensorflow/lite/delegates/gpu/gl/workgroups/calculator.h"
29 #include "tensorflow/lite/delegates/gpu/gl/workgroups/default_calculator.h"
30 #include "tensorflow/lite/delegates/gpu/gl/workgroups_generated.h"
31
32 #endif // TFLITE_GPU_BINARY_RELEASE
33
34 namespace tflite {
35 namespace gpu {
36 namespace gl {
37
38 #ifndef TFLITE_GPU_BINARY_RELEASE
39 namespace {
40 class WorkgroupsCalculatorFromMetadata : public WorkgroupsCalculator {
41 public:
WorkgroupsCalculatorFromMetadata(const data::HardcodedWorkgroups & workgroups,const GpuInfo & gpu_info)42 WorkgroupsCalculatorFromMetadata(const data::HardcodedWorkgroups& workgroups,
43 const GpuInfo& gpu_info)
44 : WorkgroupsCalculator(gpu_info),
45 default_calculator_(NewDefaultWorkgroupsCalculator(gpu_info)) {
46 for (const auto* workgroup : *workgroups.workgroups()) {
47 uint3 size(workgroup->size()->x(), workgroup->size()->y(),
48 workgroup->size()->z());
49 // Class implementation relies on the fact that it uses unique graph
50 // representation where each node id appears in a single workgroup.
51 for (auto node_id : *workgroup->node_indices()) {
52 workgroups_.insert({node_id, size});
53 }
54 }
55 }
56
CalculateInternal(const ShaderCode & shader_code) const57 uint3 CalculateInternal(const ShaderCode& shader_code) const final {
58 auto it = workgroups_.find(shader_code.node_indices[0]);
59 return it != workgroups_.end()
60 ? it->second
61 : default_calculator_->Calculate(shader_code);
62 }
63
64 private:
65 absl::flat_hash_map<NodeId, uint3> workgroups_;
66 std::unique_ptr<WorkgroupsCalculator> default_calculator_;
67 };
68
// Returns the first hardcoded workgroups entry whose `gpu_info` string
// compares equal to `gpu_info.opengl_info.renderer_name`, or nullptr if no
// entry matches. Entries are ignored when their `gpu_info` field is absent,
// and nullptr is also returned when the table itself is absent.
const data::HardcodedWorkgroups* FindWorkgroups(
    const data::CustomWorkgroups& workgroups, const GpuInfo& gpu_info) {
  // Optional flatbuffer vector: null when absent from the buffer.
  const auto* candidates = workgroups.hardcoded_workgroups();
  if (!candidates) return nullptr;
  for (const auto* workgroup : *candidates) {
    // Optional flatbuffer string: skip entries that do not carry it.
    if (workgroup && workgroup->gpu_info() &&
        workgroup->gpu_info()->c_str() == gpu_info.opengl_info.renderer_name) {
      return workgroup;
    }
  }
  return nullptr;
}
78
79 } // namespace
80
NewWorkgroupsCalculatorFromMetadata(const uint8_t * metadata,const GpuInfo & gpu_info)81 std::unique_ptr<WorkgroupsCalculator> NewWorkgroupsCalculatorFromMetadata(
82 const uint8_t* metadata, const GpuInfo& gpu_info) {
83 if (!metadata) return nullptr;
84 const auto* flow_metadata =
85 flatbuffers::GetRoot<data::FlowMetadata>(metadata);
86 if (!flow_metadata || !flow_metadata->workgroups()) return nullptr;
87 const data::HardcodedWorkgroups* workgroups =
88 FindWorkgroups(*flow_metadata->workgroups(), gpu_info);
89 if (!workgroups) return nullptr;
90 return std::make_unique<WorkgroupsCalculatorFromMetadata>(*workgroups,
91 gpu_info);
92 }
93
94 #else // TFLITE_GPU_BINARY_RELEASE
95
96 std::unique_ptr<WorkgroupsCalculator> NewWorkgroupsCalculatorFromMetadata(
97 const uint8_t* metadata, const GpuInfo& gpu_info) {
98 return nullptr;
99 }
100
101 #endif // TFLITE_GPU_BINARY_RELEASE
102
103 } // namespace gl
104 } // namespace gpu
105 } // namespace tflite
106