/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/batch_util.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)

namespace tensorflow {
namespace batch_util {

namespace {

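// Returns an error if the flat size of `element` does not match the flat size
// of a single slice of `parent` along dimension 0, i.e.
// parent.NumElements() / parent.dim_size(0).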
Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) {
  DCHECK_NE(parent.dim_size(0), 0);
  DCHECK_GE(index, 0);
  if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
    TensorShape chip_shape = parent.shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "ValidateInput Cannot perform copy: number of elements does not "
        "match. Shapes are: [element]: ",
        element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

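// Copies or moves `num_values` values from an element tensor into a slice of
// a batch tensor. For simple types this is a memcpy; the tstring and Variant
// specializations move the values when `element` holds the only reference to
// its buffer and copy them otherwise.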
template <typename T>
Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest,
                            int64 num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
  return Status::OK();
}

template <>
Status HandleElementToSlice<tstring>(const Tensor& element, tstring* src,
                                     tstring* dest, int64 num_values) {
  if (element.RefCountIsOne()) {
    for (int64 i = 0; i < num_values; ++i) {
      *dest++ = std::move(*src++);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
  return Status::OK();
}

template <>
Status HandleElementToSlice<Variant>(const Tensor& element, Variant* src,
                                     Variant* dest, int64 num_values) {
  if (element.RefCountIsOne()) {
    for (int64 i = 0; i < num_values; ++i) {
      *dest++ = std::move(*src++);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
  return Status::OK();
}

template <>
Status HandleElementToSlice<ResourceHandle>(const Tensor& /* element */,
                                            ResourceHandle* src,
                                            ResourceHandle* dest,
                                            int64 num_values) {
  std::copy_n(src, num_values, dest);
  return Status::OK();
}

template <>
Status HandleElementToSlice<Eigen::half>(const Tensor& /* element */,
                                         Eigen::half* src, Eigen::half* dest,
                                         int64 num_values) {
  std::copy_n(src, num_values, dest);
  return Status::OK();
}

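// Copies `num_values` values from a slice of a batch tensor into an element
// tensor. The overloads that take a `Tensor* parent` may move tstring and
// Variant values out of the parent when its buffer is uniquely owned.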
template <typename T>
void HandleSliceToElement(const T* src, T* dest, int64 num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
}

template <>
void HandleSliceToElement<tstring>(const tstring* src, tstring* dest,
                                   int64 num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<Variant>(const Variant* src, Variant* dest,
                                   int64 num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<ResourceHandle>(const ResourceHandle* src,
                                          ResourceHandle* dest,
                                          int64 num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<Eigen::half>(const Eigen::half* src,
                                       Eigen::half* dest, int64 num_values) {
  std::copy_n(src, num_values, dest);
}

template <typename T>
void HandleSliceToElement(Tensor* parent, T* src, T* dest, int64 num_values) {
  static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
  memcpy(dest, src, num_values * sizeof(T));
}

template <>
void HandleSliceToElement<tstring>(Tensor* parent, tstring* src, tstring* dest,
                                   int64 num_values) {
  if (parent->RefCountIsOne()) {
    for (int64 i = 0; i < num_values; ++i) {
      dest[i] = std::move(src[i]);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
}

template <>
void HandleSliceToElement<Variant>(Tensor* parent, Variant* src, Variant* dest,
                                   int64 num_values) {
  if (parent->RefCountIsOne()) {
    for (int64 i = 0; i < num_values; ++i) {
      dest[i] = std::move(src[i]);
    }
  } else {
    std::copy_n(src, num_values, dest);
  }
}

template <>
void HandleSliceToElement<ResourceHandle>(Tensor* parent, ResourceHandle* src,
                                          ResourceHandle* dest,
                                          int64 num_values) {
  std::copy_n(src, num_values, dest);
}

template <>
void HandleSliceToElement<Eigen::half>(Tensor* parent, Eigen::half* src,
                                       Eigen::half* dest, int64 num_values) {
  std::copy_n(src, num_values, dest);
}

}  // namespace

// Copies element into the index^th slice of parent (in the 0th dimension).
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
  TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
  const int64 num_values = element.NumElements();
#define HANDLE_TYPE(T)                                               \
  case DataTypeToEnum<T>::value: {                                   \
    T* src = element.base<T>();                                      \
    T* dest = parent->base<T>() + (num_values * index);              \
    return HandleElementToSlice<T>(element, src, dest, num_values);  \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
                                   element.dtype());
  }
}
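
// Illustrative sketch (not part of this library): a caller might batch a
// vector of equally shaped tensors by allocating a result with an extra
// leading dimension and copying each element into its slice. `components` and
// `batch` are hypothetical names.
//
//   TensorShape batch_shape = components[0].shape();
//   batch_shape.InsertDim(0, components.size());
//   Tensor batch(components[0].dtype(), batch_shape);
//   for (int64 i = 0; i < components.size(); ++i) {
//     TF_RETURN_IF_ERROR(
//         CopyElementToSlice(std::move(components[i]), &batch, i));
//   }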

// Copies the index^th slice of parent (in the 0th dimension) into element.
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
  TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
  const int64 num_values = element->NumElements();

#define HANDLE_TYPE(T)                                      \
  case DataTypeToEnum<T>::value: {                          \
    const T* src = parent.base<T>() + (num_values * index); \
    T* dest = element->base<T>();                           \
    HandleSliceToElement<T>(src, dest, num_values);         \
    return Status::OK();                                    \
  }

  switch (parent.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
                                   element->dtype());
  }
}

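// Copies `num_slices` contiguous slices (along dimension 0) from `src`,
// starting at slice `src_offset`, into `dst`, starting at slice `dst_offset`.
// The two tensors must have the same dtype and the same per-slice size.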
Status CopyContiguousSlices(const Tensor& src, int64 src_offset,
                            int64 dst_offset, int64 num_slices, Tensor* dst) {
  if (src.dtype() != dst->dtype()) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: src and dst have different "
        "dtypes. Source dtype: ",
        src.dtype(), " destination dtype: ", dst->dtype(), ".");
  }
  if (src.dims() < 1) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: src has to be a tensor "
        "with rank >= 1. Source shape: ",
        src.shape().DebugString());
  }

  if (dst->dims() < 1) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: dst has to be a tensor "
        "with rank >= 1. Dest shape: ",
        dst->shape().DebugString());
  }

  const int64 src_dim0 = src.dim_size(0);
  const int64 dst_dim0 = dst->dim_size(0);
  int64 src_chip_size = 1;
  int64 dst_chip_size = 1;
  for (int i = 1; i < src.dims(); ++i) {
    src_chip_size *= src.dim_size(i);
  }
  for (int i = 1; i < dst->dims(); ++i) {
    dst_chip_size *= dst->dim_size(i);
  }

  if (src_chip_size != dst_chip_size) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: source and dst shapes are "
        "not compatible. Source shape: ",
        src.shape().DebugString(), ", dst shape: ", dst->shape().DebugString());
  }

  if (src_chip_size == 0 && dst_chip_size == 0) {
    return Status::OK();
  }

  if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 ||
      dst_offset + num_slices > dst_dim0) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: index out of range. "
        "src_offset: ",
        src_offset, ", num_slices: ", num_slices, ", src_dim0: ", src_dim0,
        ", dst_offset: ", dst_offset, ", dst_dim0: ", dst_dim0, ".");
  }

#define HANDLE_TYPE(T)                                                  \
  case DataTypeToEnum<T>::value: {                                     \
    const T* src_p = src.base<T>() + (src_chip_size * src_offset);     \
    T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset);          \
    HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \
    return Status::OK();                                               \
  }

  switch (src.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
                                   src.dtype());
  }
}
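
// Illustrative sketch (hypothetical caller code): copying a window of `n`
// slices starting at slice `start` of `src` into the front of a preallocated
// batch `dst` that has the same non-batch shape and dtype.
//
//   TF_RETURN_IF_ERROR(CopyContiguousSlices(src, /*src_offset=*/start,
//                                           /*dst_offset=*/0,
//                                           /*num_slices=*/n, &dst));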

// Copies the index^th slice of parent (in the 0th dimension) into element.
//
// NOTE(mrry): The implementation may be able to optimize the copy to a move.
// This is particularly important for DT_STRING tensors.
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
  TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
  const int64 num_values = element->NumElements();

#define HANDLE_TYPE(T)                                      \
  case DataTypeToEnum<T>::value: {                          \
    T* src = parent->base<T>() + (num_values * index);      \
    T* dest = element->base<T>();                           \
    HandleSliceToElement<T>(parent, src, dest, num_values); \
    return Status::OK();                                    \
  }

  switch (parent->dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
  }
}

// The following five functions are copied from padding_fifo_queue.cc.
// TODO(mrry): Reconcile these functions with the similar methods in the
// queue implementation.
Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice. ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return Status::OK();
}

template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
  if (element.NumElements() == 0) {
    return Status::OK();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return Status::OK();
}

template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                    \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          element.dtype());
  }
}

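// Copies `element` into the index^th slice of `parent` (in the 0th
// dimension). Each dimension of the slice may be larger than the
// corresponding dimension of `element`; the remaining entries of the slice
// are left untouched.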
Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
                                int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks. Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return Status::OK();                                                    \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
    HANDLE_DIMS(5);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

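// Fills every entry of `element` with the scalar value held in `padding`.
// `padding` must be a scalar tensor with the same dtype as `element`.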
Status SetElementZero(Tensor* element, const Tensor& padding) {
#define HANDLE_TYPE(T)                                       \
  if (element->dtype() == DataTypeToEnum<T>::value) {        \
    element->flat<T>().setConstant(padding.scalar<T>()());   \
    return Status::OK();                                     \
  }
  TF_CALL_DATASET_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               element->dtype());
}
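
// Illustrative sketch (hypothetical caller code): padded batching. `batch` is
// assumed to be preallocated with the maximal per-component shape plus a
// leading batch dimension, and `padding` to be a scalar of the same dtype.
//
//   TF_RETURN_IF_ERROR(SetElementZero(&batch, padding));
//   for (int i = 0; i < num_components; ++i) {
//     TF_RETURN_IF_ERROR(CopyElementToLargerSlice(components[i], &batch, i));
//   }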

}  // namespace batch_util
}  // namespace tensorflow