/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA Stack operators.

#include <limits>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

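// The value of a Stack resource is a 2-tuple:
//   element 0: the backing buffer, of shape [max_size] ++ elem_shape;
//   element 1: an int32 scalar holding the next push index.
// For example (illustrative), a stack holding at most 10 f32[2,3] elements
// is represented as the tuple (f32[10,2,3], s32[]).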
Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource,
                     TensorShape* stack_shape) {
  auto shape_or_status = builder->GetShape(resource->value());
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  xla::Shape shape = shape_or_status.ValueOrDie();
  TF_RET_CHECK(shape.IsTuple());
  return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
                               stack_shape);
}

// Since the element shape is not provided to the Stack operator,
// we lazily initialize the Stack at the time of the first write.
//
// If a Stack `resource` has not been initialized, constructs storage for the
// Stack with elements of `elem_shape`. For both initialized and
// uninitialized Stacks, checks that the tensor has a type compatible with
// `dtype` and shape compatible with `elem_shape`.
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource,
                            DataType dtype, const TensorShape& elem_shape) {
  if (resource->type() != dtype) {
    return errors::InvalidArgument(
        "Stack dtype is ", DataTypeString(resource->type()),
        " but op has dtype ", DataTypeString(dtype), ".");
  }

  TensorShape stack_shape;
  stack_shape.AddDim(resource->max_array_size());
  stack_shape.AppendShape(elem_shape);
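  // For example, elem_shape = [2, 3] with max_array_size = 10 yields
  // stack_shape = [10, 2, 3].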

  if (!resource->initialized()) {
    // Stack has not been initialized.
    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Check that the expected shape matches the actual shape.
    TensorShape actual_shape;
    TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &actual_shape));
    if (stack_shape != actual_shape) {
      return errors::InvalidArgument(
          "Mismatched Stack shapes: ", stack_shape.DebugString(), " vs ",
          actual_shape.DebugString());
    }
  }
  return Status::OK();
}

class StackOp : public XlaOpKernel {
 public:
  explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64 max_size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &max_size));
    OP_REQUIRES(
        ctx, max_size >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed stack size upper bound. If "
            "you are using tf.while_loop, set the maximum_iterations parameter "
            "to fix this issue."));

    // We defer initializing the Stack resource until we see the first push.
    // Otherwise we do not know the shape of the stack elements.
    XlaResource* resource =
        ctx->xla_context()->AddResource(XlaResource::CreateStack(
            /*name=*/absl::StrCat("Stack: ", stack_name_), dtype_, max_size));
    ctx->SetResourceOutput(0, resource);
  }

 private:
  DataType dtype_;
  string stack_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_XLA_OP(
    Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
    StackOp);
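
// Note: the "max_size" input must be a compile-time constant (it is read
// with ConstantInputAsIntScalar above) because the stack's backing buffer
// needs a statically known shape.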

class StackPushOp : public XlaOpKernel {
 public:
  explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    TensorShape elem_shape = ctx->InputShape(1);

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // Initializes the Stack, if the element shape was not already known.
    OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));

    xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0);
    xla::XlaOp index = xla::GetTupleElement(resource->value(), 1);
    xla::XlaOp value = ctx->Input(1);

    // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
    std::vector<xla::XlaOp> start_indices(elem_shape.dims() + 1,
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = index;

    TensorShape slice_shape = elem_shape;
    slice_shape.InsertDim(0, 1LL);
    auto update = xla::Reshape(value, slice_shape.dim_sizes());
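
    // For example, pushing an f32[2,3] value onto a stack of capacity 10
    // reshapes the value to f32[1,2,3] and writes it into the f32[10,2,3]
    // buffer at start_indices [index, 0, 0].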

    // TODO(phawkins): We don't check that the index is in bounds; there is
    // no error mechanism in XLA.
    OP_REQUIRES_OK(ctx,
                   resource->SetValue(xla::Tuple(
                       b, {xla::DynamicUpdateSlice(ta, update, start_indices),
                           xla::Add(index, xla::ConstantR0<int32>(b, 1))})));

    ctx->SetOutput(0, value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};

REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);

class StackPopOp : public XlaOpKernel {
 public:
  explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // There is a somewhat subtle issue here: "uninitialized" means we have
    // not yet seen a push in the order that we compile operators, not the
    // order that we run them. However, in practice the two orders should be
    // the same for the sole user of the stack operators (loop gradients).
    OP_REQUIRES(ctx, resource->initialized(),
                errors::InvalidArgument("Stack pop on uninitialized stack"));

    TensorShape stack_shape;
    OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));

    xla::XlaOp state = resource->value();
    xla::XlaOp ta = xla::GetTupleElement(state, 0);
    xla::XlaOp index = xla::GetTupleElement(state, 1);

    index = xla::Sub(index, xla::ConstantR0<int32>(b, 1));
    OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index})));
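
    // Note that popping only decrements the index; the buffer contents are
    // left unchanged.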

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
    std::vector<xla::XlaOp> start_indices(stack_shape.dims(),
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = index;

    auto slice_shape = stack_shape.dim_sizes();
    slice_shape[0] = 1LL;
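
    // For example, with stack_shape [10, 2, 3], slice_shape is [1, 2, 3]:
    // a single element read at [index, 0, 0] and reshaped to [2, 3] below.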

    // TODO(phawkins): We don't check that the index is in bounds; there is
    // no error mechanism in XLA.
    xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);

    // Remove the leading '1' dimension.
    std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
    ctx->SetOutput(0, xla::Reshape(read, value_shape));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};

REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);

class StackCloseOp : public XlaOpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    // Do nothing.
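    // Under XLA the stack is an ordinary tuple value threaded through the
    // computation rather than a runtime resource, so there appears to be
    // nothing to free here.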
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);

}  // anonymous namespace
}  // namespace tensorflow