1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
17
18 #include "tensorflow/compiler/xla/status_macros.h"
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
22
23 namespace tensorflow {
24
CondBuilder(string name,string device,const NodeDebugInfo & debug,Graph * graph)25 CondBuilder::CondBuilder(string name, string device, const NodeDebugInfo& debug,
26 Graph* graph)
27 : graph_(graph), name_(std::move(name)), device_(std::move(device)) {
28 auto new_name = [graph, this](string suffix) {
29 return graph->NewName(strings::StrCat(name_, "/", suffix));
30 };
31 TF_CHECK_OK(
32 IncompleteNodeDefBuilder::Identity(new_name("pred"), DT_BOOL, debug)
33 .Device(device_)
34 .Build(graph_, &pred_));
35 Node* switch_pred;
36 TF_CHECK_OK(
37 IncompleteNodeDefBuilder::Switch(new_name("switch_pred"), DT_BOOL, debug)
38 .Device(device_)
39 .Build(graph_, &switch_pred));
40 graph_->AddEdge(pred(), 0, switch_pred, 0);
41 graph_->AddEdge(pred(), 0, switch_pred, 1);
42 TF_CHECK_OK(
43 IncompleteNodeDefBuilder::Identity(new_name("switch_f"), DT_BOOL, debug)
44 .Device(device_)
45 .Build(graph_, &switch_f_));
46 TF_CHECK_OK(
47 IncompleteNodeDefBuilder::Identity(new_name("switch_t"), DT_BOOL, debug)
48 .Device(device_)
49 .Build(graph_, &switch_t_));
50 graph_->AddEdge(switch_pred, kElseBranch, switch_f_, 0);
51 graph_->AddEdge(switch_pred, kThenBranch, switch_t_, 0);
52 Node* merge_pred;
53 TF_CHECK_OK(IncompleteNodeDefBuilder::Merge(new_name("merge_pred"), DT_BOOL,
54 debug, /*n=*/2)
55 .Device(device_)
56 .Build(graph_, &merge_pred));
57 graph_->AddEdge(switch_f_, 0, merge_pred, kElseBranch);
58 graph_->AddEdge(switch_t_, 0, merge_pred, kThenBranch);
59 // Note: when additional return values are added then there should be a
60 // control dependency between those merge nodes and control_successor_ to
61 // ensure that it is control successor of conditional.
62 control_successor_ = merge_pred;
63 }
64
pred()65 Node* CondBuilder::pred() { return pred_; }
66
switch_f()67 Node* CondBuilder::switch_f() { return switch_f_; }
68
switch_t()69 Node* CondBuilder::switch_t() { return switch_t_; }
70
control_successor()71 Node* CondBuilder::control_successor() { return control_successor_; }
72
AddInput(const string & input_name,const DataType & type,const string & device,const NodeDebugInfo & debug,Node ** input)73 Status CondBuilder::AddInput(const string& input_name, const DataType& type,
74 const string& device, const NodeDebugInfo& debug,
75 Node** input) {
76 auto b = IncompleteNodeDefBuilder::Switch(
77 graph_->NewName(strings::StrCat(name_, "/", input_name)), type, debug);
78 TF_RETURN_IF_ERROR(b.Device(device).Build(graph_, input));
79 graph_->AddEdge(pred(), 0, *input, 1);
80 return Status::OK();
81 }
82
83 } // namespace tensorflow
84