1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/batch_util.h"
17
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/errors.h"
21
22 #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
23
24 namespace tensorflow {
25 namespace batch_util {
26
27 namespace {
28
ValidateInput(const Tensor & parent,const Tensor & element,int64 index)29 Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) {
30 DCHECK_NE(parent.dim_size(0), 0);
31 DCHECK_GE(index, 0);
32 if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
33 TensorShape chip_shape = parent.shape();
34 chip_shape.RemoveDim(0);
35 return errors::Internal(
36 "ValidateInput Cannot perform copy: number of elements does not match. "
37 " Shapes are: [element]: ",
38 element.shape().DebugString(),
39 ", [parent slice]: ", chip_shape.DebugString());
40 }
41 return Status::OK();
42 }
43
44 template <typename T>
HandleElementToSlice(Tensor element,Tensor * parent,int64 index,bool)45 Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index,
46 bool /* can_move */) {
47 parent->flat_outer_dims<T>().chip(index, 0) = element.flat<T>();
48 return Status::OK();
49 }
50
51 template <>
HandleElementToSlice(Tensor element,Tensor * parent,int64 index,bool can_move)52 Status HandleElementToSlice<string>(Tensor element, Tensor* parent, int64 index,
53 bool can_move) {
54 auto parent_as_matrix = parent->flat_outer_dims<string>();
55 auto element_flat = element.flat<string>();
56 if (can_move) {
57 for (int64 i = 0; i < element.NumElements(); ++i) {
58 parent_as_matrix(index, i) = std::move(element_flat(i));
59 }
60 } else {
61 parent_as_matrix.chip(index, 0) = element_flat;
62 }
63 return Status::OK();
64 }
65
66 template <>
HandleElementToSlice(Tensor element,Tensor * parent,int64 index,bool can_move)67 Status HandleElementToSlice<Variant>(Tensor element, Tensor* parent,
68 int64 index, bool can_move) {
69 auto parent_as_matrix = parent->flat_outer_dims<Variant>();
70 auto element_flat = element.flat<Variant>();
71 if (can_move) {
72 for (int64 i = 0; i < element.NumElements(); ++i) {
73 parent_as_matrix(index, i) = std::move(element_flat(i));
74 }
75 } else {
76 parent_as_matrix.chip(index, 0) = element_flat;
77 }
78 return Status::OK();
79 }
80
81 // TODO(b/78245576): Consider removing this overload.
82 template <typename T>
HandleSliceToElement(const Tensor & parent,Tensor * element,int64 index)83 void HandleSliceToElement(const Tensor& parent, Tensor* element, int64 index) {
84 element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0);
85 }
86
87 template <typename T>
HandleSliceToElement(Tensor * parent,Tensor * element,int64 index,bool can_move)88 void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index,
89 bool can_move) {
90 element->flat<T>() = parent->flat_outer_dims<T>().chip(index, 0);
91 }
92
93 template <>
HandleSliceToElement(Tensor * parent,Tensor * element,int64 index,bool can_move)94 void HandleSliceToElement<string>(Tensor* parent, Tensor* element, int64 index,
95 bool can_move) {
96 auto parent_as_matrix = parent->flat_outer_dims<string>();
97 auto element_flat = element->flat<string>();
98 if (can_move) {
99 for (int64 i = 0; i < element->NumElements(); ++i) {
100 element_flat(i) = std::move(parent_as_matrix(index, i));
101 }
102 } else {
103 element_flat = parent_as_matrix.chip(index, 0);
104 }
105 }
106
107 template <>
HandleSliceToElement(Tensor * parent,Tensor * element,int64 index,bool can_move)108 void HandleSliceToElement<Variant>(Tensor* parent, Tensor* element, int64 index,
109 bool can_move) {
110 auto parent_as_matrix = parent->flat_outer_dims<Variant>();
111 auto element_flat = element->flat<Variant>();
112 if (can_move) {
113 for (int64 i = 0; i < element->NumElements(); ++i) {
114 element_flat(i) = std::move(parent_as_matrix(index, i));
115 }
116 } else {
117 element_flat = parent_as_matrix.chip(index, 0);
118 }
119 }
120
121 } // namespace
122
123 // Copies element into the index^th slice of parent (in the 0th dimension).
CopyElementToSlice(Tensor element,Tensor * parent,int64 index)124 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
125 TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
126
127 bool can_move = element.RefCountIsOne();
128 #define HANDLE_TYPE(T) \
129 case DataTypeToEnum<T>::value: { \
130 return HandleElementToSlice<T>(std::move(element), parent, index, \
131 can_move); \
132 }
133
134 switch (element.dtype()) {
135 TF_CALL_ALL_TYPES(HANDLE_TYPE);
136 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
137 TF_CALL_uint32(HANDLE_TYPE);
138 TF_CALL_uint64(HANDLE_TYPE);
139 #undef HANDLE_TYPE
140 default:
141 return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
142 element.dtype());
143 }
144 }
145
146 // Copies the index^th slice of parent (in the 0th dimension) into element.
CopySliceToElement(const Tensor & parent,Tensor * element,int64 index)147 Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
148 TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
149
150 #define HANDLE_TYPE(T) \
151 case DataTypeToEnum<T>::value: { \
152 HandleSliceToElement<T>(parent, element, index); \
153 return Status::OK(); \
154 }
155
156 switch (parent.dtype()) {
157 TF_CALL_ALL_TYPES(HANDLE_TYPE);
158 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
159 #undef HANDLE_TYPE
160 default:
161 return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
162 element->dtype());
163 }
164 }
165
166 // Copies the index^th slice of parent (in the 0th dimension) into element.
167 //
168 // NOTE(mrry): The implementation may be able to optimize the copy to a move.
169 // This is particularly important for DT_STRING tensors.
MaybeMoveSliceToElement(Tensor * parent,Tensor * element,int64 index)170 Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
171 TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
172 bool can_move = parent->RefCountIsOne();
173
174 #define HANDLE_TYPE(T) \
175 case DataTypeToEnum<T>::value: { \
176 HandleSliceToElement<T>(parent, element, index, can_move); \
177 return Status::OK(); \
178 }
179
180 switch (parent->dtype()) {
181 TF_CALL_ALL_TYPES(HANDLE_TYPE);
182 TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
183 #undef HANDLE_TYPE
184 default:
185 return errors::Unimplemented(
186 "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
187 }
188 }
189
190 // The following five functions are copied from padding_fifo_queue.cc.
191 // TODO(mrry): Reconcile these functions with the similar methods in the
192 // queue implementation.
ValidateElementToLargerSlice(const Tensor & element,Tensor * parent)193 Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
194 DCHECK_NE(parent->dim_size(0), 0);
195 if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
196 TensorShape chip_shape = parent->shape();
197 chip_shape.RemoveDim(0);
198 return errors::Internal(
199 "HandleElementToLargerSlice Cannot copy slice: number of entries in "
200 "element is greater than number of elements in parent slice. ",
201 "Shapes are: [element]: ", element.shape().DebugString(),
202 ", [parent slice]: ", chip_shape.DebugString());
203 }
204 return Status::OK();
205 }
206
207 template <typename T, int NDIMS>
HandleElementToLargerSlice(const Tensor & element,Tensor * parent,int index)208 Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
209 int index) {
210 TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
211 if (element.NumElements() == 0) {
212 return Status::OK();
213 }
214 auto element_t = element.tensor<T, NDIMS>();
215 auto parent_t = parent->tensor<T, NDIMS + 1>();
216 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
217 slice_indices[0] = index;
218 Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
219 slice_size[0] = 1;
220 for (size_t i = 1; i < slice_size.size(); ++i) {
221 slice_size[i] = element_t.dimension(i - 1);
222 }
223 parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
224 return Status::OK();
225 }
226
227 template <int NDIMS>
HandleElementToLargerSliceWithRank(const Tensor & element,Tensor * parent,int index)228 Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
229 int index) {
230 #define HANDLE_TYPE(T) \
231 case DataTypeToEnum<T>::value: { \
232 return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
233 }
234
235 switch (element.dtype()) {
236 TF_CALL_DATASET_TYPES(HANDLE_TYPE);
237 #undef HANDLE_TYPE
238 default:
239 return errors::Unimplemented(
240 "HandleElementToLargerSliceWithRank Unhandled data type: ",
241 element.dtype());
242 }
243 }
244
CopyElementToLargerSlice(const Tensor & element,Tensor * parent,int index)245 Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
246 int index) {
247 if (parent->dims() != element.dims() + 1) {
248 return errors::Internal(
249 "Mismatched ranks. Element's rank is: ", element.dims(),
250 " but element is meant to be a slice in output Tensor having rank: ",
251 parent->dims(), " (should be: ", element.dims() + 1, ")");
252 }
253
254 #define HANDLE_DIMS(NDIMS) \
255 case NDIMS: { \
256 TF_RETURN_IF_ERROR( \
257 HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
258 return Status::OK(); \
259 }
260
261 switch (element.dims()) {
262 HANDLE_DIMS(0);
263 HANDLE_DIMS(1);
264 HANDLE_DIMS(2);
265 HANDLE_DIMS(3);
266 HANDLE_DIMS(4);
267 HANDLE_DIMS(5);
268 #undef HANDLE_DIMS
269 default:
270 return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
271 element.dims());
272 }
273 }
274
SetElementZero(Tensor * element,const Tensor & padding)275 Status SetElementZero(Tensor* element, const Tensor& padding) {
276 #define HANDLE_TYPE(T) \
277 if (element->dtype() == DataTypeToEnum<T>::value) { \
278 element->flat<T>().setConstant(padding.scalar<T>()()); \
279 return Status::OK(); \
280 }
281 TF_CALL_DATASET_TYPES(HANDLE_TYPE);
282 #undef HANDLE_TYPE
283 return errors::Unimplemented("SetElementZero Unhandled data type: ",
284 element->dtype());
285 }
286
287 } // namespace batch_util
288 } // namespace tensorflow
289