• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
18 
19 #import <Metal/Metal.h>
20 
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
24 #include "tensorflow/lite/delegates/gpu/common/util.h"
25 #include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
26 
27 namespace tflite {
28 namespace gpu {
29 namespace metal {
30 
// Spatial tensor whose storage is a Metal buffer (`memory_`) and/or a Metal
// texture (`texture_mem_`), together with its logical BHWDC shape and the
// TensorDescriptor it was described with. Copying is disabled; instances can
// only be moved.
class MetalSpatialTensor : public GPUObject, public GpuSpatialTensor {
 public:
  // Empty tensor: no buffer, no texture, both ownership flags set to true
  // (see the default member initializers below).
  MetalSpatialTensor() = default;

  // Construct from an existing buffer/texture pair. `memory_owner` and
  // `texture_mem_owner` are stored as the ownership flags for the buffer and
  // texture respectively. The 4D (BHWC) and 5D (BHWDC) overloads differ only
  // in the shape type.
  // NOTE(review): how the flags affect Release() is defined in the .mm.
  MetalSpatialTensor(id<MTLBuffer> buffer,
                     id<MTLTexture> texture,
                     bool memory_owner,
                     bool texture_mem_owner,
                     const BHWC& shape,
                     const TensorDescriptor& descriptor);
  MetalSpatialTensor(id<MTLBuffer> buffer,
                     id<MTLTexture> texture,
                     bool memory_owner,
                     bool texture_mem_owner,
                     const BHWDC& shape,
                     const TensorDescriptor& descriptor);

  // Move-only type: copy operations are deleted, move operations are
  // defined out of line.
  MetalSpatialTensor(const MetalSpatialTensor&) = delete;
  MetalSpatialTensor& operator=(const MetalSpatialTensor&) = delete;
  MetalSpatialTensor(MetalSpatialTensor&& tensor);
  MetalSpatialTensor& operator=(MetalSpatialTensor&& tensor);

  // Destruction delegates all cleanup to Release().
  ~MetalSpatialTensor() override { Release(); }

  // GPUObject interface.
  absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
                               GPUResourcesWithValue* resources) const override;

  // GpuSpatialTensor interface: dimensions are read straight from `shape_`.
  int Width() const override {
    return shape_.w;
  }
  int Height() const override {
    return shape_.h;
  }
  int Depth() const override {
    return shape_.d;
  }
  int Channels() const override {
    return shape_.c;
  }
  // Channel count divided by 4, rounded up via DivideRoundUp.
  int Slices() const override {
    return DivideRoundUp(shape_.c, 4);
  }
  int Batch() const override {
    return shape_.b;
  }

  // Descriptor accessors. GetDescriptor() returns a copy of `descriptor_`.
  TensorDescriptor GetDescriptor() const {
    return descriptor_;
  }
  DataType GetDataType() const {
    return descriptor_.data_type;
  }
  TensorStorageType GetStorageType() const {
    return descriptor_.storage_type;
  }

  // for profiling and memory statistics
  uint64_t GetMemorySizeInBytes() const;

  // Upload host-side float32 data into this tensor, one overload per source
  // tensor layout.
  absl::Status WriteData(id<MTLDevice> device, const TensorFloat32& src);
  absl::Status WriteData(
      id<MTLDevice> device,
      const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src);
  absl::Status WriteData(
      id<MTLDevice> device,
      const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src);
  absl::Status WriteData(id<MTLDevice> device, const Tensor5DFloat32& src);

  // Download this tensor's contents into a host-side float32 tensor.
  absl::Status ReadData(id<MTLDevice> device, TensorFloat32* dst) const;
  absl::Status ReadData(id<MTLDevice> device, Tensor5DFloat32* dst) const;

  absl::Status CreateFromDescriptor(const TensorDescriptor& desc,
                                    id<MTLDevice> device);

  // Accessors for the underlying Metal buffer (`memory_`).
  absl::Status SetBufferHandle(id<MTLBuffer> buffer);
  id<MTLBuffer> GetBufferHandle() const;

 private:
  // Shape validation helpers for the 4D and 5D cases.
  absl::Status IsValid(const BHWC& shape) const;
  absl::Status IsValid(const BHWDC& shape) const;

  // Raw float transfer helpers operating on BHWDC-ordered host memory
  // (NOTE(review): element order inferred from the name — confirm in .mm).
  absl::Status WriteDataBHWDC(id<MTLDevice> device, const float* in);
  absl::Status ReadDataBHWDC(id<MTLDevice> device, float* out) const;

  int GetAlignedChannels() const;
  int3 GetFullTensorRegion() const;
  void Release();

  // Declaration order is kept identical to the original layout.
  id<MTLBuffer> memory_ = nullptr;
  id<MTLTexture> texture_mem_ = nullptr;
  bool memory_owner_ = true;
  bool texture_mem_owner_ = true;
  BHWDC shape_;
  TensorDescriptor descriptor_;
};
105 
// Free-function factories for MetalSpatialTensor. Each takes a 4D (BHWC) or
// 5D (BHWDC) shape plus a TensorDescriptor and reports success or failure via
// the returned absl::Status. The tensor is presumably written to `*result`
// (NOTE(review): behavior of `*result` on failure is defined in the .mm —
// confirm).

// Creates a tensor using `device` for the given shape and descriptor.
absl::Status CreateTensor(id<MTLDevice> device,
                          const BHWC& shape,
                          const TensorDescriptor& descriptor,
                          MetalSpatialTensor* result);
absl::Status CreateTensor(id<MTLDevice> device,
                          const BHWDC& shape,
                          const TensorDescriptor& descriptor,
                          MetalSpatialTensor* result);

// Creates a tensor on top of the caller-supplied `buffer` instead of taking a
// device (NOTE(review): ownership of `buffer` is decided in the .mm — confirm).
absl::Status CreateSharedBufferTensor(id<MTLBuffer> buffer,
                                      const BHWC& shape,
                                      const TensorDescriptor& descriptor,
                                      MetalSpatialTensor* result);
absl::Status CreateSharedBufferTensor(id<MTLBuffer> buffer,
                                      const BHWDC& shape,
                                      const TensorDescriptor& descriptor,
                                      MetalSpatialTensor* result);
121 
122 }  // namespace metal
123 }  // namespace gpu
124 }  // namespace tflite
125 
126 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
127