1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
17 #define TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
18
19 #include <array>
20 #include <vector>
21
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/lib/gtl/array_slice.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/types.h"
27
28 namespace tensorflow {
29
// Tensor format for input/output activations used in convolution operations.
// The mnemonics specify the meaning of each tensor dimension sorted from
// largest to smallest memory stride.
// N = Batch, H = Image Height, W = Image Width, C = Number of Channels.
// TODO(pauldonnelly): It would probably be better to switch to a registration
// process for tensor formats, so specialized formats could be defined more
// locally to where they are used.
enum TensorFormat {
  // FORMAT_NHWC is the default format in TensorFlow.
  // Layout: [N, spatial..., C].
  FORMAT_NHWC = 0,

  // FORMAT_NCHW often improves performance on GPUs.
  // Layout: [N, C, spatial...].
  FORMAT_NCHW = 1,

  // NCHW_VECT_C is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is laid out in the same order
  // as NCHW, except that the size of the Channels dimension is divided by 4,
  // and a new dimension of size 4 is appended, which packs 4 adjacent channel
  // activations for the same pixel into an int32. Thus an NCHW format tensor
  // with dimensions [N, C, H, W] would have dimensions [N, C/4, H, W, 4] in
  // NCHW_VECT_C format.
  // A pre-condition of this format is that C must be a multiple of 4.
  FORMAT_NCHW_VECT_C = 2,

  // Similar to NHWC, but the size of the W dimension is divided by 4, and a
  // new dimension of size 4 is appended, which packs 4 adjacent activations
  // in the width dimension. Thus an NHWC format tensor with dimensions
  // [N, H, W, C] would have dimensions [N, H, W/4, C, 4] in NHWC_VECT_W format.
  // A pre-condition of this format is that W must be a multiple of 4.
  FORMAT_NHWC_VECT_W = 3,

  // Note: although the current code in this file assumes VECT_C and VECT_W
  // enums imply int8x4 vectors, this should not be relied upon.
  // In the future we may change the meaning of these enums to include vectors
  // of other types such as int16x2, with op implementations automatically
  // determining which format is implied based on the datatype.

  // FORMAT_HWNC is for TPUs. Layout: [spatial..., N, C].
  FORMAT_HWNC = 4,

  // FORMAT_HWCN is for TPUs. Layout: [spatial..., C, N].
  FORMAT_HWCN = 5,
};
71
// Tensor format for convolutional filters.
// The mnemonics specify the meaning of each tensor dimension sorted
// from largest to smallest memory stride.
// H = Kernel Height, W = Kernel Width, I = Input Channels, O = Output Channels.
// Note: In cudnnGetFilter4dDescriptor(), 'O' is called 'K', 'I' is called 'C'.
enum FilterTensorFormat {
  // FORMAT_HWIO is the default filter format in TensorFlow.
  // Ops that do not have a 'filter_format' attribute will assume this format.
  // Layout: [spatial..., I, O].
  FORMAT_HWIO = 0,

  // FORMAT_OIHW often improves performance on GPUs.
  // Layout: [O, I, spatial...].
  FORMAT_OIHW = 1,

  // FORMAT_OHWI is used by cuDNN for NHWC convolutions.
  // Layout: [O, spatial..., I].
  FORMAT_OHWI = 2,

  // OIHW_VECT_I is the most performant tensor format for cudnn6's quantized
  // int8 convolution and fused convolution. It is analogous to the NCHW_VECT_C
  // data format. It is laid out in the same order as OIHW, except that the size
  // of the Input Channels dimension is divided by 4, and a new dimension of
  // size 4 is appended, which packs 4 adjacent input channel weights into an
  // int32. Thus an OIHW format filter with dimensions [O, I, H, W] would have
  // dimensions [O, I/4, H, W, 4] in OIHW_VECT_I format.
  // A pre-condition of this format is that I must be a multiple of 4.
  FORMAT_OIHW_VECT_I = 3,
};
98
// Parses an activation tensor format from 'format_str' into '*format'.
// Returns true if the parsing succeeds, and false if it fails.
bool FormatFromString(absl::string_view format_str, TensorFormat* format);

// Parses a filter tensor format from 'format_str' into '*format'.
// Returns true if the parsing succeeds, and false if it fails.
bool FilterFormatFromString(absl::string_view format_str,
                            FilterTensorFormat* format);

// Converts an activation tensor format into a string.
std::string ToString(TensorFormat format);

// Converts a filter tensor format into a string.
std::string ToString(FilterTensorFormat format);
113
114 // Returns the number of spatial dims of a tensor of rank 'num_dims' and tensor
115 // format 'format'.
GetTensorSpatialDims(int num_dims,TensorFormat format)116 inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
117 switch (format) {
118 case FORMAT_NHWC:
119 case FORMAT_NCHW:
120 case FORMAT_HWNC:
121 case FORMAT_HWCN:
122 return num_dims - 2; // Exclude N,C.
123 case FORMAT_NCHW_VECT_C:
124 case FORMAT_NHWC_VECT_W:
125 // Note: the VECT_W is not counted as an independent spatial dim here,
126 // since it just a component of the width dimension.
127 return num_dims - 3; // Exclude N,C,VectDim.
128 default:
129 LOG(FATAL) << "Unknown format " << format;
130 return -1; // Avoid compiler warning about missing return value
131 }
132 }
133
GetFilterTensorSpatialDims(int num_dims,FilterTensorFormat format)134 inline int GetFilterTensorSpatialDims(int num_dims, FilterTensorFormat format) {
135 if (format == FORMAT_OIHW_VECT_I) {
136 return num_dims - 3; // Exclude O,I,InnerI.
137 } else {
138 return num_dims - 2; // Exclude O,I.
139 }
140 }
141
142 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
143 // tensor format 'format'. This is the inverse of GetTensorSpatialDims.
GetTensorDimsFromSpatialDims(int num_spatial_dims,TensorFormat format)144 inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
145 TensorFormat format) {
146 switch (format) {
147 case FORMAT_NHWC:
148 case FORMAT_NCHW:
149 case FORMAT_HWNC:
150 case FORMAT_HWCN:
151 return num_spatial_dims + 2; // Include N,C.
152 case FORMAT_NCHW_VECT_C:
153 case FORMAT_NHWC_VECT_W:
154 return num_spatial_dims + 3; // Include N,C,VectDim.
155 default:
156 LOG(FATAL) << "Unknown format " << format;
157 return -1; // Avoid compiler warning about missing return value
158 }
159 }
160
161 // Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
162 // filter tensor format 'format'.
GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,FilterTensorFormat format)163 inline int GetFilterTensorDimsFromSpatialDims(int num_spatial_dims,
164 FilterTensorFormat format) {
165 if (format == FORMAT_OIHW_VECT_I) {
166 return num_spatial_dims + 3; // Include O,I,InnerI.
167 } else {
168 return num_spatial_dims + 2; // Include O,I.
169 }
170 }
171
172 // Returns the index of the batch dimension.
GetTensorBatchDimIndex(int num_dims,TensorFormat format)173 inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
174 switch (format) {
175 case FORMAT_NHWC:
176 case FORMAT_NCHW:
177 case FORMAT_NCHW_VECT_C:
178 case FORMAT_NHWC_VECT_W:
179 return 0;
180 case FORMAT_HWNC:
181 return num_dims - 2;
182 case FORMAT_HWCN:
183 return num_dims - 1;
184 default:
185 LOG(FATAL) << "Unknown format " << format;
186 return -1; // Avoid compiler warning about missing return value
187 }
188 }
189
190 // Returns the index of the feature dimension. If format is NCHW_VECT_C, returns
191 // the index of the outer feature dimension (i.e. dimension 1, whose size would
192 // be num_features / 4 in this case).
GetTensorFeatureDimIndex(int num_dims,TensorFormat format)193 inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
194 switch (format) {
195 case FORMAT_NHWC:
196 case FORMAT_HWNC:
197 return num_dims - 1;
198 case FORMAT_NHWC_VECT_W:
199 case FORMAT_HWCN:
200 return num_dims - 2;
201 case FORMAT_NCHW:
202 case FORMAT_NCHW_VECT_C:
203 return 1;
204 default:
205 LOG(FATAL) << "Unknown format " << format;
206 return -1; // Avoid compiler warning about missing return value
207 }
208 }
209
210 // Returns the index of the inner feature dimension.
GetTensorInnerFeatureDimIndex(int num_dims,TensorFormat format)211 inline int GetTensorInnerFeatureDimIndex(int num_dims, TensorFormat format) {
212 DCHECK_EQ(format, FORMAT_NCHW_VECT_C);
213 return num_dims - 1;
214 }
215
216 // Returns the index of the inner width dimension.
GetTensorInnerWidthDimIndex(int num_dims,TensorFormat format)217 inline int GetTensorInnerWidthDimIndex(int num_dims, TensorFormat format) {
218 DCHECK_EQ(format, FORMAT_NHWC_VECT_W);
219 return num_dims - 1;
220 }
221
222 // Returns the dimension index of the specified 'spatial_dim' within an
223 // activation tensor. If format is NHWC_VECT_W and spatial_dim is 1, returns
224 // the index of the outer width dimension (i.e. dimension 2, whose size would
225 // be width / 4 in this case).
GetTensorSpatialDimIndex(int num_dims,TensorFormat format,int spatial_dim)226 inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
227 int spatial_dim) {
228 CHECK(spatial_dim >= 0 &&
229 spatial_dim < GetTensorSpatialDims(num_dims, format))
230 << spatial_dim << " " << num_dims << " " << ToString(format);
231 switch (format) {
232 case FORMAT_NHWC:
233 case FORMAT_NHWC_VECT_W:
234 return spatial_dim + 1;
235 case FORMAT_NCHW:
236 case FORMAT_NCHW_VECT_C:
237 return spatial_dim + 2;
238 case FORMAT_HWNC:
239 case FORMAT_HWCN:
240 return spatial_dim;
241 default:
242 LOG(FATAL) << "Unknown format " << format;
243 return -1; // Avoid compiler warning about missing return value
244 }
245 }
246
GetFilterTensorSpatialDimIndex(int num_dims,FilterTensorFormat format,int dim)247 inline int GetFilterTensorSpatialDimIndex(int num_dims,
248 FilterTensorFormat format, int dim) {
249 CHECK(dim >= 0 && dim < GetFilterTensorSpatialDims(num_dims, format))
250 << dim << " " << num_dims << " " << ToString(format);
251 switch (format) {
252 case FORMAT_HWIO:
253 return dim;
254 case FORMAT_OIHW:
255 case FORMAT_OIHW_VECT_I:
256 return dim + 2;
257 default:
258 LOG(FATAL) << "Unknown format " << format;
259 return -1; // Avoid compiler warning about missing return value
260 }
261 }
262
263 // Returns the index of the inner input channels dimension.
GetFilterTensorInnerInputChannelsDimIndex(int num_dims,FilterTensorFormat format)264 inline int GetFilterTensorInnerInputChannelsDimIndex(
265 int num_dims, FilterTensorFormat format) {
266 DCHECK_EQ(format, FORMAT_OIHW_VECT_I);
267 return num_dims - 1;
268 }
269
270 // Returns the index of the input channels dimension.
271 // If 'format' is FORMAT_OIHW_VECT_I, returns the dimension index of the
272 // outer input channel (i.e. 1), which holds num_input_channels / 4.
GetFilterTensorInputChannelsDimIndex(int num_dims,FilterTensorFormat format)273 inline int GetFilterTensorInputChannelsDimIndex(int num_dims,
274 FilterTensorFormat format) {
275 switch (format) {
276 case FORMAT_HWIO:
277 return num_dims - 2;
278 case FORMAT_OIHW:
279 case FORMAT_OIHW_VECT_I:
280 return 1;
281 default:
282 LOG(FATAL) << "Unknown format " << format;
283 return -1; // Avoid compiler warning about missing return value
284 }
285 }
286
287 // Returns the index of the output channels dimension.
GetFilterTensorOutputChannelsDimIndex(int num_dims,FilterTensorFormat format)288 inline int GetFilterTensorOutputChannelsDimIndex(int num_dims,
289 FilterTensorFormat format) {
290 switch (format) {
291 case FORMAT_HWIO:
292 return num_dims - 1;
293 case FORMAT_OIHW:
294 case FORMAT_OIHW_VECT_I:
295 return 0;
296 default:
297 LOG(FATAL) << "Unknown format " << format;
298 return -1; // Avoid compiler warning about missing return value
299 }
300 }
301
302 // TODO(pauldonnelly): Replace these tensor dimension index functions with
303 // constant structs to improve performance and reduce code size in Compute()
304 // functions.
305
306 // Return the dimension index for the specified 'dimension' of the specified
307 // data 'tensor_format'. 'dimension' is a char that can be 'N' (batch size),
308 // 'C' (channels), 'H' (height), 'W' (width), or a numbered spatial dimension:
309 // '0', .. (NUM_SPATIAL_DIMS-1)..
310 // If 'format' is NCHW_VECT_C and 'dimension' is 'C', returns the index of
311 // the outer channel dimension (i.e. 1).
312 template <int NUM_SPATIAL_DIMS>
GetTensorDimIndex(TensorFormat format,char dimension)313 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
314 if (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) {
315 // clang-format off
316 switch (dimension) {
317 case 'N': return 0;
318 case '0': return 1;
319 case '1': return 2;
320 case '2': return 3;
321 case 'H': return NUM_SPATIAL_DIMS - 1;
322 case 'W': return NUM_SPATIAL_DIMS;
323 case 'C': return NUM_SPATIAL_DIMS + 1;
324 default:
325 LOG(FATAL) << "Invalid dimension: " << dimension;
326 return -1; // Avoid compiler warning about missing return value
327 }
328 } else if (format == FORMAT_NCHW || format == FORMAT_NCHW_VECT_C) {
329 switch (dimension) {
330 case 'N': return 0;
331 case 'C': return 1;
332 case '0': return 2;
333 case '1': return 3;
334 case '2': return 4;
335 case 'H': return NUM_SPATIAL_DIMS;
336 case 'W': return NUM_SPATIAL_DIMS + 1;
337 default:
338 LOG(FATAL) << "Invalid dimension: " << dimension;
339 return -1; // Avoid compiler warning about missing return value
340 }
341 } else if (format == FORMAT_HWNC) {
342 switch (dimension) {
343 case '0': return 0;
344 case '1': return 1;
345 case '2': return 2;
346 case 'H': return NUM_SPATIAL_DIMS - 2;
347 case 'W': return NUM_SPATIAL_DIMS - 1;
348 case 'N': return NUM_SPATIAL_DIMS;
349 case 'C': return NUM_SPATIAL_DIMS + 1;
350 default:
351 LOG(FATAL) << "Invalid dimension: " << dimension;
352 return -1; // Avoid compiler warning about missing return value
353 }
354 } else if (format == FORMAT_HWCN) {
355 switch (dimension) {
356 case '0': return 0;
357 case '1': return 1;
358 case '2': return 2;
359 case 'H': return NUM_SPATIAL_DIMS - 2;
360 case 'W': return NUM_SPATIAL_DIMS - 1;
361 case 'C': return NUM_SPATIAL_DIMS;
362 case 'N': return NUM_SPATIAL_DIMS + 1;
363 default:
364 LOG(FATAL) << "Invalid dimension: " << dimension;
365 return -1; // Avoid compiler warning about missing return value
366 }
367 } else {
368 LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
369 return -1; // Avoid compiler warning about missing return value
370 }
371 // clang-format on
372 }
373
374 // Return the dimension index for the specified 'dimension' of the specified
375 // 'filter_tensor_format'. 'dimension' is a char that can be 'O' (num output
376 // channels), 'I' (num input channels), 'H' (height), 'W' (width), or a
377 // numbered spatial dimension: '0', .. (NUM_SPATIAL_DIMS-1).
378 // If 'format' is OIHW_VECT_I and 'dimension' is 'I', returns the index of the
379 // outer input channels dimension (i.e. 1).
380 template <int NUM_SPATIAL_DIMS>
GetFilterDimIndex(FilterTensorFormat filter_tensor_format,char dimension)381 inline int GetFilterDimIndex(FilterTensorFormat filter_tensor_format,
382 char dimension) {
383 // clang-format off
384 if (filter_tensor_format == FORMAT_HWIO) {
385 switch (dimension) {
386 case '0': return 0;
387 case '1': return 1;
388 case '2': return 2;
389 case 'H': return NUM_SPATIAL_DIMS - 2;
390 case 'W': return NUM_SPATIAL_DIMS - 1;
391 case 'I': return NUM_SPATIAL_DIMS;
392 case 'O': return NUM_SPATIAL_DIMS + 1;
393 default:
394 LOG(FATAL) << "Invalid dimension: " << dimension;
395 return -1; // Avoid compiler warning about missing return value
396 }
397 } else if (filter_tensor_format == FORMAT_OIHW ||
398 filter_tensor_format == FORMAT_OIHW_VECT_I) {
399 switch (dimension) {
400 case 'O': return 0;
401 case 'I': return 1;
402 case '0': return 2;
403 case '1': return 3;
404 case '2': return 4;
405 case 'H': return NUM_SPATIAL_DIMS;
406 case 'W': return NUM_SPATIAL_DIMS + 1;
407 default:
408 LOG(FATAL) << "Invalid dimension: " << dimension;
409 return -1; // Avoid compiler warning about missing return value
410 }
411 } else {
412 LOG(FATAL) << "Invalid format: " << static_cast<int>(filter_tensor_format);
413 return -1; // Avoid compiler warning about missing return value
414 }
415 // clang-format on
416 }
417
GetTensorDimIndex(TensorFormat format,char dimension)418 inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
419 return GetTensorDimIndex<2>(format, dimension);
420 }
421
GetTensorDimIndex(TensorFormat format,char dimension,int num_total_dims)422 inline int32 GetTensorDimIndex(TensorFormat format, char dimension,
423 int num_total_dims) {
424 int32 index = (GetTensorSpatialDims(num_total_dims, format) == 3)
425 ? GetTensorDimIndex<3>(format, dimension)
426 : GetTensorDimIndex<2>(format, dimension);
427 CHECK(index >= 0 && index < num_total_dims) // Crash OK.
428 << "Invalid index from the dimension: " << index << ", " << format << ", "
429 << dimension;
430 return index;
431 }
432
433 // Return the element from 'dimension_attributes' that corresponds to the
434 // specified 'dimension' according to 'tensor_format'.
435 template <typename T>
GetTensorDim(gtl::ArraySlice<T> dimension_attributes,TensorFormat tensor_format,char dimension)436 T GetTensorDim(gtl::ArraySlice<T> dimension_attributes,
437 TensorFormat tensor_format, char dimension) {
438 int index =
439 GetTensorDimIndex(tensor_format, dimension, dimension_attributes.size());
440 return dimension_attributes[index];
441 }
442
443 // Return the element from 'dimension_attribute' that corresponds to the
444 // specified 'dimension' according to 'filter_tensor_format'.
445 template <typename T>
GetFilterDim(gtl::ArraySlice<T> dimension_attribute,FilterTensorFormat filter_tensor_format,char dimension)446 T GetFilterDim(gtl::ArraySlice<T> dimension_attribute,
447 FilterTensorFormat filter_tensor_format, char dimension) {
448 int index = (GetFilterTensorSpatialDims(dimension_attribute.size(),
449 filter_tensor_format) == 3)
450 ? GetFilterDimIndex<3>(filter_tensor_format, dimension)
451 : GetFilterDimIndex<2>(filter_tensor_format, dimension);
452 using size_type = typename gtl::ArraySlice<T>::size_type;
453 CHECK(index >= 0 &&
454 static_cast<size_type>(index) < dimension_attribute.size())
455 << "Invalid index from the dimension: " << index << ", "
456 << filter_tensor_format << ", " << dimension;
457 return dimension_attribute[index];
458 }
459
460 template <typename T>
GetTensorDim(const std::vector<T> & attributes,TensorFormat format,char dimension)461 T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
462 char dimension) {
463 return GetTensorDim(gtl::ArraySlice<T>(attributes), format, dimension);
464 }
465
466 // Return the size of the specified 'dimension' within 'tensor_shape'
467 // according to 'tensor_format'.
GetTensorDim(const TensorShape & tensor_shape,TensorFormat tensor_format,char dimension)468 inline int64 GetTensorDim(const TensorShape& tensor_shape,
469 TensorFormat tensor_format, char dimension) {
470 return GetTensorDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
471 tensor_format, dimension);
472 }
473
474 // Return the size of the specified 'dimension' within 'tensor_shape'
475 // according to 'tensor_filter_format'.
GetFilterDim(const TensorShape & tensor_shape,FilterTensorFormat tensor_filter_format,char dimension)476 inline int64 GetFilterDim(const TensorShape& tensor_shape,
477 FilterTensorFormat tensor_filter_format,
478 char dimension) {
479 return GetFilterDim(gtl::ArraySlice<int64>(tensor_shape.dim_sizes()),
480 tensor_filter_format, dimension);
481 }
482
483 // Return the size of the specified 'dimension' of 'tensor' according to
484 // 'tensor_format'.
GetTensorDim(const Tensor & tensor,TensorFormat tensor_format,char dimension)485 inline int64 GetTensorDim(const Tensor& tensor, TensorFormat tensor_format,
486 char dimension) {
487 return GetTensorDim(tensor.shape(), tensor_format, dimension);
488 }
489
490 // Return the size of the specified 'dimension' of 'tensor' according to
491 // 'filter_tensor_format'.
GetFilterDim(const Tensor & tensor,FilterTensorFormat filter_tensor_format,char dimension)492 inline int64 GetFilterDim(const Tensor& tensor,
493 FilterTensorFormat filter_tensor_format,
494 char dimension) {
495 return GetFilterDim(tensor.shape(), filter_tensor_format, dimension);
496 }
497
GetExplicitPaddingForDim(const std::vector<int64> & explicit_paddings,TensorFormat tensor_format,char dimension,int64 * padding_before,int64 * padding_after)498 inline void GetExplicitPaddingForDim(
499 const std::vector<int64>& explicit_paddings, TensorFormat tensor_format,
500 char dimension, int64* padding_before, int64* padding_after) {
501 int index =
502 GetTensorDimIndex(tensor_format, dimension, explicit_paddings.size() / 2);
503 *padding_before = explicit_paddings[2 * index];
504 *padding_after = explicit_paddings[2 * index + 1];
505 }
506
// Return the string that specifies the data format for convnet operations.
std::string GetConvnetDataFormatAttrString();
// NOTE(review): presumably the 3D-convnet counterpart of the above, judging
// by the name; confirm against the definition.
std::string GetConvnet3dDataFormatAttrString();

// Return the string that specifies the filter format for convnet operations.
std::string GetConvnetFilterFormatAttrString();
// NOTE(review): presumably the 3D-convnet counterpart of the above, judging
// by the name; confirm against the definition.
std::string GetConvnet3dFilterFormatAttrString();
// NOTE(review): presumably a data-format attr string valid for both 2D and 3D
// ops, judging by the name; confirm against the definition.
std::string GetConvnetDataFormat2D3DAttrString();
515
516 // Returns a tensor shape for the specified format and dimension sizes.
517 // Works for both 2D and 3D operations. The output shapes are as follows:
518 // FORMAT_NHWC: (N, spatial, C); rank = spatial.size() + 2
519 // FORMAT_NCHW: (N, C, spatial); rank = spatial.size() + 2
520 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3
521 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3
ShapeFromFormat(TensorFormat format,int64 N,gtl::ArraySlice<int64> spatial,int64 C)522 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
523 gtl::ArraySlice<int64> spatial, int64 C) {
524 const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
525 gtl::InlinedVector<int64, 6> dim_sizes(dims);
526 dim_sizes[GetTensorBatchDimIndex(dims, format)] = N;
527 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
528 auto dim_size = spatial[dim];
529 if (format == FORMAT_NHWC_VECT_W &&
530 static_cast<size_t>(dim) == spatial.size() - 1) {
531 CHECK_EQ(0, dim_size % 4)
532 << "FORMAT_NHWC_VECT_W requires W to be a multiple of 4, but W="
533 << dim_size;
534 dim_sizes[GetTensorInnerWidthDimIndex(dims, format)] = 4;
535 dim_size /= 4;
536 }
537 dim_sizes[GetTensorSpatialDimIndex(dims, format, dim)] = dim_size;
538 }
539
540 int feature_index = GetTensorFeatureDimIndex(dims, format);
541 if (format == FORMAT_NCHW_VECT_C) {
542 CHECK_EQ(0, C % 4) << "NCHW_VECT_C requires C to be a multiple of 4, but C="
543 << C;
544 C /= 4;
545 dim_sizes[GetTensorInnerFeatureDimIndex(dims, format)] = 4;
546 }
547 dim_sizes[feature_index] = C;
548 return TensorShape(dim_sizes);
549 }
550
551 // Return a tensor shape of the specified 'format', and dimensions.
552 // Works for both 2D and 3D operations. If 'format' is OIHW_VECT_I,
553 // the output TensorShape has spatial.size() + 3 dimensions, otherwise
554 // it has spatial.size() + 2 dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,gtl::ArraySlice<int64> spatial,int64 I,int64 O)555 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
556 gtl::ArraySlice<int64> spatial,
557 int64 I, int64 O) {
558 const int dims = GetFilterTensorDimsFromSpatialDims(spatial.size(), format);
559 gtl::InlinedVector<int64, 6> dim_sizes(dims);
560 dim_sizes[GetFilterTensorOutputChannelsDimIndex(dims, format)] = O;
561 for (int dim = 0; static_cast<size_t>(dim) < spatial.size(); dim++) {
562 dim_sizes[GetFilterTensorSpatialDimIndex(dims, format, dim)] = spatial[dim];
563 }
564
565 if (format == FORMAT_OIHW_VECT_I) {
566 CHECK_EQ(0, I % 4) << "OIHW_VECT_I requires I to be a multiple of 4, but I="
567 << I;
568 I /= 4;
569 dim_sizes[GetFilterTensorInnerInputChannelsDimIndex(dims, format)] = 4;
570 }
571 dim_sizes[GetFilterTensorInputChannelsDimIndex(dims, format)] = I;
572 return TensorShape(dim_sizes);
573 }
574
575 // Return a tensor shape of the specified 'format', and dimensions.
ShapeFromFormat(TensorFormat format,int64 N,int64 H,int64 W,int64 C)576 inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
577 int64 W, int64 C) {
578 return ShapeFromFormat(format, N, {H, W}, C);
579 }
580
581 // Return a filter tensor shape of the specified 'format', and dimensions.
ShapeFromFilterTensorFormat(FilterTensorFormat format,int64 H,int64 W,int64 I,int64 O)582 inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format,
583 int64 H, int64 W, int64 I,
584 int64 O) {
585 return ShapeFromFilterTensorFormat(format, {H, W}, I, O);
586 }
587
588 // Returns a copy of the specified tensor 'src_shape' converted from
589 // 'src_format' to 'dst_format'.
ShapeFromFormat(TensorFormat dst_format,const TensorShape & src_shape,TensorFormat src_format)590 inline TensorShape ShapeFromFormat(TensorFormat dst_format,
591 const TensorShape& src_shape,
592 TensorFormat src_format) {
593 if (src_format == dst_format) {
594 return src_shape;
595 }
596
597 const int64 batch = GetTensorDim(src_shape, src_format, 'N');
598 const int64 channels = GetTensorDim(src_shape, src_format, 'C') *
599 (src_format == FORMAT_NCHW_VECT_C ? 4 : 1);
600 const int num_src_spatial_dims =
601 GetTensorSpatialDims(src_shape.dims(), src_format);
602 std::vector<int64> spatial_dims(num_src_spatial_dims);
603 for (int spatial_dim = 0; spatial_dim < num_src_spatial_dims; ++spatial_dim) {
604 spatial_dims[spatial_dim] =
605 gtl::ArraySlice<int64>(src_shape.dim_sizes())[GetTensorSpatialDimIndex(
606 src_shape.dims(), src_format, spatial_dim)];
607 }
608 if (src_format == FORMAT_NHWC_VECT_W) {
609 spatial_dims[num_src_spatial_dims - 1] *= 4;
610 }
611 return ShapeFromFormat(dst_format, batch, {spatial_dims}, channels);
612 }
613
614 // Returns a copy of the specified filter tensor 'src_shape' converted from
615 // 'src_filter_format' to 'dst_filter_format'.
ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,const TensorShape & src_shape,FilterTensorFormat src_filter_format)616 inline TensorShape ShapeFromFilterFormat(FilterTensorFormat dst_filter_format,
617 const TensorShape& src_shape,
618 FilterTensorFormat src_filter_format) {
619 if (src_filter_format == dst_filter_format) {
620 return src_shape;
621 }
622
623 const int64 output_channels = GetFilterDim(src_shape, src_filter_format, 'O');
624 const int64 input_channels =
625 GetFilterDim(src_shape, src_filter_format, 'I') *
626 (src_filter_format == FORMAT_OIHW_VECT_I ? 4 : 1);
627
628 if (GetFilterTensorSpatialDims(src_shape.dims(), src_filter_format) == 3) {
629 return ShapeFromFilterTensorFormat(
630 dst_filter_format,
631 {{GetFilterDim(src_shape, src_filter_format, '0'),
632 GetFilterDim(src_shape, src_filter_format, '1'),
633 GetFilterDim(src_shape, src_filter_format, '2')}},
634 input_channels, output_channels);
635 }
636
637 return ShapeFromFilterTensorFormat(
638 dst_filter_format,
639 {{GetFilterDim(src_shape, src_filter_format, 'H'),
640 GetFilterDim(src_shape, src_filter_format, 'W')}},
641 input_channels, output_channels);
642 }
643
644 } // namespace tensorflow
645
646 #endif // TENSORFLOW_CORE_UTIL_TENSOR_FORMAT_H_
647