1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ 17 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/graph/graph.h" 22 #include "tensorflow/core/lib/core/status.h" 23 24 namespace tensorflow { 25 26 // Conditional builder. 27 // Convenience builder to make it easy to construct a conditional. E.g., 28 // Node* pred = ...; 29 // CondBuilder cb("cond", g); 30 // auto switch_var = cb.AddInput("var", DT_RESOURCE); 31 // g->AddEdge(pred, 0, cb.pred(), 0); 32 // Will create the nodes of a conditional that takes as input a resource 33 // variable ("var") as input and that switches on pred. 34 // 35 // This currently only handles the case needed by distributed_tpu_rewrite_pass 36 // and is not completely general. 37 class CondBuilder { 38 public: 39 enum Branch { kElseBranch = 0, kThenBranch = 1 }; 40 41 CondBuilder(string name, string device, const NodeDebugInfo& debug, 42 Graph* graph); 43 44 // Returns node corresponding to the predicate input. 45 Node* pred(); 46 47 // Returns node corresponding to switch_f branch of predicate switch. 48 Node* switch_f(); 49 50 // Returns node corresponding to switch_t branch of predicate switch. 51 Node* switch_t(); 52 53 // Returns node corresponding to control successor. 54 Node* control_successor(); 55 56 // Returns the Switch node to feed a value of the given type into the 57 // conditional. 58 Status AddInput(const string& input_name, const DataType& type, 59 const string& device, const NodeDebugInfo& debug, 60 Node** input); 61 62 private: 63 Node* control_successor_; 64 Node* switch_f_; 65 Node* switch_t_; 66 Node* pred_; 67 Graph* const graph_; 68 const string name_; 69 const string device_; 70 }; 71 72 } // namespace tensorflow 73 74 #endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ 75