/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/while_op.h"

#include "absl/strings/str_split.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

namespace {

// Verifies that input resources are grouped at the end of the argument list.
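// For example, the argument type list (DT_INT32, DT_FLOAT, DT_RESOURCE,
// DT_RESOURCE) is accepted, while (DT_INT32, DT_RESOURCE, DT_FLOAT) is
// rejected because a non-resource argument follows a resource argument.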
Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx,
                                      const NameAttrList& body_name_attr) {
  const FunctionBody* body;
  TF_RETURN_IF_ERROR(ctx->compiler()->FindFunctionBody(body_name_attr, &body));
  bool has_seen_resource = false;
  for (int i = 0; i < body->arg_types.size(); i++) {
    DataType arg_type = body->arg_types[i];
    if (has_seen_resource) {
      if (arg_type != DT_RESOURCE) {
        return errors::InvalidArgument(
            "Expected input resources to be grouped at the end of while body ",
            body_name_attr.name(), ", but the ", i, "-th argument ",
            body->arg_nodes[i]->name(), " is not a resource.");
      }
    } else {
      if (arg_type == DT_RESOURCE) {
        has_seen_resource = true;
      }
    }
  }
  return Status::OK();
}

// Builds XlaCompiler argument descriptions `args` from `ctx`.
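// The three boolean outputs record, respectively, whether any input is an
// uninitialized resource variable, a TensorArray, or an uninitialized
// TensorList; the caller uses them to decide whether the loop body must be
// compiled twice (see the comments in XlaWhileOp::Compile below).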
Status MakeXlaCompilerArgumentsFromInputs(
    XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
    bool* has_uninitialized_vars, bool* has_tensor_arrays,
    bool* has_uninitialized_tensor_lists) {
  VLOG(2) << "Num inputs " << ctx->num_inputs();
  args->resize(ctx->num_inputs());
  *has_uninitialized_vars = false;
  *has_tensor_arrays = false;
  *has_uninitialized_tensor_lists = false;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i))
            << " shape: " << ctx->InputShape(i).DebugString();
    XlaCompiler::Argument& arg = (*args)[i];
    DataType type = ctx->input_type(i);
    // When reading a resource input, use the type and shape of the resource's
    // current value.
    if (type == DT_RESOURCE) {
      XlaResource* resource;
      TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
      XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
      if (arg.resource_kind == XlaResource::kTensorArray) {
        *has_tensor_arrays = true;
      }
      if (!arg.initialized) {
        *has_uninitialized_vars = true;
      }
      VLOG(2) << "  resource " << resource->name()
              << " type: " << DataTypeString(arg.type)
              << " shape: " << arg.ShapeHumanString()
              << " initialized: " << arg.initialized;
    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = type;
      TF_ASSIGN_OR_RETURN(arg.shape, ctx->builder()->GetShape(ctx->Input(i)));
      if (IsTensorListInput(ctx, i)) {
        // arg.initialized == false means that the element_shape of the list
        // was not available when the list was built, so an empty list was
        // created instead. In that case the body function of While is run once
        // to infer the shape of the list before actually building the While
        // op.
        TF_RETURN_IF_ERROR(
            IsTensorListInitialized(ctx->Input(i), &arg.initialized));
        if (!arg.initialized) {
          *has_uninitialized_tensor_lists = true;
        }
      }
    }
  }
  return Status::OK();
}

// Sets the entries of `loop_invariants` to true at the indices of
// loop-invariant variables.
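// A variable is considered invariant when the body's i-th return value is fed
// directly by its i-th argument node, i.e. the value flows through the body
// unchanged.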
void GetLoopInvariants(XlaOpKernelContext* ctx,
                       const NameAttrList& body_name_attr,
                       std::vector<bool>* const loop_invariants) {
  const FunctionBody* body;
  OP_REQUIRES_OK(ctx, ctx->compiler()->FindFunctionBody(body_name_attr, &body));
  for (int i = 0; i < body->ret_nodes.size(); i++) {
    const Node* arg = body->arg_nodes[i];
    const Node* ret = body->ret_nodes[i];
    const Node* ret_input_0;
    OP_REQUIRES_OK(ctx, ret->input_node(0, &ret_input_0));
    (*loop_invariants)[i] = (ret_input_0->id() == arg->id());
  }
}

// Converts to compile-time constants those entries of `args` that are loop
// invariant, have compile-time constant inputs, and must be constants in
// order to be compilable, so that their values can be propagated into the
// loop body.
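// For example, a loop-invariant operand that feeds an input the cond/body
// must know at compile time (such as a shape or size argument) is replaced by
// a constant here rather than being threaded through the loop as a parameter.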
Status ConvertLoopInvariantsToConst(
    XlaOpKernelContext* ctx, const NameAttrList& body_name_attr,
    const NameAttrList& cond_name_attr,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<bool>* compile_time_const_arg_indices,
    int* num_compile_time_const_args, xla::Client* client) {
  std::vector<bool> loop_invariants(ctx->num_inputs());
  GetLoopInvariants(ctx, body_name_attr, &loop_invariants);

  std::vector<bool> body_must_be_const_nodes;
  const FunctionBody* body;
  std::vector<bool> cond_must_be_const_nodes;
  const FunctionBody* cond;
  TF_RETURN_IF_ERROR(FindMustBeConstNodes(ctx, body_name_attr,
                                          &body_must_be_const_nodes, &body));
  TF_RETURN_IF_ERROR(FindMustBeConstNodes(ctx, cond_name_attr,
                                          &cond_must_be_const_nodes, &cond));

  auto should_convert_to_const = [&](int arg_idx) {
    XlaCompiler::Argument& arg = (*args)[arg_idx];
    return arg.kind != XlaCompiler::Argument::kResource &&
           loop_invariants[arg_idx] &&
           (body_must_be_const_nodes[body->arg_nodes[arg_idx]->id()] ||
            cond_must_be_const_nodes[cond->arg_nodes[arg_idx]->id()]);
  };
  absl::InlinedVector<int, 5> converted_constants =
      ConvertCompileTimeConstArgumentsToConst(ctx, args,
                                              /*xla_expression_offset=*/0,
                                              should_convert_to_const);
  for (int arg_idx : converted_constants) {
    compile_time_const_arg_indices->at(arg_idx) = true;
    (*num_compile_time_const_args)++;
  }
  return Status::OK();
}

Status VerifyBodyInputAndOutputShapeMatch(
    XlaOpKernelContext* ctx,
    const std::vector<bool>& compile_time_const_arg_indices,
    const XlaCompiler::CompilationResult& body, bool has_token_input_output) {
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  xla::Shape body_output_shape;
  body_output_shape.set_element_type(xla::TUPLE);
  for (int i = 0; i < ctx->num_outputs(); i++) {
    if (!compile_time_const_arg_indices[i]) {
      *(body_output_shape.add_tuple_shapes()) =
          body.xla_output_shape.tuple_shapes(i);
    }
  }
  // If `body` has a token output, append its shape to `body_output_shape`.
  if (has_token_input_output) {
    *(body_output_shape.add_tuple_shapes()) =
        body.xla_output_shape.tuple_shapes(ctx->num_inputs());
  }
  if (!xla::ShapeUtil::Compatible(body_input_shape, body_output_shape)) {
    return errors::InvalidArgument(
        "Input and output shapes of loop body do not match: ",
        xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
        xla::ShapeUtil::HumanString(body_output_shape));
  }
  return Status::OK();
}

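// Wraps `cond` in a computation that calls it and extracts element 0 of its
// tuple output, since xla::While expects the condition to produce a bare
// pred[] scalar rather than a tuple.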
xla::StatusOr<xla::XlaComputation> BuildWrappedCond(
    XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& cond) {
  xla::Shape cond_input_shape = cond.xla_input_shapes[0];
  std::unique_ptr<xla::XlaBuilder> cb =
      ctx->builder()->CreateSubBuilder("cond_wrapper");
  auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs");
  auto outputs = xla::Call(cb.get(), *cond.computation, {inputs});
  xla::GetTupleElement(outputs, 0);
  return cb->Build();
}

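// Wraps `body` in a computation that drops the compile-time constant outputs,
// so that the wrapped body's input and output tuples have matching shapes as
// required by xla::While. If no arguments were converted to constants, the
// original body computation is returned unchanged.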
xla::StatusOr<xla::XlaComputation> BuildWrappedBody(
    XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& body,
    const std::vector<bool>& compile_time_const_arg_indices,
    int num_compile_time_const_args, bool has_token_input_output) {
  if (num_compile_time_const_args <= 0) {
    return xla::XlaComputation(body.computation->proto());
  }
  xla::XlaComputation body_wrapper;
  std::unique_ptr<xla::XlaBuilder> cb =
      ctx->builder()->CreateSubBuilder("body_wrapper");
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  auto inputs = xla::Parameter(cb.get(), 0, body_input_shape, "inputs");
  // Call the original body function, which has mismatched inputs and outputs,
  // and strip the compile-time consts from the list of outputs. While requires
  // the inputs and outputs of its body function to match.
  auto outputs = xla::Call(cb.get(), *body.computation, {inputs});
  std::vector<xla::XlaOp> non_compile_time_const_outputs;
  for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
    if (!compile_time_const_arg_indices[i]) {
      non_compile_time_const_outputs.push_back(
          xla::GetTupleElement(outputs, i));
    }
  }
  // If `body` has a token output, append it to
  // `non_compile_time_const_outputs`.
  if (has_token_input_output) {
    non_compile_time_const_outputs.push_back(
        xla::GetTupleElement(outputs, ctx->num_outputs()));
  }
  xla::Tuple(cb.get(), non_compile_time_const_outputs);
  return cb->Build();
}

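// Emits the xla::While op and pads its output tuple back up to the op's full
// output arity: positions that correspond to compile-time constant arguments
// (which were pruned from the loop) are filled directly from the op's inputs,
// since loop invariants are unchanged by the loop.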
xla::XlaOp BuildWhile(XlaOpKernelContext* ctx,
                      const xla::XlaComputation& wrapped_cond,
                      const xla::XlaComputation& wrapped_body,
                      const xla::XlaOp& initial_values,
                      const std::vector<int>& input_mapping,
                      const std::vector<bool>& compile_time_const_arg_indices,
                      int num_compile_time_const_args,
                      bool has_token_input_output) {
  xla::XlaOp while_result =
      xla::While(wrapped_cond, wrapped_body, initial_values);
  std::vector<xla::XlaOp> padded_while_outputs(ctx->num_outputs());
  int while_result_index = 0;
  for (int i = 0; i < ctx->num_inputs(); i++) {
    if (!compile_time_const_arg_indices[i]) {
      padded_while_outputs[input_mapping[while_result_index]] =
          xla::GetTupleElement(while_result, while_result_index);
      while_result_index++;
    } else {
      padded_while_outputs[i] = ctx->Input(i);
    }
  }
  // If `body` has a token output, append it to `padded_while_outputs`.
  if (has_token_input_output) {
    padded_while_outputs.push_back(xla::GetTupleElement(
        while_result, ctx->num_inputs() - num_compile_time_const_args));
  }
  return xla::Tuple(ctx->builder(), padded_while_outputs);
}

}  // anonymous namespace

XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
  const NameAttrList* name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr));
  cond_name_attr_ = *name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
  body_name_attr_ = *name_attr;
  if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
    has_token_input_output_ = false;
  } else {
    has_token_input_output_ = !token_input_nodes_.empty();
  }
  if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
                                     &propagate_compile_time_consts_));
  }
}

void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
  VLOG(1) << "WhileOp::Compile";

  // Input resources need to be grouped at the end of the body function,
  // following the convention of the XLA bridge.
  OP_REQUIRES_OK(ctx, VerifyResourceArgsGroupedAtEnd(ctx, body_name_attr_));

  std::vector<XlaCompiler::Argument> arguments;
  bool has_uninitialized_vars;
  bool has_tensor_arrays;
  bool has_uninitialized_tensor_lists;
  OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
                          ctx, &arguments, &has_uninitialized_vars,
                          &has_tensor_arrays, &has_uninitialized_tensor_lists));

  xla::XlaBuilder* builder = ctx->builder();
  XlaCompiler* compiler = ctx->compiler();

  // Indices of loop vars which satisfy the following conditions:
  // 1. They are loop invariants.
  // 2. The op inputs at these indices are compile-time constants.
  //
  // These compile-time consts do not appear as _Args in the cond/body
  // functions and are replaced by kConstant nodes instead. As a result, the
  // compiled body function does not have matching input and output shapes. We
  // fix this by rewriting the body computation (see body_wrapper below) to
  // output just the non-compile-time-const values, and later pad up the while
  // output with the const args.
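  // For example, with loop vars (n, x) where n is a loop-invariant
  // compile-time constant, the rewritten body carries only x, and the final
  // While output is padded back to (n, x) using the op's own input for n.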
  std::vector<bool> compile_time_const_arg_indices(ctx->num_inputs());
  int num_compile_time_const_args = 0;
  if (propagate_compile_time_consts_) {
    OP_REQUIRES_OK(ctx, ConvertLoopInvariantsToConst(
                            ctx, body_name_attr_, cond_name_attr_, &arguments,
                            &compile_time_const_arg_indices,
                            &num_compile_time_const_args, compiler->client()));
  }

  VLOG(1) << "Compiling body";

  // All resources that are inputs to the loop's body must also be present as
  // loop body outputs; the signatures of the loop's input and output must
  // match. We ensure this by asking the compiler to include the current values
  // of all resources, even if they haven't been updated by the computation. We
  // must also ask the compiler to keep compile-time constant outputs as part
  // of the generated computation, for the same reason.
  // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
  // operator.
  XlaCompiler::CompileOptions body_options;
  body_options.use_tuple_arg = true;
  body_options.return_updated_values_for_all_resources = true;
  body_options.is_entry_computation = false;
  body_options.add_token_input_output = has_token_input_output_;
  XlaCompiler::CompilationResult body;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                arguments, &body));

  // We must use a static shape for parameters to an XLA compilation. However,
  // we may not know the shape of a resource if it is first written inside the
  // loop. Furthermore, we do not know ahead of time which gradient
  // TensorArrays will be created by the TensorArrayGradV3 operator.
  //
  // Ideally we would change TensorFlow to always provide static shapes, but
  // this is not easy to do. So if uninitialized resources or TensorArrays are
  // used by the loop body, we compile the body function twice:
  // 1) once with uninitialized resource inputs and no TensorArray gradient
  //    inputs. We then discard the computation, but we assume resource shapes
  //    and the set of gradients read or written will reach a fixpoint after
  //    one iteration. Hence we can use the output shapes and TensorArray
  //    gradients of each resource as the "true" shapes.
  // 2) again with the "correct" resource information determined by (1).
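  // For example, a variable that enters the loop uninitialized gets its type
  // and shape from the first compilation's resource updates; it is then
  // zero-initialized and the body is recompiled with that shape (see below).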
  if (has_uninitialized_vars || has_tensor_arrays ||
      has_uninitialized_tensor_lists) {
    VLOG(2) << "Recompiling loop body: has_uninitialized_vars: "
            << has_uninitialized_vars
            << " has_tensor_arrays: " << has_tensor_arrays
            << " has_uninitialized_tensor_lists: "
            << has_uninitialized_tensor_lists;
    // Initializes any uninitialized resource with zero values of the shape
    // determined by the first compilation.
    for (int i = 0; i < body.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));

      XlaCompiler::Argument& arg = arguments[update.input_index];
      if (!arg.initialized) {
        VLOG(2) << "Update shape for argument " << update.input_index << " "
                << update.shape.DebugString();
        arg.initialized = true;

        arg.shape = update.shape;
        OP_REQUIRES_OK(ctx,
                       resource->SetTypeAndShape(update.type, update.shape));

        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
      }

      // Add any TensorArray gradients touched by the body to the enclosing
      // graph.
      for (const string& grad_source : update.tensor_array_gradients_accessed) {
        VLOG(4) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, builder, &gradient));
      }

      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
    }

    // Set the shape of any uninitialized TensorLists to the shape determined
    // by the first compilation. Note that, unlike resources, we do not
    // initialize the input list with zeros here; that is done later.
    xla::Shape body_output_shape = body.xla_output_shape;
    OP_REQUIRES(ctx, body_output_shape.IsTuple(),
                errors::FailedPrecondition(
                    "xla_output_shape of while body must be a tuple."));
    for (int i = 0; i < arguments.size(); i++) {
      XlaCompiler::Argument& arg = arguments[i];
      if (arg.initialized || !IsTensorListInput(ctx, i)) {
        continue;
      }
      arg.shape = body_output_shape.tuple_shapes(i);
      arg.initialized = true;
    }

    // Recompile the body with the "correct" resource shapes.
    VLOG(1) << "Recompiling body with corrected resource shapes";
    body = {};
    OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                  arguments, &body));
  }

  VLOG(1) << "Compiling condition";

  XlaCompiler::CompileOptions cond_options;
  cond_options.use_tuple_arg = true;
  cond_options.is_entry_computation = false;
  cond_options.add_token_input_output = has_token_input_output_;
  XlaCompiler::CompilationResult cond;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                arguments, &cond));

  OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  OP_REQUIRES(ctx, body_input_shape.IsTuple(),
              errors::FailedPrecondition("Expected tuple shape"));
  OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape cond_input_shape = cond.xla_input_shapes[0];
  OP_REQUIRES(ctx, cond_input_shape.IsTuple(),
              errors::FailedPrecondition("Expected tuple shape"));

  VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
  VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);

  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
              errors::InvalidArgument(
                  "Input shapes of loop body and condition do not match: ",
                  xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
                  xla::ShapeUtil::HumanString(cond_input_shape)));

  // Check that the shape of the body outputs, excluding the compile-time const
  // args (which are pruned from the body outputs in body_wrapper), matches the
  // shape of the inputs.
  OP_REQUIRES_OK(ctx, VerifyBodyInputAndOutputShapeMatch(
                          ctx, compile_time_const_arg_indices, body,
                          has_token_input_output_));

  xla::Shape expected_cond_output_shape_without_side_effect =
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::PRED, {})});
  xla::Shape expected_cond_output_shape_with_side_effect =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}),
                                      xla::ShapeUtil::MakeTokenShape()});
  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(
                  cond.xla_output_shape,
                  expected_cond_output_shape_without_side_effect) ||
                  xla::ShapeUtil::Compatible(
                      cond.xla_output_shape,
                      expected_cond_output_shape_with_side_effect),
              errors::InvalidArgument(
                  "Output shape of loop condition should be (pred[]) or "
                  "(pred[], token[]), got: ",
                  xla::ShapeUtil::HumanString(cond.xla_output_shape)));

  int num_inputs = body.input_mapping.size();
  std::vector<xla::XlaOp> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = body.input_mapping[i];
    if (has_token_input_output_ && i == num_inputs - 1) {
      // Set the token input for this "while" op.
      std::vector<xla::XlaOp> token_inputs;
      for (const string& node_name : token_input_nodes_) {
        auto token_or = compiler->GetNodeToken(node_name);
        OP_REQUIRES_OK(ctx, token_or.status());
        token_inputs.push_back(token_or.ValueOrDie());
      }
      inputs[i] = xla::AfterAll(builder, token_inputs);
    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
    } else if (IsTensorListInput(ctx, input_num)) {
      xla::XlaOp input = ctx->Input(input_num);
      auto input_shape_or = ctx->builder()->GetShape(input);
      OP_REQUIRES_OK(ctx, input_shape_or.status());
      xla::Shape input_shape = input_shape_or.ValueOrDie();
      const xla::Shape& list_shape = body_input_shape.tuple_shapes(i);
      // The shape/datatype of the input list may differ from the
      // shape/datatype of the body/cond input if the list's shape/datatype was
      // inferred after the first compilation and the body/cond was recompiled
      // with the updated shape/datatype of the list.
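      // In that case the raw input cannot be passed through directly; instead
      // a zero-valued list with the inferred shape is created below via
      // CreateZerosTensorListWithShape.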
      if (input_shape != list_shape) {
        // Prepare dynamic dimensions for element shapes.
        std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
        for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
          std::vector<xla::XlaOp> dynamic_dims;

          const xla::Shape& shape = list_shape.tuple_shapes(i);

          // We already have the dynamic size of the leading dimension outside
          // of the while loop, without initializing the TensorList inside the
          // while loop.
          if (shape.is_dynamic_dimension(0)) {
            xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
            dynamic_dims.push_back(leading_dim_size);
          } else {
            int32 dim_size = shape.dimensions(0);
            dynamic_dims.push_back(
                xla::ConstantR0<int32>(ctx->builder(), dim_size));
          }

          // Set the dynamic dimension size to 0 for element values. Inside the
          // while loop, TensorListSetItem will properly set the element
          // shape's dynamic dimensions.
          for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
            int32 dim_size = shape.dimensions(dim);
            if (shape.is_dynamic_dimension(dim)) {
              dim_size = 0;
            }
            dynamic_dims.push_back(
                xla::ConstantR0<int32>(ctx->builder(), dim_size));
          }
          list_dynamic_dims.push_back(dynamic_dims);
        }
        OP_REQUIRES_OK(
            ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                                list_dynamic_dims, &inputs[i]));
      } else {
        inputs[i] = ctx->Input(input_num);
      }
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }

  xla::XlaOp init = xla::Tuple(builder, inputs);

  VLOG(1) << "Building while loop";

  // Wraps the condition in a computation that unpacks the output tuple.
  xla::StatusOr<xla::XlaComputation> cond_result = BuildWrappedCond(ctx, cond);
  OP_REQUIRES_OK(ctx, cond_result.status());
  xla::XlaComputation wrapped_cond = std::move(cond_result.ValueOrDie());

  // Removes compile-time const args from the list of body outputs.
  xla::StatusOr<xla::XlaComputation> body_result =
      BuildWrappedBody(ctx, body, compile_time_const_arg_indices,
                       num_compile_time_const_args, has_token_input_output_);
  OP_REQUIRES_OK(ctx, body_result.status());
  xla::XlaComputation wrapped_body = std::move(body_result.ValueOrDie());

  // Builds the While op and pads its output with the compile-time const args.
  xla::XlaOp while_result =
      BuildWhile(ctx, wrapped_cond, wrapped_body, init, body.input_mapping,
                 compile_time_const_arg_indices, num_compile_time_const_args,
                 has_token_input_output_);

  // Sets non-variable outputs and determines where resource variables start.
  int resource_index = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (ctx->input_type(i) != DT_RESOURCE) {
      if (IsTensorListInput(ctx, i)) {
        ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i));
      } else {
        ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
      }
      ++resource_index;
    } else {
      break;
    }
  }
  if (has_token_input_output_) {
    // Set the token output for this "while" op.
    xla::XlaOp token_output =
        xla::GetTupleElement(while_result, ctx->num_outputs());
    auto shape_or = builder->GetShape(token_output);
    OP_REQUIRES_OK(ctx, shape_or.status());
    OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(),
                errors::FailedPrecondition(
                    "Token output is not token type: ",
                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
  }

  // Updates the values of any resource variables modified by the loop.
  for (int i = 0; i < body.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
      int pos = resource_index + i;
      OP_REQUIRES_OK(ctx,
                     resource->SetFromPack(
                         arguments[update.input_index].tensor_array_gradients,
                         xla::GetTupleElement(while_result, pos), builder));
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name() << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
        update.input_index,
        ctx->op_kernel_context()->input(update.input_index));
  }

  VLOG(1) << "Done building while loop";
}

REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(),
                XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(),
                XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(),
                XlaWhileOp);
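
// Note: AllowResourceTypes() and AllowVariantTypes() permit resource variables
// and DT_VARIANT values (e.g. TensorLists) as loop-carried inputs and outputs,
// which this kernel handles specially above.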

}  // namespace tensorflow