1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
18
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
21
22 #import <Metal/Metal.h>
23
24 #include "absl/types/span.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
27 #include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
28
29 namespace tflite {
30 namespace gpu {
31 namespace metal {
32
// GPUObject that wraps a single MTLBuffer together with its size in bytes.
//
// Buffer is move-only (copying is deleted). A default-constructed Buffer holds
// no MTLBuffer (nil) and reports a size of 0 bytes.
class Buffer : public GPUObject {
 public:
  // Creates an empty Buffer; exists so Buffer can be used as a class member.
  Buffer() = default;
  // Wraps `buffer`, recording `size_in_bytes` as its size.
  // NOTE(review): retention/ownership of `buffer` is decided by the
  // out-of-line definition — confirm there.
  Buffer(id<MTLBuffer> buffer, size_t size_in_bytes);

  // Move only.
  Buffer(Buffer&& buffer);
  Buffer& operator=(Buffer&& buffer);
  Buffer(const Buffer&) = delete;
  Buffer& operator=(const Buffer&) = delete;

  ~Buffer();

  // Returns the recorded size of the buffer in bytes (for profiling and
  // memory statistics).
  uint64_t GetMemorySizeInBytes() const { return size_; }

  // Returns the wrapped MTLBuffer (nil for a default-constructed Buffer).
  id<MTLBuffer> GetMemoryPtr() const { return buffer_; }

  // Copies `data` into the MTLBuffer's `contents`. `data` must span exactly
  // GetMemorySizeInBytes() bytes; otherwise InvalidArgumentError is returned.
  template <typename T>
  absl::Status WriteData(const absl::Span<T> data);

  // Copies the whole buffer into `*result`, resizing it to
  // GetMemorySizeInBytes() / sizeof(T) elements. Fails if the byte size is
  // not a multiple of sizeof(T).
  template <typename T>
  absl::Status ReadData(std::vector<T>* result) const;

  absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
                               GPUResourcesWithValue* resources) const override;

  // NOTE(review): defined out of line; presumably creates the MTLBuffer on
  // `device` as described by `desc` — confirm against the definition.
  absl::Status CreateFromBufferDescriptor(const BufferDescriptor& desc, id<MTLDevice> device);

 private:
  void Release();

  id<MTLBuffer> buffer_ = nullptr;
  // Zero-initialized so a default-constructed Buffer reports 0 bytes instead
  // of exposing an indeterminate value through GetMemorySizeInBytes() and the
  // size checks in WriteData/ReadData.
  size_t size_ = 0;
};
71
// Factory for Buffer; the definition is not in this header.
// NOTE(review): presumably allocates `size_in_bytes` bytes on `device`, copies
// from `data` when it is non-null, and move-assigns the new Buffer into
// `*result` — confirm against the out-of-line definition.
absl::Status CreateBuffer(size_t size_in_bytes, const void* data, id<MTLDevice> device,
                          Buffer* result);
74
// Copies exactly GetMemorySizeInBytes() bytes from `data` into the MTLBuffer's
// `contents`.
//
// Returns:
//   InvalidArgumentError    if `data` does not span exactly the buffer's size.
//   FailedPreconditionError if the buffer has no CPU-addressable contents
//                           (a nil buffer, or - per Apple's MTLBuffer docs -
//                           a buffer with private storage mode).
//   OkStatus                otherwise (including the zero-byte case).
template <typename T>
absl::Status Buffer::WriteData(const absl::Span<T> data) {
  if (size_ != sizeof(T) * data.size()) {
    return absl::InvalidArgumentError(
        "absl::Span<T> data size is different from buffer allocated size.");
  }
  if (size_ == 0) {
    // Nothing to copy. Returning here also keeps possibly-null pointers away
    // from std::memcpy, which is undefined for null arguments even when the
    // count is zero.
    return absl::OkStatus();
  }
  void* destination = [buffer_ contents];
  if (destination == nullptr) {
    return absl::FailedPreconditionError(
        "Buffer contents are not accessible from the CPU.");
  }
  std::memcpy(destination, data.data(), size_);
  return absl::OkStatus();
}
84
// Copies the entire buffer into `*result`, resizing it to
// GetMemorySizeInBytes() / sizeof(T) elements.
//
// Returns:
//   UnknownError            if the byte size is not a multiple of sizeof(T).
//   FailedPreconditionError if the buffer has no CPU-addressable contents
//                           (a nil buffer, or - per Apple's MTLBuffer docs -
//                           a buffer with private storage mode); `*result` is
//                           left unmodified in that case.
//   OkStatus                otherwise (a zero-byte buffer yields an empty
//                           vector).
template <typename T>
absl::Status Buffer::ReadData(std::vector<T>* result) const {
  if (size_ % sizeof(T) != 0) {
    return absl::UnknownError("Wrong element size(typename T is not correct?)");
  }
  if (size_ == 0) {
    // Nothing to copy; avoids handing possibly-null pointers to std::memcpy.
    result->clear();
    return absl::OkStatus();
  }
  const void* source = [buffer_ contents];
  if (source == nullptr) {
    return absl::FailedPreconditionError(
        "Buffer contents are not accessible from the CPU.");
  }
  // size_t rather than int: avoids truncating the element count for buffers
  // holding more than INT_MAX elements.
  const size_t elements_count = size_ / sizeof(T);
  result->resize(elements_count);
  std::memcpy(result->data(), source, size_);
  return absl::OkStatus();
}
97
98 } // namespace metal
99 } // namespace gpu
100 } // namespace tflite
101
102 #endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_BUFFER_H_
103