• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
17 
18 #include <cstdint>
19 #include <string>
20 
21 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
22 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
23 #include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h"
24 #include "tensorflow/lite/delegates/gpu/cl/util.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include <farmhash.h>
27 
28 namespace tflite {
29 namespace gpu {
30 namespace cl {
31 
ProgramDescriptor(const std::string & code_text,const std::string & options,bool use_fingerprints)32 ProgramCache::ProgramDescriptor::ProgramDescriptor(const std::string& code_text,
33                                                    const std::string& options,
34                                                    bool use_fingerprints)
35     : code(code_text),
36       compiler_options(options),
37       use_fingerprint(use_fingerprints) {
38   const uint64_t code_fingerprint = ::util::Fingerprint64(code);
39   const uint64_t options_fingerprint =
40       ::util::Fingerprint64(compiler_options);
41   fingerprint = code_fingerprint + options_fingerprint;
42 }
43 
ProgramDescriptor(uint64_t fingerprints)44 ProgramCache::ProgramDescriptor::ProgramDescriptor(uint64_t fingerprints)
45     : fingerprint(fingerprints), use_fingerprint(true) {}
46 
ProgramCache(ProgramCache && program_cache)47 ProgramCache::ProgramCache(ProgramCache&& program_cache)
48     : use_fingerprints_(program_cache.use_fingerprints_),
49       programs_(std::move(program_cache.programs_)) {}
50 
operator =(ProgramCache && program_cache)51 ProgramCache& ProgramCache::operator=(ProgramCache&& program_cache) {
52   if (this != &program_cache) {
53     use_fingerprints_ = program_cache.use_fingerprints_;
54     programs_ = std::move(program_cache.programs_);
55   }
56   return *this;
57 }
58 
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const std::vector<CompilerOptions> & compiler_options,const CLContext & context,const CLDevice & device,CLKernel * result)59 absl::Status ProgramCache::GetOrCreateCLKernel(
60     const std::string& code, const std::string& function_name,
61     const std::vector<CompilerOptions>& compiler_options,
62     const CLContext& context, const CLDevice& device, CLKernel* result) {
63   const std::string options =
64       CompilerOptionsToString(device.GetInfo(), compiler_options);
65   ProgramDescriptor desc{code, options, use_fingerprints_};
66   auto it = programs_.find(desc);
67   if (it != programs_.end()) {
68     return result->CreateFromProgram(it->second, function_name);
69   }
70 
71   CLProgram program;
72   RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program));
73   RETURN_IF_ERROR(result->CreateFromProgram(program, function_name));
74   programs_.insert(std::make_pair(std::move(desc), std::move(program)));
75   return absl::OkStatus();
76 }
77 
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const CLContext & context,const CLDevice & device,CLKernel * result)78 absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
79                                                const std::string& function_name,
80                                                const CLContext& context,
81                                                const CLDevice& device,
82                                                CLKernel* result) {
83   return GetOrCreateCLKernel(code, function_name, {}, context, device, result);
84 }
85 
AddSerializedCache(const CLContext & context,const CLDevice & device,absl::Span<const uint8_t> serialized_cache)86 absl::Status ProgramCache::AddSerializedCache(
87     const CLContext& context, const CLDevice& device,
88     absl::Span<const uint8_t> serialized_cache) {
89   flatbuffers::Verifier verifier(serialized_cache.data(),
90                                  serialized_cache.size());
91   if (!data::VerifyCompiledCacheBuffer(verifier)) {
92     return absl::InvalidArgumentError("Serialized model is corrupted.");
93   }
94 
95   auto model = data::GetCompiledCache(serialized_cache.data());
96   std::string platform_version(model->driver_version()->c_str(),
97                                model->driver_version()->size());
98 
99   if (device.GetPlatformVersion() != platform_version) {
100     return absl::InvalidArgumentError(
101         "OpenCL driver changed, cache invalid, should be regenerated");
102   }
103 
104   use_fingerprints_ = true;
105 
106   for (auto serialized_program : *model->programs()) {
107     ProgramDescriptor desc(serialized_program->fingerprint());
108     CLProgram program;
109     RETURN_IF_ERROR(CreateCLProgramFromBinary(
110         context, device,
111         absl::MakeSpan(serialized_program->binary()->data(),
112                        serialized_program->binary()->size()),
113         &program));
114     auto it = programs_.find(desc);
115     if (it == programs_.end()) {
116       programs_.insert(std::make_pair(std::move(desc), std::move(program)));
117     }
118   }
119   return absl::OkStatus();
120 }
121 
GetSerializedCache(const CLDevice & device,std::vector<uint8_t> * serialized_cache) const122 absl::Status ProgramCache::GetSerializedCache(
123     const CLDevice& device, std::vector<uint8_t>* serialized_cache) const {
124   ::flatbuffers::FlatBufferBuilder builder;
125   std::vector<flatbuffers::Offset<data::Program>> serialized_programs;
126   for (auto& program : programs_) {
127     std::vector<uint8_t> binary;
128     RETURN_IF_ERROR(program.second.GetBinary(&binary));
129     auto binary_offset = builder.CreateVector(binary);
130     data::ProgramBuilder program_builder(builder);
131     program_builder.add_fingerprint(program.first.fingerprint);
132     program_builder.add_binary(binary_offset);
133     serialized_programs.push_back(program_builder.Finish());
134   }
135   auto driver_version = builder.CreateString(device.GetPlatformVersion());
136   auto programs_s = builder.CreateVector(serialized_programs);
137   data::CompiledCacheBuilder cache_builder(builder);
138   cache_builder.add_driver_version(driver_version);
139   cache_builder.add_programs(programs_s);
140   data::FinishCompiledCacheBuffer(builder, cache_builder.Finish());
141   size_t next_element = serialized_cache->size();
142   serialized_cache->resize(serialized_cache->size() + builder.GetSize());
143   std::memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(),
144               builder.GetSize());
145   return absl::OkStatus();
146 }
147 
148 }  // namespace cl
149 }  // namespace gpu
150 }  // namespace tflite
151