• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
18 
#import <Metal/Metal.h>

#include <cstddef>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/texture2d_desc.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
27 
28 namespace tflite {
29 namespace gpu {
30 namespace metal {
31 
// Texture2D represents formatted GPU data storage backed by an id<MTLTexture>.
// Texture2D is movable but not copyable.
class Texture2D : public GPUObject {
 public:
  // Default constructor, just for using Texture2D as a class member.
  // All fields have in-class initializers, so a default-constructed object
  // holds no texture and zero dimensions rather than indeterminate values.
  Texture2D() = default;

  // Wraps an existing texture.
  // NOTE(review): width/height/pixel_format are presumably expected to
  // describe `texture`; nothing in this header checks that.
  Texture2D(id<MTLTexture> texture, int width, int height, MTLPixelFormat pixel_format);

  // Move only
  Texture2D(Texture2D&& texture);
  Texture2D& operator=(Texture2D&& texture);
  Texture2D(const Texture2D&) = delete;
  Texture2D& operator=(const Texture2D&) = delete;

  // Calls Release() (defined out of line) to drop the held texture.
  ~Texture2D() override { Release(); }

  // Writes data to a texture. Data should point to a region that
  // has exactly width * height * sizeof(pixel) bytes.
  template <typename T>
  absl::Status WriteData(id<MTLDevice> device, const absl::Span<T> data);

  // Reads data from Texture2D into CPU memory; `result` is resized to hold
  // the whole texture.
  template <typename T>
  absl::Status ReadData(id<MTLDevice> device, std::vector<T>* result) const;

  // GPUObject override.
  // NOTE(review): presumably exposes texture_ through `resources` for the
  // descriptor `obj_ptr` -- confirm in the implementation file.
  absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
                               GPUResourcesWithValue* resources) const override;

  // NOTE(review): presumably allocates/initializes this texture from `desc`
  // on `device` -- confirm in the implementation file.
  absl::Status CreateFromTexture2DDescriptor(const Texture2DDescriptor& desc, id<MTLDevice> device);

 private:
  void Release();

  id<MTLTexture> texture_ = nullptr;
  // Explicit initializers: without them a default-constructed Texture2D
  // would hand indeterminate values to WriteData/ReadData (undefined
  // behavior).
  int width_ = 0;
  int height_ = 0;
  MTLPixelFormat pixel_format_ = MTLPixelFormatInvalid;
};
69 
// Creates a new 4-channel 2D texture whose elements are 32-bit floats.
absl::Status CreateTexture2DRGBA32F(int width, int height,
                                    id<MTLDevice> device,
                                    Texture2D* result);

// Creates a new 4-channel 2D texture whose elements are 16-bit floats.
absl::Status CreateTexture2DRGBA16F(int width, int height,
                                    id<MTLDevice> device,
                                    Texture2D* result);

// Creates a new 4-channel 2D texture whose element type is chosen by `type`.
// NOTE(review): the set of supported DataType values is decided by the
// implementation file -- confirm there.
absl::Status CreateTexture2DRGBA(DataType type, int width, int height,
                                 id<MTLDevice> device,
                                 Texture2D* result);

// Same as the overload above, but additionally takes `data` -- presumably
// the initial texel contents (NOTE(review): confirm expected size/layout in
// the implementation file).
absl::Status CreateTexture2DRGBA(DataType type, int width, int height,
                                 void* data,
                                 id<MTLDevice> device,
                                 Texture2D* result);
81 
// Uploads `data` into the whole texture.
//
// `data` must cover the texture exactly: width_ * height_ * pixel_size bytes,
// where pixel_size = PixelFormatToSizeInBytes(pixel_format_).
//
// Returns:
//   FailedPreconditionError if no texture is held,
//   InvalidArgumentError if the pixel format reports a non-positive size or
//     the byte counts differ,
//   OkStatus after handing the data to WriteDataToTexture2D otherwise.
template <typename T>
absl::Status Texture2D::WriteData(id<MTLDevice> device,
                                  const absl::Span<T> data) {
  if (texture_ == nullptr) {
    return absl::FailedPreconditionError("Texture2D has no allocated texture.");
  }
  const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
  // NOTE(review): assumes a non-positive size means the pixel format is
  // unsupported; a negative value would otherwise wrap when converted to
  // size_t below.
  if (pixel_size <= 0) {
    return absl::InvalidArgumentError("Unsupported texture pixel format.");
  }
  // Compute in size_t: width * height * pixel_size in int overflows for large
  // textures (e.g. 16384 x 16384 x 16 bytes > INT_MAX), which is UB.
  const size_t texture_bytes = static_cast<size_t>(width_) *
                               static_cast<size_t>(height_) *
                               static_cast<size_t>(pixel_size);
  if (texture_bytes != data.size() * sizeof(T)) {
    return absl::InvalidArgumentError(
        "absl::Span<T> data size is different from texture allocated size.");
  }

  WriteDataToTexture2D(texture_, device, data.data());

  return absl::OkStatus();
}
95 
// Reads the whole texture back into `result`.
//
// `result` is resized to width_ * height_ * (pixel_size / sizeof(T))
// elements, where pixel_size = PixelFormatToSizeInBytes(pixel_format_), and
// then filled by ReadDataFromTexture2D.
//
// Returns:
//   FailedPreconditionError if no texture is held,
//   InvalidArgumentError if the pixel format reports a non-positive size or
//     a pixel is not a whole number of T elements,
//   OkStatus otherwise.
template <typename T>
absl::Status Texture2D::ReadData(id<MTLDevice> device,
                                 std::vector<T>* result) const {
  if (texture_ == nullptr) {
    return absl::FailedPreconditionError("Texture2D has no allocated texture.");
  }
  const int pixel_size = PixelFormatToSizeInBytes(pixel_format_);
  // NOTE(review): assumes a non-positive size means the pixel format is
  // unsupported. Without this guard a negative value converts to a huge
  // size_t, can pass the divisibility check (e.g. sizeof(T) == 1) and
  // trigger an enormous resize.
  if (pixel_size <= 0) {
    return absl::InvalidArgumentError("Unsupported texture pixel format.");
  }
  const size_t pixel_bytes = static_cast<size_t>(pixel_size);
  if (pixel_bytes % sizeof(T) != 0) {
    return absl::InvalidArgumentError("Pixel format is different.");
  }
  const size_t elements_per_pixel = pixel_bytes / sizeof(T);
  // size_t arithmetic so the element count cannot overflow int.
  result->resize(static_cast<size_t>(width_) * static_cast<size_t>(height_) *
                 elements_per_pixel);

  ReadDataFromTexture2D(texture_, device, result->data());

  return absl::OkStatus();
}
109 
110 }  // namespace metal
111 }  // namespace gpu
112 }  // namespace tflite
113 
114 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_TEXTURE2D_H_
115