• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/lower_if_op.h"
17 
18 #include "tensorflow/core/common_runtime/inline_function_utils.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/graph/graph.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 
23 namespace tensorflow {
24 namespace {
25 
26 using NodeOut = NodeBuilder::NodeOut;
27 
28 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
29     LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
30 
31 // Convenience builder to make it easy to construct a conditional with a single
32 // function call in the then and else branch. This first converts the if node
33 // into switches (for inputs) and merges (for outputs) around a function call
34 // per branch.
35 class CondBuilder {
36  public:
37   enum Branch { kElseBranch = 0, kThenBranch = 1 };
38 
39   // Create a CondBuilder to create the lowered form of `if_op` with then and
40   // else functions `then_fn` and `else_fn` respectively in the `graph`. The
41   // functions should be available in `flib`.
42   CondBuilder(Node* if_op, const NameAttrList& then_fn,
43               const NameAttrList& else_fn, bool keep_node_fetchable,
44               Graph* graph);
45 
46   // Constructs the basic conditional control flow using switch and merge nodes.
47   Status CreatePivotNodes();
48 
49   // Adds the inputs from the if node to the merge nodes of the lowered if.
50   Status AddInputs();
51 
52   // Adds the outputs from the if node to the merge nodes of the lowered if.
53   // Note: no inputs can be added once outputs are added as the then and else
54   // nodes are finalized while adding outputs.
55   Status AddOutputs();
56 
57   // Builds an identity node with the same outputs as If.
58   Status BuildLoweredIfOutput();
59 
60  private:
61   // Returns unique name containing the name of the If op being rewritten
62   // (name_), infix and a suffix to ensure it is unique within the graph.
63   string NewName(const string& infix);
64 
65   // Adds input to both the then and else nodes from src:src_output.
66   Status AddInput(Node* src, int src_output);
67 
68   // Finalizes the node described by `node_builder`. If `coloc_attr_` is not
69   // nullptr, adds the colocation attr to the node before finalizing it.
70   Status SetColocationAndFinalize(NodeBuilder node_builder, Graph* graph,
71                                   Node** created_node);
72 
73   // The merged outputs of the then and else nodes.
74   std::vector<NodeOut> outputs_;
75 
76   // The node that dominates all execution of the then and else body nodes.
77   Node* control_predecessor_;
78   // The original If op.
79   Node* if_op_;
80   // The colocation attr on the original If op. If it exists, control flow nodes
81   // created in the lowering (except the data Switch nodes) will inherit this
82   // attribute.
83   const AttrValue* coloc_attr_;
84   // The node with the same name as the original If op:
85   //   (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
86   //       and if the original If op had non-zero data outputs.
87   //   (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
88   Node* lowered_if_output_;
89   // The predicate of the conditional.
90   OutputTensor pred_;
91   // Node corresponding to pivot_f branch of predicate switch which is
92   // the pivot node that dominates all nodes in the false/else branch.
93   Node* pivot_f_;
94   // Node corresponding to pivot_t branch of predicate switch which is
95   // the pivot node that dominates all nodes in the true/then branch.
96   Node* pivot_t_;
97   Node* then_call_node_;
98   Node* else_call_node_;
99   // Merge node that has inputs from [pivot_t, pivot_f] and control edges from
100   // [^then_call_node_, ^else_call_node_]. This node will guarantee that even
101   // when then/else branch functions do not have outputs, they still will be
102   // executed for the side effects.
103   Node* branch_executed_node_;
104   Graph* graph_;
105   string name_;
106   bool keep_node_fetchable_;
107 
108   NodeDebugInfo debug_info_;
109   NodeBuilder then_call_builder_;
110   NodeBuilder else_call_builder_;
111 };
112 
CondBuilder(Node * if_op,const NameAttrList & then_fn,const NameAttrList & else_fn,bool keep_node_fetchable,Graph * graph)113 CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn,
114                          const NameAttrList& else_fn, bool keep_node_fetchable,
115                          Graph* graph)
116     : if_op_(if_op),
117       coloc_attr_(if_op_->attrs().Find(kColocationAttrName)),
118       graph_(graph),
119       name_(if_op->name()),
120       keep_node_fetchable_(keep_node_fetchable),
121       debug_info_(*if_op_),
122       then_call_builder_(NewName("then"), then_fn.name(), graph->op_registry(),
123                          &debug_info_),
124       else_call_builder_(NewName("else"), else_fn.name(), graph->op_registry(),
125                          &debug_info_) {
126   TF_CHECK_OK(if_op_->input_tensor(0, &pred_));
127   then_call_builder_.Device(if_op_->requested_device());
128   then_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
129   for (const auto& i : then_fn.attr()) {
130     then_call_builder_.Attr(i.first, i.second);
131   }
132   else_call_builder_.Device(if_op_->requested_device());
133   else_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
134   for (const auto& i : else_fn.attr()) {
135     else_call_builder_.Attr(i.first, i.second);
136   }
137 }
138 
SetColocationAndFinalize(NodeBuilder node_builder,Graph * graph,Node ** created_node)139 Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
140                                              Graph* graph,
141                                              Node** created_node) {
142   if (coloc_attr_ != nullptr) {
143     node_builder = node_builder.Attr(kColocationAttrName, *coloc_attr_);
144   }
145   return node_builder.Finalize(graph, created_node);
146 }
147 
CreatePivotNodes()148 Status CondBuilder::CreatePivotNodes() {
149   // Construct the basic cond body (consisting of feeding in the predicate to
150   // create pivot nodes).
151   Node* switch_pred;
152   TF_RETURN_IF_ERROR(
153       SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
154                                            graph_->op_registry(), &debug_info_)
155                                    .Input(NodeOut(pred_))
156                                    .Input(NodeOut(pred_))
157                                    .Device(if_op_->requested_device()),
158                                graph_, &switch_pred));
159   control_predecessor_ = switch_pred;
160   TF_RETURN_IF_ERROR(
161       SetColocationAndFinalize(NodeBuilder(NewName("pivot_f"), "Identity",
162                                            graph_->op_registry(), &debug_info_)
163                                    .Input(switch_pred, kElseBranch)
164                                    .Device(if_op_->requested_device()),
165                                graph_, &pivot_f_));
166   TF_RETURN_IF_ERROR(
167       SetColocationAndFinalize(NodeBuilder(NewName("pivot_t"), "Identity",
168                                            graph_->op_registry(), &debug_info_)
169                                    .Input(switch_pred, kThenBranch)
170                                    .Device(if_op_->requested_device()),
171                                graph_, &pivot_t_));
172   return OkStatus();
173 }
174 
// Returns "<if op name>/<infix>" made unique by the graph.
string CondBuilder::NewName(const string& infix) {
  const string scoped = strings::StrCat(name_, "/", infix);
  return graph_->NewName(scoped);
}
178 
AddInput(Node * src,int src_output)179 Status CondBuilder::AddInput(Node* src, int src_output) {
180   Node* input;
181   NodeDebugInfo debug_info(*src);
182   // Colocate the Switch node with the `src` node.
183   //
184   // This is to avoid unnecessary Host<->Device copies between src and the
185   // Switch node.
186   //
187   // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize`,
188   // and instead ignore the existing colocation stack. This is aligned with the
189   // legacy impl in control_flow_ops.py. The legacy impl colocates this Switch
190   // with the input tensor which resets the device stack and forces the Switch
191   // to have the same device as the input node (if set) and sets the colocation
192   // _class attr. It also ignores the existing colocation stack in the context
193   // by using colocate_with(ignore_existing=True).
194   TF_RETURN_IF_ERROR(
195       NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry(),
196                   &debug_info)
197           .Input(src, src_output)
198           .Input(pred_)
199           .Device(src->requested_device())
200           .Attr(kColocationAttrName,
201                 {absl::StrCat(kColocationGroupPrefix, src->name())})
202           .Finalize(graph_, &input));
203   then_call_builder_.Input(input, kThenBranch);
204   else_call_builder_.Input(input, kElseBranch);
205   return OkStatus();
206 }
207 
AddInputs()208 Status CondBuilder::AddInputs() {
209   // Add input data edges.
210   std::vector<const Edge*> edges;
211   TF_RETURN_IF_ERROR(if_op_->input_edges(&edges));
212   // Start at index 1 as the first input is the predicate.
213   for (int i = 1; i < edges.size(); ++i) {
214     const Edge* e = edges[i];
215     TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
216   }
217   // Add input control edges.
218   for (const Edge* e : if_op_->in_edges()) {
219     if (e->IsControlEdge()) {
220       graph_->AddControlEdge(e->src(), control_predecessor_);
221     }
222   }
223   return OkStatus();
224 }
225 
AddOutputs()226 Status CondBuilder::AddOutputs() {
227   // Construct the then and else nodes.
228   // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize`
229   // because the colocation for branch nodes is applied in python.
230   TF_RETURN_IF_ERROR(then_call_builder_.Finalize(graph_, &then_call_node_));
231   graph_->AddControlEdge(pivot_t_, then_call_node_);
232   TF_RETURN_IF_ERROR(else_call_builder_.Finalize(graph_, &else_call_node_));
233   graph_->AddControlEdge(pivot_f_, else_call_node_);
234 
235   // Add Merge node for each data output of the If node.
236   std::vector<Node*> merges(then_call_node_->num_outputs());
237   outputs_.resize(merges.size());
238   for (int i = 0; i < then_call_node_->num_outputs(); ++i) {
239     TF_RETURN_IF_ERROR(SetColocationAndFinalize(
240         NodeBuilder(NewName("output"), "Merge", graph_->op_registry(),
241                     &debug_info_)
242             .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
243             .Device(if_op_->requested_device()),
244         graph_, &merges[i]));
245     outputs_[i] = NodeOut(merges[i], 0);
246   }
247 
248   // Add a Merge node that will be used as a control dependency source for the
249   // lowered output node. This Merge node will guarantee that lowered else/then
250   // function calls will be executed even if they do not have data outputs.
251   //
252   // Furthermore it will guarantee that all function side effects will be
253   // executed, if the function will be inlined into the graph. Having data
254   // outputs is not enough, because they might become unused after inlining.
255   //
256   // We will use this node to rewrite outgoing control edges from lowered 'If'
257   // node. All data edges will read tensors directly from Merge nodes.
258   TF_RETURN_IF_ERROR(SetColocationAndFinalize(
259       NodeBuilder(NewName("branch_executed"), "Merge", graph_->op_registry(),
260                   &debug_info_)
261           .Input({pivot_t_, pivot_f_})
262           .ControlInputs({then_call_node_, else_call_node_})
263           .Device(if_op_->requested_device()),
264       graph_, &branch_executed_node_));
265 
266   TF_RETURN_IF_ERROR(BuildLoweredIfOutput());
267 
268   // Add outputs.
269   for (const Edge* e : if_op_->out_edges()) {
270     if (e->IsControlEdge()) {
271       graph_->AddControlEdge(branch_executed_node_, e->dst());
272     } else {
273       // Feed the outputs directly from the merge nodes so that downstream ops
274       // can start before all the outputs have been computed.
275       graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
276     }
277   }
278 
279   return OkStatus();
280 }
281 
BuildLoweredIfOutput()282 Status CondBuilder::BuildLoweredIfOutput() {
283   // If outputs are empty, it means that we might have only output control
284   // edges (already connected to the `branch_executed_node`). Furthermore it's
285   // illegal to have an IdentityN with empty inputs.
286   //
287   // We still must keep lowered If node as a valid source of control edges,
288   // because it might be a part of function control output set.
289   NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
290                             ? NodeBuilder(name_, "IdentityN").Input(outputs_)
291                             : NodeBuilder(name_, "NoOp");
292 
293   return builder.Device(if_op_->requested_device())
294       .ControlInput(branch_executed_node_)
295       .Finalize(graph_, &lowered_if_output_);
296 }
297 
298 }  // namespace
299 
RewriteIfNode(Node * n,Graph * g,bool keep_node_fetchable)300 Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) {
301   VLOG(2) << "Lower If node (keep_node_fetchable=" << keep_node_fetchable
302           << "): " << SummarizeNode(*n);
303 
304   const AttrValue* then_attr = n->attrs().Find("then_branch");
305   if (then_attr == nullptr) {
306     return errors::InvalidArgument("Then branch function missing");
307   }
308   const AttrValue* else_attr = n->attrs().Find("else_branch");
309   if (else_attr == nullptr) {
310     return errors::InvalidArgument("Else branch function missing");
311   }
312 
313   CondBuilder cb(n, then_attr->func(), else_attr->func(), keep_node_fetchable,
314                  g);
315   TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
316   TF_RETURN_IF_ERROR(cb.AddInputs());
317   TF_RETURN_IF_ERROR(cb.AddOutputs());
318   g->RemoveNode(n);
319 
320   return OkStatus();
321 }
322 
323 }  // namespace tensorflow
324