• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/shape_util.h"
17 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
20 #include "tensorflow/core/framework/tensor_shape.h"
21 
22 namespace tensorflow {
23 namespace {
24 
25 class ReverseSequenceOp : public XlaOpKernel {
26  public:
ReverseSequenceOp(OpKernelConstruction * context)27   explicit ReverseSequenceOp(OpKernelConstruction* context)
28       : XlaOpKernel(context) {
29     OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
30     OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
31   }
32 
Compile(XlaOpKernelContext * context)33   void Compile(XlaOpKernelContext* context) override {
34     const TensorShape input_shape = context->InputShape(0);
35     const TensorShape seq_lens_shape = context->InputShape(1);
36 
37     OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape),
38                 errors::InvalidArgument("seq_lens input must be 1-dim, not ",
39                                         seq_lens_shape.dims()));
40     OP_REQUIRES(context, batch_dim_ != seq_dim_,
41                 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_));
42     OP_REQUIRES(
43         context, seq_dim_ < input_shape.dims(),
44         errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
45                                 seq_dim_, " vs. ", input_shape.dims(), ")"));
46     OP_REQUIRES(
47         context, batch_dim_ < input_shape.dims(),
48         errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
49                                 batch_dim_, " vs. ", input_shape.dims(), ")"));
50     OP_REQUIRES(
51         context,
52         seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_),
53         errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_,
54                                 "), ", "(", seq_lens_shape.num_elements(),
55                                 " vs. ", input_shape.dim_size(batch_dim_)));
56 
57     xla::ComputationBuilder* builder = context->builder();
58     const auto input = context->Input(0);
59     const auto seq_lens = context->Input(1);
60 
61     const int64 batch_size = input_shape.dim_size(batch_dim_);
62 
63     const DataType input_type = context->input_type(0);
64     const DataType seq_lens_type = context->input_type(1);
65     const int64 max_seq_len = input_shape.dim_size(seq_dim_);
66 
67     xla::Shape input_xla_shape;
68     OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape,
69                                                   &input_xla_shape));
70     xla::Shape seq_lens_xla_shape;
71     OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape,
72                                                   &seq_lens_xla_shape));
73 
74     const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({
75         xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}),
76         seq_lens_xla_shape,
77         input_xla_shape,
78     });
79 
80     // For each entry in the batch, reverse the sequence.
81     // TODO(b/65689298): generalize the Map() operator to non-scalar cases and
82     // use it here, instead of a While loop.
83 
84     // Condition: lambda (i, _, _): i < batch_size
85     auto condition_builder =
86         builder->CreateSubBuilder("reverse_sequence_condition");
87     {
88       auto param = condition_builder->Parameter(0, tuple_shape, "param");
89       auto i = condition_builder->GetTupleElement(param, 0);
90       condition_builder->Lt(
91           i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type,
92                                         batch_size));
93     }
94     auto condition = condition_builder->Build();
95     OP_REQUIRES_OK(context, condition.status());
96 
97     auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
98     {
99       auto param = body_builder->Parameter(0, tuple_shape, "param");
100       auto i = body_builder->GetTupleElement(param, 0);
101       auto seq_lens = body_builder->GetTupleElement(param, 1);
102       auto output = body_builder->GetTupleElement(param, 2);
103 
104       // seq_len is the sequence length of the current batch element (rank 1)
105       auto seq_len = body_builder->DynamicSlice(
106           seq_lens, body_builder->Reshape(i, {1}), {1});
107 
108       // Indices is the offset of the batch element in the input.
109       auto indices = body_builder->Broadcast(
110           XlaHelpers::Zero(body_builder.get(), seq_lens_type),
111           {input_shape.dims()});
112       indices = body_builder->DynamicUpdateSlice(
113           indices, body_builder->Reshape(i, {1}),
114           body_builder->Reshape(
115               XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
116                                          batch_dim_),
117               {1}));
118 
119       // slice_indices is the offset of the start of the reversed sequence in
120       // the input.
121       auto slice_indices = body_builder->DynamicUpdateSlice(
122           indices,
123           body_builder->Sub(XlaHelpers::IntegerLiteral(
124                                 body_builder.get(), seq_lens_type, max_seq_len),
125                             seq_len),
126           body_builder->Reshape(
127               XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
128                                          seq_dim_),
129               {1}));
130 
131       // Slice out the reversed sequence. The slice will overflow the end of the
132       // sequence, and the contents of the overflow are implementation-defined.
133       // However, we will mask off these elements and replace them with elements
134       // from the original input so their values do not matter.
135       TensorShape slice_shape = input_shape;
136       slice_shape.set_dim(batch_dim_, 1);
137       auto slice = body_builder->DynamicSlice(output, slice_indices,
138                                               slice_shape.dim_sizes());
139 
140       // Shift the reversed sequence to the left.
141       output = body_builder->DynamicUpdateSlice(output, slice, indices);
142 
143       body_builder->Tuple(
144           {body_builder->Add(
145                i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
146            seq_lens, output});
147     }
148     auto body = body_builder->Build();
149     OP_REQUIRES_OK(context, body.status());
150 
151     auto loop_output = builder->While(
152         condition.ValueOrDie(), body.ValueOrDie(),
153         builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
154                         builder->Rev(input, {seq_dim_})}));
155     auto output = builder->GetTupleElement(loop_output, 2);
156 
157     // Mask out elements after the sequence length.
158     xla::ComputationDataHandle iota;
159     OP_REQUIRES_OK(
160         context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
161     std::vector<int64> dims(input_shape.dims(), 1);
162     dims[batch_dim_] = batch_size;
163     auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_});
164 
165     // Broadcast the mask up to the input shape.
166     mask =
167         builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false),
168                                              input_shape.dim_sizes()));
169 
170     output = builder->Select(mask, output, input);
171     context->SetOutput(0, output);
172   }
173 
174  private:
175   int32 batch_dim_;
176   int32 seq_dim_;
177 };
178 
// Register this kernel as the XLA lowering of the "ReverseSequence" op.
REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp);
180 
181 }  // namespace
182 }  // namespace tensorflow
183