• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// XLA-specific Transpose Op. This is very different from the Eigen
// version in third_party/tensorflow because xla::Transpose directly
// handles arbitrary permutations, while Eigen needs a restricted
// DoTranspose helper.
20 
21 #include "tensorflow/compiler/tf2xla/lib/scatter.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/core/framework/bounds_check.h"
28 #include "tensorflow/core/framework/kernel_def_builder.h"
29 #include "tensorflow/core/framework/register_types.h"
30 
31 namespace tensorflow {
32 namespace {
33 
34 class TransposeOp : public XlaOpKernel {
35  public:
TransposeOp(OpKernelConstruction * ctx,bool conjugate=false)36   explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false)
37       : XlaOpKernel(ctx), conjugate_(conjugate) {}
38 
Compile(XlaOpKernelContext * ctx)39   void Compile(XlaOpKernelContext* ctx) override {
40     const TensorShape input_shape = ctx->InputShape("x");
41     const TensorShape perm_tensor_shape = ctx->InputShape("perm");
42 
43     // Preliminary validation of sizes.
44     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
45                 errors::InvalidArgument("perm must be a vector, not ",
46                                         perm_tensor_shape.DebugString()));
47 
48     const int dims = input_shape.dims();
49     OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(),
50                 errors::InvalidArgument("transpose expects a vector of size ",
51                                         input_shape.dims(),
52                                         ". But input(1) is a vector of size ",
53                                         perm_tensor_shape.num_elements()));
54 
55     std::vector<int64> perm;
56     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm));
57 
58     std::vector<int64> transposed_order;
59     // Check whether permutation is a permutation of integers of [0 .. dims).
60     absl::InlinedVector<bool, 8> bits(dims);
61     bool is_identity = true;
62     for (int i = 0; i < dims; ++i) {
63       const int64 d = perm[i];
64       OP_REQUIRES(
65           ctx, 0 <= d && d < dims,
66           errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
67       bits[d] = true;
68       transposed_order.push_back(d);
69       if (d != i) {
70         is_identity = false;
71       }
72     }
73     for (int i = 0; i < dims; ++i) {
74       OP_REQUIRES(
75           ctx, bits[i],
76           errors::InvalidArgument(i, " is missing from 'perm' argument."));
77     }
78 
79     xla::XlaOp transposed;
80     // 0-D, 1-D, and identity transposes do nothing.
81     if (dims <= 1 || is_identity) {
82       transposed = ctx->Input("x");
83     } else {
84       transposed = xla::Transpose(ctx->Input("x"), transposed_order);
85     }
86 
87     // Conjugate the transposed result if this is ConjugateTransposeOp.
88     if (conjugate_) {
89       ctx->SetOutput(0, xla::Conj(transposed));
90     } else {
91       ctx->SetOutput(0, transposed);
92     }
93   }
94 
95  private:
96   const bool conjugate_;
97 };
98 
99 class ConjugateTransposeOp : public TransposeOp {
100  public:
ConjugateTransposeOp(OpKernelConstruction * ctx)101   explicit ConjugateTransposeOp(OpKernelConstruction* ctx)
102       : TransposeOp(ctx, /*conjugate=*/true) {}
103 };
104 
105 REGISTER_XLA_OP(Name("Transpose").CompileTimeConstantInput("perm"),
106                 TransposeOp);
107 
108 REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstantInput("perm"),
109                 ConjugateTransposeOp);
110 
111 // InvertPermutation frequently forms part of the gradient of Transpose.
112 //
113 // inv = InvertPermutationOp(T<int32> p) takes a permutation of
114 // integers 0, 1, ..., n - 1 and returns the inverted
115 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
116 //
117 // REQUIRES: input is a vector of int32.
118 // REQUIRES: input is a permutation of 0, 1, ..., n-1.
119 
120 class InvertPermutationOp : public XlaOpKernel {
121  public:
InvertPermutationOp(OpKernelConstruction * ctx)122   explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
123 
Compile(XlaOpKernelContext * ctx)124   void Compile(XlaOpKernelContext* ctx) override {
125     OP_REQUIRES(ctx,
126                 FastBoundsCheck(ctx->InputShape(0).num_elements(),
127                                 std::numeric_limits<int32>::max()),
128                 errors::InvalidArgument("permutation of nonnegative int32s "
129                                         "must have <= int32 max elements"));
130 
131     auto e = ctx->InputExpression(0);
132     auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client());
133     OP_REQUIRES_OK(ctx, tensor_or_status.status());
134     // If the input is a constant, we also want the output to be a constant.
135     // Some models rely on the result of InvertPermutation being a constant.
136     // TODO(b/32495713): Remove this when we can check whether Scatter is
137     // constant. Right now, we always assume it is non-constant because we don't
138     // check the embedded computation.
139     if (tensor_or_status.ValueOrDie().has_value()) {
140       std::vector<int64> perm;
141       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm));
142 
143       int size = perm.size();
144 
145       std::vector<int32> output(size);
146       std::fill_n(output.data(), size, -1);
147       for (int i = 0; i < size; ++i) {
148         const int64 d = perm[i];
149         OP_REQUIRES(ctx, FastBoundsCheck(d, size),
150                     errors::InvalidArgument(d, " is not between 0 and ", size));
151         OP_REQUIRES(ctx, output[d] == -1,
152                     errors::InvalidArgument(d, " is duplicated in the input."));
153         output[d] = i;
154       }
155 
156       ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output));
157     } else {
158       auto indices = ctx->Input(0);
159       int size = ctx->InputShape(0).num_elements();
160       auto iota = xla::Iota(ctx->builder(), xla::S32, size);
161       auto result = XlaScatter(iota, iota, indices,
162                                /*indices_are_vectors=*/false, /*combiner=*/{},
163                                ctx->builder());
164       OP_REQUIRES_OK(ctx, result.status());
165       ctx->SetOutput(0, result.ValueOrDie());
166     }
167   }
168 };
169 
170 REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32),
171                 InvertPermutationOp);
172 
173 }  // namespace
174 }  // namespace tensorflow
175