1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/lib/data_format.h"
17 #include "tensorflow/core/lib/core/errors.h"
18
19 namespace tensorflow {
20 namespace {
21
Contract(xla::XlaOp input,int64 dim)22 xla::StatusOr<xla::XlaOp> Contract(xla::XlaOp input, int64 dim) {
23 xla::XlaBuilder* builder = input.builder();
24 TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
25
26 if (input_shape.dimensions().back() != 4) {
27 return errors::InvalidArgument("Expected last dimension to be 4; got ",
28 input_shape.dimensions().back());
29 }
30
31 // Transpose the input so C is directly followed by VECT_C.
32 std::vector<int64> permutation;
33 for (int64 i = 0; i != input_shape.rank() - 1; ++i) {
34 permutation.push_back(i);
35 if (i == dim) {
36 permutation.push_back(input_shape.rank() - 1);
37 }
38 }
39
40 // Now merge the adjacent dimensions with a reshape.
41 std::vector<int64> contracted_shape(input_shape.dimensions().begin(),
42 input_shape.dimensions().end() - 1);
43 contracted_shape[dim] *= 4;
44
45 return xla::Reshape(xla::Transpose(input, permutation), contracted_shape);
46 }
47
Expand(xla::XlaOp input,int64 dim)48 xla::StatusOr<xla::XlaOp> Expand(xla::XlaOp input, int64 dim) {
49 xla::XlaBuilder* builder = input.builder();
50 TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
51
52 if (input_shape.dimensions(dim) % 4 != 0) {
53 return errors::InvalidArgument(
54 "Expected vectorized dimension to be evenly divisible by 4; got ",
55 input_shape.dimensions(dim));
56 }
57
58 // Split the `dim` into two dimensions with a reshape. The size of the new
59 // dimension is always 4.
60 std::vector<int64> expanded_shape(input_shape.dimensions());
61 expanded_shape[dim] /= 4;
62 expanded_shape.insert(expanded_shape.begin() + dim, 4);
63
64 // Move the newly created dimension to the end with a transpose.
65 std::vector<int64> permutation;
66 for (int64 i = 0; i != expanded_shape.size(); ++i) {
67 permutation.push_back(i);
68 if (i == dim) {
69 ++i;
70 }
71 }
72 permutation.push_back(dim + 1);
73
74 return xla::Transpose(xla::Reshape(input, expanded_shape), permutation);
75 }
76
77 } // namespace
78
NCHW_VECT_CToNCHW(xla::XlaOp input)79 xla::StatusOr<xla::XlaOp> NCHW_VECT_CToNCHW(xla::XlaOp input) {
80 return Contract(input, 1);
81 }
82
NCHWToNCHW_VECT_C(xla::XlaOp input)83 xla::StatusOr<xla::XlaOp> NCHWToNCHW_VECT_C(xla::XlaOp input) {
84 return Expand(input, 1);
85 }
86
87 } // namespace tensorflow
88