1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
17
18 #include <cstdint>
19 #include <string>
20
21 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
22 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
23 #include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h"
24 #include "tensorflow/lite/delegates/gpu/cl/util.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include <farmhash.h>
27
28 namespace tflite {
29 namespace gpu {
30 namespace cl {
31
ProgramDescriptor(const std::string & code_text,const std::string & options,bool use_fingerprints)32 ProgramCache::ProgramDescriptor::ProgramDescriptor(const std::string& code_text,
33 const std::string& options,
34 bool use_fingerprints)
35 : code(code_text),
36 compiler_options(options),
37 use_fingerprint(use_fingerprints) {
38 const uint64_t code_fingerprint = ::util::Fingerprint64(code);
39 const uint64_t options_fingerprint =
40 ::util::Fingerprint64(compiler_options);
41 fingerprint = code_fingerprint + options_fingerprint;
42 }
43
ProgramDescriptor(uint64_t fingerprints)44 ProgramCache::ProgramDescriptor::ProgramDescriptor(uint64_t fingerprints)
45 : fingerprint(fingerprints), use_fingerprint(true) {}
46
ProgramCache(ProgramCache && program_cache)47 ProgramCache::ProgramCache(ProgramCache&& program_cache)
48 : use_fingerprints_(program_cache.use_fingerprints_),
49 programs_(std::move(program_cache.programs_)) {}
50
operator =(ProgramCache && program_cache)51 ProgramCache& ProgramCache::operator=(ProgramCache&& program_cache) {
52 if (this != &program_cache) {
53 use_fingerprints_ = program_cache.use_fingerprints_;
54 programs_ = std::move(program_cache.programs_);
55 }
56 return *this;
57 }
58
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const std::vector<CompilerOptions> & compiler_options,const CLContext & context,const CLDevice & device,CLKernel * result)59 absl::Status ProgramCache::GetOrCreateCLKernel(
60 const std::string& code, const std::string& function_name,
61 const std::vector<CompilerOptions>& compiler_options,
62 const CLContext& context, const CLDevice& device, CLKernel* result) {
63 const std::string options =
64 CompilerOptionsToString(device.GetInfo(), compiler_options);
65 ProgramDescriptor desc{code, options, use_fingerprints_};
66 auto it = programs_.find(desc);
67 if (it != programs_.end()) {
68 return result->CreateFromProgram(it->second, function_name);
69 }
70
71 CLProgram program;
72 RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program));
73 RETURN_IF_ERROR(result->CreateFromProgram(program, function_name));
74 programs_.insert(std::make_pair(std::move(desc), std::move(program)));
75 return absl::OkStatus();
76 }
77
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const CLContext & context,const CLDevice & device,CLKernel * result)78 absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
79 const std::string& function_name,
80 const CLContext& context,
81 const CLDevice& device,
82 CLKernel* result) {
83 return GetOrCreateCLKernel(code, function_name, {}, context, device, result);
84 }
85
AddSerializedCache(const CLContext & context,const CLDevice & device,absl::Span<const uint8_t> serialized_cache)86 absl::Status ProgramCache::AddSerializedCache(
87 const CLContext& context, const CLDevice& device,
88 absl::Span<const uint8_t> serialized_cache) {
89 flatbuffers::Verifier verifier(serialized_cache.data(),
90 serialized_cache.size());
91 if (!data::VerifyCompiledCacheBuffer(verifier)) {
92 return absl::InvalidArgumentError("Serialized model is corrupted.");
93 }
94
95 auto model = data::GetCompiledCache(serialized_cache.data());
96 std::string platform_version(model->driver_version()->c_str(),
97 model->driver_version()->size());
98
99 if (device.GetPlatformVersion() != platform_version) {
100 return absl::InvalidArgumentError(
101 "OpenCL driver changed, cache invalid, should be regenerated");
102 }
103
104 use_fingerprints_ = true;
105
106 for (auto serialized_program : *model->programs()) {
107 ProgramDescriptor desc(serialized_program->fingerprint());
108 CLProgram program;
109 RETURN_IF_ERROR(CreateCLProgramFromBinary(
110 context, device,
111 absl::MakeSpan(serialized_program->binary()->data(),
112 serialized_program->binary()->size()),
113 &program));
114 auto it = programs_.find(desc);
115 if (it == programs_.end()) {
116 programs_.insert(std::make_pair(std::move(desc), std::move(program)));
117 }
118 }
119 return absl::OkStatus();
120 }
121
GetSerializedCache(const CLDevice & device,std::vector<uint8_t> * serialized_cache) const122 absl::Status ProgramCache::GetSerializedCache(
123 const CLDevice& device, std::vector<uint8_t>* serialized_cache) const {
124 ::flatbuffers::FlatBufferBuilder builder;
125 std::vector<flatbuffers::Offset<data::Program>> serialized_programs;
126 for (auto& program : programs_) {
127 std::vector<uint8_t> binary;
128 RETURN_IF_ERROR(program.second.GetBinary(&binary));
129 auto binary_offset = builder.CreateVector(binary);
130 data::ProgramBuilder program_builder(builder);
131 program_builder.add_fingerprint(program.first.fingerprint);
132 program_builder.add_binary(binary_offset);
133 serialized_programs.push_back(program_builder.Finish());
134 }
135 auto driver_version = builder.CreateString(device.GetPlatformVersion());
136 auto programs_s = builder.CreateVector(serialized_programs);
137 data::CompiledCacheBuilder cache_builder(builder);
138 cache_builder.add_driver_version(driver_version);
139 cache_builder.add_programs(programs_s);
140 data::FinishCompiledCacheBuffer(builder, cache_builder.Finish());
141 size_t next_element = serialized_cache->size();
142 serialized_cache->resize(serialized_cache->size() + builder.GetSize());
143 std::memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(),
144 builder.GetSize());
145 return absl::OkStatus();
146 }
147
148 } // namespace cl
149 } // namespace gpu
150 } // namespace tflite
151