1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
17 #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
18
19 #include <array>
20 #include <vector>
21
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/gtl/inlined_vector.h"
24 #include "tensorflow/core/platform/types.h"
25
26 namespace tensorflow {
27
// Memory layouts for the input/output activations of convolution operations.
// Each mnemonic lists the tensor dimensions in order of decreasing memory
// stride: N = batch, H = image height, W = image width, C = channels.
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // The default activation layout in TensorFlow.
  FORMAT_NHWC = 0,

  // Frequently the faster activation layout on GPUs.
  FORMAT_NCHW = 1,

  // NCHW ordering with the channel dimension split in two: the C dimension
  // shrinks to C/4 and a trailing dimension of size 4 is appended, packing 4
  // neighbouring channel activations of one pixel into a single int32. An NCHW
  // tensor [N, C, H, W] therefore becomes [N, C/4, H, W, 4]. C must be a
  // multiple of 4. This is the most performant layout for cudnn6's quantized
  // int8 convolution and fused convolution.
  FORMAT_NCHW_VECT_C = 2,

  // NHWC ordering with the width dimension split in two: W shrinks to W/4 and
  // a trailing dimension of size 4 is appended, packing 4 neighbouring
  // activations along the width.
  FORMAT_NHWC_VECT_W = 3,

  // Although the helpers in this file currently treat VECT_C / VECT_W as
  // int8x4 vectors, callers should not depend on that: these enums may later
  // describe other vector types (e.g. int16x2), with each op implementation
  // deciding the concrete meaning from the tensor's datatype.

  // TPU layout: spatial dimensions, then batch, then channels.
  FORMAT_HWNC = 4,

  // TPU layout: spatial dimensions, then channels, then batch.
  FORMAT_HWCN = 5,
};
69
// Memory layouts for convolution filters. Each mnemonic lists the filter
// dimensions in order of decreasing memory stride: H = kernel height,
// W = kernel width, I = input channels, O = output channels.
// Note: cudnnGetFilter4dDescriptor() calls 'O' 'K' and calls 'I' 'C'.
enum FilterTensorFormat {
  // The default filter layout in TensorFlow; ops without a 'filter_format'
  // attribute assume it.
  FORMAT_HWIO = 0,

  // Frequently the faster filter layout on GPUs.
  FORMAT_OIHW = 1,

  // Filter counterpart of NCHW_VECT_C: OIHW ordering with the input-channel
  // dimension shrunk to I/4 and a trailing dimension of size 4 appended, which
  // packs 4 neighbouring input-channel weights into one int32. An OIHW filter
  // [O, I, H, W] therefore becomes [O, I/4, H, W, 4]. I must be a multiple of
  // 4. This is the most performant filter layout for cudnn6's quantized int8
  // convolution and fused convolution.
  FORMAT_OIHW_VECT_I = 2,
};
93
94 // Parse tensor format from the given string.
95 // Return true if the parsing succeeds, and false if it fails.
96 bool FormatFromString(const string& format_str, TensorFormat* format);
97
98 // Parse tensor format from the given string.
99 // Return true if the parsing succeeds, and false if it fails.
100 bool FilterFormatFromString(const string& format_str,
101 FilterTensorFormat* format);
102
103 // Convert a tensor format into string.
104 string ToString(TensorFormat format);
105
106 // Convert a filter tensor format into string.
107 string ToString(FilterTensorFormat format);
108
109 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
110 // format 'format'.
GetTensorSpatialDims(int num_dims,TensorFormat format)111 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
112 switch (format) {
113 case FORMAT_NHWC:
114 case FORMAT_NCHW:
115 case FORMAT_HWNC:
116 case FORMAT_HWCN:
117 return num_dims - 2; // Exclude N,C.
118 case FORMAT_NCHW_VECT_C:
119 case FORMAT_NHWC_VECT_W:
120 // Note: the VECT_W is not counted as an independent spatial dim here,
121 // since it just a component of the width dimension.
122 return num_dims - 3; // Exclude N,C,VectDim.
123 }
124 }
125
GetFilterTensorSpatialDims(int num_dims,FilterTensorFormat format)126 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
127 if (format == FORMAT_OIHW_VECT_I) {
128 return num_dims - 3; // Exclude O,I,InnerI.
129 } else {
130 return num_dims - 2; // Exclude O,I.
131 }
132 }
133
134 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
135 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
GetTensorDimsFromSpatialDims(int num_spatial_dims,TensorFormat format)136 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
137 TensorFormat format) {
138 switch (format) {
139 case FORMAT_NHWC:
140 case FORMAT_NCHW:
141 case FORMAT_HWNC:
142 case FORMAT_HWCN:
143 return num_spatial_dims + 2; // Include N,C.
144 case FORMAT_NCHW_VECT_C:
145 case FORMAT_NHWC_VECT_W:
146 return num_spatial_dims + 3; // Include N,C,VectDim.
147 }
148 }
149
150 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
151 // filter tensor format 'format'.
GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,FilterTensorFormat format)152 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
153 FilterTensorFormat format) {
154 if (format == FORMAT_OIHW_VECT_I) {
155 return num_spatial_dims + 3; // Include O,I,InnerI.
156 } else {
157 return num_spatial_dims + 2; // Include O,I.
158 }
159 }
160
161 // Returns the index of the batch dimension.
GetTensorBatchDimIndex(int num_dims,TensorFormat format)162 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
163 switch (format) {
164 case FORMAT_NHWC:
165 case FORMAT_NCHW:
166 case FORMAT_NCHW_VECT_C:
167 case FORMAT_NHWC_VECT_W:
168 return 0;
169 case FORMAT_HWNC:
170 return num_dims - 2;
171 case FORMAT_HWCN:
172 return num_dims - 1;
173 default:
174 LOG(FATAL) << "Unknown format " << format;
175 return -1; // Avoid compiler warning about missing return value
176 }
177 }
178
179 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
180 // the index of the outer feature dimension (i.e. dimension 1, whose size would
181 // be num_features / 4 in this case).
GetTensorFeatureDimIndex(int num_dims,TensorFormat format)182 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
183 switch (format) {
184 case FORMAT_NHWC:
185 case FORMAT_HWNC:
186 return num_dims - 1;
187 case FORMAT_NHWC_VECT_W:
188 case FORMAT_HWCN:
189 return num_dims - 2;
190 case FORMAT_NCHW:
191 case FORMAT_NCHW_VECT_C:
192 return 1;
193 default:
194 LOG(FATAL) << "Unknown format " << format;
195 return -1; // Avoid compiler warning about missing return value
196 }
197 }
198
199 // Returns the index of the inner feature dimension.
GetTensorInnerFeatureDimIndex(int num_dims,TensorFormat format)200 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
201 DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
202 return num_dims - 1;
203 }
204
205 // Returns the index of the inner width dimension.
GetTensorInnerWidthDimIndex(int num_dims,TensorFormat format)206 inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
207 DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
208 return num_dims - 1;
209 }
210
211 // Returns the dimension index of the specified 'spatial_dim' within an
212 // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
213 // the index of the outer width dimension (i.e. dimension 2, whose size would
214 // be width / 4 in this case).
GetTensorSpatialDimIndex(int num_dims,TensorFormat format,int spatial_dim)215 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
216 int spatial_dim) {
217 CHECK(spatial_dim >= 0 &&
218 spatial_dim < GetTensorSpatialDims(num_dims, format))
219 << spatial_dim << " " << num_dims << " " << ToString(format);
220 switch (format) {
221 case FORMAT_NHWC:
222 case FORMAT_NHWC_VECT_W:
223 return spatial_dim + 1;
224 case FORMAT_NCHW:
225 case FORMAT_NCHW_VECT_C:
226 return spatial_dim + 2;
227 case FORMAT_HWNC:
228 case FORMAT_HWCN:
229 return spatial_dim;
230 default:
231 LOG(FATAL) << "Unknown format " << format;
232 return -1; // Avoid compiler warning about missing return value
233 }
234 }
235
GetFilterTensorSpatialDimIndex(int num_dims,FilterTensorFormat format,int dim)236 inline int GetFilterTensorSpatialDimIndex(int num_dims,
237 FilterTensorFormat format, int dim) {
238 CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
239 << dim << " " << num_dims << " " << ToString(format);
240 switch (format) {
241 case FORMAT_HWIO:
242 return dim;
243 case FORMAT_OIHW:
244 case FORMAT_OIHW_VECT_I:
245 return dim + 2;
246 default:
247 LOG(FATAL) << "Unknown format " << format;
248 return -1; // Avoid compiler warning about missing return value
249 }
250 }
251
252 // Returns the index of the inner input channels dimension.
GetFilterTensorInnerInputChannelsDimIndex(int num_dims,FilterTensorFormat format)253 inline int GetFilterTensorInnerInputChannelsDimIndex(
254 int num_dims, FilterTensorFormat format) {
255 DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
256 return num_dims - 1;
257 }
258
259 // Returns the index of the input channels dimension.
260 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
261 // outer input channel (i.e. 1), which holds num_input_channels / 4.
GetFilterTensorInputChannelsDimIndex(int num_dims,FilterTensorFormat format)262 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
263 FilterTensorFormat format) {
264 switch (format) {
265 case FORMAT_HWIO:
266 return num_dims - 2;
267 case FORMAT_OIHW:
268 case FORMAT_OIHW_VECT_I:
269 return 1;
270 default:
271 LOG(FATAL) << "Unknown format " << format;
272 return -1; // Avoid compiler warning about missing return value
273 }
274 }
275
276 // Returns the index of the output channels dimension.
GetFilterTensorOutputChannelsDimIndex(int num_dims,FilterTensorFormat format)277 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
278 FilterTensorFormat format) {
279 switch (format) {
280 case FORMAT_HWIO:
281 return num_dims - 1;
282 case FORMAT_OIHW:
283 case FORMAT_OIHW_VECT_I:
284 return 0;
285 default:
286 LOG(FATAL) << "Unknown format " << format;
287 return -1; // Avoid compiler warning about missing return value
288 }
289 }
290
291 // TODO(pauldonnelly): Replace these tensor dimension index functions with
292 // constant structs to improve performance and reduce code size in Compute()
293 // functions.
294
295 // Return the dimension index for the specified 'dimension' of the specified
296 // data 'tensor_format'. 'dimension' is a char that can be 'N' (batch size),
297 // 'C' (channels), 'H' (height), 'W' (width), or a numbered spatial dimension:
298 // '0', .. (NUM_SPATIAL_DIMS-1)..
299 // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
300 // the outer channel dimension (i.e. 1).
301 template <int NUM_SPATIAL_DIMS>
GetTensorDimIndex(TensorFormat format,char dimension)302 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
303 if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
304 // clang-format off
305 switch (dimension) {
306 case 'N': return 0;
307 case '0': return 1;
308 case '1': return 2;
309 case '2': return 3;
310 case 'H': return NUM_SPATIAL_DIMS - 1;
311 case 'W': return NUM_SPATIAL_DIMS;
312 case 'C': return NUM_SPATIAL_DIMS + 1;
313 default:
314 LOG(FATAL) << "Invalid dimension: " << dimension;
315 return -1; // Avoid compiler warning about missing return value
316 }
317 } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
318 switch (dimension) {
319 case 'N': return 0;
320 case 'C': return 1;
321 case '0': return 2;
322 case '1': return 3;
323 case '2': return 4;
324 case 'H': return NUM_SPATIAL_DIMS;
325 case 'W': return NUM_SPATIAL_DIMS + 1;
326 default:
327 LOG(FATAL) << "Invalid dimension: " << dimension;
328 return -1; // Avoid compiler warning about missing return value
329 }
330 } else if (format == FORMAT_HWNC) {
331 switch (dimension) {
332 case '0': return 0;
333 case '1': return 1;
334 case '2': return 2;
335 case 'H': return NUM_SPATIAL_DIMS - 2;
336 case 'W': return NUM_SPATIAL_DIMS - 1;
337 case 'N': return NUM_SPATIAL_DIMS;
338 case 'C': return NUM_SPATIAL_DIMS + 1;
339 default:
340 LOG(FATAL) << "Invalid dimension: " << dimension;
341 return -1; // Avoid compiler warning about missing return value
342 }
343 } else if (format == FORMAT_HWCN) {
344 switch (dimension) {
345 case '0': return 0;
346 case '1': return 1;
347 case '2': return 2;
348 case 'H': return NUM_SPATIAL_DIMS - 2;
349 case 'W': return NUM_SPATIAL_DIMS - 1;
350 case 'C': return NUM_SPATIAL_DIMS;
351 case 'N': return NUM_SPATIAL_DIMS + 1;
352 default:
353 LOG(FATAL) << "Invalid dimension: " << dimension;
354 return -1; // Avoid compiler warning about missing return value
355 }
356 } else {
357 LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
358 return -1; // Avoid compiler warning about missing return value
359 }
360 // clang-format on
361 }
362
363 // Return the dimension index for the specified 'dimension' of the specified
364 // 'filter_tensor_format'. 'dimension' is a char that can be 'O' (num output
365 // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
366 // numbered spatial dimension: '0', .. (NUM_SPATIAL_DIMS-1).
367 // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
368 // outer input channels dimension (i.e. 1).
369 template <int NUM_SPATIAL_DIMS>
GetFilterDimIndex(FilterTensorFormat filter_tensor_format,char dimension)370 inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
371 char dimension) {
372 // clang-format off
373 if (filter_tensor_format == FORMAT_HWIO) {
374 switch (dimension) {
375 case '0': return 0;
376 case '1': return 1;
377 case '2': return 2;
378 case 'H': return NUM_SPATIAL_DIMS - 2;
379 case 'W': return NUM_SPATIAL_DIMS - 1;
380 case 'I': return NUM_SPATIAL_DIMS;
381 case 'O': return NUM_SPATIAL_DIMS + 1;
382 default:
383 LOG(FATAL) << "Invalid dimension: " << dimension;
384 return -1; // Avoid compiler warning about missing return value
385 }
386 } else if (filter_tensor_format == FORMAT_OIHW ||
387 filter_tensor_format == FORMAT_OIHW_VECT_I) {
388 switch (dimension) {
389 case 'O': return 0;
390 case 'I': return 1;
391 case '0': return 2;
392 case '1': return 3;
393 case '2': return 4;
394 case 'H': return NUM_SPATIAL_DIMS;
395 case 'W': return NUM_SPATIAL_DIMS + 1;
396 default:
397 LOG(FATAL) << "Invalid dimension: " << dimension;
398 return -1; // Avoid compiler warning about missing return value
399 }
400 } else {
401 LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
402 return -1; // Avoid compiler warning about missing return value
403 }
404 // clang-format on
405 }
406
GetTensorDimIndex(TensorFormat format,char dimension)407 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
408 return GetTensorDimIndex<2>(format, dimension);
409 }
410
GetTensorDimIndex(TensorFormat format,char dimension,int num_total_dims)411 inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
412 int num_total_dims) {
413 int32 index = (GetTensorSpatialDims(num_total_dims, format) == 3)
414 ? GetTensorDimIndex<3>(format, dimension)
415 : GetTensorDimIndex<2>(format, dimension);
416 CHECK(index >= 0 && index < num_total_dims) // Crash OK.
417 << "Invalid index from the dimension: " << index << ", " << format << ", "
418 << dimension;
419 return index;
420 }
421
422 // Return the element from 'dimension_attributes' that corresponds to the
423 // specified 'dimension' according to 'tensor_format'.
424 template <typename T>
GetTensorDim(gtl::ArraySlice<T> dimension_attributes,TensorFormat tensor_format,char dimension)425 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
426 TensorFormat tensor_format, char dimension) {
427 int index =
428 GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
429 return dimension_attributes[index];
430 }
431
432 // Return the element from 'dimension_attribute' that corresponds to the
433 // specified 'dimension' according to 'filter_tensor_format'.
434 template <typename T>
GetFilterDim(gtl::ArraySlice<T> dimension_attribute,FilterTensorFormat filter_tensor_format,char dimension)435 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
436 FilterTensorFormat filter_tensor_format, char dimension) {
437 int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
438 filter_tensor_format) == 3)
439 ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
440 : GetFilterDimIndex<2>(filter_tensor_format, dimension);
441 CHECK(index >= 0 && index < dimension_attribute.size())
442 << "Invalid index from the dimension: " << index << ", "
443 << filter_tensor_format << ", " << dimension;
444 return dimension_attribute[index];
445 }
446
447 template <typename T>
GetTensorDim(const std::vector<T> & attributes,TensorFormat format,char dimension)448 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
449 char dimension) {
450 return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
451 }
452
453 // Return the size of the specified 'dimension' within 'tensor_shape'
454 // according to 'tensor_format'.
GetTensorDim(const TensorShape & tensor_shape,TensorFormat tensor_format,char dimension)455 inline int64 GetTensorDim(const TensorShape& tensor_shape,
456 TensorFormat tensor_format, char dimension) {
457 return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
458 tensor_format, dimension);
459 }
460
461 // Return the size of the specified 'dimension' within 'tensor_shape'
462 // according to 'tensor_filter_format'.
GetFilterDim(const TensorShape & tensor_shape,FilterTensorFormat tensor_filter_format,char dimension)463 inline int64 GetFilterDim(const TensorShape& tensor_shape,
464 FilterTensorFormat tensor_filter_format,
465 char dimension) {
466 return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
467 tensor_filter_format, dimension);
468 }
469
470 // Return the size of the specified 'dimension' of 'tensor' according to
471 // 'tensor_format'.
GetTensorDim(const Tensor & tensor,TensorFormat tensor_format,char dimension)472 inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
473 char dimension) {
474 return GetTensorDim(tensor.shape(), tensor_format, dimension);
475 }
476
477 // Return the size of the specified 'dimension' of 'tensor' according to
478 // 'filter_tensor_format'.
GetFilterDim(const Tensor & tensor,FilterTensorFormat filter_tensor_format,char dimension)479 inline int64 GetFilterDim(const Tensor& tensor,
480 FilterTensorFormat filter_tensor_format,
481 char dimension) {
482 return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
483 }
484
GetExplicitPaddingForDim(const std::vector<int64> & explicit_paddings,TensorFormat tensor_format,char dimension,int64 * padding_before,int64 * padding_after)485 inline void GetExplicitPaddingForDim(
486 const std::vector<int64>& explicit_paddings, TensorFormat tensor_format,
487 char dimension, int64* padding_before, int64* padding_after) {
488 int index =
489 GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
490 *padding_before = explicit_paddings[2 * index];
491 *padding_after = explicit_paddings[2 * index + 1];
492 }
493
494 // Return the string that specifies the data format for convnet operations.
495 string GetConvnetDataFormatAttrString();
496 string GetConvnet3dDataFormatAttrString();
497
498 // Return the string that specifies the filter format for convnet operations.
499 string GetConvnetFilterFormatAttrString();
500 string GetConvnet3dFilterFormatAttrString();
501 string GetConvnetDataFormat2D3DAttrString();
502
503 // Returns a tensor shape for the specified format and dimension sizes.
504 // Works for both 2D and 3D operations. The output shapes are as follows:
505 // FORMAT_NHWC: (N, spatial, C); rank = spatial.size() + 2
506 // FORMAT_NCHW: (N, C, spatial); rank = spatial.size() + 2
507 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
508 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
ShapeFromFormat(TensorFormat format,int64 N,gtl::ArraySlice<int64> spatial,int64 C)509 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
510 gtl::ArraySlice<int64> spatial, int64 C) {
511 const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
512 gtl::InlinedVector<int64, 6> dim_sizes(dims);
513 dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
514 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
515 auto dim_size = spatial[dim];
516 if (format == FORMAT_NHWC_VECT_W &&
517 static_cast<size_t>(dim) == spatial.size() - 1) {
518 CHECK_EQ(0, dim_size % 4)
519 << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
520 << dim_size;
521 dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
522 dim_size /= 4;
523 }
524 dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
525 }
526
527 int feature_index = GetTensorFeatureDimIndex(dims, format);
528 if (format == FORMAT_NCHW_VECT_C) {
529 CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
530 << C;
531 C /= 4;
532 dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
533 }
534 dim_sizes[feature_index] = C;
535 return TensorShape(dim_sizes);
536 }
537
538 // Return a tensor shape of the specified 'format', and dimensions.
539 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
540 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
541 // it has spatial.size() + 2 dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,gtl::ArraySlice<int64> spatial,int64 I,int64 O)542 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
543 gtl::ArraySlice<int64> spatial,
544 int64 I, int64 O) {
545 const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
546 gtl::InlinedVector<int64, 6> dim_sizes(dims);
547 dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
548 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
549 dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
550 }
551
552 if (format == FORMAT_OIHW_VECT_I) {
553 CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
554 << I;
555 I /= 4;
556 dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
557 }
558 dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
559 return TensorShape(dim_sizes);
560 }
561
562 // Return a tensor shape of the specified 'format', and dimensions.
ShapeFromFormat(TensorFormat format,int64 N,int64 H,int64 W,int64 C)563 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
564 int64 W, int64 C) {
565 return ShapeFromFormat(format, N, {H, W}, C);
566 }
567
568 // Return a filter tensor shape of the specified 'format', and dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,int64 H,int64 W,int64 I,int64 O)569 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
570 int64 H, int64 W, int64 I,
571 int64 O) {
572 return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
573 }
574
575 // Returns a copy of the specified tensor 'src_shape' converted from
576 // 'src_format' to 'dst_format'.
ShapeFromFormat(TensorFormat dst_format,const TensorShape & src_shape,TensorFormat src_format)577 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
578 const TensorShape& src_shape,
579 TensorFormat src_format) {
580 if (src_format == dst_format) {
581 return src_shape;
582 }
583
584 const int64 batch = GetTensorDim(src_shape, src_format, 'N');
585 const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
586 (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
587 const int num_src_spatial_dims =
588 GetTensorSpatialDims(src_shape.dims(), src_format);
589 std::vector<int64> spatial_dims(num_src_spatial_dims);
590 for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
591 spatial_dims[spatial_dim] =
592 gtl::ArraySlice<int64>(src_shape.dim_sizes())[GetTensorSpatialDimIndex(
593 src_shape.dims(), src_format, spatial_dim)];
594 }
595 if (src_format == FORMAT_NHWC_VECT_W) {
596 spatial_dims[num_src_spatial_dims - 1] *= 4;
597 }
598 return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
599 }
600
601 // Returns a copy of the specified filter tensor 'src_shape' converted from
602 // 'src_filter_format' to 'dst_filter_format'.
ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,const TensorShape & src_shape,FilterTensorFormat src_filter_format)603 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
604 const TensorShape& src_shape,
605 FilterTensorFormat src_filter_format) {
606 if (src_filter_format == dst_filter_format) {
607 return src_shape;
608 }
609
610 const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
611 const int64 input_channels =
612 GetFilterDim(src_shape, src_filter_format, 'I') *
613 (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
614
615 if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
616 return ShapeFromFilterTensorFormat(
617 dst_filter_format,
618 {{GetFilterDim(src_shape, src_filter_format, '0'),
619 GetFilterDim(src_shape, src_filter_format, '1'),
620 GetFilterDim(src_shape, src_filter_format, '2')}},
621 input_channels, output_channels);
622 }
623
624 return ShapeFromFilterTensorFormat(
625 dst_filter_format,
626 {{GetFilterDim(src_shape, src_filter_format, 'H'),
627 GetFilterDim(src_shape, src_filter_format, 'W')}},
628 input_channels, output_channels);
629 }
630
631 } // namespace tensorflow
632
633 #endif // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
634