• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
18 
19 #include <map>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/lite/delegates/gpu/common/access_type.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
27 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
28 #include "tensorflow/lite/delegates/gpu/common/types.h"
29 #include "tensorflow/lite/delegates/gpu/common/util.h"
30 
31 namespace tflite {
32 namespace gpu {
// Forward declarations of the backend-specific argument classes. They only
// need to be named here because Arguments declares them as friends.
namespace cl {
class CLArguments;
}  // namespace cl

namespace metal {
class MetalArguments;
}  // namespace metal

40 
41 class ArgumentsBinder {
42  public:
43   virtual absl::Status SetInt(const std::string& name, int value) = 0;
44   virtual absl::Status SetFloat(const std::string& name, float value) = 0;
45   virtual absl::Status SetHalf(const std::string& name, half value) = 0;
46   virtual ~ArgumentsBinder() = default;
47 };
48 
49 class Arguments {
50  public:
51   Arguments() = default;
52   ~Arguments() = default;
53 
54   // Move only
55   Arguments(Arguments&& args) = default;
56   Arguments& operator=(Arguments&& args) = default;
57   Arguments(const Arguments&) = delete;
58   Arguments& operator=(const Arguments&) = delete;
59 
60   void AddFloat(const std::string& name, float value = 0.0f);
61   void AddHalf(const std::string& name, half value = half(0.0f));
62   void AddInt(const std::string& name, int value = 0);
63   void AddObjectRef(const std::string& name, AccessType access_type,
64                     GPUObjectDescriptorPtr&& descriptor_ptr);
65   void AddObject(const std::string& name,
66                  GPUObjectDescriptorPtr&& descriptor_ptr);
67 
68   void RenameArgs(const std::string& postfix, std::string* code) const;
69   absl::Status Merge(Arguments&& args, const std::string& postfix,
70                      const std::vector<std::string>& exception_names = {});
71 
72   absl::Status GetDescriptor(const std::string& name,
73                              GPUObjectDescriptor** descriptor) const;
74 
75   int GetReadTexturesCount(const GpuInfo& gpu_info) const;
76   int GetWriteTexturesCount(const GpuInfo& gpu_info) const;
77 
78   void ReleaseCPURepresentation();
79 
80   void GetActiveArguments(const std::string& args_prefix,
81                           const std::string& code);
82 
83   void SetStateValueForAllObjects(const std::string& key,
84                                   const std::string& value);
85 
86   struct IntValue {
87     int value;
88 
89     // many uniforms generated automatically and not used
90     // to reduce amount of data transferred we adding this optimization
91     bool active = false;
92   };
93   struct FloatValue {
94     float value;
95 
96     // many uniforms generated automatically and not used
97     // to reduce amount of data transferred we adding this optimization
98     bool active = false;
99   };
100   struct HalfValue {
101     half value;
102 
103     // many uniforms generated automatically and not used
104     // to reduce amount of data transferred we adding this optimization
105     bool active = false;
106   };
107 
GetIntValues()108   const std::map<std::string, IntValue>& GetIntValues() const {
109     return int_values_;
110   }
GetFloatValues()111   const std::map<std::string, FloatValue>& GetFloatValues() const {
112     return float_values_;
113   }
GetHalfValues()114   const std::map<std::string, HalfValue>& GetHalfValues() const {
115     return half_values_;
116   }
117 
GetObjectRefs()118   const std::map<std::string, GPUObjectDescriptorPtr>& GetObjectRefs() const {
119     return object_refs_;
120   }
GetObjects()121   const std::map<std::string, GPUObjectDescriptorPtr>& GetObjects() const {
122     return objects_;
123   }
MoveObjectRefs(std::map<std::string,GPUObjectDescriptorPtr> * result)124   void MoveObjectRefs(std::map<std::string, GPUObjectDescriptorPtr>* result) {
125     *result = std::move(object_refs_);
126   }
127 
128  private:
129   friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
130       const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
131   friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
132                              Arguments* args);
133 
134   friend class cl::CLArguments;
135   friend class metal::MetalArguments;
136 
137   std::map<std::string, IntValue> int_values_;
138   std::map<std::string, FloatValue> float_values_;
139   std::map<std::string, HalfValue> half_values_;
140 
141   std::map<std::string, GPUObjectDescriptorPtr> object_refs_;
142   std::map<std::string, GPUObjectDescriptorPtr> objects_;
143 };
144 
145 }  // namespace gpu
146 }  // namespace tflite
147 
148 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
149