1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_ 18 19 #include <cstdint> 20 #include <functional> 21 #include <string> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "flatbuffers/flatbuffers.h" // from @flatbuffers 26 #include "tensorflow/lite/delegates/gpu/common/status.h" 27 #include "tensorflow/lite/delegates/gpu/common/types.h" 28 #include "tensorflow/lite/delegates/gpu/gl/compiled_model_generated.h" 29 #include "tensorflow/lite/delegates/gpu/gl/object.h" 30 #include "tensorflow/lite/delegates/gpu/gl/variable.h" 31 32 namespace tflite { 33 namespace gpu { 34 namespace gl { 35 36 struct CompiledModelOptions { 37 // If true, a model was compiled with dynamic batch size and therefore, 38 // a user may change BATCH dimension at runtime. 39 bool dynamic_batch = false; 40 }; 41 42 // Accumulates shaders and programs and stores it in FlatBuffer format. 43 class SerializedCompiledModelBuilder { 44 public: SerializedCompiledModelBuilder()45 SerializedCompiledModelBuilder() : builder_(32 * 1024) {} 46 47 void AddShader(const std::string& shader_src); 48 49 void AddProgram(const std::vector<Variable>& parameters, 50 const std::vector<Object>& objects, 51 const uint3& workgroup_size, const uint3& num_workgroups, 52 size_t shader_index); 53 54 // Returns serialized data that will stay valid until this object is 55 // destroyed. 56 absl::Span<const uint8_t> Finalize(const CompiledModelOptions& options); 57 58 private: 59 std::vector<flatbuffers::Offset<flatbuffers::String>> shaders_; 60 std::vector<flatbuffers::Offset<data::Program>> programs_; 61 ::flatbuffers::FlatBufferBuilder builder_; 62 }; 63 64 // Handles deserialization events. it is guaranteed that shaders will be called 65 // first in the appropriate order and programs come next. 66 class DeserializationHandler { 67 public: 68 virtual ~DeserializationHandler() = default; 69 70 virtual absl::Status OnShader(absl::Span<const char> shader_src) = 0; 71 72 virtual absl::Status OnProgram(const std::vector<Variable>& parameters, 73 const std::vector<Object>& objects, 74 const uint3& workgroup_size, 75 const uint3& num_workgroups, 76 size_t shader_index) = 0; 77 78 virtual void OnOptions(const CompiledModelOptions& options) = 0; 79 }; 80 81 absl::Status DeserializeCompiledModel(absl::Span<const uint8_t> serialized, 82 DeserializationHandler* handler); 83 84 } // namespace gl 85 } // namespace gpu 86 } // namespace tflite 87 88 #endif // TENSORFLOW_LITE_DELEGATES_GPU_GL_SERIALIZATION_H_ 89