1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 17 #define TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 18 19 #include "tensorflow/cc/framework/ops.h" 20 #include "tensorflow/cc/framework/scope.h" 21 22 namespace tensorflow { 23 namespace ops { 24 25 // Function that takes cond graph inputs and returns cond graph boolean output. 26 // 'output' need not be set if an error is returned. 27 typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, 28 Output* output)> 29 CondGraphBuilderFn; 30 31 // Function that takes body graph inputs and returns body graph outputs. 32 // 'outputs' need not be populated if an error is returned. 33 typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, 34 std::vector<Output>* outputs)> 35 BodyGraphBuilderFn; 36 37 // Constructs a while loop. 38 // 39 // Arguments: 40 // * scope: used to construct the while loop. 41 // * inputs: the initial values of the loop variables. Must be non-empty. 42 // * cond: a function that builds the condition graph of the loop. Takes the 43 // current loop variables as inputs and returns a scalar boolean Output 44 // indicating whether the loop should continue. 45 // * body: a function that builds the body graph of the loop. Takes the current 46 // loop variables as inputs and returns the updated loop variables. 47 // * frame_name: the frame name to use for this while loop. This should be a 48 // unique name. This will be used as a prefix for created operations. 49 // * outputs: output param that returns final loop variable outputs in non-error 50 // case. Must be non-null and empty. 51 // * create_while_ctx: if true, a WhileContext is created and populated for this 52 // loop. See core/graph/while_context.h for more details on 53 // WhileContexts. This is set to false for loops used as part of gradient 54 // computations, since they're part of the gradient for a loop in the 55 // forward-pass. 56 // TODO(skyewm): revisit this. Should we create WhileContexts for all loops, 57 // even if we don't need them? 58 // * cond_output: if non-null, the output of the predicate is returned. This 59 // will always be a LoopCond node. 60 // 61 // Returns an error if the while loop could not be fully constructed. 62 // 63 // TODO(skyewm): clean up partially-constructed loop in error case 64 // TODO(skyewm): create public interface to this method 65 Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, 66 const CondGraphBuilderFn& cond, 67 const BodyGraphBuilderFn& body, const string& frame_name, 68 OutputList* outputs, bool create_while_ctx = true, 69 Output* cond_output = nullptr); 70 71 } // namespace ops 72 } // namespace tensorflow 73 74 #endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 75