/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TAPE_H_
#define TENSORFLOW_C_EAGER_TAPE_H_

// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.

#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace eager {

// Represents an entry in the tape.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  string op_type;
  std::vector<TapeTensor> output_tensor_info;
  std::vector<int64> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};

// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = std::unordered_map<int64, int64>;

// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape =
    std::unordered_map<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;

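// Illustrative example (the IDs below are hypothetical): after watching a
// tensor x and recording a single "Square" op that produces y from x, the
// bookkeeping looks roughly like
//
//   tensor_tape = {{/*x_id=*/1, -1},  // x was watched directly
//                  {/*y_id=*/2, 0}};  // y was produced by op 0
//   op_tape[0]  = {"Square", /*output_tensor_info=*/{y},
//                  /*input_tensor_id=*/{1}, backward_fn, backward_fn_deleter};
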
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's either
// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
// to allow their size to be computed and they need to be passable to a backward
// function and deleted (as the backprop code creates lots of gradients the user
// is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64 NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and returns
  // a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;

  // Calls the passed-in backward function.
  virtual Status CallBackwardFunction(
      BackwardFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result) const = 0;

  // Looks up the ID of a Gradient.
  virtual int64 TensorId(Gradient* tensor) const = 0;

  // Converts a Gradient to a TapeTensor.
  virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;

  // Marks the following gradient as a result so it's not consumed by backward
  // functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes the input tensor.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};

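// A concrete VSpace is supplied by the language binding; the tape classes below
// only ever manipulate tensors through this interface. As a purely hypothetical
// sketch (MyGradient, MyBackwardFn, MyTapeTensor and SumAndUnref are
// illustrative names, not TensorFlow APIs):
//
//   class MyVSpace final
//       : public VSpace<MyGradient, MyBackwardFn, MyTapeTensor> {
//    public:
//     int64 NumElements(MyGradient* tensor) const override {
//       return tensor->num_elements();
//     }
//     MyGradient* AggregateGradients(
//         gtl::ArraySlice<MyGradient*> gradient_tensors) const override {
//       return SumAndUnref(gradient_tensors);  // hypothetical helper
//     }
//     // ... remaining pure-virtual methods ...
//   };
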
// Traces the execution of operations, performs eager garbage collection, and
// exports a full trace so other code can do backpropagation. Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  void Watch(int64 tensor_id);

  void RecordOperation(
      const string& op_type, const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  void DeleteTrace(int64 tensor_id);

  // Consumes the internal state of the tape (so cannot be called more than
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result);

  bool IsPersistent() const { return persistent_; }

 private:
  TensorTape tensor_tape_;
  OpTape<BackwardFunction, TapeTensor> op_tape_;
  int64 next_op_id_{0};

  // Map from tensor id to number of remaining usages (i.e. how many entries in
  // the tape refer to it); to aid in tape garbage collection.
  std::unordered_map<int64, int64> tensor_usage_;

  // If false, all activations are deleted in the first call to ComputeGradient.
  // Else, only when this is destructed.
  bool persistent_;
};

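// Rough usage sketch (illustrative only; `my_vspace`, the tensor IDs and the
// backward-function callbacks below are hypothetical stand-ins supplied by a
// language binding; G, BF and TT abbreviate the template arguments):
//
//   GradientTape<G, BF, TT> tape(/*persistent=*/false);
//   tape.Watch(x_id);
//   if (tape.ShouldRecord({x_id}, {DT_FLOAT})) {
//     tape.RecordOperation("Square", {y_info}, {x_id}, {DT_FLOAT},
//                          backward_getter, backward_deleter);
//   }
//   std::vector<G*> result;
//   Status s = tape.ComputeGradient(my_vspace, /*target_tensor_ids=*/{y_id},
//                                   /*source_tensor_ids=*/{x_id},
//                                   /*sources_that_are_targets=*/{},
//                                   /*output_gradients=*/{}, &result);
//
// With persistent=false, ComputeGradient consumes the tape and may only be
// called once.
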
// Describes a callback for special-cased and more efficient jvp computation.
//
// Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on
// that.
template <typename Gradient>
class ForwardFunction
    : public std::function<Status(const std::vector<Gradient*>&,
                                  std::vector<Gradient*>*)> {
 public:
  template <typename lambda_type>
  explicit ForwardFunction(lambda_type lambda)
      : std::function<Status(const std::vector<Gradient*>&,
                             std::vector<Gradient*>*)>(lambda) {}
};

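// For example, a binding might wrap a special-cased JVP for an op in a lambda
// (sketch only; the lambda body is a placeholder):
//
//   ForwardFunction<G> add_jvp(
//       [](const std::vector<G*>& tangents, std::vector<G*>* jvps) -> Status {
//         // The JVP of Add is just the sum of the input tangents.
//         return Status::OK();
//       });
//
// and pass its address as `forward_function` to ForwardAccumulator::Accumulate.
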
// Computes Jacobian-vector products using forward-mode automatic
// differentiation.
//
// While GradientTape's RecordOperation is trivial, ForwardAccumulator's
// Accumulate runs the gradient computation immediately.
//
// Keeps references to Tensors watched via Watch and computed in Accumulate
// corresponding to output_tensors, and releases these references in its
// destructor. However, waiting until the destructor runs loses the memory
// efficiency of forward-mode autodiff. Instead, language bindings should call
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
// Tensor passed to Accumulate goes out of scope.
//
// Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class ForwardAccumulator {
 public:
  // Does not take ownership of `vspace`, which must outlive the
  // ForwardAccumulator.
  explicit ForwardAccumulator(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
      : vspace_(vspace) {
    call_state_.emplace(nullptr, false);
  }

  virtual ~ForwardAccumulator() {
    for (auto accumulated : accumulated_gradients_) {
      vspace_.DeleteGradient(accumulated.second);
    }
  }

  // Tell the forward accumulator to watch tensor_id, with a Tensor tangent
  // vector `tangent` of matching shape and dtype. Tangents are the "vector" in
  // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
  // FetchJVP for it would return `tangent`.
  void Watch(int64 tensor_id, Gradient* tangent);

  // Removes the gradient associated with tensor_id. Should be called when the
  // Tensor associated with `tensor_id` is deleted.
  void DeleteGradient(int64 tensor_id);

  // Runs forward autodiff. Should be called whenever a new operation is
  // available and the accumulator is active.
  //
  // Like GradientTape::RecordOperation, this method takes the operation type
  // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
  // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
  // redundant but taken as arguments to avoid repeatedly fetching these values
  // between calls to ShouldRecord and Accumulate), and its outputs
  // (`output_tensors`).
  //
  // If provided, a non-null `forward_function` will be used instead of the
  // backward function (`backward_function_getter` /
  // `backward_function_deleter`) to compute jvps for this operation. If
  // `forward_function` is null, a GradientTape is used on the backward function
  // to compute the jvp, which will waste computation when executing eagerly.
  //
  // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
  // immediately. It stores the results, which feed into Accumulate for future
  // operations and may be fetched by calling FetchJVP. ForwardAccumulator
  // maintains a reference to these JVPs: if an `output_tensors` Tensor is
  // deleted, `DeleteGradient` should be called as soon as possible to free the
  // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
  // will release remaining references.
  //
  // This method is not thread-safe (and in general ForwardAccumulator is not
  // thread-safe).
  Status Accumulate(
      const string& op_type, const std::vector<TapeTensor>& input_tensors,
      const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const ForwardFunction<Gradient>* forward_function,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  // Returns true if `Accumulate` is active somewhere above on the stack and
  // there isn't an intervening PushState. This is useful for ordering
  // ForwardAccumulators, where more deeply nested accumulators should not see
  // computations from less deeply nested accumulators.
  bool BusyAccumulating() const { return call_state_.top().accumulating; }

  // Fetches the current Jacobian-vector product associated with `tensor_id`, or
  // a nullptr if none is available.
  //
  // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on its
  // return value. The caller should increment the reference count before
  // deleting the ForwardAccumulator or calling DeleteGradient if keeping a
  // persistent reference to a non-null result.
  Gradient* FetchJVP(int64 tensor_id);

  // Indicates whether the forward accumulator should run on an operation with
  // the specified inputs and dtypes.
  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  // Temporarily push or pop transient state for this accumulator.
  //
  // Allows an accumulator which is currently processing an operation to
  // temporarily reset its state. Without pushing and popping, accumulators
  // ignore operations executed as a direct result of their own jvp
  // computations.
  void PushState() { call_state_.emplace(nullptr, false); }
  void PopState() { call_state_.pop(); }

 private:
  // Helper for Accumulate: uses a GradientTape to compute forward gradients
  // from a backward gradient function. Fills `out_grads` corresponding to
  // `output_tensors`. `out_grads` must not be null.
  //
  // Executes the backward function in order to trace its gradient, which will
  // waste computation if executing eagerly (when graph building the unneeded
  // computation is pruned). Temporarily sets `backward_tape` so that
  // Accumulate will forward op executions to the tape while the backward
  // function is running; this effectively adds the backward tape to the active
  // set (but does not require complicated callbacks to the language bindings).
  Status ForwardpropFromTape(
      const std::vector<TapeTensor>& output_tensors,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter,
      const std::vector<Gradient*>& in_grads,
      std::vector<Gradient*>* out_grads);

  // Maps from tensor IDs to corresponding JVPs.
  std::unordered_map<int64, Gradient*> accumulated_gradients_;
  // Not owned; provides operations on Tensors which are currently only
  // available in language bindings (e.g. Python).
  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;

  struct AccumulatorCallState {
    AccumulatorCallState(
        GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
        bool accumulating)
        : backward_tape(backward_tape), accumulating(accumulating) {}
    // Set temporarily while in the Accumulate method; if backward_tape is not
    // nullptr then we forward op executions to it so Accumulate can compute a
    // backward pass on its backward function.
    //
    // Not owned by the ForwardAccumulator. The method which sets
    // `backward_tape` keeps ownership.
    GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
    // While the Accumulate method is running (accumulating is true), any op
    // executions not forwarded to backward_tape should be ignored.
    bool accumulating;
  };
  // A deque-backed stack, whose element references are not invalidated by
  // pushes and pops at the back.
  std::stack<AccumulatorCallState> call_state_;
};

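// Rough usage sketch (illustrative only; the concrete types, tensor IDs and
// callbacks are hypothetical stand-ins supplied by a language binding):
//
//   ForwardAccumulator<G, BF, TT> acc(my_vspace);
//   acc.Watch(x_id, dx_tangent);                 // seed the "vector" in JVP
//   // For each op executed while the accumulator is active:
//   TF_RETURN_IF_ERROR(acc.Accumulate("Square", {x_info}, {y_info}, {x_id},
//                                     {DT_FLOAT}, /*forward_function=*/nullptr,
//                                     backward_getter, backward_deleter));
//   G* dy_tangent = acc.FetchJVP(y_id);          // borrowed reference
//   acc.DeleteGradient(y_id);                    // once y goes out of scope
//
// Passing a non-null ForwardFunction avoids the backward-tape fallback in
// ForwardpropFromTape.
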
// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    case DT_RESOURCE:
    case DT_VARIANT:
      return true;
    default:
      return false;
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  CHECK_EQ(tensor_ids.size(), dtypes.size());
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
    const string& op_type, const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return;
  }
  std::vector<int64> ids;
  ids.reserve(input_tensor_id.size());
  for (int64 i : input_tensor_id) {
    tensor_usage_[i]++;
    ids.push_back(i);
  }
  const int64 op_id = next_op_id_++;
  std::vector<TapeTensor> tensors;
  tensors.reserve(output_tensors.size());
  for (const TapeTensor& o : output_tensors) {
    // Note: the tensor can have already been watched and hence be in the tape,
    // so we cannot check that we're inserting it here.
    tensor_tape_[o.GetID()] = op_id;
    tensor_usage_[o.GetID()] = 1;
    tensors.push_back(o);
  }
  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
      op_type, std::move(tensors), std::move(ids), backward_function_getter(),
      backward_function_deleter};
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64 tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64 op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  for (int64 id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}

// Terminology:
//
//  - op: a possibly composite operation, which has an entry in the tape
//  - target: y in dy/dx, the tensor being differentiated
//  - source: x in dy/dx, the tensor the gradient is taken with respect to
//  - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each Tensor
// is used as an input to an op (so we know when we're done computing gradients
// for that Tensor). We also count, for each tape entry, how many of its output
// Tensors need gradients to be computed (Tensors which are not used do not need
// any gradients to be computed).
//
// Next, we start a backprop stack with a set of tape entries for which we
// have all gradients available. This set usually is a subset of the set of
// targets (not all since targets which have outputs in the tape will not have
// gradients available initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and update
// the gradients of its inputs. Once we have computed all gradients for a single
// input we can mark this input as done, and this can trigger adding an entry to
// the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.

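// Tiny worked example (illustrative): suppose the tape recorded y = Square(x)
// and z = Exp(y), with x watched, and we differentiate z with respect to x.
//
//   1. PrepareBackprop walks back from z and keeps both ops; y is counted as
//      used once (by Exp), and Square is recorded as missing one output
//      gradient (dy).
//   2. InitialStack contains only the Exp op, since dz is seeded with ones.
//   3. Popping Exp produces dy; Square is then no longer missing any output
//      gradients, so it is pushed, popped, and produces dx.
//   4. The stack is empty, and dx is returned as the gradient for source x.
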
namespace {

template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  std::unordered_map<int64, int64> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  std::unordered_map<int64, int64> op_missing_tensor;
};

// If `persistent_tape` is true, op_tape is not changed and none of the
// backward functions are deleted.
// If `persistent_tape` is false, op_tape is cleared and backward functions
// not needed for gradient computation are deleted. Backward functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
    OpTape<BackwardFunction, TapeTensor>* op_tape,
    const std::unordered_set<int64>& sources_set, bool persistent_tape) {
  std::vector<int64> tensor_stack;
  tensor_stack.reserve(target.size());
  for (auto t : target) {
    tensor_stack.push_back(t);
  }
  BackpropInitialState<BackwardFunction, TapeTensor> result;
  while (!tensor_stack.empty()) {
    int64 tensor_id = tensor_stack.back();
    tensor_stack.pop_back();
    auto op_id_it = tensor_tape.find(tensor_id);
    if (op_id_it == tensor_tape.end()) {
      continue;
    }
    int64 op_id = op_id_it->second;
    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
    for (auto it : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(it);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[it] = 1;
        if (tensor_tape.find(it) != tensor_tape.end()) {
          tensor_stack.push_back(it);
        }
      }
    }
    if (!persistent_tape) {
      op_tape->erase(op_it);
    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
    if (it != tensor_tape.end() && it->second != -1) {
      result.op_missing_tensor[it->second] += 1;
    }
  }
  if (!persistent_tape) {
    // Call destructors for all unneeded gradient functions and
    // clear the op_tape. We can clear the tape because ownership of
    // backward functions that will be used for gradient computation
    // has been transferred to `result`.
    for (const auto& op_pair : *op_tape) {
      op_pair.second.backward_function_deleter(
          op_pair.second.backward_function);
    }
    op_tape->clear();
  }
  return result;
}

template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    const std::unordered_map<int64, int64>& op_missing_tensor) {
  std::vector<int64> result;
  for (auto& op_entry : op_tape) {
    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
      result.push_back(op_entry.first);
    }
  }
  return result;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    gtl::ArraySlice<int64> target_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    std::unordered_map<int64, std::vector<Gradient*>>* result) {
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    const int64 id = target_tensor_ids[i];
    if (output_gradients.empty() || output_gradients[i] == nullptr) {
      auto tensor_it = tensor_tape.find(id);
      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
          if (op_it->second.output_tensor_info[j].GetID() == id) {
            found = true;
            (*result)[id].push_back(
                op_it->second.output_tensor_info[j].OnesLike());
            break;
          }
        }
        if (!found) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of operations outputs match expected tensor");
        }
      } else {
        // This target tensor was not generated by any operation recorded on
        // the tape, so no gradient needs to be computed from it unless this
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          (*result)[id].push_back(source_tensor->second.OnesLike());
        }
      }
    } else {
      (*result)[id].push_back(output_gradients[i]);
    }
  }
  return Status::OK();
}

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Some gradient functions can accept None arguments for gradients. The
// following maps the operation name to the indices at which the corresponding
// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
// values and hence receives 5 gradient values during backprop. However the
// gradient function uses only the first of those values and ignores the rest.
// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
std::unordered_map<string, std::unordered_set<int>>*
FunctionsAcceptingNoneForIndicesMap() {
  static auto* const m =
      new std::unordered_map<string, std::unordered_set<int>>({
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"SparseSoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      });
  return m;
}

}  // namespace

// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensors and save memory.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;

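// For example (the aggregation check below assumes 4-byte elements, e.g.
// float32): once more than four partial gradients are pending for a single
// tensor and their combined size exceeds 128 MiB, they are summed early. Five
// pending partials of 8M elements each is roughly 5 * 8M * 4 bytes = 160 MiB,
// which crosses the threshold and triggers the early aggregation.
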
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) {
  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
                                        source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  std::unordered_map<int64, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids,
                              sources_that_are_targets, output_gradients,
                              tensor_tape_, state.op_tape, &gradients);
  auto cleanup = gtl::MakeCleanup([this, &state]() {
    if (!persistent_) {
      // Release all backprop functions
      for (const auto& pair : state.op_tape) {
        pair.second.backward_function_deleter(pair.second.backward_function);
      }
    }
  });
  if (!s.ok()) {
    return s;
  }

  std::unordered_map<int64, int64> gradients_size;
  // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
  // time, for better CPU backprop performance.
  VLOG(1) << "Initial stack:";
  if (VLOG_IS_ON(1)) {
    for (auto t : op_stack) {
      VLOG(1) << "  " << t;
    }
  }
  while (!op_stack.empty()) {
    const int64 op = op_stack.back();
    VLOG(1) << "Popped " << op;
    op_stack.pop_back();
    auto op_it = state.op_tape.find(op);
    if (op_it == state.op_tape.end()) {
      // It is possible for ops to end up on the stack if they are unrelated to
      // the target; we should just skip them.
      continue;
    }
    auto trace = std::move(op_it->second);
    state.op_tape.erase(op_it);
    std::vector<Gradient*> out_gradients;
    out_gradients.reserve(trace.output_tensor_info.size());
    std::vector<int64> unneeded_gradients;
    for (int i = 0; i < trace.input_tensor_id.size(); i++) {
      const auto& in_tensor_id = trace.input_tensor_id[i];
      if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
          sources_set.find(in_tensor_id) == sources_set.end()) {
        unneeded_gradients.push_back(i);
      }
    }

    bool any_gradient_nonzero = false;
    std::vector<int> zero_indices;
    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
      const int64 id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        auto func_name_it =
            FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
        if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
            func_name_it->second.find(i) != func_name_it->second.end()) {
          out_gradients.push_back(nullptr);
        } else {
          out_gradients.push_back(nullptr);
          zero_indices.push_back(i);
        }
      } else {
        any_gradient_nonzero = true;
        Gradient* new_gradients = nullptr;
        if (grad_it->second.size() == 1) {
          new_gradients = grad_it->second.at(0);
        } else {
          new_gradients = vspace.AggregateGradients(grad_it->second);
        }
        if (sources_set.find(grad_it->first) == sources_set.end()) {
          gradients.erase(grad_it);
        } else {
          grad_it->second.clear();
          grad_it->second.push_back(new_gradients);
          vspace.MarkAsResult(new_gradients);
        }
        out_gradients.push_back(new_gradients);
      }
    }
    std::vector<Gradient*> in_gradients;
    if (any_gradient_nonzero) {
      for (const auto i : zero_indices) {
        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
      }
      Status s;
      s = vspace.CallBackwardFunction(trace.backward_function,
                                      unneeded_gradients, out_gradients,
                                      &in_gradients);
      if (in_gradients.size() != trace.input_tensor_id.size()) {
        return tensorflow::errors::Internal(
            "Recorded operation '", trace.op_type,
            "' returned too few gradients. Expected ",
            trace.input_tensor_id.size(), " but received ",
            in_gradients.size());
      }
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      if (!s.ok()) {
        return s;
      }
    } else {
      in_gradients.resize(trace.input_tensor_id.size());
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      for (Gradient* grad : out_gradients) {
        if (grad != nullptr) {
          vspace.DeleteGradient(grad);
        }
      }
    }
    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
            << trace.input_tensor_id.size() << " sources";
    for (int i = 0; i < in_gradients.size(); ++i) {
      const int64 id = trace.input_tensor_id[i];
      if (in_gradients[i] != nullptr) {
        auto& unaggregated_grads = gradients[id];
        unaggregated_grads.push_back(in_gradients[i]);
        if (unaggregated_grads.size() > kMinAggregateCount) {
          auto size_it = gradients_size.find(id);
          int64 size;
          if (size_it == gradients_size.end()) {
            size = vspace.NumElements(unaggregated_grads[0]);
            gradients_size.emplace(id, size);
          } else {
            size = size_it->second;
          }
          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
            unaggregated_grads.clear();
            unaggregated_grads.push_back(grad);
          }
        }
      }
      auto usage_count_it = state.tensor_usage_counts.find(id);
      if (usage_count_it == state.tensor_usage_counts.end()) {
        VLOG(1) << "Tensor " << id << " not used";
        continue;
      }
      usage_count_it->second--;
      if (usage_count_it->second > 0) {
        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
        continue;
      }
      auto tape_it = tensor_tape_.find(id);
      if (tape_it == tensor_tape_.end()) {
        VLOG(1) << "Tensor " << id
                << " has no associated op. Deleting gradient";
        auto grad_it = gradients.find(id);
        if (grad_it != gradients.end()) {
          for (auto g : grad_it->second) {
            vspace.DeleteGradient(g);
          }
          gradients.erase(grad_it);
        }
        continue;
      }
      const int64 op_id = tape_it->second;
      if (op_id == -1) {
        VLOG(1) << "Tensor " << id << " is source";
        continue;
      }
      auto missing_it = state.op_missing_tensor.find(op_id);
      if (missing_it != state.op_missing_tensor.end()) {
        missing_it->second--;
        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
                << " output gradients";
        if (missing_it->second == 0) {
          op_stack.insert(op_stack.begin(), op_id);
        }
      }
    }
  }
  if (!state.op_tape.empty()) {
    return tensorflow::errors::Internal("Invalid tape state.");
  }
  result->reserve(source_tensor_ids.size());
  std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
  for (auto is : source_tensor_ids) {
    auto grad_it = gradients.find(is);
    if (grad_it == gradients.end()) {
      result->push_back(nullptr);
    } else {
      if (grad_it->second.size() > 1) {
        Gradient* grad = vspace.AggregateGradients(grad_it->second);
        grad_it->second.clear();
        grad_it->second.push_back(grad);
      }
      result->push_back(grad_it->second[0]);
      used_gradient_ids.insert(is);
    }
  }
  VLOG(1) << "Final gradients size: "
          << gradients.size() - used_gradient_ids.size();
  for (auto grad_pair : gradients) {
    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
      for (const auto& g : grad_pair.second) {
        vspace.DeleteGradient(g);
      }
    }
  }
  return Status::OK();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  if (call_state_.top().backward_tape != nullptr) {
    // If we're forwarding Accumulate calls to backward_tape's RecordOperation,
    // we should also delegate ShouldRecord.
    return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
  }
  if (call_state_.top().accumulating) {
    return false;
  }
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (accumulated_gradients_.find(tensor_ids[i]) !=
        accumulated_gradients_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
    const std::vector<TapeTensor>& output_tensors,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter,
    const std::vector<Gradient*>& in_grads, std::vector<Gradient*>* out_grads) {
  /* This function is approximately equivalent to this Python code:

  forwardprop_aids = tf.ones_like(output_tensors)
  with tf.GradientTape() as g:
    g.watch(forwardprop_aids)
    grad = backward_function(forwardprop_aids)
  forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
  accumulated_gradients_[ID(output_tensors)] = forward_grads
  */
  std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
      new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
  AccumulatorCallState& call_state = call_state_.top();
  call_state.backward_tape = tape.get();
  auto pop_backward_tape =
      gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
  std::vector<Gradient*> forwardprop_aids;
  std::vector<int64> sources;
  std::unordered_set<int64> sources_set;
  sources.reserve(output_tensors.size());
  for (const TapeTensor& output_tensor : output_tensors) {
    // Ownership of `aid` transferred to CallBackwardFunction below.
    Gradient* aid;
    if (output_tensor.GetDType() == tensorflow::DT_VARIANT) {
      // Note: Needs to be zeros rather than ones since there's currently no
      // ones_like for variants.
      aid = output_tensor.ZerosLike();
    } else {
      // TODO(allenl): Figure out why using zeros_like everywhere causes issues
      // for some gradient functions and if there's another way to work around
      // it (e.g. conds instead of ifs). The value shouldn't really matter.
      aid = output_tensor.OnesLike();
    }
    if (TF_PREDICT_FALSE(aid == nullptr)) {
      return tensorflow::errors::Internal(
          "Failed to create ones tensor for tensor ", output_tensor.GetID(),
          " with dtype ", output_tensor.GetDType());
    }
    forwardprop_aids.push_back(aid);
    int64 aid_id = vspace_.TensorId(aid);
    sources.push_back(aid_id);
    sources_set.insert(aid_id);
    tape->Watch(aid_id);
  }
  std::vector<Gradient*> grad;
  auto delete_grad = gtl::MakeCleanup([&grad, this] {
    for (Gradient* tensor : grad) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  {
    std::vector<int64> unneeded_gradients;
    std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
        backward_function(backward_function_getter(),
                          backward_function_deleter);
    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
        backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
  }

  // Stop the tape from recording
  pop_backward_tape.release()();

  if (grad.size() != in_grads.size()) {
    return tensorflow::errors::Internal("Wrong number of gradients returned.");
  }

  std::vector<int64> targets;
  std::vector<Gradient*> used_in_grads;
  // We may end up with slightly fewer elements than we reserve, but grad.size()
  // should be a reasonably tight upper bound.
  targets.reserve(grad.size());
  used_in_grads.reserve(grad.size());
  std::unordered_map<int64, TapeTensor> sources_that_are_targets;
  for (int grad_index = 0; grad_index < grad.size(); ++grad_index) {
    Gradient* grad_tensor = grad[grad_index];
    if (grad_tensor != nullptr) {
      int64 tensor_id = vspace_.TensorId(grad_tensor);
      targets.push_back(tensor_id);
      if (sources_set.find(tensor_id) != sources_set.end()) {
        sources_that_are_targets.emplace(
            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
      }
      Gradient* in_grad = in_grads[grad_index];
      if (in_grad != nullptr) {
        // ComputeGradient steals a reference
        vspace_.MarkAsResult(in_grad);
      }
      used_in_grads.push_back(in_grad);
    }
  }

  return tape->ComputeGradient(vspace_, targets, sources,
                               sources_that_are_targets, used_in_grads,
                               out_grads);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
    const string& op_type, const std::vector<TapeTensor>& input_tensors,
    const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const ForwardFunction<Gradient>* forward_function,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (call_state_.top().backward_tape != nullptr) {
    // If backward_tape is not null, then this call to Accumulate is the result
    // of a still-active call to Accumulate which is running operations. We
    // forward these operations to backward_tape so the outer Accumulate call
    // can do its work.
    //
    // Rather than re-entering and delegating Accumulate like this, we could
    // instead allow ForwardAccumulator some control over the current tape set
    // (so it can deactivate itself and activate its GradientTape). Currently
    // that is managed by the language binding and would require relatively
    // messy callbacks.
    call_state_.top().backward_tape->RecordOperation(
        op_type, output_tensors, input_tensor_id, input_dtypes,
        backward_function_getter, backward_function_deleter);
    return Status::OK();
  }
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return Status::OK();
  }

  // We may need to allocate zero inputs for trainable dtypes we don't have JVPs
  // for. Make sure they get cleaned up.
  std::vector<Gradient*> new_zeros;
  auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
    for (Gradient* tensor : new_zeros) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  std::vector<Gradient*> in_grads;
  in_grads.reserve(input_tensors.size());
  for (int target_index = 0; target_index < input_tensors.size();
       ++target_index) {
    const auto current_grad =
        accumulated_gradients_.find(input_tensors[target_index].GetID());
    if (current_grad == accumulated_gradients_.end()) {
      if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
        // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
        // GradientTape which uses ones.
        Gradient* zero = input_tensors[target_index].ZerosLike();
        new_zeros.push_back(zero);
        in_grads.push_back(zero);
      } else {
        in_grads.push_back(nullptr);
      }
    } else {
      in_grads.push_back(current_grad->second);
    }
  }

  // Avoid infinite recursion. Whichever forward function we run, it'll end up
  // executing ops, and we don't want to watch those with this accumulator.
  call_state_.emplace(nullptr, true);
  auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });

  std::vector<Gradient*> forward_grads;
  if (forward_function == nullptr) {
    // We have no special-cased forward gradient. Fall back to running the
    // backward function under a gradient tape.
    TF_RETURN_IF_ERROR(ForwardpropFromTape(
        output_tensors, backward_function_getter, backward_function_deleter,
        in_grads, &forward_grads));
  } else {
    TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
  }
  for (int i = 0; i < forward_grads.size(); ++i) {
    if (forward_grads[i] != nullptr) {
      int64 tensor_id = output_tensors[i].GetID();
      auto existing = accumulated_gradients_.find(tensor_id);
      if (existing != accumulated_gradients_.end()) {
        // This is a somewhat odd case to be in, since it means we have two
        // operations which supposedly both created the same Tensor. It comes up
        // in recompute_grad, where the gradients have the same value. However,
        // only the original gradient is connected to everything else, so we
        // should still use that.
        vspace_.DeleteGradient(forward_grads[i]);
      } else {
        accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
      }
    }
  }
  return Status::OK();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id, Gradient* tangent) {
  typename std::unordered_map<int64, Gradient*>::iterator existing =
      accumulated_gradients_.find(tensor_id);
  vspace_.MarkAsResult(tangent);
  if (existing == accumulated_gradients_.end()) {
    accumulated_gradients_.emplace(tensor_id, tangent);
  } else {
    std::array<Gradient*, 2> to_aggregate;
    to_aggregate[0] = tangent;
    to_aggregate[1] = existing->second;
    // AggregateGradients steals a reference to each of its arguments. We
    // MarkAsResult on `tangent` above so we don't steal a reference to it.
    existing->second = vspace_.AggregateGradients(to_aggregate);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
    int64 tensor_id) {
  auto existing = accumulated_gradients_.find(tensor_id);
  if (existing != accumulated_gradients_.end()) {
    vspace_.DeleteGradient(existing->second);
    accumulated_gradients_.erase(existing);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
    int64 tensor_id) {
  auto lookup = accumulated_gradients_.find(tensor_id);
  if (lookup == accumulated_gradients_.end()) {
    return nullptr;
  } else {
    return lookup->second;
  }
}

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TAPE_H_