/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA Stack operators.
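//
// A Stack resource is compiled to an XLA tuple (buffer, index): `buffer` has
// shape [max_size, <element shape>...] and holds the stacked elements, and
// `index` is a scalar int32 giving the number of elements currently on the
// stack.
//
// A sketch of how the ops registered below compose at the graph level:
//   handle = StackV2(max_size, elem_type=T)    // StackOp
//   StackPushV2(handle, value)                 // StackPushOp
//   value = StackPopV2(handle, elem_type=T)    // StackPopOp
//   StackCloseV2(handle)                       // StackCloseOp (a no-op)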

#include <limits>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource,
                     TensorShape* stack_shape) {
  auto shape_or_status = builder->GetShape(resource->value());
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  xla::Shape shape = shape_or_status.ValueOrDie();
  TF_RET_CHECK(shape.IsTuple());
  return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
                               stack_shape);
}

// Since the element shape is not provided to the Stack operator,
// we lazily initialize the Stack at the time of the first write.
//
// If a Stack `resource` has not been initialized, constructs storage for the
// Stack with elements of `elem_shape`. For both initialized and
// uninitialized Stacks, checks that the tensor has a type compatible with
// `dtype` and a shape compatible with `elem_shape`.
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
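//
// For example (a sketch of the lazy initialization): the first push of an
// f32[2,3] element onto a stack created with max_size 10 sets the resource
// value to a zeroed (f32[10,2,3], s32[]) tuple; a later push of an element
// of any other shape fails with the "Mismatched Stack shapes" error below.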
Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource,
                            DataType dtype, const TensorShape& elem_shape) {
  if (resource->type() != dtype) {
    return errors::InvalidArgument(
        "Stack dtype is ", DataTypeString(resource->type()),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  TensorShape stack_shape;
  stack_shape.AddDim(resource->max_array_size());
  stack_shape.AppendShape(elem_shape);

  if (!resource->initialized()) {
    // Stack has not been initialized.
    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Check that the expected shape matches the actual shape.
    TensorShape actual_shape;
    TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &actual_shape));
    if (stack_shape != actual_shape) {
      return errors::InvalidArgument(
          "Mismatched Stack shapes: ", stack_shape.DebugString(), " vs ",
          actual_shape.DebugString());
    }
  }
  return Status::OK();
}

class StackOp : public XlaOpKernel {
 public:
  explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64 max_size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &max_size));
    OP_REQUIRES(
        ctx, max_size >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed stack size upper bound. If "
            "you are using tf.while_loop, set the maximum_iterations parameter "
            "to fix this issue."));

    // We defer initializing the Stack resource until we see the first push.
    // Otherwise we do not know the shape of the stack elements.
    XlaResource* resource =
        ctx->xla_context()->AddResource(XlaResource::CreateStack(
            /*name=*/absl::StrCat("Stack: ", stack_name_), dtype_, max_size));
    ctx->SetResourceOutput(0, resource);
  }

 private:
  DataType dtype_;
  string stack_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_XLA_OP(
    Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
    StackOp);

class StackPushOp : public XlaOpKernel {
 public:
  explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    TensorShape elem_shape = ctx->InputShape(1);

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // Initializes the Stack, if the element shape was not already known.
    OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));

    xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0);
    xla::XlaOp index = xla::GetTupleElement(resource->value(), 1);
    xla::XlaOp value = ctx->Input(1);

    // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
    std::vector<xla::XlaOp> start_indices(elem_shape.dims() + 1,
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = index;

    TensorShape slice_shape = elem_shape;
    slice_shape.InsertDim(0, 1LL);
    auto update = xla::Reshape(value, slice_shape.dim_sizes());
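    // For example, with elem_shape [2, 3] the buffer has shape
    // [max_size, 2, 3]; the pushed value is reshaped to [1, 2, 3] and
    // written at offset [index, 0, 0].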

    // TODO(phawkins): We don't check that the index is in bounds --- there is
    // no error mechanism in XLA.
    OP_REQUIRES_OK(ctx,
                   resource->SetValue(xla::Tuple(
                       b, {xla::DynamicUpdateSlice(ta, update, start_indices),
                           xla::Add(index, xla::ConstantR0<int32>(b, 1))})));

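    // The op forwards the pushed value as its output.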
    ctx->SetOutput(0, value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};

REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);

class StackPopOp : public XlaOpKernel {
 public:
  explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // There is a somewhat subtle issue here: "uninitialized" means we have
    // not yet seen a push in the order that we compile operators, not the
    // order that we run them. However, in practice the two orders should be
    // the same for the sole user of the stack operators (loop gradients).
    OP_REQUIRES(ctx, resource->initialized(),
                errors::InvalidArgument("Stack pop on uninitialized stack"));

    TensorShape stack_shape;
    OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));

    xla::XlaOp state = resource->value();
    xla::XlaOp ta = xla::GetTupleElement(state, 0);
    xla::XlaOp index = xla::GetTupleElement(state, 1);

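    // Pop decrements the index: the decremented value is both the buffer
    // position of the element to read and the new size of the stack.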
    index = xla::Sub(index, xla::ConstantR0<int32>(b, 1));
    OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index})));

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
    std::vector<xla::XlaOp> start_indices(stack_shape.dims(),
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = index;

    auto slice_shape = stack_shape.dim_sizes();
    slice_shape[0] = 1LL;
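    // For example, with stack_shape [max_size, 2, 3] this reads the
    // [1, 2, 3] slice starting at [index, 0, 0], i.e. the old top of the
    // stack.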

    // TODO(phawkins): We don't check that the index is in bounds --- there is
    // no error mechanism in XLA.
    xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);

    // Remove the leading '1' dimension.
    std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
    ctx->SetOutput(0, xla::Reshape(read, value_shape));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};

REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);

class StackCloseOp : public XlaOpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Do nothing.
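    // (Presumably there is nothing to release at compile time: the stack's
    // storage is just an XLA value tracked by the compilation context.)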
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);

}  // anonymous namespace
}  // namespace tensorflow