• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/kernels/while_op.h"
17 
18 #include "absl/strings/str_split.h"
19 #include "tensorflow/compiler/tf2xla/const_analysis.h"
20 #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
21 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/client/xla_computation.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/core/framework/attr_value.pb.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/op_kernel.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 
41 // Verify that input resources are grouped in the end.
VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext * ctx,const NameAttrList & body_name_attr)42 Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx,
43                                       const NameAttrList& body_name_attr) {
44   const FunctionBody* body;
45   TF_RETURN_IF_ERROR(ctx->compiler()->FindFunctionBody(body_name_attr, &body));
46   bool has_seen_resource = false;
47   for (int i = 0; i < body->arg_types.size(); i++) {
48     DataType arg_type = body->arg_types[i];
49     if (has_seen_resource) {
50       if (arg_type != DT_RESOURCE) {
51         return errors::InvalidArgument(
52             "Expect input resources are grouped in the end of while body ",
53             body_name_attr.name(), ", but the ", i, "-th argument ",
54             body->arg_nodes[i]->name(), " is not a resource.");
55       }
56     } else {
57       if (arg_type == DT_RESOURCE) {
58         has_seen_resource = true;
59       }
60     }
61   }
62   return Status::OK();
63 }
64 
65 // Builds XlaCompiler argument descriptions `args` from `ctx`.
MakeXlaCompilerArgumentsFromInputs(XlaOpKernelContext * ctx,std::vector<XlaCompiler::Argument> * args,bool * has_uninitialized_vars,bool * has_tensor_arrays,bool * has_uninitialized_tensor_lists)66 Status MakeXlaCompilerArgumentsFromInputs(
67     XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
68     bool* has_uninitialized_vars, bool* has_tensor_arrays,
69     bool* has_uninitialized_tensor_lists) {
70   VLOG(2) << "Num inputs " << ctx->num_inputs();
71   args->resize(ctx->num_inputs());
72   *has_uninitialized_vars = false;
73   *has_tensor_arrays = false;
74   *has_uninitialized_tensor_lists = false;
75   for (int i = 0; i < ctx->num_inputs(); ++i) {
76     VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i))
77             << " shape: " << ctx->InputShape(i).DebugString();
78     XlaCompiler::Argument& arg = (*args)[i];
79     DataType type = ctx->input_type(i);
80     // When reading a resource input, use the type and shape of the resource's
81     // current value.
82     if (type == DT_RESOURCE) {
83       XlaResource* resource;
84       TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
85       XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
86       if (arg.resource_kind == XlaResource::kTensorArray) {
87         *has_tensor_arrays = true;
88       }
89       if (!arg.initialized) {
90         *has_uninitialized_vars = true;
91       }
92       VLOG(2) << "    resource " << resource->name()
93               << " type: " << DataTypeString(arg.type)
94               << " shape: " << arg.ShapeHumanString()
95               << " initialized: " << arg.initialized;
96     } else {
97       arg.kind = XlaCompiler::Argument::kParameter;
98       arg.type = type;
99       TF_ASSIGN_OR_RETURN(arg.shape, ctx->builder()->GetShape(ctx->Input(i)));
100       if (IsTensorListInput(ctx, i)) {
101         // arg.initialized == false means that the element_shape of the list
102         // was not available at the time of building the list so an empty list
103         // was created instead. If so, the body function of While is run once
104         // to infer the shape of the list before actually building the While op.
105         TF_RETURN_IF_ERROR(
106             IsTensorListInitialized(ctx->Input(i), &arg.initialized));
107         if (!arg.initialized) {
108           *has_uninitialized_tensor_lists = true;
109         }
110       }
111     }
112   }
113   return Status::OK();
114 }
115 
116 // Populates loop invariant indices to true in `loop_invariants`.
GetLoopInvariants(XlaOpKernelContext * ctx,const NameAttrList & body_name_attr,std::vector<bool> * const loop_invariants)117 void GetLoopInvariants(XlaOpKernelContext* ctx,
118                        const NameAttrList& body_name_attr,
119                        std::vector<bool>* const loop_invariants) {
120   const FunctionBody* body;
121   OP_REQUIRES_OK(ctx, ctx->compiler()->FindFunctionBody(body_name_attr, &body));
122   for (int i = 0; i < body->ret_nodes.size(); i++) {
123     const Node* arg = body->arg_nodes[i];
124     const Node* ret = body->ret_nodes[i];
125     const Node* ret_input_0;
126     OP_REQUIRES_OK(ctx, ret->input_node(0, &ret_input_0));
127     (*loop_invariants)[i] = (ret_input_0->id() == arg->id());
128   }
129 }
130 
131 // Converts entries in `args` which are loop invariants and have compile time
132 // constant inputs and need to be constants in order to be compilable to
133 // constants so that they can be propagated in the loop body.
ConvertLoopInvariantsToConst(XlaOpKernelContext * ctx,const NameAttrList & body_name_attr,const NameAttrList & cond_name_attr,std::vector<XlaCompiler::Argument> * args,std::vector<bool> * compile_time_const_arg_indices,int * num_compile_time_const_args,xla::Client * client)134 Status ConvertLoopInvariantsToConst(
135     XlaOpKernelContext* ctx, const NameAttrList& body_name_attr,
136     const NameAttrList& cond_name_attr,
137     std::vector<XlaCompiler::Argument>* args,
138     std::vector<bool>* compile_time_const_arg_indices,
139     int* num_compile_time_const_args, xla::Client* client) {
140   std::vector<bool> loop_invariants(ctx->num_inputs());
141   GetLoopInvariants(ctx, body_name_attr, &loop_invariants);
142 
143   std::vector<bool> body_must_be_const_nodes;
144   const FunctionBody* body;
145   std::vector<bool> cond_must_be_const_nodes;
146   const FunctionBody* cond;
147   TF_RETURN_IF_ERROR(FindMustBeConstNodes(ctx, body_name_attr,
148                                           &body_must_be_const_nodes, &body));
149   TF_RETURN_IF_ERROR(FindMustBeConstNodes(ctx, cond_name_attr,
150                                           &cond_must_be_const_nodes, &cond));
151 
152   auto should_convert_to_const = [&](int arg_idx) {
153     XlaCompiler::Argument& arg = (*args)[arg_idx];
154     return arg.kind != XlaCompiler::Argument::kResource &&
155            loop_invariants[arg_idx] &&
156            (body_must_be_const_nodes[body->arg_nodes[arg_idx]->id()] ||
157             cond_must_be_const_nodes[cond->arg_nodes[arg_idx]->id()]);
158   };
159   absl::InlinedVector<int, 5> converted_constants =
160       ConvertCompileTimeConstArgumentsToConst(ctx, args,
161                                               /*xla_expression_offset=*/0,
162                                               should_convert_to_const);
163   for (int arg_idx : converted_constants) {
164     compile_time_const_arg_indices->at(arg_idx) = true;
165     (*num_compile_time_const_args)++;
166   }
167   return Status::OK();
168 }
169 
VerifyBodyInputAndOutputShapeMatch(XlaOpKernelContext * ctx,const std::vector<bool> & compile_time_const_arg_indices,const XlaCompiler::CompilationResult & body,bool has_token_input_output)170 Status VerifyBodyInputAndOutputShapeMatch(
171     XlaOpKernelContext* ctx,
172     const std::vector<bool>& compile_time_const_arg_indices,
173     const XlaCompiler::CompilationResult& body, bool has_token_input_output) {
174   xla::Shape body_input_shape = body.xla_input_shapes[0];
175   xla::Shape body_output_shape;
176   body_output_shape.set_element_type(xla::TUPLE);
177   for (int i = 0; i < ctx->num_outputs(); i++) {
178     if (!compile_time_const_arg_indices[i]) {
179       *(body_output_shape.add_tuple_shapes()) =
180           body.xla_output_shape.tuple_shapes(i);
181     }
182   }
183   // If `body` has a token output, append its shape to `body_output_shape`.
184   if (has_token_input_output) {
185     *(body_output_shape.add_tuple_shapes()) =
186         body.xla_output_shape.tuple_shapes(ctx->num_inputs());
187   }
188   if (!xla::ShapeUtil::Compatible(body_input_shape, body_output_shape)) {
189     return errors::InvalidArgument(
190         "Input and output shapes of loop body do not match: ",
191         xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
192         xla::ShapeUtil::HumanString(body_output_shape));
193   }
194   return Status::OK();
195 }
196 
BuildWrappedCond(XlaOpKernelContext * ctx,const XlaCompiler::CompilationResult & cond)197 xla::StatusOr<xla::XlaComputation> BuildWrappedCond(
198     XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& cond) {
199   xla::Shape cond_input_shape = cond.xla_input_shapes[0];
200   std::unique_ptr<xla::XlaBuilder> cb =
201       ctx->builder()->CreateSubBuilder("cond_wrapper");
202   auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs");
203   auto outputs = xla::Call(cb.get(), *cond.computation, {inputs});
204   xla::GetTupleElement(outputs, 0);
205   return cb->Build();
206 }
207 
BuildWrappedBody(XlaOpKernelContext * ctx,const XlaCompiler::CompilationResult & body,const std::vector<bool> & compile_time_const_arg_indices,int num_compile_time_const_args,bool has_token_input_output)208 xla::StatusOr<xla::XlaComputation> BuildWrappedBody(
209     XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& body,
210     const std::vector<bool>& compile_time_const_arg_indices,
211     int num_compile_time_const_args, bool has_token_input_output) {
212   if (num_compile_time_const_args <= 0) {
213     return xla::XlaComputation(body.computation->proto());
214   }
215   xla::XlaComputation body_wrapper;
216   std::unique_ptr<xla::XlaBuilder> cb =
217       ctx->builder()->CreateSubBuilder("body_wrapper");
218   xla::Shape body_input_shape = body.xla_input_shapes[0];
219   auto inputs = xla::Parameter(cb.get(), 0, body_input_shape, "inputs");
220   // Call the original body function which has mismatched inputs and outputs
221   // and strip the compile time consts from the list of outputs. While requires
222   // the inputs and outputs of its body function to match.
223   auto outputs = xla::Call(cb.get(), *body.computation, {inputs});
224   std::vector<xla::XlaOp> non_compile_time_const_outputs;
225   for (int i = 0; i < compile_time_const_arg_indices.size(); i++) {
226     if (!compile_time_const_arg_indices[i]) {
227       non_compile_time_const_outputs.push_back(
228           xla::GetTupleElement(outputs, i));
229     }
230   }
231   // If `body` has a token output, append it to
232   // `non_compile_time_const_outputs`.
233   if (has_token_input_output) {
234     non_compile_time_const_outputs.push_back(
235         xla::GetTupleElement(outputs, ctx->num_outputs()));
236   }
237   xla::Tuple(cb.get(), non_compile_time_const_outputs);
238   return cb->Build();
239 }
240 
// Builds the xla::While op from the wrapped cond/body computations and pads
// its (pruned) result back up to `ctx->num_outputs()` entries: positions that
// were compile time consts are filled from the corresponding op inputs, all
// other positions come from the While result tuple.
xla::XlaOp BuildWhile(XlaOpKernelContext* ctx,
                      const xla::XlaComputation& wrapped_cond,
                      const xla::XlaComputation& wrapped_body,
                      const xla::XlaOp& initial_values,
                      const std::vector<int>& input_mapping,
                      const std::vector<bool>& compile_time_const_arg_indices,
                      int num_compile_time_const_args,
                      bool has_token_input_output) {
  xla::XlaOp while_result =
      xla::While(wrapped_cond, wrapped_body, initial_values);
  std::vector<xla::XlaOp> padded_while_outputs(ctx->num_outputs());
  int while_result_index = 0;
  for (int i = 0; i < ctx->num_inputs(); i++) {
    if (!compile_time_const_arg_indices[i]) {
      // Non-const loop vars appear in the While result in compacted order;
      // `input_mapping` maps the compacted index back to the op output slot.
      padded_while_outputs[input_mapping[while_result_index]] =
          xla::GetTupleElement(while_result, while_result_index);
      while_result_index++;
    } else {
      // Compile time consts were pruned from the loop; forward the op input.
      padded_while_outputs[i] = ctx->Input(i);
    }
  }
  // If `body` has a token output, append it to `padded_while_outputs`.
  if (has_token_input_output) {
    // The token is the element after the num_inputs - num_compile_time_const
    // loop-carried values in the While result tuple.
    padded_while_outputs.push_back(xla::GetTupleElement(
        while_result, ctx->num_inputs() - num_compile_time_const_args));
  }
  return xla::Tuple(ctx->builder(), padded_while_outputs);
}
269 
270 }  // anonymous namespace
271 
XlaWhileOp(OpKernelConstruction * ctx)272 XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
273   const NameAttrList* name_attr;
274   OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr));
275   cond_name_attr_ = *name_attr;
276   OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
277   body_name_attr_ = *name_attr;
278   if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
279     has_token_input_output_ = false;
280   } else {
281     has_token_input_output_ = !token_input_nodes_.empty();
282   }
283   if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
284     OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
285                                      &propagate_compile_time_consts_));
286   }
287 }
288 
// Compiles the cond and body functions, resolves uninitialized resources /
// TensorArrays / TensorLists via a second body compilation, builds the XLA
// While op, and wires the op outputs and resource variable updates.
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
  VLOG(1) << "WhileOp::Compile";

  // Input resources need to be grouped in the end of the body function
  // according to the convention of the XLA bridge.
  OP_REQUIRES_OK(ctx, VerifyResourceArgsGroupedAtEnd(ctx, body_name_attr_));

  std::vector<XlaCompiler::Argument> arguments;
  bool has_uninitialized_vars;
  bool has_tensor_arrays;
  bool has_uninitialized_tensor_lists;
  OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
                          ctx, &arguments, &has_uninitialized_vars,
                          &has_tensor_arrays, &has_uninitialized_tensor_lists));

  xla::XlaBuilder* builder = ctx->builder();
  XlaCompiler* compiler = ctx->compiler();

  // Indices of loop vars which satisfy the following conditions:
  // 1. They are loop invariants.
  // 2. The op inputs at these indices are compile time constants.
  //
  // These compile time consts do not appear as _Args in the cond/body functions
  // and are replaced by kConstant nodes instead. As a result, the compiled
  // body function does not have matching input and output shape. We fix this
  // by rewriting the body computation (see body_wrapper below) to output
  // just the non compile-time-const values and later pad up the while output
  // with the const args.
  std::vector<bool> compile_time_const_arg_indices(ctx->num_inputs());
  int num_compile_time_const_args = 0;
  if (propagate_compile_time_consts_) {
    OP_REQUIRES_OK(ctx, ConvertLoopInvariantsToConst(
                            ctx, body_name_attr_, cond_name_attr_, &arguments,
                            &compile_time_const_arg_indices,
                            &num_compile_time_const_args, compiler->client()));
  }

  VLOG(1) << "Compiling body";

  // All resource that are inputs to the loop's body must also be
  // present as loop body outputs; the signature of the loop's input and
  // output must match. We ensure this by asking the compiler to include the
  // current values of all resources, even if they haven't been updated by the
  // computation. We must also ask the compiler to keep compile-time constant
  // outputs as part of the generated computation, for the same reason.
  // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
  // operator.
  XlaCompiler::CompileOptions body_options;
  body_options.use_tuple_arg = true;
  body_options.return_updated_values_for_all_resources = true;
  body_options.is_entry_computation = false;
  body_options.add_token_input_output = has_token_input_output_;
  XlaCompiler::CompilationResult body;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                arguments, &body));

  // We must use a static shape for parameters to an XLA compilation. However,
  // we may not know the shape of a resource if it is first
  // written inside the loop. Furthermore, we do not know ahead of time which
  // gradient TensorArrays will be created by the TensorArrayGradV3 operator.
  //
  // Ideally we would change TensorFlow to provide static shape always, but
  // but this is not easy to do. So if uninitialized resources or TensorArrays
  // are used by the loop body, we compile the body function twice:
  // 1) once with uninitialized resource inputs and no TensorArray gradient
  //    inputs. We then discard the computation but we assume resource shapes
  //    and the set of gradients read or written will reach a fixpoint after one
  //    iteration.
  //    Hence we can use the output shapes and TensorArray gradients of each
  //    resource as the "true" shapes.
  // 2) again with the "correct" resource information determined by (1).
  if (has_uninitialized_vars || has_tensor_arrays ||
      has_uninitialized_tensor_lists) {
    VLOG(2) << "Recompiling loop body: has_uninitialized_vars: "
            << has_uninitialized_vars
            << " has_tensor_arrays: " << has_tensor_arrays
            << " has_uninitialized_tensor_lists: "
            << has_uninitialized_tensor_lists;
    // Initializes any uninitialized resource with zero values of the
    // shape determined by the first compilation.
    for (int i = 0; i < body.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));

      XlaCompiler::Argument& arg = arguments[update.input_index];
      if (!arg.initialized) {
        VLOG(2) << "Update shape for argument " << update.input_index << " "
                << update.shape.DebugString();
        arg.initialized = true;

        // Adopt the shape/type inferred by the first compilation and give the
        // resource a zero-valued initial value of that shape.
        arg.shape = update.shape;
        OP_REQUIRES_OK(ctx,
                       resource->SetTypeAndShape(update.type, update.shape));

        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
      }

      // Add any TensorArray gradients touched by the body to the enclosing
      // graph.
      for (const string& grad_source : update.tensor_array_gradients_accessed) {
        VLOG(4) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, builder, &gradient));
      }

      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
    }

    // Set the shape of any uninitialized TensorLists to the shape determined by
    // the first compilation. Note that, unlike resources, we do not initialize
    // the input list with zeros here, that is done later.
    xla::Shape body_output_shape = body.xla_output_shape;
    OP_REQUIRES(ctx, body_output_shape.IsTuple(),
                errors::FailedPrecondition(
                    "xla_output_shape of while body must be a tuple."));
    for (int i = 0; i < arguments.size(); i++) {
      XlaCompiler::Argument& arg = arguments[i];
      if (arg.initialized || !IsTensorListInput(ctx, i)) {
        continue;
      }
      arg.shape = body_output_shape.tuple_shapes(i);
      arg.initialized = true;
    }

    // Recompile the body with the "correct" resource shapes.
    VLOG(1) << "Recompiling body with corrected resource shapes";
    body = {};
    OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                  arguments, &body));
  }

  VLOG(1) << "Compiling condition";

  XlaCompiler::CompileOptions cond_options;
  cond_options.use_tuple_arg = true;
  cond_options.is_entry_computation = false;
  cond_options.add_token_input_output = has_token_input_output_;
  XlaCompiler::CompilationResult cond;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                arguments, &cond));

  // Both computations take a single tuple parameter; validate that before
  // comparing shapes element-wise.
  OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  OP_REQUIRES(ctx, body_input_shape.IsTuple(),
              errors::FailedPrecondition("Expected tuple shape"));
  OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape cond_input_shape = cond.xla_input_shapes[0];
  OP_REQUIRES(ctx, cond_input_shape.IsTuple(),
              errors::FailedPrecondition("Expected tuple shape"));

  VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
  VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);

  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
              errors::InvalidArgument(
                  "Input shapes of loop body and condition do not match: ",
                  xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
                  xla::ShapeUtil::HumanString(cond_input_shape)));

  // Check that the shape of the body outputs excluding the compile time const
  // args (which are pruned from the body outputs in body_wapper) matches the
  // shape of the inputs.
  OP_REQUIRES_OK(ctx, VerifyBodyInputAndOutputShapeMatch(
                          ctx, compile_time_const_arg_indices, body,
                          has_token_input_output_));

  // The cond must produce (pred[]) — or (pred[], token[]) when side effects
  // are token-ordered.
  xla::Shape expected_cond_output_shape_without_side_effect =
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::PRED, {})});
  xla::Shape expected_cond_output_shape_with_side_effect =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}),
                                      xla::ShapeUtil::MakeTokenShape()});
  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(
                  cond.xla_output_shape,
                  expected_cond_output_shape_without_side_effect) ||
                  xla::ShapeUtil::Compatible(
                      cond.xla_output_shape,
                      expected_cond_output_shape_with_side_effect),
              errors::InvalidArgument(
                  "Output shape of loop condition should be (pred[]) or "
                  "(pred[], token[]), got: ",
                  xla::ShapeUtil::HumanString(cond.xla_output_shape)));

  // Build the initial loop-carried tuple. `input_mapping` lists which op
  // inputs survived compilation (compile time consts were pruned).
  int num_inputs = body.input_mapping.size();
  std::vector<xla::XlaOp> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = body.input_mapping[i];
    if (has_token_input_output_ && i == num_inputs - 1) {
      // Set token input for this "while" op.
      std::vector<xla::XlaOp> token_inputs;
      for (const string& node_name : token_input_nodes_) {
        auto token_or = compiler->GetNodeToken(node_name);
        OP_REQUIRES_OK(ctx, token_or.status());
        token_inputs.push_back(token_or.ValueOrDie());
      }
      inputs[i] = xla::AfterAll(builder, token_inputs);
    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
    } else if (IsTensorListInput(ctx, input_num)) {
      xla::XlaOp input = ctx->Input(input_num);
      auto input_shape_or = ctx->builder()->GetShape(input);
      OP_REQUIRES_OK(ctx, input_shape_or.status());
      xla::Shape input_shape = input_shape_or.ValueOrDie();
      const xla::Shape& list_shape = body_input_shape.tuple_shapes(i);
      // Shape/datatype of the input list may differ from shape/datatype of the
      // body/cond input if the list's shape/datatype was inferred after the
      // first compilation and the body/cond was recompiled with the updated
      // shape/datatype of the list.
      if (input_shape != list_shape) {
        // Prepare dynamic dimensions for element shapes.
        std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
        for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
          // NOTE: this inner `i` shadows the outer loop index; it walks the
          // list's element shapes (the last tuple element is excluded).
          std::vector<xla::XlaOp> dynamic_dims;

          const xla::Shape& shape = list_shape.tuple_shapes(i);

          // We already have the dynamic size of leading dimension outside of
          // the while loop without initializing the TensorList inside the while
          // loop.
          if (shape.is_dynamic_dimension(0)) {
            xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
            dynamic_dims.push_back(leading_dim_size);
          } else {
            int32 dim_size = shape.dimensions(0);
            dynamic_dims.push_back(
                xla::ConstantR0<int32>(ctx->builder(), dim_size));
          }

          // Set dynamic dimension size to 0 for element value. Inside the while
          // loop, TensorlistSetItem will properly set the element shape's
          // dynamic dimension.
          for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
            int32 dim_size = shape.dimensions(dim);
            if (shape.is_dynamic_dimension(dim)) {
              dim_size = 0;
            }
            dynamic_dims.push_back(
                xla::ConstantR0<int32>(ctx->builder(), dim_size));
          }
          list_dynamic_dims.push_back(dynamic_dims);
        }
        OP_REQUIRES_OK(
            ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                                list_dynamic_dims, &inputs[i]));
      } else {
        inputs[i] = ctx->Input(input_num);
      }
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }

  xla::XlaOp init = xla::Tuple(builder, inputs);

  VLOG(1) << "Building while loop";

  // Wraps the condition in a computation that unpacks the output tuple.
  xla::StatusOr<xla::XlaComputation> cond_result = BuildWrappedCond(ctx, cond);
  OP_REQUIRES_OK(ctx, cond_result.status());
  xla::XlaComputation wrapped_cond = std::move(cond_result.ValueOrDie());

  // Remove compile time const args from the list of body outputs.
  xla::StatusOr<xla::XlaComputation> body_result =
      BuildWrappedBody(ctx, body, compile_time_const_arg_indices,
                       num_compile_time_const_args, has_token_input_output_);
  OP_REQUIRES_OK(ctx, body_result.status());
  xla::XlaComputation wrapped_body = std::move(body_result.ValueOrDie());

  // Builds the While op and pads its output with the compile time const args.
  xla::XlaOp while_result =
      BuildWhile(ctx, wrapped_cond, wrapped_body, init, body.input_mapping,
                 compile_time_const_arg_indices, num_compile_time_const_args,
                 has_token_input_output_);

  // Sets non-variable outputs and determine when resource variables start.
  int resource_index = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (ctx->input_type(i) != DT_RESOURCE) {
      if (IsTensorListInput(ctx, i)) {
        ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i));
      } else {
        ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
      }
      ++resource_index;
    } else {
      // Resources are grouped at the end (verified above), so the first
      // resource marks the end of the plain outputs.
      break;
    }
  }
  if (has_token_input_output_) {
    // Set token output for this "while" op.
    xla::XlaOp token_output =
        xla::GetTupleElement(while_result, ctx->num_outputs());
    auto shape_or = builder->GetShape(token_output);
    OP_REQUIRES_OK(ctx, shape_or.status());
    OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(),
                errors::FailedPrecondition(
                    "Token output is not token type: ",
                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
  }

  // Updates the values of any resource variables modified by the loop.
  for (int i = 0; i < body.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
      int pos = resource_index + i;
      OP_REQUIRES_OK(ctx,
                     resource->SetFromPack(
                         arguments[update.input_index].tensor_array_gradients,
                         xla::GetTupleElement(while_result, pos), builder));
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name() << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
        update.input_index,
        ctx->op_kernel_context()->input(update.input_index));
  }

  VLOG(1) << "Done building while loop";
}
630 
// Register the kernel for all functional While op variants. Resource and
// variant (e.g. TensorList) typed loop variables are permitted.
REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(),
                XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(),
                XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(),
                XlaWhileOp);
637 
638 }  // namespace tensorflow
639