/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Transpose Op. This is very different to the Eigen
// version in third_party/tensorflow because XLA's reshape neatly
// handles all transposes, while Eigen needs a restricted DoTranspose
// helper.

#include <limits>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {
namespace {

class TransposeOp : public XlaOpKernel {
 public:
  explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false)
      : XlaOpKernel(ctx), conjugate_(conjugate) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape("x");
    const TensorShape perm_tensor_shape = ctx->InputShape("perm");

    // Preliminary validation of sizes.
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
                errors::InvalidArgument("perm must be a vector, not ",
                                        perm_tensor_shape.DebugString()));

    const int dims = input_shape.dims();
    OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(),
                errors::InvalidArgument("transpose expects a vector of size ",
                                        input_shape.dims(),
                                        ". But input(1) is a vector of size ",
                                        perm_tensor_shape.num_elements()));

    std::vector<int64> perm;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm));

    std::vector<int64> transposed_order;
    // Check whether permutation is a permutation of integers of [0 .. dims).
    absl::InlinedVector<bool, 8> bits(dims);
    bool is_identity = true;
    for (int i = 0; i < dims; ++i) {
      const int64 d = perm[i];
      OP_REQUIRES(
          ctx, 0 <= d && d < dims,
          errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
      bits[d] = true;
      transposed_order.push_back(d);
      if (d != i) {
        is_identity = false;
      }
    }
    for (int i = 0; i < dims; ++i) {
      OP_REQUIRES(
          ctx, bits[i],
          errors::InvalidArgument(i, " is missing from 'perm' argument."));
    }

    xla::XlaOp transposed;
    // 0-D, 1-D, and identity transposes do nothing.
    if (dims <= 1 || is_identity) {
      transposed = ctx->Input("x");
    } else {
      transposed = xla::Transpose(ctx->Input("x"), transposed_order);
    }

    // Conjugate the transposed result if this is ConjugateTransposeOp.
    if (conjugate_) {
      ctx->SetOutput(0, xla::Conj(transposed));
    } else {
      ctx->SetOutput(0, transposed);
    }
  }

 private:
  const bool conjugate_;
};

class ConjugateTransposeOp : public TransposeOp {
 public:
  explicit ConjugateTransposeOp(OpKernelConstruction* ctx)
      : TransposeOp(ctx, /*conjugate=*/true) {}
};

REGISTER_XLA_OP(Name("Transpose").CompileTimeConstantInput("perm"),
                TransposeOp);

REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstantInput("perm"),
                ConjugateTransposeOp);

// InvertPermutation frequently forms part of the gradient of Transpose.
//
// inv = InvertPermutationOp(T<int32> p) takes a permutation of
// integers 0, 1, ..., n - 1 and returns the inverted
// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
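//
// For example, p = [2, 0, 1] has inverse inv = [1, 2, 0], since
// inv[p[0]] = inv[2] = 0, inv[p[1]] = inv[0] = 1, and inv[p[2]] = inv[1] = 2.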
//
// REQUIRES: input is a vector of int32.
// REQUIRES: input is a permutation of 0, 1, ..., n-1.

class InvertPermutationOp : public XlaOpKernel {
 public:
  explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES(ctx,
                FastBoundsCheck(ctx->InputShape(0).num_elements(),
                                std::numeric_limits<int32>::max()),
                errors::InvalidArgument("permutation of nonnegative int32s "
                                        "must have <= int32 max elements"));

    auto e = ctx->InputExpression(0);
    auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client());
    OP_REQUIRES_OK(ctx, tensor_or_status.status());
    // If the input is a constant, we also want the output to be a constant.
    // Some models rely on the result of InvertPermutation being a constant.
    // TODO(b/32495713): Remove this when we can check whether Scatter is
    // constant. Right now, we always assume it is non-constant because we
    // don't check the embedded computation.
    if (tensor_or_status.ValueOrDie().has_value()) {
      std::vector<int64> perm;
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm));

      int size = perm.size();

      std::vector<int32> output(size);
      std::fill_n(output.data(), size, -1);
      for (int i = 0; i < size; ++i) {
        const int64 d = perm[i];
        OP_REQUIRES(ctx, FastBoundsCheck(d, size),
                    errors::InvalidArgument(d, " is not between 0 and ", size));
        OP_REQUIRES(ctx, output[d] == -1,
                    errors::InvalidArgument(d, " is duplicated in the input."));
        output[d] = i;
      }

      ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output));
    } else {
      auto indices = ctx->Input(0);
      int size = ctx->InputShape(0).num_elements();
      auto iota = xla::Iota(ctx->builder(), xla::S32, size);
      // Scatter the values [0, size) into an iota buffer at the positions
      // given by the input permutation: output[perm[i]] = i, which is the
      // inverse permutation.
      auto result = XlaScatter(iota, iota, indices,
                               /*indices_are_vectors=*/false, /*combiner=*/{},
                               ctx->builder());
      OP_REQUIRES_OK(ctx, result.status());
      ctx->SetOutput(0, result.ValueOrDie());
    }
  }
};

REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32),
                InvertPermutationOp);

}  // namespace
}  // namespace tensorflow