• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
17 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/graph/graph.h"
22 #include "tensorflow/core/lib/core/status.h"
23 
24 namespace tensorflow {
25 
26 // Conditional builder.
27 // Convenience builder to make it easy to construct a conditional. E.g.,
28 //   Node* pred = ...;
29 //   CondBuilder cb("cond", g);
30 //   auto switch_var = cb.AddInput("var", DT_RESOURCE);
31 //   g->AddEdge(pred, 0, cb.pred(), 0);
32 // Will create the nodes of a conditional that takes as input a resource
33 // variable ("var") as input and that switches on pred.
34 //
35 // This currently only handles the case needed by distributed_tpu_rewrite_pass
36 // and is not completely general.
37 class CondBuilder {
38  public:
39   enum Branch { kElseBranch = 0, kThenBranch = 1 };
40 
41   CondBuilder(string name, string device, const NodeDebugInfo& debug,
42               Graph* graph);
43 
44   // Returns node corresponding to the predicate input.
45   Node* pred();
46 
47   // Returns node corresponding to switch_f branch of predicate switch.
48   Node* switch_f();
49 
50   // Returns node corresponding to switch_t branch of predicate switch.
51   Node* switch_t();
52 
53   // Returns node corresponding to control successor.
54   Node* control_successor();
55 
56   // Returns the Switch node to feed a value of the given type into the
57   // conditional.
58   Status AddInput(const string& input_name, const DataType& type,
59                   const string& device, const NodeDebugInfo& debug,
60                   Node** input);
61 
62  private:
63   Node* control_successor_;
64   Node* switch_f_;
65   Node* switch_t_;
66   Node* pred_;
67   Graph* const graph_;
68   const string name_;
69   const string device_;
70 };
71 
72 }  // namespace tensorflow
73 
74 #endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
75