/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/gl/api.h"

#include <algorithm>
#include <cstdint>
#include <deque>
#include <mutex>  // NOLINT
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/gl/compiler.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_call.h"
#include "tensorflow/lite/delegates/gpu/gl/object.h"
#include "tensorflow/lite/delegates/gpu/gl/portable_gl31.h"
#include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h"
#include "tensorflow/lite/delegates/gpu/gl/runtime.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"

#ifndef TFLITE_GPU_BINARY_RELEASE
#include "tensorflow/lite/delegates/gpu/gl/serialization.h"
#endif  // TFLITE_GPU_BINARY_RELEASE

namespace tflite {
namespace gpu {
namespace gl {
namespace {

using ObjectsSizes = absl::flat_hash_map<ValueId, size_t>;

enum class InferenceContextState {
  NOT_STARTED,
  IN_PROGRESS,
};
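
// Inference context for a model without dynamic batching. Execute() runs the
// prepared runtime exactly once and must be followed by Reset() before the
// next run.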
class InferenceContextImpl : public InferenceContext {
 public:
  explicit InferenceContextImpl(std::unique_ptr<Runtime> runtime)
      : runtime_(std::move(runtime)) {}

  absl::Status Execute() final {
    std::lock_guard<std::mutex> lock(guard_);
    if (state_ != InferenceContextState::NOT_STARTED) {
      return absl::FailedPreconditionError("InferenceContext is not reset");
    }
    state_ = InferenceContextState::IN_PROGRESS;
    return runtime_->Execute();
  }

  absl::Status Reset() final {
    std::lock_guard<std::mutex> lock(guard_);
    // TODO(akulik): should Reset not return Status?
    state_ = InferenceContextState::NOT_STARTED;
    return absl::OkStatus();
  }

  RuntimeStats stats() const final { return runtime_->stats(); }

 private:
  std::unique_ptr<Runtime> runtime_;

  mutable std::mutex guard_;
  InferenceContextState state_ = InferenceContextState::NOT_STARTED;
};
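
// Inference context for a model compiled with dynamic batch support. Each
// Execute() derives the batch count from the sizes of the external objects
// and runs the underlying runtime once per batch element.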
class InferenceContextWithBatchImpl : public InferenceContext {
 public:
  InferenceContextWithBatchImpl(const ObjectsSizes& sizes,
                                const ObjectManager* objects,
                                std::unique_ptr<ObjectManager> refs,
                                std::unique_ptr<Runtime> runtime)
      : sizes_(sizes),
        objects_(objects),
        refs_(std::move(refs)),
        runtime_(std::move(runtime)) {}

  absl::Status Execute() final {
    std::lock_guard<std::mutex> lock(guard_);
    if (state_ != InferenceContextState::NOT_STARTED) {
      return absl::FailedPreconditionError("InferenceContext is not reset");
    }
    state_ = InferenceContextState::IN_PROGRESS;

    // Calculate the expected number of batches and check that all external
    // objects match that number.
    size_t num_batches = 0;
    for (const auto& s : sizes_) {
      const ValueId id = s.first;
      const size_t byte_size = s.second;

      auto buffer = objects_->FindBuffer(id);
      if (!buffer) continue;

      if (buffer->bytes_size() % byte_size) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Object ", id, " does not match expected byte size: ", byte_size));
      }

      const size_t b = buffer->bytes_size() / byte_size;
      if (num_batches == 0) {
        num_batches = b;
      } else if (num_batches != b) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Object ", id, " size does not match expected batch size: ", b,
            " vs ", num_batches));
      }
    }
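
    // Run the model once per batch element, retargeting the reference
    // buffers at the b-th slice of each external object before each run.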
    for (size_t b = 0; b < num_batches; ++b) {
      // Slice external objects by batch.
      for (const auto& s : sizes_) {
        const ValueId id = s.first;
        const size_t byte_size = s.second;
        auto buffer = objects_->FindBuffer(id);
        if (buffer) {
          auto ref = refs_->FindBuffer(id);
          if (!ref) {
            return absl::InvalidArgumentError(
                absl::StrCat("Reference to ", id, " is not found"));
          }
          RETURN_IF_ERROR(buffer->MakeView(b * byte_size, byte_size, ref));
        }
      }
      RETURN_IF_ERROR(runtime_->Execute());
    }
    return absl::OkStatus();
  }

  absl::Status Reset() final {
    std::lock_guard<std::mutex> lock(guard_);
    state_ = InferenceContextState::NOT_STARTED;
    // TODO(akulik): should Reset not return Status?
    return absl::OkStatus();
  }

  RuntimeStats stats() const final { return runtime_->stats(); }

 private:
  const ObjectsSizes sizes_;
  const ObjectManager* objects_;

  // A view over external objects provided by a user.
  std::unique_ptr<ObjectManager> refs_;
  std::unique_ptr<Runtime> runtime_;

  mutable std::mutex guard_;
  InferenceContextState state_ = InferenceContextState::NOT_STARTED;
};
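
// Everything needed to attach a single compiled shader to a Runtime as an
// executable program.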
struct ProgramParameters {
  // A list of uniform parameters to be set.
  std::vector<Variable> parameters;

  // A list of objects to bind to the OpenGL program.
  std::vector<Object> objects;

  uint3 workgroup_size;
  uint3 num_workgroups;

  size_t shader_idx;
};
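
// Generates the GLSL compute shader preamble: the OpenGL ES 3.10 version
// directive and the workgroup size layout qualifier.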
std::string GetShaderHeader(uint3 localsize) {
  return absl::StrCat("#version 310 es\nlayout(local_size_x = ", localsize.x,
                      ", local_size_y = ", localsize.y,
                      ", local_size_z = ", localsize.z, ") in;\n");
}
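
// Compiled model that can be populated either by compiling shaders from
// scratch (Add) or, outside of binary release mode, by replaying a serialized
// model through the DeserializationHandler callbacks.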
class CompiledModelImpl
#ifndef TFLITE_GPU_BINARY_RELEASE
    : public CompiledModel,
      public DeserializationHandler {
#else
    : public CompiledModel {
#endif  // TFLITE_GPU_BINARY_RELEASE
 public:
  explicit CompiledModelImpl(const GpuInfo& gpu_info) : gpu_info_(gpu_info) {}

  // Called while compiling shaders from scratch.
  absl::Status Add(const WorkgroupsCalculator& workgroup_calculator,
                   ShaderCode code) {
    // Calculate workgroup size.
    uint3 workgroup_size = workgroup_calculator.Calculate(code);
    uint3 num_workgroups = DivideRoundUp(code.workload, workgroup_size);

    for (const auto& object : code.objects) {
      if (IsRef(object)) {
        object_sizes_[GetRef(object)] = ByteSizeOf(object);
      }
    }

    // Store full shader and compile it if necessary.
    size_t shader_idx;
    RETURN_IF_ERROR(
        AddFullShader(code.source_code, workgroup_size, &shader_idx));
    programs_.push_back({
        std::move(code.parameters),
        std::move(code.objects),
        workgroup_size,
        num_workgroups,
        shader_idx,
    });
    return absl::OkStatus();
  }

  // Stores the full shader (header plus partial shader) and compiles it if it
  // has not been seen before. Returns the index of the full shader via
  // `shader_index`.
  absl::Status AddFullShader(const std::string& partial_shader,
                             const uint3& workgroup_size,
                             size_t* shader_index) {
    std::string shader_src = GetShaderHeader(workgroup_size) + partial_shader;
    auto it = shader_to_index_.find(shader_src);
    if (it == shader_to_index_.end()) {
      GlShader shader;
      RETURN_IF_ERROR(
          GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src, &shader));
      shaders_.push_back(std::move(shader));
      shader_to_index_.insert({shader_src, shader_to_index_.size()});
      *shader_index = shader_to_index_.size() - 1;
    } else {
      *shader_index = it->second;
    }
    return absl::OkStatus();
  }
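
  // Creates a Runtime, attaches every compiled program to it, and wraps it in
  // an InferenceContext that matches the dynamic batch setting.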
  absl::Status NewRun(
      const RuntimeOptions& options, const ObjectManager* objects,
      CommandQueue* command_queue,
      std::unique_ptr<InferenceContext>* inference_context) const final {
    std::unique_ptr<ObjectManager> refs;
    if (dynamic_batch_) {
      // The runtime will use objects from `refs`, which point into the
      // user-provided objects. At this point, create references that view
      // batch slice 0.
      refs = absl::make_unique<ObjectManager>();
      for (const auto& s : object_sizes_) {
        auto buffer = objects->FindBuffer(s.first);
        if (!buffer) continue;
        GlBuffer ref;
        RETURN_IF_ERROR(buffer->MakeView(0, s.second, &ref));
        RETURN_IF_ERROR(refs->RegisterBuffer(s.first, std::move(ref)));
      }
    }
    auto runtime = absl::make_unique<Runtime>(options, gpu_info_, command_queue,
                                              refs ? refs.get() : objects);
    for (auto& program : programs_) {
      RETURN_IF_ERROR(runtime->AddProgram(shaders_[program.shader_idx],
                                          program.parameters, program.objects,
                                          program.num_workgroups));
    }
    RETURN_IF_ERROR(runtime->PrepareForExecution());
    if (dynamic_batch_) {
      *inference_context = absl::make_unique<InferenceContextWithBatchImpl>(
          object_sizes_, objects, std::move(refs), std::move(runtime));
    } else {
      *inference_context =
          absl::make_unique<InferenceContextImpl>(std::move(runtime));
    }
    return absl::OkStatus();
  }
#ifndef TFLITE_GPU_BINARY_RELEASE
  // Called on deserialization.
  absl::Status OnProgram(const std::vector<Variable>& parameters,
                         const std::vector<Object>& objects,
                         const uint3& workgroup_size,
                         const uint3& num_workgroups,
                         size_t partial_shader_index) final {
    for (const auto& object : objects) {
      if (IsRef(object)) {
        object_sizes_[GetRef(object)] = ByteSizeOf(object);
      }
    }

    size_t shader_idx;
    RETURN_IF_ERROR(AddFullShader(partial_shaders_[partial_shader_index],
                                  workgroup_size, &shader_idx));
    programs_.push_back({
        parameters,
        objects,
        workgroup_size,
        num_workgroups,
        shader_idx,
    });
    return absl::OkStatus();
  }
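
  // Serializes deduplicated partial shaders (headers stripped, since they are
  // regenerated from the workgroup size on deserialization) together with all
  // programs and the compilation options.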
  absl::Status Serialize(
      std::vector<uint8_t>* serialized_compiled_model) const final {
    SerializedCompiledModelBuilder builder;

    // Sort shaders first; they need to be serialized in order.
    std::vector<std::string> full_shaders(shaders_.size());
    for (const auto& shader : shader_to_index_) {
      full_shaders[shader.second] = shader.first;
    }

    absl::flat_hash_map<std::string, size_t> partial_shader_to_index;
    std::vector<std::string> partial_shaders;
    for (const auto& program : programs_) {
      // Remove the generated header (everything up to and including "in;");
      // it is rebuilt from the workgroup size on deserialization.
      std::string shader_without_header = full_shaders[program.shader_idx];
      shader_without_header.erase(0, shader_without_header.find("in;") + 3);

      // Insert shader into partial shaders array.
      auto it = partial_shader_to_index.find(shader_without_header);
      size_t shader_idx;
      if (it == partial_shader_to_index.end()) {
        shader_idx = partial_shaders.size();
        partial_shaders.push_back(shader_without_header);
        builder.AddShader(shader_without_header);
        partial_shader_to_index.insert({shader_without_header, shader_idx});
      } else {
        shader_idx = it->second;
      }
      builder.AddProgram(program.parameters, program.objects,
                         program.workgroup_size, program.num_workgroups,
                         shader_idx);
    }
    CompiledModelOptions options;
    options.dynamic_batch = dynamic_batch_;
    auto data = builder.Finalize(options);
    serialized_compiled_model->insert(serialized_compiled_model->end(),
                                      data.begin(), data.end());
    return absl::OkStatus();
  }

  absl::Status OnShader(absl::Span<const char> shader_src) final {
    std::string source(shader_src.data(), shader_src.size());
    partial_shaders_.push_back(source);
    return absl::OkStatus();
  }

  void OnOptions(const CompiledModelOptions& options) final {
    dynamic_batch_ = options.dynamic_batch;
  }
#endif  // TFLITE_GPU_BINARY_RELEASE

  CompilerStats stats() const final { return stats_; }

  void set_dynamic_batch(bool dynamic_batch) { dynamic_batch_ = dynamic_batch; }

 private:
  const GpuInfo gpu_info_;
  bool dynamic_batch_ = false;

  std::vector<std::string> partial_shaders_;
  std::vector<GlShader> shaders_;

  // Shaders are serialized in order of their indices.
  absl::flat_hash_map<std::string, size_t> shader_to_index_;
  std::deque<ProgramParameters> programs_;
  absl::flat_hash_map<ValueId, size_t> object_sizes_;
  CompilerStats stats_;
};
}  // namespace
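
// Compiles `model` into a CompiledModel, generating a shader for every node
// via `node_shader` and sizing workgroups with `workgroup_calculator`. A
// minimal usage sketch; identifiers other than the functions declared in this
// API are hypothetical:
//
//   std::unique_ptr<CompiledModel> compiled;
//   RETURN_IF_ERROR(Compile(options, graph, tflite_graph_io, node_shader,
//                           workgroup_calculator, &compiled));
//   std::unique_ptr<InferenceContext> context;
//   RETURN_IF_ERROR(compiled->NewRun(runtime_options, &objects,
//                                    command_queue, &context));
//   RETURN_IF_ERROR(context->Execute());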
absl::Status Compile(const CompilationOptions& options,
                     const GraphFloat32& model,
                     const std::unordered_set<int>& tflite_graph_io,  // NOLINT
                     const NodeShader& node_shader,
                     const WorkgroupsCalculator& workgroup_calculator,
                     std::unique_ptr<CompiledModel>* compiled_model) {
  if (!IsBatchMatchesForAllValues(model)) {
    return absl::InvalidArgumentError(
        "Only identical batch dimension is supported");
  }
  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
  if (!gpu_info.IsApiOpenGl31OrAbove()) {
    return absl::InternalError(
        "OpenGL ES 3.1 or above is required to use OpenGL inference.");
  }
  auto compiled_model_impl = absl::make_unique<CompiledModelImpl>(gpu_info);
  compiled_model_impl->set_dynamic_batch(options.dynamic_batch);
  auto compiler = NewCompiler(&node_shader, &gpu_info, options);
  RETURN_IF_ERROR(compiler->Compile(
      model, tflite_graph_io, [&](ShaderCode code) -> absl::Status {
        return compiled_model_impl->Add(workgroup_calculator, std::move(code));
      }));
  *compiled_model = std::move(compiled_model_impl);
  return absl::OkStatus();
}
#ifndef TFLITE_GPU_BINARY_RELEASE
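// Restores a CompiledModel previously produced by Serialize(); the running
// GPU must still provide OpenGL ES 3.1 or above.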
absl::Status ReadSerializedModel(
    const std::vector<uint8_t>& serialized_model,
    std::unique_ptr<CompiledModel>* compiled_model) {
  GpuInfo gpu_info;
  RETURN_IF_ERROR(RequestGpuInfo(&gpu_info));
  if (!gpu_info.IsApiOpenGl31OrAbove()) {
    return absl::InternalError(
        "OpenGL ES 3.1 or above is required to use OpenGL inference.");
  }
  auto compiled_model_impl = absl::make_unique<CompiledModelImpl>(gpu_info);
  RETURN_IF_ERROR(DeserializeCompiledModel(
      absl::MakeConstSpan(serialized_model), compiled_model_impl.get()));
  *compiled_model = std::move(compiled_model_impl);
  return absl::OkStatus();
}
#endif  // TFLITE_GPU_BINARY_RELEASE

}  // namespace gl
}  // namespace gpu
}  // namespace tflite