• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/batch_util.h"
17 
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 
22 #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
23 
24 namespace tensorflow {
25 namespace batch_util {
26 
27 namespace {
28 
ValidateInput(const Tensor & parent,const Tensor & element,int64 index)29 Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) {
30   DCHECK_NE(parent.dim_size(0), 0);
31   DCHECK_GE(index, 0);
32   if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
33     TensorShape chip_shape = parent.shape();
34     chip_shape.RemoveDim(0);
35     return errors::Internal(
36         "ValidateInput Cannot perform copy: number of elements does not match. "
37         " Shapes are: [element]: ",
38         element.shape().DebugString(),
39         ", [parent slice]: ", chip_shape.DebugString());
40   }
41   return Status::OK();
42 }
43 
44 template <typename T>
HandleElementToSlice(const Tensor &,T * src,T * dest,int64 num_values)45 Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest,
46                             int64 num_values) {
47   static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
48   memcpy(dest, src, num_values * sizeof(T));
49   return Status::OK();
50 }
51 
52 template <>
HandleElementToSlice(const Tensor & element,tstring * src,tstring * dest,int64 num_values)53 Status HandleElementToSlice<tstring>(const Tensor& element, tstring* src,
54                                      tstring* dest, int64 num_values) {
55   if (element.RefCountIsOne()) {
56     for (int64 i = 0; i < num_values; ++i) {
57       *dest++ = std::move(*src++);
58     }
59   } else {
60     std::copy_n(src, num_values, dest);
61   }
62   return Status::OK();
63 }
64 
65 template <>
HandleElementToSlice(const Tensor & element,Variant * src,Variant * dest,int64 num_values)66 Status HandleElementToSlice<Variant>(const Tensor& element, Variant* src,
67                                      Variant* dest, int64 num_values) {
68   if (element.RefCountIsOne()) {
69     for (int64 i = 0; i < num_values; ++i) {
70       *dest++ = std::move(*src++);
71     }
72   } else {
73     std::copy_n(src, num_values, dest);
74   }
75   return Status::OK();
76 }
77 
78 template <>
HandleElementToSlice(const Tensor &,ResourceHandle * src,ResourceHandle * dest,int64 num_values)79 Status HandleElementToSlice<ResourceHandle>(const Tensor& /* element */,
80                                             ResourceHandle* src,
81                                             ResourceHandle* dest,
82                                             int64 num_values) {
83   std::copy_n(src, num_values, dest);
84   return Status::OK();
85 }
86 
87 template <>
HandleElementToSlice(const Tensor &,Eigen::half * src,Eigen::half * dest,int64 num_values)88 Status HandleElementToSlice<Eigen::half>(const Tensor& /* element */,
89                                          Eigen::half* src, Eigen::half* dest,
90                                          int64 num_values) {
91   std::copy_n(src, num_values, dest);
92   return Status::OK();
93 }
94 
95 template <typename T>
HandleSliceToElement(const T * src,T * dest,int64 num_values)96 void HandleSliceToElement(const T* src, T* dest, int64 num_values) {
97   static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
98   memcpy(dest, src, num_values * sizeof(T));
99 }
100 
101 template <>
HandleSliceToElement(const tstring * src,tstring * dest,int64 num_values)102 void HandleSliceToElement<tstring>(const tstring* src, tstring* dest,
103                                    int64 num_values) {
104   std::copy_n(src, num_values, dest);
105 }
106 
107 template <>
HandleSliceToElement(const Variant * src,Variant * dest,int64 num_values)108 void HandleSliceToElement<Variant>(const Variant* src, Variant* dest,
109                                    int64 num_values) {
110   std::copy_n(src, num_values, dest);
111 }
112 
113 template <>
HandleSliceToElement(const ResourceHandle * src,ResourceHandle * dest,int64 num_values)114 void HandleSliceToElement<ResourceHandle>(const ResourceHandle* src,
115                                           ResourceHandle* dest,
116                                           int64 num_values) {
117   std::copy_n(src, num_values, dest);
118 }
119 
120 template <>
HandleSliceToElement(const Eigen::half * src,Eigen::half * dest,int64 num_values)121 void HandleSliceToElement<Eigen::half>(const Eigen::half* src,
122                                        Eigen::half* dest, int64 num_values) {
123   std::copy_n(src, num_values, dest);
124 }
125 
126 template <typename T>
HandleSliceToElement(Tensor * parent,T * src,T * dest,int64 num_values)127 void HandleSliceToElement(Tensor* parent, T* src, T* dest, int64 num_values) {
128   static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
129   memcpy(dest, src, num_values * sizeof(T));
130 }
131 
132 template <>
HandleSliceToElement(Tensor * parent,tstring * src,tstring * dest,int64 num_values)133 void HandleSliceToElement<tstring>(Tensor* parent, tstring* src, tstring* dest,
134                                    int64 num_values) {
135   if (parent->RefCountIsOne()) {
136     for (int64 i = 0; i < num_values; ++i) {
137       dest[i] = std::move(src[i]);
138     }
139   } else {
140     std::copy_n(src, num_values, dest);
141   }
142 }
143 
144 template <>
HandleSliceToElement(Tensor * parent,Variant * src,Variant * dest,int64 num_values)145 void HandleSliceToElement<Variant>(Tensor* parent, Variant* src, Variant* dest,
146                                    int64 num_values) {
147   if (parent->RefCountIsOne()) {
148     for (int64 i = 0; i < num_values; ++i) {
149       dest[i] = std::move(src[i]);
150     }
151   } else {
152     std::copy_n(src, num_values, dest);
153   }
154 }
155 
156 template <>
HandleSliceToElement(Tensor * parent,ResourceHandle * src,ResourceHandle * dest,int64 num_values)157 void HandleSliceToElement<ResourceHandle>(Tensor* parent, ResourceHandle* src,
158                                           ResourceHandle* dest,
159                                           int64 num_values) {
160   std::copy_n(src, num_values, dest);
161 }
162 
163 template <>
HandleSliceToElement(Tensor * parent,Eigen::half * src,Eigen::half * dest,int64 num_values)164 void HandleSliceToElement<Eigen::half>(Tensor* parent, Eigen::half* src,
165                                        Eigen::half* dest, int64 num_values) {
166   std::copy_n(src, num_values, dest);
167 }
168 
169 }  // namespace
170 
171 // Copies element into the index^th slice of parent (in the 0th dimension).
CopyElementToSlice(Tensor element,Tensor * parent,int64 index)172 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
173   TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
174   const int64 num_values = element.NumElements();
175 #define HANDLE_TYPE(T)                                              \
176   case DataTypeToEnum<T>::value: {                                  \
177     T* src = element.base<T>();                                     \
178     T* dest = parent->base<T>() + (num_values * index);             \
179     return HandleElementToSlice<T>(element, src, dest, num_values); \
180   }
181 
182   switch (element.dtype()) {
183     TF_CALL_ALL_TYPES(HANDLE_TYPE);
184     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
185 #undef HANDLE_TYPE
186     default:
187       return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
188                                    element.dtype());
189   }
190 }
191 
192 // Copies the index^th slice of parent (in the 0th dimension) into element.
CopySliceToElement(const Tensor & parent,Tensor * element,int64 index)193 Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
194   TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
195   const int64 num_values = element->NumElements();
196 
197 #define HANDLE_TYPE(T)                                      \
198   case DataTypeToEnum<T>::value: {                          \
199     const T* src = parent.base<T>() + (num_values * index); \
200     T* dest = element->base<T>();                           \
201     HandleSliceToElement<T>(src, dest, num_values);         \
202     return Status::OK();                                    \
203   }
204 
205   switch (parent.dtype()) {
206     TF_CALL_ALL_TYPES(HANDLE_TYPE);
207     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
208 #undef HANDLE_TYPE
209     default:
210       return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
211                                    element->dtype());
212   }
213 }
214 
CopyContiguousSlices(const Tensor & src,int64 src_offset,int64 dst_offset,int64 num_slices,Tensor * dst)215 Status CopyContiguousSlices(const Tensor& src, int64 src_offset,
216                             int64 dst_offset, int64 num_slices, Tensor* dst) {
217   if (src.dtype() != dst->dtype()) {
218     return errors::FailedPrecondition(
219         "CopyContiguousSlices cannot perform copy: src and dst have different "
220         "dtypes. Source dtype: ",
221         src.dtype(), " dstination dtype: ", dst->dtype(), ".");
222   }
223   if (src.dims() < 1) {
224     return errors::FailedPrecondition(
225         "CopyContiguousSlices cannot perform copy: src has to be a tensor with "
226         "rank >= 1. Source shape: ",
227         src.shape().DebugString());
228   }
229 
230   if (dst->dims() < 1) {
231     return errors::FailedPrecondition(
232         "CopyContiguousSlices cannot perform copy: dst has to be a tensor "
233         "with rank >= 1. Dest shape: ",
234         dst->shape().DebugString());
235   }
236 
237   const int64 src_dim0 = src.dim_size(0);
238   const int64 dst_dim0 = dst->dim_size(0);
239   int64 src_chip_size = 1;
240   int64 dst_chip_size = 1;
241   for (int i = 1; i < src.dims(); ++i) {
242     src_chip_size *= src.dim_size(i);
243   }
244   for (int i = 1; i < dst->dims(); ++i) {
245     dst_chip_size *= dst->dim_size(i);
246   }
247 
248   if (src_chip_size != dst_chip_size) {
249     return errors::FailedPrecondition(
250         "CopyContiguousSlices cannot perform copy: source and dst shapes are"
251         "not compatible. Source shape: ",
252         src.shape().DebugString(), ", dst shape: ", dst->shape().DebugString());
253   }
254 
255   if (src_chip_size == 0 && dst_chip_size == 0) {
256     return Status::OK();
257   }
258 
259   if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 ||
260       dst_offset + num_slices > dst_dim0) {
261     return errors::FailedPrecondition(
262         "CopyContiguousSlices cannot perform copy: index out of range. "
263         "src_offset: ",
264         src_offset, ", num_slices: ", num_slices, ", src_dim0: ", src_dim0,
265         ", dst_offset: ", dst_offset, ", dst_dim0: ", dst_dim0, ".");
266   }
267 
268 #define HANDLE_TYPE(T)                                                 \
269   case DataTypeToEnum<T>::value: {                                     \
270     const T* src_p = src.base<T>() + (src_chip_size * src_offset);     \
271     T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset);          \
272     HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \
273     return Status::OK();                                               \
274   }
275 
276   switch (src.dtype()) {
277     TF_CALL_ALL_TYPES(HANDLE_TYPE);
278     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
279 #undef HANDLE_TYPE
280     default:
281       return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
282                                    src.dtype());
283   }
284 }
285 
286 // Copies the index^th slice of parent (in the 0th dimension) into element.
287 //
288 // NOTE(mrry): The implementation may be able to optimize the copy to a move.
289 // This is particularly important for DT_STRING tensors.
MaybeMoveSliceToElement(Tensor * parent,Tensor * element,int64 index)290 Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
291   TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
292   const int64 num_values = element->NumElements();
293 
294 #define HANDLE_TYPE(T)                                      \
295   case DataTypeToEnum<T>::value: {                          \
296     T* src = parent->base<T>() + (num_values * index);      \
297     T* dest = element->base<T>();                           \
298     HandleSliceToElement<T>(parent, src, dest, num_values); \
299     return Status::OK();                                    \
300   }
301 
302   switch (parent->dtype()) {
303     TF_CALL_ALL_TYPES(HANDLE_TYPE);
304     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
305 #undef HANDLE_TYPE
306     default:
307       return errors::Unimplemented(
308           "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
309   }
310 }
311 
312 // The following five functions are copied from padding_fifo_queue.cc.
313 // TODO(mrry): Reconcile these functions with the similar methods in the
314 // queue implementation.
ValidateElementToLargerSlice(const Tensor & element,Tensor * parent)315 Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
316   DCHECK_NE(parent->dim_size(0), 0);
317   if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
318     TensorShape chip_shape = parent->shape();
319     chip_shape.RemoveDim(0);
320     return errors::Internal(
321         "HandleElementToLargerSlice Cannot copy slice: number of entries in "
322         "element is greater than number of elements in parent slice.  ",
323         "Shapes are: [element]: ", element.shape().DebugString(),
324         ", [parent slice]: ", chip_shape.DebugString());
325   }
326   return Status::OK();
327 }
328 
329 template <typename T, int NDIMS>
HandleElementToLargerSlice(const Tensor & element,Tensor * parent,int index)330 Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
331                                   int index) {
332   TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
333   if (element.NumElements() == 0) {
334     return Status::OK();
335   }
336   auto element_t = element.tensor<T, NDIMS>();
337   auto parent_t = parent->tensor<T, NDIMS + 1>();
338   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
339   slice_indices[0] = index;
340   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
341   slice_size[0] = 1;
342   for (size_t i = 1; i < slice_size.size(); ++i) {
343     slice_size[i] = element_t.dimension(i - 1);
344   }
345   parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
346   return Status::OK();
347 }
348 
349 template <int NDIMS>
HandleElementToLargerSliceWithRank(const Tensor & element,Tensor * parent,int index)350 Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
351                                           int index) {
352 #define HANDLE_TYPE(T)                                                   \
353   case DataTypeToEnum<T>::value: {                                       \
354     return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
355   }
356 
357   switch (element.dtype()) {
358     TF_CALL_DATASET_TYPES(HANDLE_TYPE);
359 #undef HANDLE_TYPE
360     default:
361       return errors::Unimplemented(
362           "HandleElementToLargerSliceWithRank Unhandled data type: ",
363           element.dtype());
364   }
365 }
366 
CopyElementToLargerSlice(const Tensor & element,Tensor * parent,int index)367 Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
368                                 int index) {
369   if (parent->dims() != element.dims() + 1) {
370     return errors::Internal(
371         "Mismatched ranks.  Element's rank is: ", element.dims(),
372         " but element is meant to be a slice in output Tensor having rank: ",
373         parent->dims(), " (should be: ", element.dims() + 1, ")");
374   }
375 
376 #define HANDLE_DIMS(NDIMS)                                                  \
377   case NDIMS: {                                                             \
378     TF_RETURN_IF_ERROR(                                                     \
379         HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
380     return Status::OK();                                                    \
381   }
382 
383   switch (element.dims()) {
384     HANDLE_DIMS(0);
385     HANDLE_DIMS(1);
386     HANDLE_DIMS(2);
387     HANDLE_DIMS(3);
388     HANDLE_DIMS(4);
389     HANDLE_DIMS(5);
390 #undef HANDLE_DIMS
391     default:
392       return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
393                                    element.dims());
394   }
395 }
396 
SetElementZero(Tensor * element,const Tensor & padding)397 Status SetElementZero(Tensor* element, const Tensor& padding) {
398 #define HANDLE_TYPE(T)                                     \
399   if (element->dtype() == DataTypeToEnum<T>::value) {      \
400     element->flat<T>().setConstant(padding.scalar<T>()()); \
401     return Status::OK();                                   \
402   }
403   TF_CALL_DATASET_TYPES(HANDLE_TYPE);
404 #undef HANDLE_TYPE
405   return errors::Unimplemented("SetElementZero Unhandled data type: ",
406                                element->dtype());
407 }
408 
409 }  // namespace batch_util
410 }  // namespace tensorflow
411