• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_
18 
19 #include <cstdint>
20 #include <functional>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 #include "tensorflow/lite/delegates/gpu/gl/compiled_model_generated.h"
29 #include "tensorflow/lite/delegates/gpu/gl/object.h"
30 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
31 
32 namespace tflite {
33 namespace gpu {
34 namespace gl {
35 
// Options that are serialized together with a compiled model and handed back
// to the deserialization handler.
struct CompiledModelOptions {
  // When set, the model was compiled with a dynamic batch size, which means
  // the user is allowed to change the BATCH dimension at runtime.
  bool dynamic_batch{false};
};
41 
42 // Accumulates shaders and programs and stores it in FlatBuffer format.
43 class SerializedCompiledModelBuilder {
44  public:
SerializedCompiledModelBuilder()45   SerializedCompiledModelBuilder() : builder_(32 * 1024) {}
46 
47   void AddShader(const std::string& shader_src);
48 
49   void AddProgram(const std::vector<Variable>& parameters,
50                   const std::vector<Object>& objects,
51                   const uint3& workgroup_size, const uint3& num_workgroups,
52                   size_t shader_index);
53 
54   // Returns serialized data that will stay valid until this object is
55   // destroyed.
56   absl::Span<const uint8_t> Finalize(const CompiledModelOptions& options);
57 
58  private:
59   std::vector<flatbuffers::Offset<flatbuffers::String>> shaders_;
60   std::vector<flatbuffers::Offset<data::Program>> programs_;
61   ::flatbuffers::FlatBufferBuilder builder_;
62 };
63 
64 // Handles deserialization events. it is guaranteed that shaders will be called
65 // first in the appropriate order and programs come next.
66 class DeserializationHandler {
67  public:
68   virtual ~DeserializationHandler() = default;
69 
70   virtual absl::Status OnShader(absl::Span<const char> shader_src) = 0;
71 
72   virtual absl::Status OnProgram(const std::vector<Variable>& parameters,
73                                  const std::vector<Object>& objects,
74                                  const uint3& workgroup_size,
75                                  const uint3& num_workgroups,
76                                  size_t shader_index) = 0;
77 
78   virtual void OnOptions(const CompiledModelOptions& options) = 0;
79 };
80 
81 absl::Status DeserializeCompiledModel(absl::Span<const uint8_t> serialized,
82                                       DeserializationHandler* handler);
83 
84 }  // namespace gl
85 }  // namespace gpu
86 }  // namespace tflite
87 
88 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_
89