• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TENSOR_DESC_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TENSOR_DESC_H_
18 
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
26 
27 namespace tflite {
28 namespace gpu {
29 
// Addressing behaviour requested for tensor accesses
// (see TensorDescriptor::SetAddressMode).
// NOTE(review): kZero presumably requests that out-of-range reads yield zero,
// while kDontCare leaves that behaviour unspecified -- confirm against the
// implementation in tensor_desc.cc.
enum class AddressMode {
  kDontCare = 0,
  kZero = 1,
};
34 
// Kind of GPU object backing a tensor. Enumerator values are spelled out so
// that reordering the list cannot silently renumber existing entries.
enum class TensorStorageType {
  UNKNOWN = 0,
  BUFFER = 1,
  IMAGE_BUFFER = 2,
  TEXTURE_2D = 3,
  TEXTURE_3D = 4,
  TEXTURE_ARRAY = 5,
  SINGLE_TEXTURE_2D = 6,
};
44 
45 struct TensorDescriptor : public GPUObjectDescriptor {
46   TensorDescriptor() = default;
TensorDescriptorTensorDescriptor47   TensorDescriptor(DataType dt, TensorStorageType st, Layout l)
48       : data_type(dt), storage_type(st), layout(l) {}
49 
50   TensorDescriptor(const TensorDescriptor&) = default;
51   TensorDescriptor& operator=(const TensorDescriptor&) = default;
52   TensorDescriptor(TensorDescriptor&& desc);
53   TensorDescriptor& operator=(TensorDescriptor&& desc);
54 
55   bool operator==(const TensorDescriptor& d) const {
56     return data_type == d.data_type && storage_type == d.storage_type &&
57            layout == d.layout;
58   }
59 
60   bool operator!=(const TensorDescriptor& d) const { return !(*this == d); }
61 
62   absl::Status PerformSelector(const GpuInfo& gpu_info,
63                                const std::string& selector,
64                                const std::vector<std::string>& args,
65                                const std::vector<std::string>& template_args,
66                                std::string* result) const override;
67 
68   GPUResources GetGPUResources() const override;
69 
ReleaseTensorDescriptor70   void Release() override { data.clear(); }
71 
72   bool HasAxis(Axis axis) const;
73   void SetAddressMode(AddressMode mode);
74   int GetWidthSize(BHWDC shape) const;
75   int GetSliceStrideSize(BHWDC shape) const;
76 
77   absl::Status GetLinkingContextFromWriteSelector(
78       const std::vector<std::string>& args, std::string* value_name,
79       std::string* x_coord, std::string* y_coord, std::string* s_coord) const;
80 
81   void UploadData(const tflite::gpu::Tensor<BHWC, DataType::FLOAT32>& src);
82   void UploadData(const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src);
83   void UploadData(const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src);
84 
85   bool SupportsZeroClamp(const Axis& axis) const;
86   bool CanReadOutOfBorder(const Axis& axis) const;
87   bool IsLinear() const;
88 
89   // applicable only for types that: IsLinear -> true.
90   // In this case for address we have 1d component - addr (int)
91   // If for addr == -1 this linear storage type returns FLT4(0.0), this function
92   // returns true, otherwise false
93   bool ReturnsZeroForNegOneRead() const;
94 
95   DataType data_type = DataType::UNKNOWN;
96   TensorStorageType storage_type = TensorStorageType::UNKNOWN;
97   // This field describes logical layout, actual(physical) GPU layout can be
98   // totally different.
99   Layout layout =
100       Layout::UNKNOWN;  // Supported layouts is HWC, BHWC, HWDC, BHWDC
101 
102   // optional
103   BHWDC shape;
104   std::vector<uint8_t> data;
105 
106  private:
107   absl::Status PerformReadSelector(
108       const GpuInfo& gpu_info, const std::vector<std::string>& args,
109       const std::vector<std::string>& template_args, std::string* result) const;
110 
111   absl::Status PerformGetAddressSelector(const std::vector<std::string>& args,
112                                          std::string* result) const;
113 
114   absl::Status PerformGetPtrWithSliceOffsetSelector(
115       const std::vector<std::string>& args, std::string* result) const;
116 
117   absl::Status PerformGetWHOffsetSelector(const std::vector<std::string>& args,
118                                           std::string* result) const;
119 
120   absl::Status PerformGetHandleSelector(const std::vector<std::string>& args,
121                                         std::string* result) const;
122 
123   std::string DeclareAddress(const std::string& var_name,
124                              const std::string& address) const;
125 
126   std::string StorageTypeToAddressType() const;
127 
128   absl::Status PerformWriteSelector(const GpuInfo& gpu_info,
129                                     const std::vector<std::string>& args,
130                                     std::string* result) const;
131 
132   absl::Status PerformWriteLinearSelector(const GpuInfo& gpu_info,
133                                           const std::vector<std::string>& args,
134                                           std::string* result) const;
135 
136   absl::Status PerformWrite2DSelector(const GpuInfo& gpu_info,
137                                       const std::vector<std::string>& args,
138                                       std::string* result) const;
139 
140   std::string Read(const GpuInfo& gpu_info, DataType read_as_type,
141                    const std::vector<std::string>& coords) const;
142   std::string Write(const GpuInfo& gpu_info, const std::string& var_name,
143                     const std::vector<std::string>& coords) const;
144 
145   bool IsBatchedWidth() const;
146 
147   std::string GetWidth() const;
148 
149   AddressMode AddressModeFromState() const;
150 
151   absl::Status GetDataTypeFromTemplateArgs(const std::string& template_arg,
152                                            DataType* result) const;
153 
154   std::string GetGlobalAddressNoDeclaration(const std::string& xc,
155                                             const std::string& yc,
156                                             const std::string& zc,
157                                             const std::string& sc,
158                                             const std::string& bc) const;
159 
160   std::vector<std::string> GetPhysicalCoordsWHS(const std::string& x,
161                                                 const std::string& y,
162                                                 const std::string& s) const;
163   std::vector<std::string> GetPhysicalCoordsWHSB(const std::string& x,
164                                                  const std::string& y,
165                                                  const std::string& s,
166                                                  const std::string& b) const;
167   std::vector<std::string> GetPhysicalCoordsWHDS(const std::string& x,
168                                                  const std::string& y,
169                                                  const std::string& z,
170                                                  const std::string& s) const;
171   std::vector<std::string> GetPhysicalCoordsWHDSB(const std::string& x,
172                                                   const std::string& y,
173                                                   const std::string& z,
174                                                   const std::string& s,
175                                                   const std::string& b) const;
176   std::vector<std::string> GetPhysicalCoords(const std::string& xc,
177                                              const std::string& yc,
178                                              const std::string& zc,
179                                              const std::string& sc,
180                                              const std::string& bc) const;
181 
182   bool ParseCoordsFromArgs(const std::vector<std::string>& args, int offset,
183                            std::string* xc, std::string* yc, std::string* zc,
184                            std::string* sc, std::string* bc) const;
185 
186   void UploadData(const float* src);
187 };
188 
189 template <typename T>
190 void DataFromBHWDC(const float* src, const BHWDC& shape,
191                    const TensorDescriptor& desc, T* dst);
192 
193 template <typename T>
194 void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
195                  float* dst);
196 
197 std::string ToString(TensorStorageType type);
198 
199 }  // namespace gpu
200 }  // namespace tflite
201 
202 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASK_TENSOR_DESC_H_
203