• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
17 #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
18 
19 #include <array>
20 #include <vector>
21 
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/lib/gtl/array_slice.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 
// Layouts for the input/output activation tensors of convolution operations.
// Each mnemonic lists the tensor's dimensions in order of decreasing memory
// stride, with N = batch, H = image height, W = image width and
// C = number of channels.
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
//
// Note on the VECT_* formats: although the code in this file currently treats
// VECT_C and VECT_W as implying int8x4 vectors, this should not be relied
// upon. Their meaning may later be broadened to vectors of other types (such
// as int16x2), with op implementations deciding which layout is implied from
// the datatype.
enum TensorFormat {
  FORMAT_NHWC = 0,  // The default activation format in TensorFlow.
  FORMAT_NCHW = 1,  // Often improves performance on GPUs.

  // Channel-vectorized NCHW; the most performant format for cudnn6's quantized
  // int8 convolution and fused convolution. Dimension order matches NCHW, but
  // the channel dimension shrinks to C/4 and a trailing dimension of size 4 is
  // appended, packing 4 adjacent channel activations of the same pixel into an
  // int32. An NCHW tensor of shape [N, C, H, W] therefore becomes
  // [N, C/4, H, W, 4]. Requires C to be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Width-vectorized NHWC: the W dimension shrinks to W/4 and a trailing
  // dimension of size 4 is appended, packing 4 adjacent activations along the
  // width dimension.
  FORMAT_NHWC_VECT_W = 3,

  FORMAT_HWNC = 4,  // Used for TPUs.
  FORMAT_HWCN = 5,  // Used for TPUs.
};
71 
// Layouts for convolution filter tensors. Each mnemonic lists the filter's
// dimensions in order of decreasing memory stride, with H = kernel height,
// W = kernel width, I = input channels and O = output channels.
// (cudnnGetFilter4dDescriptor() calls 'O' 'K' and calls 'I' 'C'.)
enum FilterTensorFormat {
  // The default filter format in TensorFlow. Ops that have no
  // 'filter_format' attribute assume this format.
  FORMAT_HWIO = 0,

  FORMAT_OIHW = 1,  // Often improves performance on GPUs.
  FORMAT_OHWI = 2,  // Used by cuDNN for NHWC convolutions.

  // Input-channel-vectorized OIHW, the filter counterpart of NCHW_VECT_C and
  // the most performant filter format for cudnn6's quantized int8
  // convolution and fused convolution. Dimension order matches OIHW, but the
  // input-channel dimension shrinks to I/4 and a trailing dimension of size 4
  // is appended, packing 4 adjacent input-channel weights into an int32. An
  // OIHW filter of shape [O, I, H, W] therefore becomes [O, I/4, H, W, 4].
  // Requires I to be a multiple of 4.
  FORMAT_OIHW_VECT_I = 3,
};
98 
99 // Parse tensor format from the given string.
100 // Return true if the parsing succeeds, and false if it fails.
101 bool FormatFromString(absl::string_view format_str, TensorFormat* format);
102 
103 // Parse tensor format from the given string.
104 // Return true if the parsing succeeds, and false if it fails.
105 bool FilterFormatFromString(absl::string_view format_str,
106                             FilterTensorFormat* format);
107 
108 // Convert a tensor format into string.
109 std::string ToString(TensorFormat format);
110 
111 // Convert a filter tensor format into string.
112 std::string ToString(FilterTensorFormat format);
113 
114 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
115 // format 'format'.
GetTensorSpatialDims(int num_dims,TensorFormat format)116 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
117   switch (format) {
118     case FORMAT_NHWC:
119     case FORMAT_NCHW:
120     case FORMAT_HWNC:
121     case FORMAT_HWCN:
122       return num_dims - 2;  // Exclude N,C.
123     case FORMAT_NCHW_VECT_C:
124     case FORMAT_NHWC_VECT_W:
125       // Note: the VECT_W is not counted as an independent spatial dim here,
126       // since it just a component of the width dimension.
127       return num_dims - 3;  // Exclude N,C,VectDim.
128     default:
129       LOG(FATAL) << "Unknown format " << format;
130       return -1;  // Avoid compiler warning about missing return value
131   }
132 }
133 
GetFilterTensorSpatialDims(int num_dims,FilterTensorFormat format)134 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
135   if (format == FORMAT_OIHW_VECT_I) {
136     return num_dims - 3;  // Exclude O,I,InnerI.
137   } else {
138     return num_dims - 2;  // Exclude O,I.
139   }
140 }
141 
142 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
143 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
GetTensorDimsFromSpatialDims(int num_spatial_dims,TensorFormat format)144 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
145                                         TensorFormat format) {
146   switch (format) {
147     case FORMAT_NHWC:
148     case FORMAT_NCHW:
149     case FORMAT_HWNC:
150     case FORMAT_HWCN:
151       return num_spatial_dims + 2;  // Include N,C.
152     case FORMAT_NCHW_VECT_C:
153     case FORMAT_NHWC_VECT_W:
154       return num_spatial_dims + 3;  // Include N,C,VectDim.
155     default:
156       LOG(FATAL) << "Unknown format " << format;
157       return -1;  // Avoid compiler warning about missing return value
158   }
159 }
160 
161 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
162 // filter tensor format 'format'.
GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,FilterTensorFormat format)163 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
164                                               FilterTensorFormat format) {
165   if (format == FORMAT_OIHW_VECT_I) {
166     return num_spatial_dims + 3;  // Include O,I,InnerI.
167   } else {
168     return num_spatial_dims + 2;  // Include O,I.
169   }
170 }
171 
172 // Returns the index of the batch dimension.
GetTensorBatchDimIndex(int num_dims,TensorFormat format)173 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
174   switch (format) {
175     case FORMAT_NHWC:
176     case FORMAT_NCHW:
177     case FORMAT_NCHW_VECT_C:
178     case FORMAT_NHWC_VECT_W:
179       return 0;
180     case FORMAT_HWNC:
181       return num_dims - 2;
182     case FORMAT_HWCN:
183       return num_dims - 1;
184     default:
185       LOG(FATAL) << "Unknown format " << format;
186       return -1;  // Avoid compiler warning about missing return value
187   }
188 }
189 
190 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
191 // the index of the outer feature dimension (i.e. dimension 1, whose size would
192 // be num_features / 4 in this case).
GetTensorFeatureDimIndex(int num_dims,TensorFormat format)193 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
194   switch (format) {
195     case FORMAT_NHWC:
196     case FORMAT_HWNC:
197       return num_dims - 1;
198     case FORMAT_NHWC_VECT_W:
199     case FORMAT_HWCN:
200       return num_dims - 2;
201     case FORMAT_NCHW:
202     case FORMAT_NCHW_VECT_C:
203       return 1;
204     default:
205       LOG(FATAL) << "Unknown format " << format;
206       return -1;  // Avoid compiler warning about missing return value
207   }
208 }
209 
210 // Returns the index of the inner feature dimension.
GetTensorInnerFeatureDimIndex(int num_dims,TensorFormat format)211 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
212   DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
213   return num_dims - 1;
214 }
215 
216 // Returns the index of the inner width dimension.
GetTensorInnerWidthDimIndex(int num_dims,TensorFormat format)217 inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
218   DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
219   return num_dims - 1;
220 }
221 
222 // Returns the dimension index of the specified 'spatial_dim' within an
223 // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
224 // the index of the outer width dimension (i.e. dimension 2, whose size would
225 // be width / 4 in this case).
GetTensorSpatialDimIndex(int num_dims,TensorFormat format,int spatial_dim)226 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
227                                     int spatial_dim) {
228   CHECK(spatial_dim >= 0 &&
229         spatial_dim < GetTensorSpatialDims(num_dims, format))
230       << spatial_dim << " " << num_dims << " " << ToString(format);
231   switch (format) {
232     case FORMAT_NHWC:
233     case FORMAT_NHWC_VECT_W:
234       return spatial_dim + 1;
235     case FORMAT_NCHW:
236     case FORMAT_NCHW_VECT_C:
237       return spatial_dim + 2;
238     case FORMAT_HWNC:
239     case FORMAT_HWCN:
240       return spatial_dim;
241     default:
242       LOG(FATAL) << "Unknown format " << format;
243       return -1;  // Avoid compiler warning about missing return value
244   }
245 }
246 
GetFilterTensorSpatialDimIndex(int num_dims,FilterTensorFormat format,int dim)247 inline int GetFilterTensorSpatialDimIndex(int num_dims,
248                                           FilterTensorFormat format, int dim) {
249   CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
250       << dim << " " << num_dims << " " << ToString(format);
251   switch (format) {
252     case FORMAT_HWIO:
253       return dim;
254     case FORMAT_OIHW:
255     case FORMAT_OIHW_VECT_I:
256       return dim + 2;
257     default:
258       LOG(FATAL) << "Unknown format " << format;
259       return -1;  // Avoid compiler warning about missing return value
260   }
261 }
262 
263 // Returns the index of the inner input channels dimension.
GetFilterTensorInnerInputChannelsDimIndex(int num_dims,FilterTensorFormat format)264 inline int GetFilterTensorInnerInputChannelsDimIndex(
265     int num_dims, FilterTensorFormat format) {
266   DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
267   return num_dims - 1;
268 }
269 
270 // Returns the index of the input channels dimension.
271 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
272 // outer input channel (i.e. 1), which holds num_input_channels / 4.
GetFilterTensorInputChannelsDimIndex(int num_dims,FilterTensorFormat format)273 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
274                                                 FilterTensorFormat format) {
275   switch (format) {
276     case FORMAT_HWIO:
277       return num_dims - 2;
278     case FORMAT_OIHW:
279     case FORMAT_OIHW_VECT_I:
280       return 1;
281     default:
282       LOG(FATAL) << "Unknown format " << format;
283       return -1;  // Avoid compiler warning about missing return value
284   }
285 }
286 
287 // Returns the index of the output channels dimension.
GetFilterTensorOutputChannelsDimIndex(int num_dims,FilterTensorFormat format)288 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
289                                                  FilterTensorFormat format) {
290   switch (format) {
291     case FORMAT_HWIO:
292       return num_dims - 1;
293     case FORMAT_OIHW:
294     case FORMAT_OIHW_VECT_I:
295       return 0;
296     default:
297       LOG(FATAL) << "Unknown format " << format;
298       return -1;  // Avoid compiler warning about missing return value
299   }
300 }
301 
302 // TODO(pauldonnelly): Replace these tensor dimension index functions with
303 // constant structs to improve performance and reduce code size in Compute()
304 // functions.
305 
306 // Return the dimension index for the specified 'dimension' of the specified
307 // data 'tensor_format'.  'dimension' is a char that can be 'N' (batch size),
308 // 'C' (channels), 'H' (height), 'W' (width),  or a numbered spatial dimension:
309 // '0',  .. (NUM_SPATIAL_DIMS-1)..
310 // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
311 // the outer channel dimension (i.e. 1).
312 template <int NUM_SPATIAL_DIMS>
GetTensorDimIndex(TensorFormat format,char dimension)313 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
314   if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
315     // clang-format off
316     switch (dimension) {
317       case 'N': return 0;
318       case '0': return 1;
319       case '1': return 2;
320       case '2': return 3;
321       case 'H': return NUM_SPATIAL_DIMS - 1;
322       case 'W': return NUM_SPATIAL_DIMS;
323       case 'C': return NUM_SPATIAL_DIMS + 1;
324       default:
325         LOG(FATAL) << "Invalid dimension: " << dimension;
326         return -1;  // Avoid compiler warning about missing return value
327     }
328   } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
329     switch (dimension) {
330       case 'N': return 0;
331       case 'C': return 1;
332       case '0': return 2;
333       case '1': return 3;
334       case '2': return 4;
335       case 'H': return NUM_SPATIAL_DIMS;
336       case 'W': return NUM_SPATIAL_DIMS + 1;
337       default:
338         LOG(FATAL) << "Invalid dimension: " << dimension;
339         return -1;  // Avoid compiler warning about missing return value
340     }
341   } else if (format == FORMAT_HWNC) {
342     switch (dimension) {
343       case '0': return 0;
344       case '1': return 1;
345       case '2': return 2;
346       case 'H': return NUM_SPATIAL_DIMS - 2;
347       case 'W': return NUM_SPATIAL_DIMS - 1;
348       case 'N': return NUM_SPATIAL_DIMS;
349       case 'C': return NUM_SPATIAL_DIMS + 1;
350       default:
351         LOG(FATAL) << "Invalid dimension: " << dimension;
352         return -1;  // Avoid compiler warning about missing return value
353     }
354   } else if (format == FORMAT_HWCN) {
355     switch (dimension) {
356       case '0': return 0;
357       case '1': return 1;
358       case '2': return 2;
359       case 'H': return NUM_SPATIAL_DIMS - 2;
360       case 'W': return NUM_SPATIAL_DIMS - 1;
361       case 'C': return NUM_SPATIAL_DIMS;
362       case 'N': return NUM_SPATIAL_DIMS + 1;
363       default:
364         LOG(FATAL) << "Invalid dimension: " << dimension;
365         return -1;  // Avoid compiler warning about missing return value
366     }
367   } else {
368     LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
369     return -1;  // Avoid compiler warning about missing return value
370   }
371   // clang-format on
372 }
373 
374 // Return the dimension index for the specified 'dimension' of the specified
375 // 'filter_tensor_format'.  'dimension' is a char that can be 'O' (num output
376 // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
377 // numbered spatial dimension: '0',  .. (NUM_SPATIAL_DIMS-1).
378 // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
379 // outer input channels dimension (i.e. 1).
380 template <int NUM_SPATIAL_DIMS>
GetFilterDimIndex(FilterTensorFormat filter_tensor_format,char dimension)381 inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
382                              char dimension) {
383   // clang-format off
384   if (filter_tensor_format == FORMAT_HWIO) {
385     switch (dimension) {
386       case '0': return 0;
387       case '1': return 1;
388       case '2': return 2;
389       case 'H': return NUM_SPATIAL_DIMS - 2;
390       case 'W': return NUM_SPATIAL_DIMS - 1;
391       case 'I': return NUM_SPATIAL_DIMS;
392       case 'O': return NUM_SPATIAL_DIMS + 1;
393       default:
394         LOG(FATAL) << "Invalid dimension: " << dimension;
395         return -1;  // Avoid compiler warning about missing return value
396     }
397   } else if (filter_tensor_format == FORMAT_OIHW ||
398              filter_tensor_format == FORMAT_OIHW_VECT_I) {
399     switch (dimension) {
400       case 'O': return 0;
401       case 'I': return 1;
402       case '0': return 2;
403       case '1': return 3;
404       case '2': return 4;
405       case 'H': return NUM_SPATIAL_DIMS;
406       case 'W': return NUM_SPATIAL_DIMS + 1;
407       default:
408         LOG(FATAL) << "Invalid dimension: " << dimension;
409         return -1;  // Avoid compiler warning about missing return value
410     }
411   } else {
412     LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
413     return -1;  // Avoid compiler warning about missing return value
414   }
415   // clang-format on
416 }
417 
GetTensorDimIndex(TensorFormat format,char dimension)418 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
419   return GetTensorDimIndex<2>(format, dimension);
420 }
421 
GetTensorDimIndex(TensorFormat format,char dimension,int num_total_dims)422 inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
423                                int num_total_dims) {
424   int32_t index = (GetTensorSpatialDims(num_total_dims, format) == 3)
425                       ? GetTensorDimIndex<3>(format, dimension)
426                       : GetTensorDimIndex<2>(format, dimension);
427   CHECK(index >= 0 && index < num_total_dims)  // Crash OK.
428       << "Invalid index from the dimension: " << index << ", " << format << ", "
429       << dimension;
430   return index;
431 }
432 
433 // Return the element from 'dimension_attributes' that corresponds to the
434 // specified 'dimension' according to 'tensor_format'.
435 template <typename T>
GetTensorDim(gtl::ArraySlice<T> dimension_attributes,TensorFormat tensor_format,char dimension)436 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
437                TensorFormat tensor_format, char dimension) {
438   int index =
439       GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
440   return dimension_attributes[index];
441 }
442 
443 // Return the element from 'dimension_attribute' that corresponds to the
444 // specified 'dimension' according to 'filter_tensor_format'.
445 template <typename T>
GetFilterDim(gtl::ArraySlice<T> dimension_attribute,FilterTensorFormat filter_tensor_format,char dimension)446 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
447                FilterTensorFormat filter_tensor_format, char dimension) {
448   int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
449                                           filter_tensor_format) == 3)
450                   ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
451                   : GetFilterDimIndex<2>(filter_tensor_format, dimension);
452   using size_type = typename gtl::ArraySlice<T>::size_type;
453   CHECK(index >= 0 &&
454         static_cast<size_type>(index) < dimension_attribute.size())
455       << "Invalid index from the dimension: " << index << ", "
456       << filter_tensor_format << ", " << dimension;
457   return dimension_attribute[index];
458 }
459 
460 template <typename T>
GetTensorDim(const std::vector<T> & attributes,TensorFormat format,char dimension)461 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
462                char dimension) {
463   return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
464 }
465 
466 // Return the size of the specified 'dimension' within 'tensor_shape'
467 // according to 'tensor_format'.
GetTensorDim(const TensorShape & tensor_shape,TensorFormat tensor_format,char dimension)468 inline int64_t GetTensorDim(const TensorShape& tensor_shape,
469                             TensorFormat tensor_format, char dimension) {
470   return GetTensorDim(gtl::ArraySlice<int64_t>(tensor_shape.dim_sizes()),
471                       tensor_format, dimension);
472 }
473 
474 // Return the size of the specified 'dimension' within 'tensor_shape'
475 // according to 'tensor_filter_format'.
GetFilterDim(const TensorShape & tensor_shape,FilterTensorFormat tensor_filter_format,char dimension)476 inline int64_t GetFilterDim(const TensorShape& tensor_shape,
477                             FilterTensorFormat tensor_filter_format,
478                             char dimension) {
479   return GetFilterDim(gtl::ArraySlice<int64_t>(tensor_shape.dim_sizes()),
480                       tensor_filter_format, dimension);
481 }
482 
483 // Return the size of the specified 'dimension' of 'tensor' according to
484 // 'tensor_format'.
GetTensorDim(const Tensor & tensor,TensorFormat tensor_format,char dimension)485 inline int64_t GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
486                             char dimension) {
487   return GetTensorDim(tensor.shape(), tensor_format, dimension);
488 }
489 
490 // Return the size of the specified 'dimension' of 'tensor' according to
491 // 'filter_tensor_format'.
GetFilterDim(const Tensor & tensor,FilterTensorFormat filter_tensor_format,char dimension)492 inline int64_t GetFilterDim(const Tensor& tensor,
493                             FilterTensorFormat filter_tensor_format,
494                             char dimension) {
495   return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
496 }
497 
GetExplicitPaddingForDim(const std::vector<int64_t> & explicit_paddings,TensorFormat tensor_format,char dimension,int64_t * padding_before,int64_t * padding_after)498 inline void GetExplicitPaddingForDim(
499     const std::vector<int64_t>& explicit_paddings, TensorFormat tensor_format,
500     char dimension, int64_t* padding_before, int64_t* padding_after) {
501   int index =
502       GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
503   *padding_before = explicit_paddings[2 * index];
504   *padding_after = explicit_paddings[2 * index + 1];
505 }
506 
// Return the string that specifies the data format attr for convnet
// operations (the "3d" variant presumably for 3-D convolutions).
std::string GetConvnetDataFormatAttrString();
std::string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format attr for convnet
// operations (the "3d" variant presumably for 3-D convolutions).
std::string GetConvnetFilterFormatAttrString();
std::string GetConvnet3dFilterFormatAttrString();

// Return the string that specifies the data format attr for convnet operations
// that accept both 2-D and 3-D inputs. NOTE(review): this is a data-format
// (not filter-format) attr; the 2-D/3-D reading is inferred from the name,
// so confirm it against the definition.
std::string GetConvnetDataFormat2D3DAttrString();
515 
516 // Returns a tensor shape for the specified format and dimension sizes.
517 // Works for both 2D and 3D operations. The output shapes are as follows:
518 // FORMAT_NHWC:        (N, spatial, C); rank = spatial.size() + 2
519 // FORMAT_NCHW:        (N, C, spatial); rank = spatial.size() + 2
520 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
521 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
ShapeFromFormat(TensorFormat format,int64_t N,gtl::ArraySlice<int64_t> spatial,int64_t C)522 inline TensorShape ShapeFromFormat(TensorFormat format, int64_t N,
523                                    gtl::ArraySlice<int64_t> spatial,
524                                    int64_t C) {
525   const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
526   gtl::InlinedVector<int64_t, 6> dim_sizes(dims);
527   dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
528   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
529     auto dim_size = spatial[dim];
530     if (format == FORMAT_NHWC_VECT_W &&
531         static_cast<size_t>(dim) == spatial.size() - 1) {
532       CHECK_EQ(0, dim_size % 4)
533           << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
534           << dim_size;
535       dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
536       dim_size /= 4;
537     }
538     dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
539   }
540 
541   int feature_index = GetTensorFeatureDimIndex(dims, format);
542   if (format == FORMAT_NCHW_VECT_C) {
543     CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
544                        << C;
545     C /= 4;
546     dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
547   }
548   dim_sizes[feature_index] = C;
549   return TensorShape(dim_sizes);
550 }
551 
552 // Return a tensor shape of the specified 'format', and dimensions.
553 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
554 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
555 // it has spatial.size() + 2 dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,gtl::ArraySlice<int64_t> spatial,int64_t I,int64_t O)556 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
557                                                gtl::ArraySlice<int64_t> spatial,
558                                                int64_t I, int64_t O) {
559   const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
560   gtl::InlinedVector<int64_t, 6> dim_sizes(dims);
561   dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
562   for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
563     dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
564   }
565 
566   if (format == FORMAT_OIHW_VECT_I) {
567     CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
568                        << I;
569     I /= 4;
570     dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
571   }
572   dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
573   return TensorShape(dim_sizes);
574 }
575 
576 // Return a tensor shape of the specified 'format', and dimensions.
ShapeFromFormat(TensorFormat format,int64_t N,int64_t H,int64_t W,int64_t C)577 inline TensorShape ShapeFromFormat(TensorFormat format, int64_t N, int64_t H,
578                                    int64_t W, int64_t C) {
579   return ShapeFromFormat(format, N, {H, W}, C);
580 }
581 
582 // Return a filter tensor shape of the specified 'format', and dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,int64_t H,int64_t W,int64_t I,int64_t O)583 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
584                                                int64_t H, int64_t W, int64_t I,
585                                                int64_t O) {
586   return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
587 }
588 
589 // Returns a copy of the specified tensor 'src_shape' converted from
590 // 'src_format' to 'dst_format'.
ShapeFromFormat(TensorFormat dst_format,const TensorShape & src_shape,TensorFormat src_format)591 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
592                                    const TensorShape& src_shape,
593                                    TensorFormat src_format) {
594   if (src_format == dst_format) {
595     return src_shape;
596   }
597 
598   const int64_t batch = GetTensorDim(src_shape, src_format, 'N');
599   const int64_t channels = GetTensorDim(src_shape, src_format, 'C') *
600                            (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
601   const int num_src_spatial_dims =
602       GetTensorSpatialDims(src_shape.dims(), src_format);
603   std::vector<int64_t> spatial_dims(num_src_spatial_dims);
604   for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
605     spatial_dims[spatial_dim] = gtl::ArraySlice<int64_t>(
606         src_shape.dim_sizes())[GetTensorSpatialDimIndex(
607         src_shape.dims(), src_format, spatial_dim)];
608   }
609   if (src_format == FORMAT_NHWC_VECT_W) {
610     spatial_dims[num_src_spatial_dims - 1] *= 4;
611   }
612   return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
613 }
614 
615 // Returns a copy of the specified filter tensor 'src_shape' converted from
616 // 'src_filter_format' to 'dst_filter_format'.
ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,const TensorShape & src_shape,FilterTensorFormat src_filter_format)617 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
618                                          const TensorShape& src_shape,
619                                          FilterTensorFormat src_filter_format) {
620   if (src_filter_format == dst_filter_format) {
621     return src_shape;
622   }
623 
624   const int64_t output_channels =
625       GetFilterDim(src_shape, src_filter_format, 'O');
626   const int64_t input_channels =
627       GetFilterDim(src_shape, src_filter_format, 'I') *
628       (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
629 
630   if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
631     return ShapeFromFilterTensorFormat(
632         dst_filter_format,
633         {{GetFilterDim(src_shape, src_filter_format, '0'),
634           GetFilterDim(src_shape, src_filter_format, '1'),
635           GetFilterDim(src_shape, src_filter_format, '2')}},
636         input_channels, output_channels);
637   }
638 
639   return ShapeFromFilterTensorFormat(
640       dst_filter_format,
641       {{GetFilterDim(src_shape, src_filter_format, 'H'),
642         GetFilterDim(src_shape, src_filter_format, 'W')}},
643       input_channels, output_channels);
644 }
645 
646 }  // namespace tensorflow
647 
648 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
649