• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
17 #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
18 
19 #include <array>
20 #include <vector>
21 
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/lib/gtl/array_slice.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 
// Tensor format for input/output activations used in convolution operations.
// The mnemonics specify the meaning of each tensor dimension sorted from
// largest to smallest memory stride.
// N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
// The dimension-index helpers in this header switch on these values and
// LOG(FATAL) on any format they do not handle, so adding a format requires
// updating them as well.
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // FORMAT_NHWC is the default format in TensorFlow.
  FORMAT_NHWC = 0,

  // FORMAT_NCHW often improves performance on GPUs.
  FORMAT_NCHW = 1,

  // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is laid out in the same order
  // as NCHW, except that the size of the Channels dimension is divided by 4,
  // and a new dimension of size 4 is appended, which packs 4 adjacent channel
  // activations for the same pixel into an int32. Thus an NCHW format tensor
  // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
  // NCHW_VECT_C format.
  // A pre-condition of this format is that C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Similar to NHWC, but the size of the W dimension is divided by 4, and a
  // new dimension of size 4 is appended, which packs 4 adjacent activations
  // in the width dimension. Thus an NHWC format tensor with dimensions
  // [N, H, W, C] would have dimensions [N, H, W/4, C, 4] in NHWC_VECT_W
  // format.
  // A pre-condition of this format is that W must be a multiple of 4.
  FORMAT_NHWC_VECT_W = 3,

  // Note: although the current code in this file assumes VECT_C and VECT_W
  // enums imply int8x4 vectors, this should not be relied upon.
  // In the future we may change the meaning of these enums to include vectors
  // of other types such as int16x2, with op implementations automatically
  // determining which format is implied based on the datatype.

  // FORMAT_HWNC is for TPUs. Layout: [spatial..., N, C].
  FORMAT_HWNC = 4,

  // FORMAT_HWCN is for TPUs. Layout: [spatial..., C, N].
  FORMAT_HWCN = 5,
};
71 
// Tensor format for convolutional filters.
// The mnemonics specify the meaning of each tensor dimension sorted
// from largest to smallest memory stride.
// H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // FORMAT_HWIO is the default filter format in TensorFlow.
  // Ops that do not have a 'filter_format' attribute will assume this format.
  // Layout: [spatial..., I, O].
  FORMAT_HWIO = 0,

  // FORMAT_OIHW often improves performance on GPUs.
  // Layout: [O, I, spatial...].
  FORMAT_OIHW = 1,

  // FORMAT_OHWI used by cuDNN for NHWC convolutions.
  // Layout: [O, spatial..., I].
  FORMAT_OHWI = 2,

  // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
  // data format. It is laid out in the same order as OIHW, except that the size
  // of the Input Channels dimension is divided by 4, and a new dimension of
  // size 4 is appended, which packs 4 adjacent input channel weights into an
  // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
  // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // A pre-condition of this format is that I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 3,
};
98 
// Parse a data tensor format (e.g. the string form of a TensorFormat) from the
// given string.
// Return true if the parsing succeeds, and false if it fails.
// NOTE(review): presumably '*format' is only written on success; confirm in
// the definition before relying on it.
bool FormatFromString(absl::string_view format_str, TensorFormat* format);

// Parse a filter tensor format from the given string.
// Return true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(absl::string_view format_str,
                            FilterTensorFormat* format);

// Convert a tensor format into string.
std::string ToString(TensorFormat format);

// Convert a filter tensor format into string.
std::string ToString(FilterTensorFormat format);
113 
114 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
115 // format 'format'.
GetTensorSpatialDims(int num_dims,TensorFormat format)116 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
117   switch (format) {
118     case FORMAT_NHWC:
119     case FORMAT_NCHW:
120     case FORMAT_HWNC:
121     case FORMAT_HWCN:
122       return num_dims - 2;  // Exclude N,C.
123     case FORMAT_NCHW_VECT_C:
124     case FORMAT_NHWC_VECT_W:
125       // Note: the VECT_W is not counted as an independent spatial dim here,
126       // since it just a component of the width dimension.
127       return num_dims - 3;  // Exclude N,C,VectDim.
128     default:
129       LOG(FATAL) << "Unknown format " << format;
130       return -1;  // Avoid compiler warning about missing return value
131   }
132 }
133 
GetFilterTensorSpatialDims(int num_dims,FilterTensorFormat format)134 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
135   if (format == FORMAT_OIHW_VECT_I) {
136     return num_dims - 3;  // Exclude O,I,InnerI.
137   } else {
138     return num_dims - 2;  // Exclude O,I.
139   }
140 }
141 
142 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
143 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
GetTensorDimsFromSpatialDims(int num_spatial_dims,TensorFormat format)144 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
145                                         TensorFormat format) {
146   switch (format) {
147     case FORMAT_NHWC:
148     case FORMAT_NCHW:
149     case FORMAT_HWNC:
150     case FORMAT_HWCN:
151       return num_spatial_dims + 2;  // Include N,C.
152     case FORMAT_NCHW_VECT_C:
153     case FORMAT_NHWC_VECT_W:
154       return num_spatial_dims + 3;  // Include N,C,VectDim.
155     default:
156       LOG(FATAL) << "Unknown format " << format;
157       return -1;  // Avoid compiler warning about missing return value
158   }
159 }
160 
161 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
162 // filter tensor format 'format'.
GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,FilterTensorFormat format)163 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
164                                               FilterTensorFormat format) {
165   if (format == FORMAT_OIHW_VECT_I) {
166     return num_spatial_dims + 3;  // Include O,I,InnerI.
167   } else {
168     return num_spatial_dims + 2;  // Include O,I.
169   }
170 }
171 
172 // Returns the index of the batch dimension.
GetTensorBatchDimIndex(int num_dims,TensorFormat format)173 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
174   switch (format) {
175     case FORMAT_NHWC:
176     case FORMAT_NCHW:
177     case FORMAT_NCHW_VECT_C:
178     case FORMAT_NHWC_VECT_W:
179       return 0;
180     case FORMAT_HWNC:
181       return num_dims - 2;
182     case FORMAT_HWCN:
183       return num_dims - 1;
184     default:
185       LOG(FATAL) << "Unknown format " << format;
186       return -1;  // Avoid compiler warning about missing return value
187   }
188 }
189 
190 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
191 // the index of the outer feature dimension (i.e. dimension 1, whose size would
192 // be num_features / 4 in this case).
GetTensorFeatureDimIndex(int num_dims,TensorFormat format)193 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
194   switch (format) {
195     case FORMAT_NHWC:
196     case FORMAT_HWNC:
197       return num_dims - 1;
198     case FORMAT_NHWC_VECT_W:
199     case FORMAT_HWCN:
200       return num_dims - 2;
201     case FORMAT_NCHW:
202     case FORMAT_NCHW_VECT_C:
203       return 1;
204     default:
205       LOG(FATAL) << "Unknown format " << format;
206       return -1;  // Avoid compiler warning about missing return value
207   }
208 }
209 
210 // Returns the index of the inner feature dimension.
GetTensorInnerFeatureDimIndex(int num_dims,TensorFormat format)211 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
212   DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
213   return num_dims - 1;
214 }
215 
216 // Returns the index of the inner width dimension.
GetTensorInnerWidthDimIndex(int num_dims,TensorFormat format)217 inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
218   DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
219   return num_dims - 1;
220 }
221 
222 // Returns the dimension index of the specified 'spatial_dim' within an
223 // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
224 // the index of the outer width dimension (i.e. dimension 2, whose size would
225 // be width / 4 in this case).
GetTensorSpatialDimIndex(int num_dims,TensorFormat format,int spatial_dim)226 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
227                                     int spatial_dim) {
228   CHECK(spatial_dim >= 0 &&
229         spatial_dim < GetTensorSpatialDims(num_dims, format))
230       << spatial_dim << " " << num_dims << " " << ToString(format);
231   switch (format) {
232     case FORMAT_NHWC:
233     case FORMAT_NHWC_VECT_W:
234       return spatial_dim + 1;
235     case FORMAT_NCHW:
236     case FORMAT_NCHW_VECT_C:
237       return spatial_dim + 2;
238     case FORMAT_HWNC:
239     case FORMAT_HWCN:
240       return spatial_dim;
241     default:
242       LOG(FATAL) << "Unknown format " << format;
243       return -1;  // Avoid compiler warning about missing return value
244   }
245 }
246 
GetFilterTensorSpatialDimIndex(int num_dims,FilterTensorFormat format,int dim)247 inline int GetFilterTensorSpatialDimIndex(int num_dims,
248                                           FilterTensorFormat format, int dim) {
249   CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
250       << dim << " " << num_dims << " " << ToString(format);
251   switch (format) {
252     case FORMAT_HWIO:
253       return dim;
254     case FORMAT_OIHW:
255     case FORMAT_OIHW_VECT_I:
256       return dim + 2;
257     default:
258       LOG(FATAL) << "Unknown format " << format;
259       return -1;  // Avoid compiler warning about missing return value
260   }
261 }
262 
263 // Returns the index of the inner input channels dimension.
GetFilterTensorInnerInputChannelsDimIndex(int num_dims,FilterTensorFormat format)264 inline int GetFilterTensorInnerInputChannelsDimIndex(
265     int num_dims, FilterTensorFormat format) {
266   DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
267   return num_dims - 1;
268 }
269 
270 // Returns the index of the input channels dimension.
271 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
272 // outer input channel (i.e. 1), which holds num_input_channels / 4.
GetFilterTensorInputChannelsDimIndex(int num_dims,FilterTensorFormat format)273 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
274                                                 FilterTensorFormat format) {
275   switch (format) {
276     case FORMAT_HWIO:
277       return num_dims - 2;
278     case FORMAT_OIHW:
279     case FORMAT_OIHW_VECT_I:
280       return 1;
281     default:
282       LOG(FATAL) << "Unknown format " << format;
283       return -1;  // Avoid compiler warning about missing return value
284   }
285 }
286 
287 // Returns the index of the output channels dimension.
GetFilterTensorOutputChannelsDimIndex(int num_dims,FilterTensorFormat format)288 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
289                                                  FilterTensorFormat format) {
290   switch (format) {
291     case FORMAT_HWIO:
292       return num_dims - 1;
293     case FORMAT_OIHW:
294     case FORMAT_OIHW_VECT_I:
295       return 0;
296     default:
297       LOG(FATAL) << "Unknown format " << format;
298       return -1;  // Avoid compiler warning about missing return value
299   }
300 }
301 
302 // TODO(pauldonnelly): Replace these tensor dimension index functions with
303 // constant structs to improve performance and reduce code size in Compute()
304 // functions.
305 
306 // Return the dimension index for the specified 'dimension' of the specified
307 // data 'tensor_format'.  'dimension' is a char that can be 'N' (batch size),
308 // 'C' (channels), 'H' (height), 'W' (width),  or a numbered spatial dimension:
309 // '0',  .. (NUM_SPATIAL_DIMS-1)..
310 // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
311 // the outer channel dimension (i.e. 1).
312 template <int NUM_SPATIAL_DIMS>
GetTensorDimIndex(TensorFormat format,char dimension)313 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
314   if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
315     // clang-format off
316     switch (dimension) {
317       case 'N': return 0;
318       case '0': return 1;
319       case '1': return 2;
320       case '2': return 3;
321       case 'H': return NUM_SPATIAL_DIMS - 1;
322       case 'W': return NUM_SPATIAL_DIMS;
323       case 'C': return NUM_SPATIAL_DIMS + 1;
324       default:
325         LOG(FATAL) << "Invalid dimension: " << dimension;
326         return -1;  // Avoid compiler warning about missing return value
327     }
328   } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
329     switch (dimension) {
330       case 'N': return 0;
331       case 'C': return 1;
332       case '0': return 2;
333       case '1': return 3;
334       case '2': return 4;
335       case 'H': return NUM_SPATIAL_DIMS;
336       case 'W': return NUM_SPATIAL_DIMS + 1;
337       default:
338         LOG(FATAL) << "Invalid dimension: " << dimension;
339         return -1;  // Avoid compiler warning about missing return value
340     }
341   } else if (format == FORMAT_HWNC) {
342     switch (dimension) {
343       case '0': return 0;
344       case '1': return 1;
345       case '2': return 2;
346       case 'H': return NUM_SPATIAL_DIMS - 2;
347       case 'W': return NUM_SPATIAL_DIMS - 1;
348       case 'N': return NUM_SPATIAL_DIMS;
349       case 'C': return NUM_SPATIAL_DIMS + 1;
350       default:
351         LOG(FATAL) << "Invalid dimension: " << dimension;
352         return -1;  // Avoid compiler warning about missing return value
353     }
354   } else if (format == FORMAT_HWCN) {
355     switch (dimension) {
356       case '0': return 0;
357       case '1': return 1;
358       case '2': return 2;
359       case 'H': return NUM_SPATIAL_DIMS - 2;
360       case 'W': return NUM_SPATIAL_DIMS - 1;
361       case 'C': return NUM_SPATIAL_DIMS;
362       case 'N': return NUM_SPATIAL_DIMS + 1;
363       default:
364         LOG(FATAL) << "Invalid dimension: " << dimension;
365         return -1;  // Avoid compiler warning about missing return value
366     }
367   } else {
368     LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
369     return -1;  // Avoid compiler warning about missing return value
370   }
371   // clang-format on
372 }
373 
374 // Return the dimension index for the specified 'dimension' of the specified
375 // 'filter_tensor_format'.  'dimension' is a char that can be 'O' (num output
376 // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
377 // numbered spatial dimension: '0',  .. (NUM_SPATIAL_DIMS-1).
378 // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
379 // outer input channels dimension (i.e. 1).
380 template <int NUM_SPATIAL_DIMS>
GetFilterDimIndex(FilterTensorFormat filter_tensor_format,char dimension)381 inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
382                              char dimension) {
383   // clang-format off
384   if (filter_tensor_format == FORMAT_HWIO) {
385     switch (dimension) {
386       case '0': return 0;
387       case '1': return 1;
388       case '2': return 2;
389       case 'H': return NUM_SPATIAL_DIMS - 2;
390       case 'W': return NUM_SPATIAL_DIMS - 1;
391       case 'I': return NUM_SPATIAL_DIMS;
392       case 'O': return NUM_SPATIAL_DIMS + 1;
393       default:
394         LOG(FATAL) << "Invalid dimension: " << dimension;
395         return -1;  // Avoid compiler warning about missing return value
396     }
397   } else if (filter_tensor_format == FORMAT_OIHW ||
398              filter_tensor_format == FORMAT_OIHW_VECT_I) {
399     switch (dimension) {
400       case 'O': return 0;
401       case 'I': return 1;
402       case '0': return 2;
403       case '1': return 3;
404       case '2': return 4;
405       case 'H': return NUM_SPATIAL_DIMS;
406       case 'W': return NUM_SPATIAL_DIMS + 1;
407       default:
408         LOG(FATAL) << "Invalid dimension: " << dimension;
409         return -1;  // Avoid compiler warning about missing return value
410     }
411   } else {
412     LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
413     return -1;  // Avoid compiler warning about missing return value
414   }
415   // clang-format on
416 }
417 
GetTensorDimIndex(TensorFormat format,char dimension)418 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
419   return GetTensorDimIndex<2>(format, dimension);
420 }
421 
GetTensorDimIndex(TensorFormat format,char dimension,int num_total_dims)422 inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
423                                int num_total_dims) {
424   int32 index = (GetTensorSpatialDims(num_total_dims, format) == 3)
425                     ? GetTensorDimIndex<3>(format, dimension)
426                     : GetTensorDimIndex<2>(format, dimension);
427   CHECK(index >= 0 && index < num_total_dims)  // Crash OK.
428       << "Invalid index from the dimension: " << index << ", " << format << ", "
429       << dimension;
430   return index;
431 }
432 
433 // Return the element from 'dimension_attributes' that corresponds to the
434 // specified 'dimension' according to 'tensor_format'.
435 template <typename T>
GetTensorDim(gtl::ArraySlice<T> dimension_attributes,TensorFormat tensor_format,char dimension)436 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
437                TensorFormat tensor_format, char dimension) {
438   int index =
439       GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
440   return dimension_attributes[index];
441 }
442 
443 // Return the element from 'dimension_attribute' that corresponds to the
444 // specified 'dimension' according to 'filter_tensor_format'.
445 template <typename T>
GetFilterDim(gtl::ArraySlice<T> dimension_attribute,FilterTensorFormat filter_tensor_format,char dimension)446 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
447                FilterTensorFormat filter_tensor_format, char dimension) {
448   int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
449                                           filter_tensor_format) == 3)
450                   ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
451                   : GetFilterDimIndex<2>(filter_tensor_format, dimension);
452   using size_type = typename gtl::ArraySlice<T>::size_type;
453   CHECK(index >= 0 &&
454         static_cast<size_type>(index) < dimension_attribute.size())
455       << "Invalid index from the dimension: " << index << ", "
456       << filter_tensor_format << ", " << dimension;
457   return dimension_attribute[index];
458 }
459 
460 template <typename T>
GetTensorDim(const std::vector<T> & attributes,TensorFormat format,char dimension)461 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
462                char dimension) {
463   return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
464 }
465 
466 // Return the size of the specified 'dimension' within 'tensor_shape'
467 // according to 'tensor_format'.
GetTensorDim(const TensorShape & tensor_shape,TensorFormat tensor_format,char dimension)468 inline int64 GetTensorDim(const TensorShape& tensor_shape,
469                           TensorFormat tensor_format, char dimension) {
470   return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
471                       tensor_format, dimension);
472 }
473 
474 // Return the size of the specified 'dimension' within 'tensor_shape'
475 // according to 'tensor_filter_format'.
GetFilterDim(const TensorShape & tensor_shape,FilterTensorFormat tensor_filter_format,char dimension)476 inline int64 GetFilterDim(const TensorShape& tensor_shape,
477                           FilterTensorFormat tensor_filter_format,
478                           char dimension) {
479   return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
480                       tensor_filter_format, dimension);
481 }
482 
483 // Return the size of the specified 'dimension' of 'tensor' according to
484 // 'tensor_format'.
GetTensorDim(const Tensor & tensor,TensorFormat tensor_format,char dimension)485 inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
486                           char dimension) {
487   return GetTensorDim(tensor.shape(), tensor_format, dimension);
488 }
489 
490 // Return the size of the specified 'dimension' of 'tensor' according to
491 // 'filter_tensor_format'.
GetFilterDim(const Tensor & tensor,FilterTensorFormat filter_tensor_format,char dimension)492 inline int64 GetFilterDim(const Tensor& tensor,
493                           FilterTensorFormat filter_tensor_format,
494                           char dimension) {
495   return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
496 }
497 
GetExplicitPaddingForDim(const std::vector<int64> & explicit_paddings,TensorFormat tensor_format,char dimension,int64 * padding_before,int64 * padding_after)498 inline void GetExplicitPaddingForDim(
499     const std::vector<int64>& explicit_paddings, TensorFormat tensor_format,
500     char dimension, int64* padding_before, int64* padding_after) {
501   int index =
502       GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
503   *padding_before = explicit_paddings[2 * index];
504   *padding_after = explicit_paddings[2 * index + 1];
505 }
506 
// Return the string that specifies the data format for convnet operations.
std::string GetConvnetDataFormatAttrString();
std::string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format for convnet operations.
std::string GetConvnetFilterFormatAttrString();
std::string GetConvnet3dFilterFormatAttrString();
// NOTE(review): despite sitting under the filter-format comment, the name
// suggests this returns a data-format attr string covering both 2D and 3D
// layouts; confirm against the definition.
std::string GetConvnetDataFormat2D3DAttrString();
515 
516 // Returns a tensor shape for the specified format and dimension sizes.
517 // Works for both 2D and 3D operations. The output shapes are as follows:
518 // FORMAT_NHWC:        (N, spatial, C); rank = spatial.size() + 2
519 // FORMAT_NCHW:        (N, C, spatial); rank = spatial.size() + 2
520 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
521 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
ShapeFromFormat(TensorFormat format,int64 N,gtl::ArraySlice<int64> spatial,int64 C)522 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
523                                    gtl::ArraySlice<int64> spatial, int64 C) {
524   const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
525   gtl::InlinedVector<int64, 6> dim_sizes(dims);
526   dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
527   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
528     auto dim_size = spatial[dim];
529     if (format == FORMAT_NHWC_VECT_W &&
530         static_cast<size_t>(dim) == spatial.size() - 1) {
531       CHECK_EQ(0, dim_size % 4)
532           << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
533           << dim_size;
534       dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
535       dim_size /= 4;
536     }
537     dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
538   }
539 
540   int feature_index = GetTensorFeatureDimIndex(dims, format);
541   if (format == FORMAT_NCHW_VECT_C) {
542     CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
543                        << C;
544     C /= 4;
545     dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
546   }
547   dim_sizes[feature_index] = C;
548   return TensorShape(dim_sizes);
549 }
550 
551 // Return a tensor shape of the specified 'format', and dimensions.
552 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
553 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
554 // it has spatial.size() + 2 dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,gtl::ArraySlice<int64> spatial,int64 I,int64 O)555 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
556                                                gtl::ArraySlice<int64> spatial,
557                                                int64 I, int64 O) {
558   const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
559   gtl::InlinedVector<int64, 6> dim_sizes(dims);
560   dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
561   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
562     dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
563   }
564 
565   if (format == FORMAT_OIHW_VECT_I) {
566     CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
567                        << I;
568     I /= 4;
569     dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
570   }
571   dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
572   return TensorShape(dim_sizes);
573 }
574 
575 // Return a tensor shape of the specified 'format', and dimensions.
ShapeFromFormat(TensorFormat format,int64 N,int64 H,int64 W,int64 C)576 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
577                                    int64 W, int64 C) {
578   return ShapeFromFormat(format, N, {H, W}, C);
579 }
580 
581 // Return a filter tensor shape of the specified 'format', and dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,int64 H,int64 W,int64 I,int64 O)582 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
583                                                int64 H, int64 W, int64 I,
584                                                int64 O) {
585   return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
586 }
587 
588 // Returns a copy of the specified tensor 'src_shape' converted from
589 // 'src_format' to 'dst_format'.
ShapeFromFormat(TensorFormat dst_format,const TensorShape & src_shape,TensorFormat src_format)590 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
591                                    const TensorShape& src_shape,
592                                    TensorFormat src_format) {
593   if (src_format == dst_format) {
594     return src_shape;
595   }
596 
597   const int64 batch = GetTensorDim(src_shape, src_format, 'N');
598   const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
599                          (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
600   const int num_src_spatial_dims =
601       GetTensorSpatialDims(src_shape.dims(), src_format);
602   std::vector<int64> spatial_dims(num_src_spatial_dims);
603   for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
604     spatial_dims[spatial_dim] =
605         gtl::ArraySlice<int64>(src_shape.dim_sizes())[GetTensorSpatialDimIndex(
606             src_shape.dims(), src_format, spatial_dim)];
607   }
608   if (src_format == FORMAT_NHWC_VECT_W) {
609     spatial_dims[num_src_spatial_dims - 1] *= 4;
610   }
611   return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
612 }
613 
614 // Returns a copy of the specified filter tensor 'src_shape' converted from
615 // 'src_filter_format' to 'dst_filter_format'.
ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,const TensorShape & src_shape,FilterTensorFormat src_filter_format)616 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
617                                          const TensorShape& src_shape,
618                                          FilterTensorFormat src_filter_format) {
619   if (src_filter_format == dst_filter_format) {
620     return src_shape;
621   }
622 
623   const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
624   const int64 input_channels =
625       GetFilterDim(src_shape, src_filter_format, 'I') *
626       (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
627 
628   if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
629     return ShapeFromFilterTensorFormat(
630         dst_filter_format,
631         {{GetFilterDim(src_shape, src_filter_format, '0'),
632           GetFilterDim(src_shape, src_filter_format, '1'),
633           GetFilterDim(src_shape, src_filter_format, '2')}},
634         input_channels, output_channels);
635   }
636 
637   return ShapeFromFilterTensorFormat(
638       dst_filter_format,
639       {{GetFilterDim(src_shape, src_filter_format, 'H'),
640         GetFilterDim(src_shape, src_filter_format, 'W')}},
641       input_channels, output_channels);
642 }
643 
644 }  // namespace tensorflow
645 
646 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
647