1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the Licensgoe is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_ 17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_ 18 19 #import <Metal/Metal.h> 20 21 #include "tensorflow/lite/delegates/gpu/common/status.h" 22 #include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h" 23 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" 24 #include "tensorflow/lite/delegates/gpu/common/util.h" 25 #include "tensorflow/lite/delegates/gpu/metal/gpu_object.h" 26 27 namespace tflite { 28 namespace gpu { 29 namespace metal { 30 31 class MetalSpatialTensor : public GPUObject, public GpuSpatialTensor { 32 public: MetalSpatialTensor()33 MetalSpatialTensor() 34 : memory_(nullptr), 35 texture_mem_(nullptr), 36 memory_owner_(true), 37 texture_mem_owner_(true) {} 38 MetalSpatialTensor(id<MTLBuffer> buffer, id<MTLTexture> texture, 39 bool memory_owner, bool texture_mem_owner, 40 const BHWC& shape, const TensorDescriptor& descriptor); 41 MetalSpatialTensor(id<MTLBuffer> buffer, id<MTLTexture> texture, 42 bool memory_owner, bool texture_mem_owner, 43 const BHWDC& shape, const TensorDescriptor& descriptor); 44 45 // Move only 46 MetalSpatialTensor(MetalSpatialTensor&& tensor); 47 MetalSpatialTensor& operator=(MetalSpatialTensor&& tensor); 48 MetalSpatialTensor(const MetalSpatialTensor&) = delete; 49 MetalSpatialTensor& operator=(const MetalSpatialTensor&) = delete; 50 ~MetalSpatialTensor()51 ~MetalSpatialTensor() override { Release(); } 52 53 absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr, 54 GPUResourcesWithValue* resources) const override; 55 Width()56 int Width() const override { return shape_.w; } Height()57 int Height() const override { return shape_.h; } Depth()58 int Depth() const override { return shape_.d; } Channels()59 int Channels() const override { return shape_.c; } Slices()60 int Slices() const override { return DivideRoundUp(shape_.c, 4); } Batch()61 int Batch() const override { return shape_.b; } 62 GetDescriptor()63 TensorDescriptor GetDescriptor() const { return descriptor_; } GetDataType()64 DataType GetDataType() const { return descriptor_.data_type; } GetStorageType()65 TensorStorageType GetStorageType() const { return descriptor_.storage_type; } 66 67 // for profiling and memory statistics 68 uint64_t GetMemorySizeInBytes() const; 69 70 absl::Status WriteData(id<MTLDevice> device, const TensorFloat32& src); 71 absl::Status WriteData( 72 id<MTLDevice> device, 73 const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& src); 74 absl::Status WriteData( 75 id<MTLDevice> device, 76 const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& src); 77 absl::Status WriteData(id<MTLDevice> device, const Tensor5DFloat32& src); 78 absl::Status ReadData(id<MTLDevice> device, TensorFloat32* dst) const; 79 absl::Status ReadData(id<MTLDevice> device, Tensor5DFloat32* dst) const; 80 81 absl::Status CreateFromDescriptor(const TensorDescriptor& desc, 82 id<MTLDevice> device); 83 84 absl::Status SetBufferHandle(id<MTLBuffer> buffer); 85 id<MTLBuffer> GetBufferHandle() const; 86 87 private: 88 absl::Status IsValid(const BHWC& shape) const; 89 absl::Status IsValid(const BHWDC& shape) const; 90 91 absl::Status WriteDataBHWDC(id<MTLDevice> device, const float* in); 92 absl::Status ReadDataBHWDC(id<MTLDevice> device, float* out) const; 93 94 int GetAlignedChannels() const; 95 int3 GetFullTensorRegion() const; 96 void Release(); 97 98 id<MTLBuffer> memory_; 99 id<MTLTexture> texture_mem_; 100 bool memory_owner_; 101 bool texture_mem_owner_; 102 BHWDC shape_; 103 TensorDescriptor descriptor_; 104 }; 105 106 absl::Status CreateTensor(id<MTLDevice> device, const BHWC& shape, 107 const TensorDescriptor& descriptor, 108 MetalSpatialTensor* result); 109 110 absl::Status CreateTensor(id<MTLDevice> device, const BHWDC& shape, 111 const TensorDescriptor& descriptor, 112 MetalSpatialTensor* result); 113 114 absl::Status CreateSharedBufferTensor(id<MTLBuffer> buffer, const BHWC& shape, 115 const TensorDescriptor& descriptor, 116 MetalSpatialTensor* result); 117 118 absl::Status CreateSharedBufferTensor(id<MTLBuffer> buffer, const BHWDC& shape, 119 const TensorDescriptor& descriptor, 120 MetalSpatialTensor* result); 121 122 } // namespace metal 123 } // namespace gpu 124 } // namespace tflite 125 126 #endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_ 127