• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/while_loop.h"
17 
18 #include "tensorflow/cc/framework/scope_internal.h"
19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/common_runtime/shape_refiner.h"
22 #include "tensorflow/core/graph/node_builder.h"
23 
24 namespace tensorflow {
25 namespace ops {
26 
27 namespace {
28 
29 // Utility function for converting to internal C++ datatypes.
ToOutputTensor(const Output & output)30 OutputTensor ToOutputTensor(const Output& output) {
31   return OutputTensor(output.node(), output.index());
32 }
33 
34 // Utility function for converting to internal C++ datatypes.
ToOutputTensors(const std::vector<Output> & outputs)35 std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) {
36   std::vector<OutputTensor> result(outputs.size());
37   for (int i = 0; i < outputs.size(); ++i) {
38     result[i] = ToOutputTensor(outputs[i]);
39   }
40   return result;
41 }
42 
43 // Utility function for converting to internal C++ datatypes.
ToNodes(const std::vector<Output> & outputs)44 std::vector<Node*> ToNodes(const std::vector<Output>& outputs) {
45   std::vector<Node*> result(outputs.size());
46   for (int i = 0; i < outputs.size(); ++i) {
47     result[i] = outputs[i].node();
48   }
49   return result;
50 }
51 
52 // Manually generates the name of the `loop_var_idx`-th NextIteration node of a
53 // loop being constructed with `scope`. This is used to define the backedge
54 // before the NextIteration node is created.
NextIterationName(const Scope & scope,int loop_var_idx)55 string NextIterationName(const Scope& scope, int loop_var_idx) {
56   string result;
57   const string& prefix = scope.impl()->name();
58   if (!prefix.empty()) strings::StrAppend(&result, prefix, "/");
59   strings::StrAppend(&result, "NextIteration");
60   if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx);
61   return result;
62 }
63 
64 // Creates the `loop_var_idx`-th Merge node of a loop being constructed with
65 // `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output.
CreateMerge(const Scope & scope,int loop_var_idx,const Output & enter_output,Output * merge_output)66 Status CreateMerge(const Scope& scope, int loop_var_idx,
67                    const Output& enter_output, Output* merge_output) {
68   // The merge nodes accept the while loop's back edges as an input (i.e. the
69   // not-yet-created next iteration nodes). Use the underlying NodeBuilder API
70   // directly to create the back edge.
71   NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index());
72 
73   const int next_output_index = 0;
74   DataType dtype = enter_output.node()->output_type(0);
75   NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx),
76                                   next_output_index, dtype);
77 
78   std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input});
79   const string unique_name = scope.GetUniqueNameForOp("Merge");
80   NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list);
81   scope.UpdateBuilder(&builder);
82 
83   Node* merge_node;
84   TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node));
85   TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node));
86   *merge_output = Output(merge_node, 0);
87   return Status::OK();
88 }
89 
90 // Creates the condition subgraph defined by `cond`.
CreateCond(const Scope & scope,const CondGraphBuilderFn & cond,const std::vector<Output> & inputs,Output * output)91 Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
92                   const std::vector<Output>& inputs, Output* output) {
93   // The control dependency is for constants in the cond graph, and other ops
94   // that do not depend on the loop variables. This ensures that these ops are
95   // in the while loop frame (since they will indirectly depend on an Enter node
96   // defining the frame) and that they are executed once per loop iteration.
97   //
98   // TODO(skyewm): the control dep will be added to all nodes in the cond graph.
99   // This is at best unnecessary, and at worst may prevent different parts of
100   // different loop iterations from executing in parallel.
101   Scope cond_scope =
102       scope.NewSubScope("cond").WithControlDependencies(inputs[0]);
103   Output raw_cond_out;
104   TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out));
105 
106   TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(),
107                                                         raw_cond_out.index()));
108   if (raw_cond_out.type() != DT_BOOL) {
109     return errors::InvalidArgument(
110         "BuildWhileLoop: 'cond' argument must return a boolean output, got ",
111         DataTypeString(raw_cond_out.type()));
112   }
113   // TODO(skyewm): check that raw_cond_out is scalar
114 
115   *output = LoopCond(scope, raw_cond_out).output;
116   return Status::OK();
117 }
118 
119 // Create the body subgraph defined by `body`. `outputs` must be non-null and
120 // empty.
CreateBody(const Scope & scope,const BodyGraphBuilderFn & body,const std::vector<Output> & inputs,std::vector<Output> * outputs)121 Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
122                   const std::vector<Output>& inputs,
123                   std::vector<Output>* outputs) {
124   DCHECK(outputs != nullptr);
125   DCHECK(outputs->empty());
126 
127   // The control dependency is analogous to that in CreateCond().
128   Scope body_scope =
129       scope.NewSubScope("body").WithControlDependencies(inputs[0]);
130   TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs));
131 
132   const size_t num_loop_vars = inputs.size();
133   if (outputs->size() != num_loop_vars) {
134     return errors::InvalidArgument(
135         "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars,
136         " output(s), got ", outputs->size());
137   }
138   for (const Output& output : *outputs) {
139     TF_RETURN_IF_ERROR(
140         scope.graph()->IsValidOutputTensor(output.node(), output.index()));
141     // TODO(skyewm): check output types/shapes
142   }
143   return Status::OK();
144 }
145 
146 }  // namespace
147 
148 // A while loop with a single loop variable looks like this:
149 //
150 // (output)
151 //     ^    +---------------+
152 //     |    | body subgraph +-------------+
153 //    Exit  +---------------+             |
154 //      ^    ^                            |
155 //      |    |                            |
156 //      Switch<--------+                  v
157 //        ^            |             NextIteration
158 //        |     +------+--------+         |
159 //        +---->| cond subgraph |         |
160 //        |     +---------------+         |
161 //       Merge<---------------------------+
162 //       ^
163 //       |
164 //    Enter
165 //      ^
166 //      |
167 //   (input)
168 //
169 // If there are multiple loop variables, each of the control flow ops is
170 // duplicated for each loop variable.
171 // TODO(skyewm): link to public version of design doc
BuildWhileLoop(const Scope & scope,const std::vector<Output> & inputs,const CondGraphBuilderFn & cond,const BodyGraphBuilderFn & body,const string & frame_name,OutputList * outputs,bool create_while_ctx,Output * cond_output)172 Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
173                       const CondGraphBuilderFn& cond,
174                       const BodyGraphBuilderFn& body, const string& frame_name,
175                       OutputList* outputs, bool create_while_ctx,
176                       Output* cond_output) {
177   DCHECK(!inputs.empty());
178   DCHECK(outputs != nullptr);
179   DCHECK(outputs->empty());
180 
181   TF_RETURN_IF_ERROR(scope.status());
182   const size_t num_loop_vars = inputs.size();
183 
184   std::vector<Output> enter_outputs(num_loop_vars);
185   for (int i = 0; i < num_loop_vars; ++i) {
186     enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name);
187   }
188   TF_RETURN_IF_ERROR(scope.status());
189 
190   std::vector<Output> merge_outputs(num_loop_vars);
191   for (int i = 0; i < num_loop_vars; ++i) {
192     TF_RETURN_IF_ERROR(
193         CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i]));
194   }
195 
196   Output cond_out;
197   TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
198   if (cond_output != nullptr) *cond_output = cond_out;
199 
200   std::vector<Output> switch_trues(num_loop_vars);
201   std::vector<Output> switch_falses(num_loop_vars);
202   for (int i = 0; i < num_loop_vars; ++i) {
203     auto switch_i = Switch(scope, merge_outputs[i], cond_out);
204     switch_trues[i] = switch_i.output_true;
205     switch_falses[i] = switch_i.output_false;
206   }
207   TF_RETURN_IF_ERROR(scope.status());
208 
209   std::vector<Output> body_outputs;
210   TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs));
211 
212   std::vector<Output> next_outputs(num_loop_vars);
213   for (int i = 0; i < num_loop_vars; ++i) {
214     next_outputs[i] = NextIteration(scope, body_outputs[i]);
215     DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i));
216   }
217   TF_RETURN_IF_ERROR(scope.status());
218 
219   // Create the backedges from the NextIteration nodes to the Merge nodes.
220   for (int i = 0; i < num_loop_vars; ++i) {
221     const int merge_backedge_output_index = 1;
222     scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(),
223                            merge_outputs[i].node(),
224                            merge_backedge_output_index);
225   }
226 
227   outputs->resize(num_loop_vars);
228   for (int i = 0; i < num_loop_vars; ++i) {
229     (*outputs)[i] = internal::Exit(scope, switch_falses[i]);
230   }
231   TF_RETURN_IF_ERROR(scope.status());
232 
233   if (create_while_ctx) {
234     WhileContext* while_ctx;
235     TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
236         frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
237         ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
238         ToOutputTensors(body_outputs), &while_ctx));
239 
240     // Set while_ctx for all exit nodes. We currently don't require knowing the
241     // while_ctx for any other nodes.
242     for (int i = 0; i < num_loop_vars; ++i) {
243       (*outputs)[i].node()->set_while_ctx(while_ctx);
244     }
245   }
246   return Status::OK();
247 }
248 
249 }  // namespace ops
250 }  // namespace tensorflow
251