/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/api.h"

#include <algorithm>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
#include "tensorflow/lite/delegates/gpu/gl/runtime.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"

#ifndef TFLITE_GPU_BINARY_RELEASE
#include "tensorflow/lite/delegates/gpu/gl/serialization.h"
#endif  // TFLITE_GPU_BINARY_RELEASE

namespace tflite {
namespace gpu {
namespace gl {
namespace {

using ObjectsSizes = absl::flat_hash_map<ValueId, size_t>;

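// Execution state of an inference context: Execute() refuses to run again
// until Reset() has been called.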
enum class InferenceContextState {
  NOT_STARTED,
  IN_PROGRESS,
};

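// Inference context for the regular (non-batched) case: a thin, thread-safe
// wrapper that runs a prepared Runtime at most once per Reset().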
class InferenceContextImpl : public InferenceContext {
 public:
  explicit InferenceContextImpl(std::unique_ptr<Runtime> runtime)
      : runtime_(std::move(runtime)) {}

  absl::Status Execute() final {
    std::lock_guard<std::mutex> lock(guard_);
    if (state_ != InferenceContextState::NOT_STARTED) {
      return absl::FailedPreconditionError("InferenceContext is not reset");
    }
    state_ = InferenceContextState::IN_PROGRESS;
    return runtime_->Execute();
  }

  absl::Status Reset() final {
    std::lock_guard<std::mutex> lock(guard_);
    // TODO(akulik): should Reset not return Status?
    state_ = InferenceContextState::NOT_STARTED;
    return absl::OkStatus();
  }

  RuntimeStats stats() const final { return runtime_->stats(); }

 private:
  std::unique_ptr<Runtime> runtime_;

  mutable std::mutex guard_;
  InferenceContextState state_ = InferenceContextState::NOT_STARTED;
};

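// Inference context for dynamic-batch models: external buffers hold N batches
// laid out back to back, and the runtime is executed once per batch over
// per-batch views of those buffers.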
class InferenceContextWithBatchImpl : public InferenceContext {
 public:
  InferenceContextWithBatchImpl(const ObjectsSizes& sizes,
                                const ObjectManager* objects,
                                std::unique_ptr<ObjectManager> refs,
                                std::unique_ptr<Runtime> runtime)
      : sizes_(sizes),
        objects_(objects),
        refs_(std::move(refs)),
        runtime_(std::move(runtime)) {}

  absl::Status Execute() final {
    std::lock_guard<std::mutex> lock(guard_);
    if (state_ != InferenceContextState::NOT_STARTED) {
      return absl::FailedPreconditionError("InferenceContext is not reset");
    }
    state_ = InferenceContextState::IN_PROGRESS;

    // Calculate expected number of batches and check that all external objects
    // match that number.
    int num_batches = 0;
    for (const auto& s : sizes_) {
      const ValueId id = s.first;
      const size_t byte_size = s.second;

      auto buffer = objects_->FindBuffer(id);
      if (!buffer) continue;

      if (buffer->bytes_size() % byte_size) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Object ", id, " does not match expected byte size: ", byte_size));
      }

      const size_t b = buffer->bytes_size() / byte_size;
      if (num_batches == 0) {
        num_batches = b;
      } else if (num_batches != b) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Object ", id, " size does not match expected batch size: ", b,
            " vs ", num_batches));
      }
    }

    for (size_t b = 0; b < num_batches; ++b) {
      // slice external objects by batch.
      for (const auto& s : sizes_) {
        const ValueId id = s.first;
        const size_t byte_size = s.second;
        auto buffer = objects_->FindBuffer(id);
        if (buffer) {
          auto ref = refs_->FindBuffer(id);
          if (!ref) {
            return absl::InvalidArgumentError(
                absl::StrCat("Reference to ", id, " is not found"));
          }
          RETURN_IF_ERROR(buffer->MakeView(b * byte_size, byte_size, ref));
        }
      }
      RETURN_IF_ERROR(runtime_->Execute());
    }
    return absl::OkStatus();
  }

  absl::Status Reset() final {
    std::lock_guard<std::mutex> lock(guard_);
    state_ = InferenceContextState::NOT_STARTED;
    // TODO(akulik): should Reset not return Status?
    return absl::OkStatus();
  }

  RuntimeStats stats() const final { return runtime_->stats(); }

 private:
  const ObjectsSizes sizes_;
  const ObjectManager* objects_;

  // view over external objects provided by a user.
  std::unique_ptr<ObjectManager> refs_;
  std::unique_ptr<Runtime> runtime_;

  mutable std::mutex guard_;
  InferenceContextState state_ = InferenceContextState::NOT_STARTED;
};

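// Describes a single compute program: its uniforms, bound objects, dispatch
// dimensions, and an index into the array of compiled shaders.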
struct ProgramParameters {
  // A list of uniform parameters to be set.
  std::vector<Variable> parameters;

  // A list of objects to bind to opengl program.
  std::vector<Object> objects;

  uint3 workgroup_size;
  uint3 num_workgroups;

  size_t shader_idx;
};

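// Produces a GLSL ES 3.10 header that pins the workgroup size, e.g. for a
// local size of (8, 4, 1):
//   #version 310 es
//   layout(local_size_x = 8, local_size_y = 4, local_size_z = 1) in;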
std::string GetShaderHeader(uint3 localsize) {
  return absl::StrCat("#version 310 es\nlayout(local_size_x = ", localsize.x,
                      ", local_size_y = ", localsize.y,
                      ", local_size_z = ", localsize.z, ") in;\n");
}

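// Compiled-model implementation that owns the compiled GL shaders and the
// per-program parameters. Identical full shader sources are deduplicated via
// shader_to_index_. When serialization support is compiled in, it also acts
// as the DeserializationHandler that rebuilds this state from a blob.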
class CompiledModelImpl
#ifndef TFLITE_GPU_BINARY_RELEASE
    : public CompiledModel,
      public DeserializationHandler {
#else
    : public CompiledModel {
#endif  // TFLITE_GPU_BINARY_RELEASE
 public:
  explicit CompiledModelImpl(const GpuInfo& gpu_info) : gpu_info_(gpu_info) {}

  // Called while compiling shaders from scratch
  absl::Status Add(const WorkgroupsCalculator& workgroup_calculator,
                   ShaderCode code) {
    // Calculate workgroup size.
    uint3 workgroup_size = workgroup_calculator.Calculate(code);
    uint3 num_workgroups = DivideRoundUp(code.workload, workgroup_size);

    for (const auto& object : code.objects) {
      if (IsRef(object)) {
        object_sizes_[GetRef(object)] = ByteSizeOf(object);
      }
    }

    // Store full shader and compile it if necessary.
    size_t shader_idx;
    RETURN_IF_ERROR(
        AddFullShader(code.source_code, workgroup_size, &shader_idx));
    programs_.push_back({
        std::move(code.parameters),
        std::move(code.objects),
        workgroup_size,
        num_workgroups,
        shader_idx,
    });
    return absl::OkStatus();
  }

  // Stores a full shader (header + partial shader) and compiles it if not
  // seen before. On success, *size holds the index of the full shader.
  absl::Status AddFullShader(const std::string& partial_shader,
                             const uint3& workgroup_size, size_t* size) {
    std::string shader_src = GetShaderHeader(workgroup_size) + partial_shader;
    auto it = shader_to_index_.find(shader_src);
    if (it == shader_to_index_.end()) {
      GlShader shader;
      RETURN_IF_ERROR(
          GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src, &shader));
      shaders_.push_back(std::move(shader));
      shader_to_index_.insert({shader_src, shader_to_index_.size()});
      *size = shader_to_index_.size() - 1;
    } else {
      *size = it->second;
    }
    return absl::OkStatus();
  }

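  // Builds a fresh Runtime over the user-provided objects (or, in
  // dynamic-batch mode, over zero-offset views of them), registers every
  // program, and wraps the result in the matching InferenceContext.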
  absl::Status NewRun(
      const RuntimeOptions& options, const ObjectManager* objects,
      CommandQueue* command_queue,
      std::unique_ptr<InferenceContext>* inference_context) const final {
    std::unique_ptr<ObjectManager> refs;
    if (dynamic_batch_) {
      // Runtime is using objects from refs that will point to provided
      // objects. At this point just create 0 batch slice references.
      refs = std::make_unique<ObjectManager>();
      for (const auto& s : object_sizes_) {
        auto buffer = objects->FindBuffer(s.first);
        if (!buffer) continue;
        GlBuffer ref;
        RETURN_IF_ERROR(buffer->MakeView(0, s.second, &ref));
        RETURN_IF_ERROR(refs->RegisterBuffer(s.first, std::move(ref)));
      }
    }
    auto runtime = std::make_unique<Runtime>(options, gpu_info_, command_queue,
                                             refs ? refs.get() : objects);
    for (auto& program : programs_) {
      RETURN_IF_ERROR(runtime->AddProgram(shaders_[program.shader_idx],
                                          program.parameters, program.objects,
                                          program.num_workgroups));
    }
    RETURN_IF_ERROR(runtime->PrepareForExecution());
    if (dynamic_batch_) {
      *inference_context = std::make_unique<InferenceContextWithBatchImpl>(
          object_sizes_, objects, std::move(refs), std::move(runtime));
    } else {
      *inference_context =
          std::make_unique<InferenceContextImpl>(std::move(runtime));
    }
    return absl::OkStatus();
  }

#ifndef TFLITE_GPU_BINARY_RELEASE
  // Called on deserialization
  absl::Status OnProgram(const std::vector<Variable>& parameters,
                         const std::vector<Object>& objects,
                         const uint3& workgroup_size,
                         const uint3& num_workgroups,
                         size_t partial_shader_index) final {
    for (auto& object : objects) {
      if (IsRef(object)) {
        object_sizes_[GetRef(object)] = ByteSizeOf(object);
      }
    }

    size_t shader_idx;
    RETURN_IF_ERROR(AddFullShader(partial_shaders_[partial_shader_index],
                                  workgroup_size, &shader_idx));
    programs_.push_back({
        parameters,
        objects,
        workgroup_size,
        num_workgroups,
        shader_idx,
    });
    return absl::OkStatus();
  }

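  // Serializes all programs and shaders. The workgroup-size header is
  // stripped from each full shader before storing, so one partial shader can
  // be shared by programs that differ only in workgroup size.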
  absl::Status Serialize(
      std::vector<uint8_t>* serialized_compiled_model) const final {
    SerializedCompiledModelBuilder builder;

    // Sort shaders first; they need to be serialized in order.
    std::vector<std::string> full_shaders(shaders_.size());
    for (const auto& shader : shader_to_index_) {
      full_shaders[shader.second] = shader.first;
    }

    absl::flat_hash_map<std::string, size_t> partial_shader_to_index;
    std::vector<std::string> partial_shaders;
    for (const auto& program : programs_) {
      // Remove the header from the shader.
      std::string shader_without_header = full_shaders[program.shader_idx];
      shader_without_header.erase(0, shader_without_header.find("in;") + 3);

      // Insert shader into partial shaders array.
      auto it = partial_shader_to_index.find(shader_without_header);
      size_t shader_idx;
      if (it == partial_shader_to_index.end()) {
        shader_idx = partial_shaders.size();
        partial_shaders.push_back(shader_without_header);
        builder.AddShader(shader_without_header);
        partial_shader_to_index.insert({shader_without_header, shader_idx});
      } else {
        shader_idx = it->second;
      }
      builder.AddProgram(program.parameters, program.objects,
                         program.workgroup_size, program.num_workgroups,
                         shader_idx);
    }
    CompiledModelOptions options;
    options.dynamic_batch = dynamic_batch_;
    auto data = builder.Finalize(options);
    serialized_compiled_model->insert(serialized_compiled_model->end(),
                                      data.begin(), data.end());
    return absl::OkStatus();
  }

  absl::Status OnShader(absl::Span<const char> shader_src) final {
    std::string source(shader_src.data(), shader_src.size());
    partial_shaders_.push_back(source);
    return absl::OkStatus();
  }

  void OnOptions(const CompiledModelOptions& options) final {
    dynamic_batch_ = options.dynamic_batch;
  }
#endif  // TFLITE_GPU_BINARY_RELEASE

  CompilerStats stats() const final { return stats_; }

  void set_dynamic_batch(bool dynamic_batch) { dynamic_batch_ = dynamic_batch; }

 private:
  const GpuInfo gpu_info_;
  bool dynamic_batch_ = false;

  std::vector<std::string> partial_shaders_;
  std::vector<GlShader> shaders_;

  // Shaders are serialized in order of their indices.
  absl::flat_hash_map<std::string, size_t> shader_to_index_;
  std::deque<ProgramParameters> programs_;
  absl::flat_hash_map<ValueId, size_t> object_sizes_;
  CompilerStats stats_;
};
}  // namespace

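// Compiles `model` into a CompiledModel. A usage sketch (illustrative only;
// options, graph, node_shader, calculator, objects, runtime_options, and
// command_queue are assumed to be set up by the caller):
//
//   std::unique_ptr<CompiledModel> compiled_model;
//   RETURN_IF_ERROR(Compile(options, graph, /*tflite_graph_io=*/{},
//                           node_shader, calculator, &compiled_model));
//   std::unique_ptr<InferenceContext> context;
//   RETURN_IF_ERROR(compiled_model->NewRun(runtime_options, &objects,
//                                          command_queue, &context));
//   RETURN_IF_ERROR(context->Execute());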
absl::Status Compile(const CompilationOptions& options,
                     const GraphFloat32& model,
                     const std::unordered_set<int>& tflite_graph_io,  // NOLINT
                     const NodeShader& node_shader,
                     const WorkgroupsCalculator& workgroup_calculator,
                     std::unique_ptr<CompiledModel>* compiled_model) {
  RETURN_IF_ERROR(CheckBatchSizeForAllValues(model));
  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
  if (!gpu_info.IsApiOpenGl31OrAbove()) {
    return absl::InternalError(
        "OpenGL ES 3.1 or above is required to use OpenGL inference.");
  }
  auto compiled_model_impl = std::make_unique<CompiledModelImpl>(gpu_info);
  compiled_model_impl->set_dynamic_batch(options.dynamic_batch);
  auto compiler = NewCompiler(&node_shader, &gpu_info, options);
  RETURN_IF_ERROR(compiler->Compile(
      model, tflite_graph_io, [&](ShaderCode code) -> absl::Status {
        return compiled_model_impl->Add(workgroup_calculator, std::move(code));
      }));
  *compiled_model = std::move(compiled_model_impl);
  return absl::OkStatus();
}

#ifndef TFLITE_GPU_BINARY_RELEASE
absl::Status ReadSerializedModel(
    const std::vector<uint8_t>& serialized_model,
    std::unique_ptr<CompiledModel>* compiled_model) {
  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
  if (!gpu_info.IsApiOpenGl31OrAbove()) {
    return absl::InternalError(
        "OpenGL ES 3.1 or above is required to use OpenGL inference.");
  }
  auto compiled_model_impl = std::make_unique<CompiledModelImpl>(gpu_info);
  RETURN_IF_ERROR(DeserializeCompiledModel(
      absl::MakeConstSpan(serialized_model), compiled_model_impl.get()));
  *compiled_model = std::move(compiled_model_impl);
  return absl::OkStatus();
}
#endif  // TFLITE_GPU_BINARY_RELEASE

}  // namespace gl
}  // namespace gpu
}  // namespace tflite