1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_ 18 19 #include <map> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/lite/delegates/gpu/common/access_type.h" 25 #include "tensorflow/lite/delegates/gpu/common/status.h" 26 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h" 27 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h" 28 #include "tensorflow/lite/delegates/gpu/common/types.h" 29 #include "tensorflow/lite/delegates/gpu/common/util.h" 30 31 namespace tflite { 32 namespace gpu { 33 namespace cl { 34 class CLArguments; 35 } 36 37 namespace metal { 38 class MetalArguments; 39 } 40 41 class ArgumentsBinder { 42 public: 43 virtual absl::Status SetInt(const std::string& name, int value) = 0; 44 virtual absl::Status SetFloat(const std::string& name, float value) = 0; 45 virtual absl::Status SetHalf(const std::string& name, half value) = 0; 46 virtual ~ArgumentsBinder() = default; 47 }; 48 49 class Arguments { 50 public: 51 Arguments() = default; 52 ~Arguments() = default; 53 54 // Move only 55 Arguments(Arguments&& args) = default; 56 Arguments& operator=(Arguments&& args) = default; 57 Arguments(const Arguments&) = delete; 58 Arguments& operator=(const Arguments&) = delete; 59 60 void AddFloat(const std::string& name, float value = 0.0f); 61 void AddHalf(const std::string& name, half value = half(0.0f)); 62 void AddInt(const std::string& name, int value = 0); 63 void AddObjectRef(const std::string& name, AccessType access_type, 64 GPUObjectDescriptorPtr&& descriptor_ptr); 65 void AddObject(const std::string& name, 66 GPUObjectDescriptorPtr&& descriptor_ptr); 67 68 void RenameArgs(const std::string& postfix, std::string* code) const; 69 absl::Status Merge(Arguments&& args, const std::string& postfix, 70 const std::vector<std::string>& exception_names = {}); 71 72 absl::Status GetDescriptor(const std::string& name, 73 GPUObjectDescriptor** descriptor) const; 74 75 int GetReadTexturesCount(const GpuInfo& gpu_info) const; 76 int GetWriteTexturesCount(const GpuInfo& gpu_info) const; 77 78 void ReleaseCPURepresentation(); 79 80 void GetActiveArguments(const std::string& args_prefix, 81 const std::string& code); 82 83 void SetStateValueForAllObjects(const std::string& key, 84 const std::string& value); 85 86 struct IntValue { 87 int value; 88 89 // many uniforms generated automatically and not used 90 // to reduce amount of data transferred we adding this optimization 91 bool active = false; 92 }; 93 struct FloatValue { 94 float value; 95 96 // many uniforms generated automatically and not used 97 // to reduce amount of data transferred we adding this optimization 98 bool active = false; 99 }; 100 struct HalfValue { 101 half value; 102 103 // many uniforms generated automatically and not used 104 // to reduce amount of data transferred we adding this optimization 105 bool active = false; 106 }; 107 GetIntValues()108 const std::map<std::string, IntValue>& GetIntValues() const { 109 return int_values_; 110 } GetFloatValues()111 const std::map<std::string, FloatValue>& GetFloatValues() const { 112 return float_values_; 113 } GetHalfValues()114 const std::map<std::string, HalfValue>& GetHalfValues() const { 115 return half_values_; 116 } 117 GetObjectRefs()118 const std::map<std::string, GPUObjectDescriptorPtr>& GetObjectRefs() const { 119 return object_refs_; 120 } GetObjects()121 const std::map<std::string, GPUObjectDescriptorPtr>& GetObjects() const { 122 return objects_; 123 } MoveObjectRefs(std::map<std::string,GPUObjectDescriptorPtr> * result)124 void MoveObjectRefs(std::map<std::string, GPUObjectDescriptorPtr>* result) { 125 *result = std::move(object_refs_); 126 } 127 128 private: 129 friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode( 130 const Arguments& args, flatbuffers::FlatBufferBuilder* builder); 131 friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args, 132 Arguments* args); 133 134 friend class cl::CLArguments; 135 friend class metal::MetalArguments; 136 137 std::map<std::string, IntValue> int_values_; 138 std::map<std::string, FloatValue> float_values_; 139 std::map<std::string, HalfValue> half_values_; 140 141 std::map<std::string, GPUObjectDescriptorPtr> object_refs_; 142 std::map<std::string, GPUObjectDescriptorPtr> objects_; 143 }; 144 145 } // namespace gpu 146 } // namespace tflite 147 148 #endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_ 149