1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/shape_util.h" 17 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 20 #include "tensorflow/compiler/xla/client/lib/constants.h" 21 #include "tensorflow/compiler/xla/client/xla_builder.h" 22 #include "tensorflow/compiler/xla/xla_data.pb.h" 23 #include "tensorflow/core/framework/tensor_shape.h" 24 25 namespace tensorflow { 26 namespace { 27 28 class ReverseSequenceOp : public XlaOpKernel { 29 public: ReverseSequenceOp(OpKernelConstruction * context)30 explicit ReverseSequenceOp(OpKernelConstruction* context) 31 : XlaOpKernel(context) { 32 OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); 33 OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); 34 } 35 Compile(XlaOpKernelContext * context)36 void Compile(XlaOpKernelContext* context) override { 37 const TensorShape input_shape = context->InputShape(0); 38 const TensorShape seq_lens_shape = context->InputShape(1); 39 40 OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape), 41 errors::InvalidArgument("seq_lens input must be 1-dim, not ", 42 seq_lens_shape.dims())); 43 OP_REQUIRES(context, batch_dim_ != seq_dim_, 44 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); 45 OP_REQUIRES( 46 context, seq_dim_ < input_shape.dims(), 47 errors::InvalidArgument("seq_dim must be < input.dims()", "( ", 48 seq_dim_, " vs. ", input_shape.dims(), ")")); 49 OP_REQUIRES( 50 context, batch_dim_ < input_shape.dims(), 51 errors::InvalidArgument("batch_dim must be < input.dims()", "( ", 52 batch_dim_, " vs. ", input_shape.dims(), ")")); 53 OP_REQUIRES( 54 context, 55 seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_), 56 errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_, 57 "), ", "(", seq_lens_shape.num_elements(), 58 " vs. ", input_shape.dim_size(batch_dim_))); 59 60 xla::XlaBuilder* builder = context->builder(); 61 const auto input = context->Input(0); 62 const auto seq_lens = context->Input(1); 63 64 const int64 batch_size = input_shape.dim_size(batch_dim_); 65 if (batch_size == 0) { 66 context->SetOutput(0, input); 67 return; 68 } 69 70 const xla::PrimitiveType seq_lens_type = context->input_xla_type(1); 71 const int64 max_seq_len = input_shape.dim_size(seq_dim_); 72 73 // Create [batch, sequence, 2] tensor that contains the indices where the 74 // real data belongs 75 xla::XlaOp back = xla::Sub(seq_lens, xla::ScalarLike(seq_lens, 1)); 76 xla::XlaOp batch_idx = xla::Iota( 77 builder, 78 xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), 79 /*iota_dimension=*/0); 80 xla::XlaOp forward_idx = xla::Iota( 81 builder, 82 xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}), 83 /*iota_dimension=*/1); 84 xla::XlaOp reverse_idx = xla::Sub(back, forward_idx, {0}); 85 reverse_idx = xla::Select(xla::Lt(reverse_idx, xla::ZerosLike(reverse_idx)), 86 forward_idx, reverse_idx); 87 if (batch_dim_ > seq_dim_) { 88 // The output of the XLA gather op keeps indices dimensions in the same 89 // order as they appear in the input. If the batch_dim_ needs to be after 90 // the seq_dim_ in the output, it also needs to be that way in the input 91 // so we transpose. 92 batch_idx = xla::Transpose(batch_idx, {1, 0, 2}); 93 forward_idx = xla::Transpose(forward_idx, {1, 0, 2}); 94 reverse_idx = xla::Transpose(reverse_idx, {1, 0, 2}); 95 } 96 xla::XlaOp start_indices = 97 xla::ConcatInDim(builder, {batch_idx, reverse_idx}, 98 /*dimension=*/2); 99 100 xla::GatherDimensionNumbers dnums; 101 dnums.set_index_vector_dim(2); 102 // The first and second element in the third dimension of reverse_idx are 103 // the batch_dim_ offset and the seq_dim_ offset respectively. 104 dnums.add_start_index_map(batch_dim_); 105 dnums.add_start_index_map(seq_dim_); 106 107 // batch_dim_ and seq_dim_ are collapsed and the other dimensions are kept 108 // in the gather. 109 for (int i = 0; i < input_shape.dims(); ++i) { 110 if (i != batch_dim_ && i != seq_dim_) { 111 dnums.add_offset_dims(i); 112 } else { 113 dnums.add_collapsed_slice_dims(i); 114 } 115 } 116 117 auto slice_sizes = input_shape.dim_sizes(); 118 slice_sizes[batch_dim_] = 1; 119 slice_sizes[seq_dim_] = 1; 120 121 context->SetOutput(0, 122 xla::Gather(input, start_indices, dnums, slice_sizes)); 123 } 124 125 private: 126 int32 batch_dim_; 127 int32 seq_dim_; 128 }; 129 130 REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); 131 132 } // namespace 133 } // namespace tensorflow 134