• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
18 
19 #include <map>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow/lite/delegates/gpu/common/access_type.h"
24 #include "tensorflow/lite/delegates/gpu/common/status.h"
25 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 #include "tensorflow/lite/delegates/gpu/common/util.h"
29 
30 namespace tflite {
31 namespace gpu {
// Backend argument classes, forward-declared here only so that `Arguments`
// can name them in its friend declarations.
namespace cl { class CLArguments; }  // namespace cl

namespace metal { class MetalArguments; }  // namespace metal
39 
40 class ArgumentsBinder {
41  public:
42   virtual absl::Status SetInt(const std::string& name, int value) = 0;
43   virtual absl::Status SetFloat(const std::string& name, float value) = 0;
44   virtual absl::Status SetHalf(const std::string& name, half value) = 0;
45   virtual ~ArgumentsBinder() = default;
46 };
47 
48 class Arguments {
49  public:
50   Arguments() = default;
51   ~Arguments() = default;
52 
53   // Move only
54   Arguments(Arguments&& args) = default;
55   Arguments& operator=(Arguments&& args) = default;
56   Arguments(const Arguments&) = delete;
57   Arguments& operator=(const Arguments&) = delete;
58 
59   void AddFloat(const std::string& name, float value = 0.0f);
60   void AddHalf(const std::string& name, half value = half(0.0f));
61   void AddInt(const std::string& name, int value = 0);
62   void AddObjectRef(const std::string& name, AccessType access_type,
63                     GPUObjectDescriptorPtr&& descriptor_ptr);
64   void AddObject(const std::string& name,
65                  GPUObjectDescriptorPtr&& descriptor_ptr);
66 
67   void RenameArgs(const std::string& postfix, std::string* code) const;
68   absl::Status Merge(Arguments&& args, const std::string& postfix);
69 
70   void ReleaseCPURepresentation();
71 
72  private:
73   friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
74       const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
75   friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
76                              Arguments* args);
77 
78   friend class cl::CLArguments;
79   friend class metal::MetalArguments;
80   void GetActiveArguments(const std::string& args_prefix,
81                           const std::string& code);
82 
83   struct IntValue {
84     int value;
85 
86     // many uniforms generated automatically and not used
87     // to reduce amount of data transferred we adding this optimization
88     bool active = false;
89   };
90   std::map<std::string, IntValue> int_values_;
91 
92   struct FloatValue {
93     float value;
94 
95     // many uniforms generated automatically and not used
96     // to reduce amount of data transferred we adding this optimization
97     bool active = false;
98   };
99   std::map<std::string, FloatValue> float_values_;
100 
101   struct HalfValue {
102     half value;
103 
104     // many uniforms generated automatically and not used
105     // to reduce amount of data transferred we adding this optimization
106     bool active = false;
107   };
108   std::map<std::string, HalfValue> half_values_;
109 
110   std::map<std::string, GPUObjectDescriptorPtr> object_refs_;
111   std::map<std::string, GPUObjectDescriptorPtr> objects_;
112 };
113 
114 }  // namespace gpu
115 }  // namespace tflite
116 
117 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_ARGUMENTS_H_
118