1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/shape_util.h" 17 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 20 #include "tensorflow/core/framework/tensor_shape.h" 21 22 namespace tensorflow { 23 namespace { 24 25 class ReverseSequenceOp : public XlaOpKernel { 26 public: ReverseSequenceOp(OpKernelConstruction * context)27 explicit ReverseSequenceOp(OpKernelConstruction* context) 28 : XlaOpKernel(context) { 29 OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); 30 OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); 31 } 32 Compile(XlaOpKernelContext * context)33 void Compile(XlaOpKernelContext* context) override { 34 const TensorShape input_shape = context->InputShape(0); 35 const TensorShape seq_lens_shape = context->InputShape(1); 36 37 OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape), 38 errors::InvalidArgument("seq_lens input must be 1-dim, not ", 39 seq_lens_shape.dims())); 40 OP_REQUIRES(context, batch_dim_ != seq_dim_, 41 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); 42 OP_REQUIRES( 43 context, seq_dim_ < input_shape.dims(), 44 errors::InvalidArgument("seq_dim must be < input.dims()", "( ", 45 seq_dim_, " vs. ", input_shape.dims(), ")")); 46 OP_REQUIRES( 47 context, batch_dim_ < input_shape.dims(), 48 errors::InvalidArgument("batch_dim must be < input.dims()", "( ", 49 batch_dim_, " vs. ", input_shape.dims(), ")")); 50 OP_REQUIRES( 51 context, 52 seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_), 53 errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_, 54 "), ", "(", seq_lens_shape.num_elements(), 55 " vs. ", input_shape.dim_size(batch_dim_))); 56 57 xla::ComputationBuilder* builder = context->builder(); 58 const auto input = context->Input(0); 59 const auto seq_lens = context->Input(1); 60 61 const int64 batch_size = input_shape.dim_size(batch_dim_); 62 63 const DataType input_type = context->input_type(0); 64 const DataType seq_lens_type = context->input_type(1); 65 const int64 max_seq_len = input_shape.dim_size(seq_dim_); 66 67 xla::Shape input_xla_shape; 68 OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, 69 &input_xla_shape)); 70 xla::Shape seq_lens_xla_shape; 71 OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, 72 &seq_lens_xla_shape)); 73 74 const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ 75 xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), 76 seq_lens_xla_shape, 77 input_xla_shape, 78 }); 79 80 // For each entry in the batch, reverse the sequence. 81 // TODO(b/65689298): generalize the Map() operator to non-scalar cases and 82 // use it here, instead of a While loop. 83 84 // Condition: lambda (i, _, _): i < batch_size 85 auto condition_builder = 86 builder->CreateSubBuilder("reverse_sequence_condition"); 87 { 88 auto param = condition_builder->Parameter(0, tuple_shape, "param"); 89 auto i = condition_builder->GetTupleElement(param, 0); 90 condition_builder->Lt( 91 i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, 92 batch_size)); 93 } 94 auto condition = condition_builder->Build(); 95 OP_REQUIRES_OK(context, condition.status()); 96 97 auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); 98 { 99 auto param = body_builder->Parameter(0, tuple_shape, "param"); 100 auto i = body_builder->GetTupleElement(param, 0); 101 auto seq_lens = body_builder->GetTupleElement(param, 1); 102 auto output = body_builder->GetTupleElement(param, 2); 103 104 // seq_len is the sequence length of the current batch element (rank 1) 105 auto seq_len = body_builder->DynamicSlice( 106 seq_lens, body_builder->Reshape(i, {1}), {1}); 107 108 // Indices is the offset of the batch element in the input. 109 auto indices = body_builder->Broadcast( 110 XlaHelpers::Zero(body_builder.get(), seq_lens_type), 111 {input_shape.dims()}); 112 indices = body_builder->DynamicUpdateSlice( 113 indices, body_builder->Reshape(i, {1}), 114 body_builder->Reshape( 115 XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, 116 batch_dim_), 117 {1})); 118 119 // slice_indices is the offset of the start of the reversed sequence in 120 // the input. 121 auto slice_indices = body_builder->DynamicUpdateSlice( 122 indices, 123 body_builder->Sub(XlaHelpers::IntegerLiteral( 124 body_builder.get(), seq_lens_type, max_seq_len), 125 seq_len), 126 body_builder->Reshape( 127 XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, 128 seq_dim_), 129 {1})); 130 131 // Slice out the reversed sequence. The slice will overflow the end of the 132 // sequence, and the contents of the overflow are implementation-defined. 133 // However, we will mask off these elements and replace them with elements 134 // from the original input so their values do not matter. 135 TensorShape slice_shape = input_shape; 136 slice_shape.set_dim(batch_dim_, 1); 137 auto slice = body_builder->DynamicSlice(output, slice_indices, 138 slice_shape.dim_sizes()); 139 140 // Shift the reversed sequence to the left. 141 output = body_builder->DynamicUpdateSlice(output, slice, indices); 142 143 body_builder->Tuple( 144 {body_builder->Add( 145 i, XlaHelpers::One(body_builder.get(), seq_lens_type)), 146 seq_lens, output}); 147 } 148 auto body = body_builder->Build(); 149 OP_REQUIRES_OK(context, body.status()); 150 151 auto loop_output = builder->While( 152 condition.ValueOrDie(), body.ValueOrDie(), 153 builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, 154 builder->Rev(input, {seq_dim_})})); 155 auto output = builder->GetTupleElement(loop_output, 2); 156 157 // Mask out elements after the sequence length. 158 xla::ComputationDataHandle iota; 159 OP_REQUIRES_OK( 160 context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); 161 std::vector<int64> dims(input_shape.dims(), 1); 162 dims[batch_dim_] = batch_size; 163 auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); 164 165 // Broadcast the mask up to the input shape. 166 mask = 167 builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false), 168 input_shape.dim_sizes())); 169 170 output = builder->Select(mask, output, input); 171 context->SetOutput(0, output); 172 } 173 174 private: 175 int32 batch_dim_; 176 int32 seq_dim_; 177 }; 178 179 REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); 180 181 } // namespace 182 } // namespace tensorflow 183