1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_ 17 #define TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_ 18 19 #include "tensorflow/core/graph/graph.h" 20 21 namespace tensorflow { 22 23 // Information about a while loop. Every user-defined while loop has an 24 // associated WhileContext, i.e., there is a WhileContext for every execution 25 // frame. Created with the while loop and used during gradient 26 // construction. Note that the gradient graph of while loop contains while loops 27 // itself, but these do not generate separate WhileContexts. 28 // 29 // TODO(skyewm): this is currently insufficient to handle nested loops and 30 // conditionals (and possibly other requirements). This may change a lot in the 31 // future to support these features. 32 // 33 // TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be 34 // differentiable. Figure out backwards compatibility story. 35 class WhileContext { 36 public: 37 WhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes, 38 std::vector<Node*> exit_nodes, OutputTensor cond_output, 39 std::vector<OutputTensor> body_inputs, 40 std::vector<OutputTensor> body_outputs); 41 frame_name()42 const string& frame_name() const { return frame_name_; } enter_nodes()43 const std::vector<Node*>& enter_nodes() const { return enter_nodes_; } exit_nodes()44 const std::vector<Node*>& exit_nodes() const { return exit_nodes_; } cond_output()45 const OutputTensor& cond_output() const { return cond_output_; } body_inputs()46 const std::vector<OutputTensor>& body_inputs() const { return body_inputs_; } body_outputs()47 const std::vector<OutputTensor>& body_outputs() const { 48 return body_outputs_; 49 } 50 51 private: 52 // Each user-defined while loop defines a new execution frame, which is 53 // uniquely identified by its frame name. Frames are used by the executor to 54 // manage the iterations of a loop. See the FrameState comment in 55 // core/common_runtime/executor.cc for more details. 56 const string frame_name_; 57 58 // The enter nodes defining the input loop variables to the while loop. This 59 // vector defines the order of the loop variables. 60 const std::vector<Node*> enter_nodes_; 61 62 // The exit nodes defining the outputs of the while loop. These are in loop 63 // variable order. 64 const std::vector<Node*> exit_nodes_; 65 66 // The boolean output of the loop predicate. 67 const OutputTensor cond_output_; 68 69 // The inputs and outputs to the loop body. 70 const std::vector<OutputTensor> body_inputs_; 71 const std::vector<OutputTensor> body_outputs_; 72 }; 73 74 } // namespace tensorflow 75 76 #endif // TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_ 77