1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 17 #define TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/cc/framework/ops.h" 23 #include "tensorflow/cc/framework/scope.h" 24 25 namespace tensorflow { 26 namespace ops { 27 28 // Function that takes cond graph inputs and returns cond graph boolean output. 29 // 'output' need not be set if an error is returned. 30 typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, 31 Output* output)> 32 CondGraphBuilderFn; 33 34 // Function that takes body graph inputs and returns body graph outputs. 35 // 'outputs' need not be populated if an error is returned. 36 typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, 37 std::vector<Output>* outputs)> 38 BodyGraphBuilderFn; 39 40 // Constructs a while loop. 41 // 42 // Arguments: 43 // * scope: used to construct the while loop. 44 // * inputs: the initial values of the loop variables. Must be non-empty. 45 // * cond: a function that builds the condition graph of the loop. Takes the 46 // current loop variables as inputs and returns a scalar boolean Output 47 // indicating whether the loop should continue. 48 // * body: a function that builds the body graph of the loop. Takes the current 49 // loop variables as inputs and returns the updated loop variables. 50 // * frame_name: the frame name to use for this while loop. This should be a 51 // unique name. This will be used as a prefix for created operations. 52 // * outputs: output param that returns final loop variable outputs in non-error 53 // case. Must be non-null and empty. 54 // * create_while_ctx: if true, a WhileContext is created and populated for this 55 // loop. See core/graph/while_context.h for more details on 56 // WhileContexts. This is set to false for loops used as part of gradient 57 // computations, since they're part of the gradient for a loop in the 58 // forward-pass. 59 // TODO(skyewm): revisit this. Should we create WhileContexts for all loops, 60 // even if we don't need them? 61 // * cond_output: if non-null, the output of the predicate is returned. This 62 // will always be a LoopCond node. 63 // 64 // Returns an error if the while loop could not be fully constructed. 65 // 66 // TODO(skyewm): clean up partially-constructed loop in error case 67 // TODO(skyewm): create public interface to this method 68 Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, 69 const CondGraphBuilderFn& cond, 70 const BodyGraphBuilderFn& body, const string& frame_name, 71 OutputList* outputs, bool create_while_ctx = true, 72 Output* cond_output = nullptr); 73 74 } // namespace ops 75 } // namespace tensorflow 76 77 #endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_ 78