• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/batch_util.h"
17 
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 
// Expands `m` once for each type in TF_CALL_ALL_TYPES and then once for each
// type in TF_CALL_QUANTIZED_TYPES. Used below to enumerate the dtypes handled
// by the padding helpers (HandleElementToLargerSliceWithRank, SetElementZero).
#define TF_CALL_DATASET_TYPES(m) \
  TF_CALL_ALL_TYPES(m)           \
  TF_CALL_QUANTIZED_TYPES(m)
23 
24 namespace tensorflow {
25 namespace batch_util {
26 
27 namespace {
28 
ValidateInput(const Tensor & parent,const Tensor & element,int64 index)29 Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) {
30   DCHECK_NE(parent.dim_size(0), 0);
31   DCHECK_GE(index, 0);
32   if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) {
33     TensorShape chip_shape = parent.shape();
34     chip_shape.RemoveDim(0);
35     return errors::Internal(
36         "ValidateInput Cannot perform copy: number of elements does not match. "
37         " Shapes are: [element]: ",
38         element.shape().DebugString(),
39         ", [parent slice]: ", chip_shape.DebugString());
40   }
41   return Status::OK();
42 }
43 
44 template <typename T>
HandleElementToSlice(Tensor element,Tensor * parent,int64 index,bool)45 Status HandleElementToSlice(Tensor element, Tensor* parent, int64 index,
46                             bool /* can_move */) {
47   parent->flat_outer_dims<T>().chip(index, 0) = element.flat<T>();
48   return Status::OK();
49 }
50 
51 template <>
HandleElementToSlice(Tensor element,Tensor * parent,int64 index,bool can_move)52 Status HandleElementToSlice<string>(Tensor element, Tensor* parent, int64 index,
53                                     bool can_move) {
54   auto parent_as_matrix = parent->flat_outer_dims<string>();
55   auto element_flat = element.flat<string>();
56   if (can_move) {
57     for (int64 i = 0; i < element.NumElements(); ++i) {
58       parent_as_matrix(index, i) = std::move(element_flat(i));
59     }
60   } else {
61     parent_as_matrix.chip(index, 0) = element_flat;
62   }
63   return Status::OK();
64 }
65 
66 template <>
HandleElementToSlice(Tensor element,Tensor * parent,int64 index,bool can_move)67 Status HandleElementToSlice<Variant>(Tensor element, Tensor* parent,
68                                      int64 index, bool can_move) {
69   auto parent_as_matrix = parent->flat_outer_dims<Variant>();
70   auto element_flat = element.flat<Variant>();
71   if (can_move) {
72     for (int64 i = 0; i < element.NumElements(); ++i) {
73       parent_as_matrix(index, i) = std::move(element_flat(i));
74     }
75   } else {
76     parent_as_matrix.chip(index, 0) = element_flat;
77   }
78   return Status::OK();
79 }
80 
81 // TODO(b/78245576): Consider removing this overload.
82 template <typename T>
HandleSliceToElement(const Tensor & parent,Tensor * element,int64 index)83 void HandleSliceToElement(const Tensor& parent, Tensor* element, int64 index) {
84   element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0);
85 }
86 
87 template <typename T>
HandleSliceToElement(Tensor * parent,Tensor * element,int64 index,bool can_move)88 void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index,
89                           bool can_move) {
90   element->flat<T>() = parent->flat_outer_dims<T>().chip(index, 0);
91 }
92 
93 template <>
HandleSliceToElement(Tensor * parent,Tensor * element,int64 index,bool can_move)94 void HandleSliceToElement<string>(Tensor* parent, Tensor* element, int64 index,
95                                   bool can_move) {
96   auto parent_as_matrix = parent->flat_outer_dims<string>();
97   auto element_flat = element->flat<string>();
98   if (can_move) {
99     for (int64 i = 0; i < element->NumElements(); ++i) {
100       element_flat(i) = std::move(parent_as_matrix(index, i));
101     }
102   } else {
103     element_flat = parent_as_matrix.chip(index, 0);
104   }
105 }
106 
107 template <>
HandleSliceToElement(Tensor * parent,Tensor * element,int64 index,bool can_move)108 void HandleSliceToElement<Variant>(Tensor* parent, Tensor* element, int64 index,
109                                    bool can_move) {
110   auto parent_as_matrix = parent->flat_outer_dims<Variant>();
111   auto element_flat = element->flat<Variant>();
112   if (can_move) {
113     for (int64 i = 0; i < element->NumElements(); ++i) {
114       element_flat(i) = std::move(parent_as_matrix(index, i));
115     }
116   } else {
117     element_flat = parent_as_matrix.chip(index, 0);
118   }
119 }
120 
121 }  // namespace
122 
123 // Copies element into the index^th slice of parent (in the 0th dimension).
CopyElementToSlice(Tensor element,Tensor * parent,int64 index)124 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
125   TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index));
126 
127   bool can_move = element.RefCountIsOne();
128 #define HANDLE_TYPE(T)                                                \
129   case DataTypeToEnum<T>::value: {                                    \
130     return HandleElementToSlice<T>(std::move(element), parent, index, \
131                                    can_move);                         \
132   }
133 
134   switch (element.dtype()) {
135     TF_CALL_ALL_TYPES(HANDLE_TYPE);
136     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
137     TF_CALL_uint32(HANDLE_TYPE);
138     TF_CALL_uint64(HANDLE_TYPE);
139 #undef HANDLE_TYPE
140     default:
141       return errors::Unimplemented("CopyElementToSlice Unhandled data type: ",
142                                    element.dtype());
143   }
144 }
145 
146 // Copies the index^th slice of parent (in the 0th dimension) into element.
CopySliceToElement(const Tensor & parent,Tensor * element,int64 index)147 Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
148   TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
149 
150 #define HANDLE_TYPE(T)                               \
151   case DataTypeToEnum<T>::value: {                   \
152     HandleSliceToElement<T>(parent, element, index); \
153     return Status::OK();                             \
154   }
155 
156   switch (parent.dtype()) {
157     TF_CALL_ALL_TYPES(HANDLE_TYPE);
158     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
159 #undef HANDLE_TYPE
160     default:
161       return errors::Unimplemented("CopySliceToElement Unhandled data type: ",
162                                    element->dtype());
163   }
164 }
165 
166 // Copies the index^th slice of parent (in the 0th dimension) into element.
167 //
168 // NOTE(mrry): The implementation may be able to optimize the copy to a move.
169 // This is particularly important for DT_STRING tensors.
MaybeMoveSliceToElement(Tensor * parent,Tensor * element,int64 index)170 Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
171   TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
172   bool can_move = parent->RefCountIsOne();
173 
174 #define HANDLE_TYPE(T)                                         \
175   case DataTypeToEnum<T>::value: {                             \
176     HandleSliceToElement<T>(parent, element, index, can_move); \
177     return Status::OK();                                       \
178   }
179 
180   switch (parent->dtype()) {
181     TF_CALL_ALL_TYPES(HANDLE_TYPE);
182     TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
183 #undef HANDLE_TYPE
184     default:
185       return errors::Unimplemented(
186           "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
187   }
188 }
189 
190 // The following five functions are copied from padding_fifo_queue.cc.
191 // TODO(mrry): Reconcile these functions with the similar methods in the
192 // queue implementation.
ValidateElementToLargerSlice(const Tensor & element,Tensor * parent)193 Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) {
194   DCHECK_NE(parent->dim_size(0), 0);
195   if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
196     TensorShape chip_shape = parent->shape();
197     chip_shape.RemoveDim(0);
198     return errors::Internal(
199         "HandleElementToLargerSlice Cannot copy slice: number of entries in "
200         "element is greater than number of elements in parent slice.  ",
201         "Shapes are: [element]: ", element.shape().DebugString(),
202         ", [parent slice]: ", chip_shape.DebugString());
203   }
204   return Status::OK();
205 }
206 
207 template <typename T, int NDIMS>
HandleElementToLargerSlice(const Tensor & element,Tensor * parent,int index)208 Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
209                                   int index) {
210   TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
211   if (element.NumElements() == 0) {
212     return Status::OK();
213   }
214   auto element_t = element.tensor<T, NDIMS>();
215   auto parent_t = parent->tensor<T, NDIMS + 1>();
216   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
217   slice_indices[0] = index;
218   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
219   slice_size[0] = 1;
220   for (size_t i = 1; i < slice_size.size(); ++i) {
221     slice_size[i] = element_t.dimension(i - 1);
222   }
223   parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
224   return Status::OK();
225 }
226 
227 template <int NDIMS>
HandleElementToLargerSliceWithRank(const Tensor & element,Tensor * parent,int index)228 Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
229                                           int index) {
230 #define HANDLE_TYPE(T)                                                   \
231   case DataTypeToEnum<T>::value: {                                       \
232     return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
233   }
234 
235   switch (element.dtype()) {
236     TF_CALL_DATASET_TYPES(HANDLE_TYPE);
237 #undef HANDLE_TYPE
238     default:
239       return errors::Unimplemented(
240           "HandleElementToLargerSliceWithRank Unhandled data type: ",
241           element.dtype());
242   }
243 }
244 
CopyElementToLargerSlice(const Tensor & element,Tensor * parent,int index)245 Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
246                                 int index) {
247   if (parent->dims() != element.dims() + 1) {
248     return errors::Internal(
249         "Mismatched ranks.  Element's rank is: ", element.dims(),
250         " but element is meant to be a slice in output Tensor having rank: ",
251         parent->dims(), " (should be: ", element.dims() + 1, ")");
252   }
253 
254 #define HANDLE_DIMS(NDIMS)                                                  \
255   case NDIMS: {                                                             \
256     TF_RETURN_IF_ERROR(                                                     \
257         HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
258     return Status::OK();                                                    \
259   }
260 
261   switch (element.dims()) {
262     HANDLE_DIMS(0);
263     HANDLE_DIMS(1);
264     HANDLE_DIMS(2);
265     HANDLE_DIMS(3);
266     HANDLE_DIMS(4);
267     HANDLE_DIMS(5);
268 #undef HANDLE_DIMS
269     default:
270       return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
271                                    element.dims());
272   }
273 }
274 
SetElementZero(Tensor * element,const Tensor & padding)275 Status SetElementZero(Tensor* element, const Tensor& padding) {
276 #define HANDLE_TYPE(T)                                     \
277   if (element->dtype() == DataTypeToEnum<T>::value) {      \
278     element->flat<T>().setConstant(padding.scalar<T>()()); \
279     return Status::OK();                                   \
280   }
281   TF_CALL_DATASET_TYPES(HANDLE_TYPE);
282 #undef HANDLE_TYPE
283   return errors::Unimplemented("SetElementZero Unhandled data type: ",
284                                element->dtype());
285 }
286 
287 }  // namespace batch_util
288 }  // namespace tensorflow
289