• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/framework/while_gradients.h"
17 
18 #include "tensorflow/cc/framework/gradients.h"
19 #include "tensorflow/cc/framework/scope_internal.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/cc/ops/while_loop.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 using ops::BodyGraphBuilderFn;
28 using ops::BuildWhileLoop;
29 using ops::CondGraphBuilderFn;
30 
ToOutput(OutputTensor output_tensor)31 Output ToOutput(OutputTensor output_tensor) {
32   return Output(const_cast<Node*>(output_tensor.node), output_tensor.index);
33 }
34 
ToOutputVector(const std::vector<OutputTensor> & output_tensors)35 std::vector<Output> ToOutputVector(
36     const std::vector<OutputTensor>& output_tensors) {
37   size_t n = output_tensors.size();
38   std::vector<Output> result;
39   result.reserve(n);
40   for (int i = 0; i < n; ++i) result.push_back(ToOutput(output_tensors[i]));
41   return result;
42 }
43 
44 // The backprop loop counter and main backprop loop run in their own execution
45 // frame (conceptually, the main forward loop and forward loop counter run
46 // together in a frame, then the backprop loop counter and backprop loop run
47 // together in a different frame). This returns the frame name to use for the
48 // backprop while loops.
49 // TODO(skyewm): make sure this is unique among existing frame names
BackPropFrameName(const string & forward_frame_name)50 string BackPropFrameName(const string& forward_frame_name) {
51   return strings::StrCat(forward_frame_name, "_backprop");
52 }
53 
54 // Creates a loop that counts the number of iterations performed by the
55 // while loop associated with `while_ctx`. The returned output yields the
56 // iteration count.
AddForwardLoopCounter(WhileContext * while_ctx,const Scope & scope,Output * count)57 Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope,
58                              Output* count) {
59   // Create while loop:
60   //   i = 0
61   //   while forward loop predicate is true:
62   //     ++i
63 
64   Output zero = ops::Const(scope, 0, {});
65 
66   // Condition function that returns condition output from original while loop.
67   CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope,
68                                            const std::vector<Output>& inputs,
69                                            Output* output) {
70     *output = ToOutput(while_ctx->cond_output());
71     return Status::OK();
72   };
73 
74   // Body function that adds one to input.
75   BodyGraphBuilderFn body_fn = [](const Scope& scope,
76                                   const std::vector<Output>& inputs,
77                                   std::vector<Output>* outputs) {
78     DCHECK_EQ(inputs.size(), 1);
79     outputs->emplace_back(ops::Add(scope, inputs[0], 1));
80     return scope.status();
81   };
82 
83   // Note that this loop runs in the same execution frame as the forward loop.
84   std::vector<Output> outputs;
85   TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn,
86                                     while_ctx->frame_name(), &outputs,
87                                     /* create_while_ctx */ false));
88   *count = outputs[0];
89   return Status::OK();
90 }
91 
92 // Creates a loop that executes `loop_count` times. The returned output is the
93 // boolean predicate indicating if the loop is still executing. This is used to
94 // drive the gradient computation for the while loop associated with
95 // `while_ctx`.
AddBackPropLoopCounter(WhileContext * while_ctx,const Output & loop_count,const Scope & scope,Output * backprop_execution_pred)96 Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count,
97                               const Scope& scope,
98                               Output* backprop_execution_pred) {
99   // Create while loop:
100   //   n = loop_count
101   //   while n > 0:
102   //     --n
103 
104   // Condition function that returns input > 0.
105   CondGraphBuilderFn cond_fn = [](const Scope& scope,
106                                   const std::vector<Output>& inputs,
107                                   Output* output) {
108     DCHECK_EQ(inputs.size(), 1);
109     *output = ops::Greater(scope, inputs[0], 0);
110     return scope.status();
111   };
112 
113   // Body function that subtracts one from input.
114   BodyGraphBuilderFn body_fn = [](const Scope& scope,
115                                   const std::vector<Output>& inputs,
116                                   std::vector<Output>* outputs) {
117     DCHECK_EQ(inputs.size(), 1);
118     outputs->emplace_back(ops::Subtract(scope, inputs[0], 1));
119     return scope.status();
120   };
121 
122   string frame_name = BackPropFrameName(while_ctx->frame_name());
123   std::vector<Output> outputs;
124   TF_RETURN_IF_ERROR(BuildWhileLoop(
125       scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs,
126       /* create_while_ctx */ false, backprop_execution_pred));
127   return Status::OK();
128 }
129 
130 // Creates the main backprop loop that computes the gradient of the loop
131 // associated with `while_ctx`. `grad_inputs` are the partial derivatives
132 // w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is
133 // the predicate to use for the backprop loop (see AddBackPropLoopCounter()).
134 // The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are
135 // returned in `grad_outputs`.
AddWhileGradientLoop(WhileContext * while_ctx,const std::vector<Output> & grad_inputs,const Output & backprop_execution_pred,const Scope & parent_scope,std::vector<Output> * grad_outputs)136 Status AddWhileGradientLoop(WhileContext* while_ctx,
137                             const std::vector<Output>& grad_inputs,
138                             const Output& backprop_execution_pred,
139                             const Scope& parent_scope,
140                             std::vector<Output>* grad_outputs) {
141   DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size());
142   DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size());
143 
144   Scope scope = parent_scope.NewSubScope("while");
145 
146   // Create while loop:
147   //   while backprop_execution_pred:
148   //     forward loop body gradient
149 
150   // Condition function that returns 'backprop_execution_pred'.
151   CondGraphBuilderFn cond_fn = [backprop_execution_pred](
152                                    const Scope& scope,
153                                    const std::vector<Output>& inputs,
154                                    Output* output) {
155     *output = backprop_execution_pred;
156     return Status::OK();
157   };
158 
159   // Body function that builds while body gradient subgraph.
160   BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope,
161                                            const std::vector<Output>& inputs,
162                                            std::vector<Output>* outputs) {
163     std::vector<Output> body_outputs =
164         ToOutputVector(while_ctx->body_outputs());
165     std::vector<Output> body_inputs = ToOutputVector(while_ctx->body_inputs());
166     return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs,
167                                 outputs);
168   };
169 
170   string frame_name = BackPropFrameName(while_ctx->frame_name());
171   TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn,
172                                     frame_name, grad_outputs,
173                                     /* create_while_ctx */ false));
174   return Status::OK();
175 }
176 
177 }  // namespace
178 
AddWhileLoopGradient(WhileContext * while_ctx,const Scope & scope,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)179 Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope,
180                             const std::vector<Output>& grad_inputs,
181                             std::vector<Output>* grad_outputs) {
182   Output forward_loop_count;
183   TF_RETURN_IF_ERROR(AddForwardLoopCounter(
184       while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count));
185 
186   // TODO(skyewm): can we combine the backprop loop counter and main gradient
187   // loop into a single loop? The original Python code doesn't combine the
188   // loops, but I'm not sure why.
189   Output backprop_counter_cond;
190   TF_RETURN_IF_ERROR(AddBackPropLoopCounter(
191       while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"),
192       &backprop_counter_cond));
193 
194   return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond,
195                               scope, grad_outputs);
196 }
197 
198 }  // namespace tensorflow
199