1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
17 #define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
18
#include <stddef.h>
#include <stdint.h>

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/cc/experimental/base/public/status.h"
29
30 namespace tensorflow {
31 namespace experimental {
32 namespace cc {
33
34 // Tensor represents an n-dimensional array of values.
35 class Tensor {
36 public:
37 using DeleterCallback = std::function<void(void*, size_t)>;
38
39 // Constructs a Tensor from user provided buffer.
40 //
41 // Params:
42 // dtype - The dtype of the tensor's data.
43 // shape - A shape vector, where each element corresponds to the size of
44 // the tensor's corresponding dimension.
45 // data - Pointer to a buffer of memory to construct a Tensor out of.
46 // len - The length (in bytes) of `data`
47 // deleter - A std::function to be called when the Tensor no longer needs the
48 // memory in `data`. This can be used to free `data`, or
49 // perhaps decrement a refcount associated with `data`, etc.
50 // status - Set to OK on success and an error on failure.
51 // Returns:
52 // If an error occurred, status->ok() will be false, and the returned
53 // Tensor must not be used.
54 // TODO(bmzhao): Add Runtime as an argument to this function so we can swap to
55 // a TFRT backed tensor.
56 // TODO(bmzhao): Add benchmarks on overhead for this function; we can
57 // consider using int64_t* + length rather than vector.
58 static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape,
59 void* data, size_t len, DeleterCallback deleter,
60 Status* status);
61
62 // TODO(bmzhao): In the case we construct a tensor from non-owned memory,
63 // we should offer a way to deep copy the tensor into a new tensor, which
64 // owns the underlying memory. This could be a .deepcopy()/clone() method.
65
66 // TODO(bmzhao): In the future, we want to relax the non-copyability
67 // constraint. To do so, we can add a C API function that acts like
68 // CopyFrom:
69 // https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311
70
71 // Tensor is movable, but not copyable
72 Tensor(Tensor&&) = default;
73 Tensor& operator=(Tensor&&) = default;
74
75 // Returns the number of dimensions in the tensor. Can be -1, which represents
76 // unknown rank.
77 int dims() const;
78
79 // Returns the number of elements in dimension `d`.
80 // REQUIRES: `0 <= d < dims()`
81 int64_t dim_size(int d) const;
82
83 // Returns a pointer to the underlying data buffer.
84 void* data() const;
85
86 // Returns the data type of the tensor.
87 TF_DataType dtype() const;
88
89 // Returns the number of elements in the tensor. For a tensor with a partially
90 // defined shape, -1 means not fully defined.
91 int64_t num_elements() const;
92
93 // Returns the size of the underlying data in bytes.
94 size_t num_bytes() const;
95
96 private:
97 friend class TensorHandle;
98 friend class Runtime;
99
100 // Wraps a TF_Tensor. Takes ownership of handle.
Tensor(TF_Tensor * tensor)101 explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {}
102
103 // Tensor is not copyable
104 Tensor(const Tensor&) = delete;
105 Tensor& operator=(const Tensor&) = delete;
106
107 // Returns the underlying TF_Tensor that this object wraps.
108 // This object retains ownership of the pointer.
GetTFTensor()109 TF_Tensor* GetTFTensor() const { return tensor_.get(); }
110
111 struct DeleterStruct {
112 std::function<void(void*, size_t)> deleter;
113 };
114
DeleterFunction(void * memory,size_t len,void * deleter_struct)115 static void DeleterFunction(void* memory, size_t len, void* deleter_struct) {
116 DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct);
117 deleter->deleter(memory, len);
118 delete deleter;
119 }
120
121 struct TFTensorDeleter {
operatorTFTensorDeleter122 void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
123 };
124 std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_;
125 };
126
data()127 inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); }
128
dims()129 inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); }
130
dim_size(int d)131 inline int64_t Tensor::dim_size(int d) const {
132 return TF_Dim(tensor_.get(), d);
133 }
134
dtype()135 inline TF_DataType Tensor::dtype() const {
136 return TF_TensorType(tensor_.get());
137 }
138
num_elements()139 inline int64_t Tensor::num_elements() const {
140 return TF_TensorElementCount(tensor_.get());
141 }
142
num_bytes()143 inline size_t Tensor::num_bytes() const {
144 return TF_TensorByteSize(tensor_.get());
145 }
146
FromBuffer(TF_DataType dtype,const std::vector<int64_t> & shape,void * data,size_t len,DeleterCallback deleter,Status * status)147 inline Tensor Tensor::FromBuffer(TF_DataType dtype,
148 const std::vector<int64_t>& shape, void* data,
149 size_t len, DeleterCallback deleter,
150 Status* status) {
151 // Credit to apassos@ for this technique:
152 // Despite the fact that our API takes a std::function deleter, we are able
153 // to maintain ABI stability because:
154 // 1. Only a function pointer is sent across the C API (&DeleterFunction)
155 // 2. DeleterFunction is defined in the same build artifact that constructed
156 // the std::function (so there isn't confusion about std::function ABI).
157 // Note that 2. is satisfied by the fact that this is a header-only API, where
158 // the function implementations are inline.
159
160 DeleterStruct* deleter_struct = new DeleterStruct{deleter};
161 TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len,
162 &DeleterFunction, deleter_struct);
163 if (tensor == nullptr) {
164 status->SetStatus(TF_INVALID_ARGUMENT,
165 "Failed to create tensor for input buffer");
166 return Tensor(nullptr);
167 }
168 return Tensor(tensor);
169 }
170
171 } // namespace cc
172 } // namespace experimental
173 } // namespace tensorflow
174
175 #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_
176