• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/shape_util.h"
17 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
20 #include "tensorflow/compiler/xla/client/lib/constants.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/xla_data.pb.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 
25 namespace tensorflow {
26 namespace {
27 
28 class ReverseSequenceOp : public XlaOpKernel {
29  public:
ReverseSequenceOp(OpKernelConstruction * context)30   explicit ReverseSequenceOp(OpKernelConstruction* context)
31       : XlaOpKernel(context) {
32     OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
33     OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
34   }
35 
Compile(XlaOpKernelContext * context)36   void Compile(XlaOpKernelContext* context) override {
37     const TensorShape input_shape = context->InputShape(0);
38     const TensorShape seq_lens_shape = context->InputShape(1);
39 
40     OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape),
41                 errors::InvalidArgument("seq_lens input must be 1-dim, not ",
42                                         seq_lens_shape.dims()));
43     OP_REQUIRES(context, batch_dim_ != seq_dim_,
44                 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_));
45     OP_REQUIRES(
46         context, seq_dim_ < input_shape.dims(),
47         errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
48                                 seq_dim_, " vs. ", input_shape.dims(), ")"));
49     OP_REQUIRES(
50         context, batch_dim_ < input_shape.dims(),
51         errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
52                                 batch_dim_, " vs. ", input_shape.dims(), ")"));
53     OP_REQUIRES(
54         context,
55         seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_),
56         errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_,
57                                 "), ", "(", seq_lens_shape.num_elements(),
58                                 " vs. ", input_shape.dim_size(batch_dim_)));
59 
60     xla::XlaBuilder* builder = context->builder();
61     const auto input = context->Input(0);
62     const auto seq_lens = context->Input(1);
63 
64     const int64 batch_size = input_shape.dim_size(batch_dim_);
65     if (batch_size == 0) {
66       context->SetOutput(0, input);
67       return;
68     }
69 
70     const xla::PrimitiveType seq_lens_type = context->input_xla_type(1);
71     const int64 max_seq_len = input_shape.dim_size(seq_dim_);
72 
73     // Create [batch, sequence, 2] tensor that contains the indices where the
74     // real data belongs
75     xla::XlaOp back = xla::Sub(seq_lens, xla::ScalarLike(seq_lens, 1));
76     xla::XlaOp batch_idx = xla::Iota(
77         builder,
78         xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}),
79         /*iota_dimension=*/0);
80     xla::XlaOp forward_idx = xla::Iota(
81         builder,
82         xla::ShapeUtil::MakeShape(seq_lens_type, {batch_size, max_seq_len, 1}),
83         /*iota_dimension=*/1);
84     xla::XlaOp reverse_idx = xla::Sub(back, forward_idx, {0});
85     reverse_idx = xla::Select(xla::Lt(reverse_idx, xla::ZerosLike(reverse_idx)),
86                               forward_idx, reverse_idx);
87     if (batch_dim_ > seq_dim_) {
88       // The output of the XLA gather op keeps indices dimensions in the same
89       // order as they appear in the input. If the batch_dim_ needs to be after
90       // the seq_dim_ in the output, it also needs to be that way in the input
91       // so we transpose.
92       batch_idx = xla::Transpose(batch_idx, {1, 0, 2});
93       forward_idx = xla::Transpose(forward_idx, {1, 0, 2});
94       reverse_idx = xla::Transpose(reverse_idx, {1, 0, 2});
95     }
96     xla::XlaOp start_indices =
97         xla::ConcatInDim(builder, {batch_idx, reverse_idx},
98                          /*dimension=*/2);
99 
100     xla::GatherDimensionNumbers dnums;
101     dnums.set_index_vector_dim(2);
102     // The first and second element in the third dimension of reverse_idx are
103     // the batch_dim_ offset and the seq_dim_ offset respectively.
104     dnums.add_start_index_map(batch_dim_);
105     dnums.add_start_index_map(seq_dim_);
106 
107     // batch_dim_ and seq_dim_ are collapsed and the other dimensions are kept
108     // in the gather.
109     for (int i = 0; i < input_shape.dims(); ++i) {
110       if (i != batch_dim_ && i != seq_dim_) {
111         dnums.add_offset_dims(i);
112       } else {
113         dnums.add_collapsed_slice_dims(i);
114       }
115     }
116 
117     auto slice_sizes = input_shape.dim_sizes();
118     slice_sizes[batch_dim_] = 1;
119     slice_sizes[seq_dim_] = 1;
120 
121     context->SetOutput(0,
122                        xla::Gather(input, start_indices, dnums, slice_sizes));
123   }
124 
125  private:
126   int32 batch_dim_;
127   int32 seq_dim_;
128 };
129 
130 REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp);
131 
132 }  // namespace
133 }  // namespace tensorflow
134