1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_TAPE_H_
16 #define TENSORFLOW_C_EAGER_TAPE_H_
17 
18 // Language-agnostic gradient tape. Does not perform backpropagation, just
19 // maintains the data structures required to do so.
20 
21 #include <stack>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/lib/gtl/array_slice.h"
29 #include "tensorflow/core/lib/gtl/cleanup.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/gtl/flatset.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 namespace eager {
37 
// Represents an entry in the tape: one recorded operation plus everything
// needed later to run (or release) its backward function.
//
// NOTE: GradientTape::RecordOperation builds this struct with aggregate
// (brace) initialization, so the member order below is part of how it is
// constructed.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  // Name of the recorded operation (e.g. "Add").
  string op_type;
  // One entry per output of the operation; each output is identified by its
  // GetID().
  std::vector<TapeTensor> output_tensor_info;
  // Tensor IDs of the operation's inputs, in the order they were recorded.
  std::vector<int64> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  // Released by passing it to `backward_function_deleter`.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};
52 
// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = std::unordered_map<int64, int64>;

// Map from operation-id to tape entry. Operation ids are handed out
// sequentially by GradientTape::RecordOperation.
template <typename BackwardFunction, typename TapeTensor>
using OpTape =
    std::unordered_map<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
62 
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's either
// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
// to allow their size to be computed and they need to be passable to a backward
// function and deleted (as the backprop code creates lots of gradients the user
// is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// All methods are const: a VSpace is a stateless view from the tape's side
// (ForwardAccumulator, for instance, only holds it by const reference).
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64 NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and returns
  // a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;

  // Calls the passed-in backward function.
  //
  // `unneeded_gradients` contains sorted list of input indices for which a
  // gradient is not required.
  // NOTE(review): presumably `output_gradients` holds one incoming gradient per
  // op output and `result` receives one gradient per op input -- confirm
  // against ComputeGradient.
  virtual Status CallBackwardFunction(
      const string& op_type, BackwardFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<Gradient*> output_gradients,
      absl::Span<Gradient*> result) const = 0;

  // Builds a tensor filled with ones with the same shape and dtype as `t` and
  // stores it in `*result`.
  virtual Status BuildOnesLike(const TapeTensor& t,
                               Gradient** result) const = 0;

  // Looks up the ID of a Gradient.
  virtual int64 TensorId(Gradient* tensor) const = 0;

  // Converts a Gradient to a TapeTensor.
  virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;

  // Marks the following gradient as a result so it's not consumed by backward
  // functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes (releases the tape's reference to) the given gradient.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};
122 
123 // Traces the execution of operations, doing eager garbage collection, and
124 // exporting a full trace so other code can do backpropagation. Not thread-safe.
125 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
126 class GradientTape {
127  public:
128   // If `persistent` is true, GradientTape will not eagerly delete backward
129   // functions (and hence the tensors they keep alive). Instead, everything
130   // is deleted in ~GradientTape. Persistent GradientTapes are useful when
131   // users want to compute multiple gradients over the same tape.
GradientTape(bool persistent)132   explicit GradientTape(bool persistent) : persistent_(persistent) {}
~GradientTape()133   ~GradientTape() {
134     for (const auto& pair : op_tape_) {
135       pair.second.backward_function_deleter(pair.second.backward_function);
136     }
137   }
138 
139   // Returns whether any tensor in a list of tensors is being watched and has
140   // a trainable dtype.
141   bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
142                     gtl::ArraySlice<tensorflow::DataType> dtypes) const;
143 
144   // Adds this tensor to the list of watched tensors.
145   //
146   // This is a no-op if the tensor is already being watched either from an
147   // earlier call to `GradientTape::Watch` or being an output of an op with
148   // watched inputs.
149   void Watch(int64_t tensor_id);
150 
151   // Records an operation with inputs `input_tensor_id` and outputs
152   // `output_tensors` on the tape and marks all its outputs as watched if at
153   // least one input of the op is watched and has trainable dtype.
154   //
155   // op_type is used to decide which of the incoming gradients can be left as
156   // nullptr instead of building zeros when build_default_zeros_grads == true.
157   void RecordOperation(
158       const string& op_type, const std::vector<TapeTensor>& output_tensors,
159       gtl::ArraySlice<int64> input_tensor_id,
160       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
161       const std::function<BackwardFunction*()>& backward_function_getter,
162       const std::function<void(BackwardFunction*)>& backward_function_deleter);
163 
164   void DeleteTrace(int64_t tensor_id);
165 
166   // Consumes the internal state of the tape (so cannot be called more than
167   // once) and produces the gradient of the target tensors with respect to the
168   // source tensors. The output gradients are used if not empty and not
169   // null. The result is populated with one tensor per target element.
170   // When running backward functions, builds zeros-like tensors for
171   // incoming grads which are nullptrs, unless `build_default_zeros_grads`
172   // is set to false.
173   Status ComputeGradient(
174       const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
175       const gtl::ArraySlice<int64> target_tensor_ids,
176       const gtl::ArraySlice<int64> source_tensor_ids,
177       const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
178       gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
179       bool build_default_zeros_grads = true);
180 
181   // Whether the tape is persistent. See ctor for detailed description.
IsPersistent()182   bool IsPersistent() const { return persistent_; }
183 
184  private:
185   TensorTape tensor_tape_;
186   OpTape<BackwardFunction, TapeTensor> op_tape_;
187   int64 next_op_id_{0};
188 
189   // Map from tensor id to number of remaining usages (i.e. how many entries in
190   // the tape refer to it); to aid in tape garbage collection.
191   std::unordered_map<int64, int64> tensor_usage_;
192 
193   // If false, all activations are deleted in the first call to ComputeGradient.
194   // Else, only when this is destructed.
195   bool persistent_;
196 };
197 
198 // Describes a callback for special-cased and more efficient jvp computation.
199 //
200 // Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on
201 // that.
202 template <typename Gradient>
203 class ForwardFunction
204     : public std::function<Status(const std::vector<Gradient*>&,
205                                   std::vector<Gradient*>*, bool)> {
206  public:
207   template <typename lambda_type>
ForwardFunction(lambda_type lambda)208   explicit ForwardFunction(lambda_type lambda)
209       : std::function<Status(const std::vector<Gradient*>&,
210                              std::vector<Gradient*>*, bool)>(lambda) {}
211 };
212 
213 // Computes Jacobian-vector products using forward-mode automatic
214 // differentiation.
215 //
216 // While GradientTape's RecordOperation is trivial, ForwardAccumulator's
217 // Accumulate runs the gradient computation immediately.
218 //
219 // Keeps references to Tensors watched via Watch and computed in Accumulate
220 // corresponding to output_tensors, and releases these references in its
221 // destructor. However, waiting until the destructor runs loses the memory
222 // efficiency of forward-mode autodiff. Instead, language bindings should call
223 // DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
224 // Tensor passed to Accumulate goes out of scope.
225 //
226 // Not thread-safe.
227 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
228 class ForwardAccumulator {
229  public:
230   // Does not take ownership of `vspace`, which must outlive the
231   // ForwardAccumulator.
ForwardAccumulator(const VSpace<Gradient,BackwardFunction,TapeTensor> & vspace,bool use_batch)232   explicit ForwardAccumulator(
233       const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
234       bool use_batch)
235       : vspace_(vspace), use_batch_(use_batch) {
236     call_state_.emplace(nullptr, false);
237   }
238 
~ForwardAccumulator()239   virtual ~ForwardAccumulator() {
240     for (auto accumulated : accumulated_gradients_) {
241       vspace_.DeleteGradient(accumulated.second);
242     }
243   }
244 
245   // Tell the forward accumulator to watch tensor_id, with a Tensor tangent
246   // vector `tangent` of matching shape and dtype. Tangents are the "vector" in
247   // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
248   // FetchJVP for it would return `tangent`.
249   void Watch(int64_t tensor_id, Gradient* tangent);
250 
251   // Removes the gradient associated with tensor_id. Should be called when the
252   // Tensor associated with `tensor_id` is deleted.
253   void DeleteGradient(int64_t tensor_id);
254 
255   // Runs forward autodiff. Should be called whenever a new operation is
256   // available and the accumulator is active.
257   //
258   // Like GradientTape::RecordOperation, this method takes the operation type
259   // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
260   // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
261   // redundant but taken as arguments to avoid repeatedly fetching these values
262   // between calls to ShouldRecord and Accumulator), and its outputs
263   // (`output_tensors`).
264   //
265   // If provided, a non-null `forward_function` will be used instead of the
266   // backward function (`backward_function_getter` /
267   // `backward_function_deleter`) to compute jvps for this operation. If
268   // `forward_function` is null, a GradientTape is used on the backward function
269   // to compute the jvp, which will waste computation when executing eagerly.
270   //
271   // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
272   // immediately. It stores the results, which feed into Accumulate for future
273   // operations and may be fetched by calling FetchJVP. ForwardAccumulator
274   // maintains a reference to these JVPs: if an `output_tensors` Tensor is
275   // deleted, `DeleteGradient` should be called as soon as possible to free the
276   // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
277   // will release remaining references.
278   //
279   // This method is not thread-safe (and in general ForwardAccumulator is not
280   // thread-safe).
281   Status Accumulate(
282       const string& op_type, const std::vector<TapeTensor>& input_tensors,
283       const std::vector<TapeTensor>& output_tensors,
284       gtl::ArraySlice<int64> input_tensor_id,
285       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
286       const ForwardFunction<Gradient>* forward_function,
287       const std::function<BackwardFunction*()>& backward_function_getter,
288       const std::function<void(BackwardFunction*)>& backward_function_deleter);
289 
290   // Returns true if `Accumulate` is active somewhere above on the stack and
291   // there isn't an intervening PushState. This is useful for ordering
292   // ForwardAccumulators, where more deeply nested accumulators should not see
293   // computations from less deeply nested accumulators.
BusyAccumulating()294   bool BusyAccumulating() const { return call_state_.top().accumulating; }
295 
296   // Fetches the current Jacobian-vector product associated with `tensor_id`, or
297   // a nullptr if none is available.
298   //
299   // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on its
300   // return value. The caller should increment the reference count before
301   // deleting the ForwardAccumulator or calling DeleteGradient if keeping a
302   // persistent reference to a non-null result.
303   Gradient* FetchJVP(int64_t tensor_id);
304 
305   // Indicates whether the forward accumulator should run on an operation with
306   // the specified inputs and dtypes.
307   bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
308                     gtl::ArraySlice<tensorflow::DataType> dtypes);
309 
310   // Temporarily push or pop transient state for this accumulator.
311   //
312   // Allows an accumulator which is currently processing an operation to
313   // temporarily reset its state. Without pushing and popping, accumulators
314   // ignore operations executed as a direct result of their own jvp
315   // computations.
PushState()316   void PushState() { call_state_.emplace(nullptr, false); }
PopState()317   void PopState() { call_state_.pop(); }
318 
319  private:
320   // Helper for Accumulate: uses a GradientTape to compute forward gradients
321   // from a backward gradient function. Fills `out_grads` corresponding to
322   // `output_tensors`. `out_grads` must not be null.
323   //
324   // Executes the backward function in order to trace its gradient, which will
325   // waste computation if executing eagerly (when graph building the unneeded
326   // computation is pruned). Temporarily sets `backward_tape` so that
327   // Accumulate will forward op executions to the tape while the backward
328   // function is running; this effectively adds the backward tape to the active
329   // set (but does not require complicated callbacks to the language bindings).
330   Status ForwardpropFromTape(
331       const string& op_type, const std::vector<TapeTensor>& output_tensors,
332       const std::function<BackwardFunction*()>& backward_function_getter,
333       const std::function<void(BackwardFunction*)>& backward_function_deleter,
334       const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads);
335 
336   // Maps from tensor IDs to corresponding JVPs.
337   std::unordered_map<int64, Gradient*> accumulated_gradients_;
338   // Not owned; provides operations on Tensors which are currently only
339   // available in language bindings (e.g. Python).
340   const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;
341 
342   // Decides if tangents are vectorized or not
343   bool use_batch_;
344 
345   struct AccumulatorCallState {
AccumulatorCallStateAccumulatorCallState346     AccumulatorCallState(
347         GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
348         bool accumulating)
349         : backward_tape(backward_tape), accumulating(accumulating) {}
350     // Set temporarily while in the Accumulate method; if backward_tape is not
351     // nullptr then we forward op executions to it so Accumulate can compute a
352     // backward pass on its backward function.
353     //
354     // Not owned by the ForwardAccumulator. The method which sets
355     // `backward_tape` keeps ownership.
356     GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
357     // While the Accumulate method is running (accumulating is True), any op
358     // executions not forwarded to backward_tape should be ignored.
359     bool accumulating;
360   };
361   // A deque-backed stack, whose element references are not invalidated by
362   // pushes and pops at the back.
363   std::stack<AccumulatorCallState> call_state_;
364 };
365 
// Template definitions follow (these are definitions, not explicit
// instantiations).
367 
IsDtypeTrainable(DataType dtype)368 inline bool IsDtypeTrainable(DataType dtype) {
369   switch (dtype) {
370     case DT_HALF:
371     case DT_BFLOAT16:
372     case DT_FLOAT:
373     case DT_DOUBLE:
374     case DT_COMPLEX64:
375     case DT_COMPLEX128:
376     case DT_RESOURCE:
377     case DT_VARIANT:
378       return true;
379     default:
380       return false;
381   }
382 }
383 
384 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
ShouldRecord(gtl::ArraySlice<int64> tensor_ids,gtl::ArraySlice<tensorflow::DataType> dtypes)385 bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
386     gtl::ArraySlice<int64> tensor_ids,
387     gtl::ArraySlice<tensorflow::DataType> dtypes) const {
388   CHECK_EQ(tensor_ids.size(), dtypes.size());
389   for (int i = 0; i < tensor_ids.size(); ++i) {
390     if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
391       if (IsDtypeTrainable(dtypes[i])) {
392         return true;
393       }
394     }
395   }
396   return false;
397 }
398 
399 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Watch(int64_t tensor_id)400 void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
401     int64_t tensor_id) {
402   tensor_tape_.emplace(tensor_id, -1);
403 }
404 
405 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
RecordOperation(const string & op_type,const std::vector<TapeTensor> & output_tensors,gtl::ArraySlice<int64> input_tensor_id,gtl::ArraySlice<tensorflow::DataType> input_dtypes,const std::function<BackwardFunction * ()> & backward_function_getter,const std::function<void (BackwardFunction *)> & backward_function_deleter)406 void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
407     const string& op_type, const std::vector<TapeTensor>& output_tensors,
408     gtl::ArraySlice<int64> input_tensor_id,
409     gtl::ArraySlice<tensorflow::DataType> input_dtypes,
410     const std::function<BackwardFunction*()>& backward_function_getter,
411     const std::function<void(BackwardFunction*)>& backward_function_deleter) {
412   if (!ShouldRecord(input_tensor_id, input_dtypes)) {
413     return;
414   }
415   std::vector<int64> ids;
416   ids.reserve(input_tensor_id.size());
417   for (int64_t i : input_tensor_id) {
418     tensor_usage_[i]++;
419     ids.push_back(i);
420   }
421   const int64_t op_id = next_op_id_++;
422   std::vector<TapeTensor> tensors;
423   tensors.reserve(output_tensors.size());
424   for (const TapeTensor& o : output_tensors) {
425     // Note: the tensor can have already been watched and hence be in the tape,
426     // so we cannot check that we're inserting it here.
427     tensor_tape_[o.GetID()] = op_id;
428     tensor_usage_[o.GetID()] = 1;
429     tensors.push_back(o);
430   }
431   op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
432       op_type, std::move(tensors), std::move(ids), backward_function_getter(),
433       backward_function_deleter};
434 }
435 
// Drops one reference to `tensor_id` from the tape's garbage-collection
// bookkeeping (`tensor_usage_`). Does nothing for unknown ids, or while other
// references remain. When the last reference goes away:
//  - the usage entry is erased;
//  - if the tensor was directly watched (producer id -1) its `tensor_tape_`
//    entry is kept and nothing else happens;
//  - otherwise its `tensor_tape_` entry is erased, and if no other output of
//    the producing op is still referenced, that op's inputs are released
//    recursively, its backward function is handed to its deleter, and the op
//    is removed from `op_tape_`.
//
// NOTE(review): recursion depth grows with the length of the chain of ops
// released by one call, so very long chains may use a lot of stack.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64_t tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64_t op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  // A tensor with a producer id must have its producing op on the tape.
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  // No output of this op is referenced anymore: release its inputs first,
  // then the op itself.
  for (int64_t id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}
472 
473 // Terminology:
474 //
475 //  - op: a possibly composite operation, which has an entry in the tape
//  - target: dy in dy/dx (the tensor being differentiated)
//  - source: dx in dy/dx (the tensor we differentiate with respect to)
478 //  - tensor: one of the many inputs or outputs of an operation
479 //
480 // Below here we do the gradient algorithm. It works as follows:
481 //
482 // First we filter the tape to just the subset of operations we want to
483 // differentiate. In the process of doing so we count how many times each Tensor
484 // is used as an input to an op (so we know when we're done computing gradients
485 // for that Tensor). We also count, for each tape entry, how many of its output
486 // Tensors need gradients to be computed (Tensors which are not used do not need
487 // any gradients to be computed).
488 //
// Next, we start a backprop stack with a set of tape entries for which we
// have all gradients available. This set is usually a subset of the entries
// that produce the targets (not all of them, since a target that is also
// consumed by another recorded op will not have its full gradient available
// initially).
493 //
494 // Then we repeatedly pop an entry from the stack, run its backprop, and update
495 // the gradients of its inputs. Once we have computed all gradients for a single
496 // input we can mark this input as done, and this can trigger adding an entry to
497 // the stack if all outputs of that entry are now done.
498 //
499 // When the stack is empty we have gradients for all tensors we're interested
500 // in.
501 
502 namespace {
503 
// Working state built by PrepareBackprop and consumed by
// GradientTape::ComputeGradient.
template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  // The tape entries reachable backwards from the targets.
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  std::unordered_map<int64, int64> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  std::unordered_map<int64, int64> op_missing_tensor;
};
516 
517 // If `persistent_tape` is true, op_tape is not changed and none of the
518 // backwards functions are deleted.
519 // If `persistent_tape` is false, op_tape is cleared and backwards functions
520 // not needed for gradient computation are deleted. Backwards functions that
521 // are needed, are copied and returned in BackpropInitialState.
522 template <typename BackwardFunction, typename TapeTensor>
PrepareBackprop(gtl::ArraySlice<int64> target,const TensorTape & tensor_tape,OpTape<BackwardFunction,TapeTensor> * op_tape,const std::unordered_set<int64> & sources_set,bool persistent_tape)523 BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
524     gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
525     OpTape<BackwardFunction, TapeTensor>* op_tape,
526     const std::unordered_set<int64>& sources_set, bool persistent_tape) {
527   std::vector<int64> tensor_stack;
528   tensor_stack.reserve(target.size());
529   for (auto t : target) {
530     tensor_stack.push_back(t);
531   }
532   BackpropInitialState<BackwardFunction, TapeTensor> result;
533   while (!tensor_stack.empty()) {
534     int64_t tensor_id = tensor_stack.back();
535     tensor_stack.pop_back();
536     auto op_id_it = tensor_tape.find(tensor_id);
537     if (op_id_it == tensor_tape.end()) {
538       continue;
539     }
540     int64_t op_id = op_id_it->second;
541     auto op_it = op_tape->find(op_id);
542     auto result_op_it = result.op_tape.find(op_id);
543     if (op_id == -1 || op_it == op_tape->end() ||
544         result_op_it != result.op_tape.end()) {
545       continue;
546     }
547     CHECK(result.op_tape.emplace(op_id, op_it->second).second);
548     for (auto it : op_it->second.input_tensor_id) {
549       auto count_it = result.tensor_usage_counts.find(it);
550       if (count_it != result.tensor_usage_counts.end()) {
551         count_it->second++;
552       } else {
553         result.tensor_usage_counts[it] = 1;
554         if (tensor_tape.find(it) != tensor_tape.end()) {
555           tensor_stack.push_back(it);
556         }
557       }
558     }
559     if (!persistent_tape) {
560       op_tape->erase(op_it);
561     }
562   }
563   for (auto& pair : result.tensor_usage_counts) {
564     auto it = tensor_tape.find(pair.first);
565     if (it != tensor_tape.end() && it->second != -1) {
566       result.op_missing_tensor[it->second] += 1;
567     }
568   }
569   if (!persistent_tape) {
570     // Call destructors for all unneeded gradient functions and
571     // clear the op_tape. We can clear the tape because ownership of
572     // backward functions that will be used for gradient computation
573     // has been transferred to `result`.
574     for (const auto& op_pair : *op_tape) {
575       op_pair.second.backward_function_deleter(
576           op_pair.second.backward_function);
577     }
578     op_tape->clear();
579   }
580   return result;
581 }
582 
583 template <typename BackwardFunction, typename TapeTensor>
InitialStack(const OpTape<BackwardFunction,TapeTensor> & op_tape,const std::unordered_map<int64,int64> & op_missing_tensor)584 std::vector<int64> InitialStack(
585     const OpTape<BackwardFunction, TapeTensor>& op_tape,
586     const std::unordered_map<int64, int64>& op_missing_tensor) {
587   std::vector<int64> result;
588   for (auto& op_entry : op_tape) {
589     if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
590       result.push_back(op_entry.first);
591     }
592   }
593   return result;
594 }
595 
596 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
InitialGradients(const VSpace<Gradient,BackwardFunction,TapeTensor> & vspace,gtl::ArraySlice<int64> target_tensor_ids,const std::unordered_map<int64,TapeTensor> & sources_that_are_targets,gtl::ArraySlice<Gradient * > output_gradients,const TensorTape & tensor_tape,const OpTape<BackwardFunction,TapeTensor> & op_tape,std::unordered_map<int64,std::vector<Gradient * >> * result)597 Status InitialGradients(
598     const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
599     gtl::ArraySlice<int64> target_tensor_ids,
600     const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
601     gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
602     const OpTape<BackwardFunction, TapeTensor>& op_tape,
603     std::unordered_map<int64, std::vector<Gradient*>>* result) {
604   for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) {
605     const int64_t id = target_tensor_ids[i];
606     if (output_gradients.empty() || output_gradients[i] == nullptr) {
607       auto tensor_it = tensor_tape.find(id);
608       if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
609         auto op_it = op_tape.find(tensor_it->second);
610         if (op_it == op_tape.end()) {
611           return errors::Internal(
612               "Internal state of the gradient tape is invalid: "
613               "failed to find operation producing a tensor");
614         }
615         bool found = false;
616         for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
617           if (op_it->second.output_tensor_info[j].GetID() == id) {
618             found = true;
619             Gradient* ones_like = nullptr;
620             TF_RETURN_IF_ERROR(vspace.BuildOnesLike(
621                 op_it->second.output_tensor_info[j], &ones_like));
622             (*result)[id].push_back(ones_like);
623             break;
624           }
625         }
626         if (!found) {
627           return errors::Internal(
628               "Internal state of the gradient tape is invalid: "
629               "none of operations outputs match expected tensor");
630         }
631       } else {
632         // This target tensor was not generated by any operation recorded on
633         // the tape, so no gradient needs to be computed from it unless this
634         // target is also a source.
635         auto source_tensor = sources_that_are_targets.find(id);
636         if (source_tensor != sources_that_are_targets.end()) {
637           Gradient* ones_like = nullptr;
638           TF_RETURN_IF_ERROR(
639               vspace.BuildOnesLike(source_tensor->second, &ones_like));
640           (*result)[id].push_back(ones_like);
641         }
642       }
643     } else {
644       (*result)[id].push_back(output_gradients[i]);
645     }
646   }
647   return Status::OK();
648 }
649 
// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Lists, per op type, the output-gradient indices that the op's gradient
// function tolerates being None (nullptr). Example: FusedBatchNorm has 5
// outputs and so receives 5 incoming gradients, but its gradient function only
// reads the one at index 0. Its entry {1, 2, 3, 4} therefore lets backprop pass
// None at those positions instead of materializing zeros.
//
// The map is built on first call and never freed; the same pointer is returned
// on every call.
std::unordered_map<string, std::unordered_set<int>>*
FunctionsAcceptingNoneForIndicesMap() {
  static auto* const kNoneIndices = [] {
    auto* table = new std::unordered_map<string, std::unordered_set<int>>();
    (*table)["SoftmaxCrossEntropyWithLogits"] = {1};
    (*table)["SparseSoftmaxCrossEntropyWithLogits"] = {1};
    (*table)["FusedBatchNorm"] = {1, 2, 3, 4};
    return table;
  }();
  return kNoneIndices;
}
672 
673 }  // namespace
674 
// Early-aggregation thresholds for gradients accumulated on a single tensor
// during backprop. Once more than kMinAggregateCount partial gradients are
// pending for a tensor and their estimated combined size exceeds
// kMinAggregateBytes, they are summed right away so the individual gradient
// tensors can be released, lowering peak memory use.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 << 20;  // 128 MiB
680 
681 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
ComputeGradient(const VSpace<Gradient,BackwardFunction,TapeTensor> & vspace,const gtl::ArraySlice<int64> target_tensor_ids,const gtl::ArraySlice<int64> source_tensor_ids,const std::unordered_map<int64,TapeTensor> & sources_that_are_targets,gtl::ArraySlice<Gradient * > output_gradients,absl::Span<Gradient * > result,bool build_default_zeros_grads)682 Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
683     const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
684     const gtl::ArraySlice<int64> target_tensor_ids,
685     const gtl::ArraySlice<int64> source_tensor_ids,
686     const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
687     gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
688     bool build_default_zeros_grads) {
689   std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
690                                         source_tensor_ids.end());
691   BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
692       target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
693   std::vector<int64> op_stack =
694       InitialStack(state.op_tape, state.op_missing_tensor);
695   std::unordered_map<int64, std::vector<Gradient*>> gradients;
696   Status s = InitialGradients(vspace, target_tensor_ids,
697                               sources_that_are_targets, output_gradients,
698                               tensor_tape_, state.op_tape, &gradients);
699   auto cleanup = gtl::MakeCleanup([this, &state]() {
700     if (!persistent_) {
701       // Release all backprop functions
702       for (const auto& pair : state.op_tape) {
703         pair.second.backward_function_deleter(pair.second.backward_function);
704       }
705     }
706   });
707   if (!s.ok()) {
708     return s;
709   }
710 
711   std::unordered_map<int64, int64> gradients_size;
712   // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
713   // time, for better CPU backprop performance.
714   VLOG(1) << "Initial stack:";
715   if (VLOG_IS_ON(1)) {
716     for (auto t : op_stack) {
717       VLOG(1) << "  " << t;
718     }
719   }
720   while (!op_stack.empty()) {
721     const int64_t op = op_stack.back();
722     VLOG(1) << "Popped " << op;
723     op_stack.pop_back();
724     auto op_it = state.op_tape.find(op);
725     if (op_it == state.op_tape.end()) {
726       // It is possible for ops to end up on the stack if they are unrelated to
727       // the target; we should just skip them.
728       continue;
729     }
730     auto trace = std::move(op_it->second);
731     state.op_tape.erase(op_it);
732     std::vector<Gradient*> out_gradients;
733     out_gradients.reserve(trace.output_tensor_info.size());
734     std::vector<int64> unneeded_gradients;
735     for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) {
736       const auto& in_tensor_id = trace.input_tensor_id[i];
737       if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
738           sources_set.find(in_tensor_id) == sources_set.end()) {
739         unneeded_gradients.push_back(i);
740       }
741     }
742 
743     bool any_gradient_nonzero = false;
744     std::vector<int> zero_indices;
745     for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) {
746       const int64_t id = trace.output_tensor_info[i].GetID();
747       auto grad_it = gradients.find(id);
748       if (grad_it == gradients.end()) {
749         out_gradients.push_back(nullptr);
750         if (build_default_zeros_grads) {
751           auto func_name_it =
752               FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
753           if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
754               func_name_it->second.find(i) == func_name_it->second.end()) {
755             zero_indices.push_back(i);
756           }
757         }
758       } else {
759         any_gradient_nonzero = true;
760         Gradient* new_gradients = nullptr;
761         if (grad_it->second.size() == 1) {
762           new_gradients = grad_it->second.at(0);
763         } else {
764           new_gradients = vspace.AggregateGradients(grad_it->second);
765         }
766         if (sources_set.find(grad_it->first) == sources_set.end()) {
767           gradients.erase(grad_it);
768         } else {
769           grad_it->second.clear();
770           grad_it->second.push_back(new_gradients);
771           vspace.MarkAsResult(new_gradients);
772         }
773         out_gradients.push_back(new_gradients);
774       }
775     }
776     VLOG(1) << "Calling gradient function for '" << trace.op_type << "'";
777     std::vector<Gradient*> in_gradients(trace.input_tensor_id.size());
778     DCHECK(build_default_zeros_grads || zero_indices.empty());
779     if (any_gradient_nonzero) {
780       for (const auto i : zero_indices) {
781         out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
782       }
783       Status s;
784       s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function,
785                                       unneeded_gradients, out_gradients,
786                                       absl::MakeSpan(in_gradients));
787       if (!persistent_) {
788         trace.backward_function_deleter(trace.backward_function);
789       }
790       if (!s.ok()) {
791         return s;
792       }
793     } else {
794       if (!persistent_) {
795         trace.backward_function_deleter(trace.backward_function);
796       }
797       for (Gradient* grad : out_gradients) {
798         if (grad != nullptr) {
799           vspace.DeleteGradient(grad);
800         }
801       }
802     }
803     for (int i = 0, end = in_gradients.size(); i < end; ++i) {
804       const int64_t id = trace.input_tensor_id[i];
805       if (in_gradients[i] != nullptr) {
806         auto& unaggregated_grads = gradients[id];
807         unaggregated_grads.push_back(in_gradients[i]);
808         if (unaggregated_grads.size() > kMinAggregateCount) {
809           auto size_it = gradients_size.find(id);
810           int64_t size;
811           if (size_it == gradients_size.end()) {
812             size = vspace.NumElements(unaggregated_grads[0]);
813             gradients_size.emplace(id, size);
814           } else {
815             size = size_it->second;
816           }
817           if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
818             Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
819             unaggregated_grads.clear();
820             unaggregated_grads.push_back(grad);
821           }
822         }
823       }
824       auto usage_count_it = state.tensor_usage_counts.find(id);
825       if (usage_count_it == state.tensor_usage_counts.end()) {
826         VLOG(1) << "Tensor " << id << " not used";
827         continue;
828       }
829       usage_count_it->second--;
830       if (usage_count_it->second > 0) {
831         VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
832         continue;
833       }
834       auto tape_it = tensor_tape_.find(id);
835       if (tape_it == tensor_tape_.end()) {
836         VLOG(1) << "Tensor " << id
837                 << " has no associated op. Deleting gradient";
838         auto grad_it = gradients.find(id);
839         if (grad_it != gradients.end()) {
840           for (auto g : grad_it->second) {
841             vspace.DeleteGradient(g);
842           }
843           gradients.erase(grad_it);
844         }
845         continue;
846       }
847       const int64_t op_id = tape_it->second;
848       if (op_id == -1) {
849         VLOG(1) << "Tensor " << id << " is source";
850         continue;
851       }
852       auto missing_it = state.op_missing_tensor.find(op_id);
853       if (missing_it != state.op_missing_tensor.end()) {
854         missing_it->second--;
855         VLOG(1) << "Op " << op_id << " missing " << missing_it->second
856                 << " output gradients";
857         if (missing_it->second == 0) {
858           op_stack.insert(op_stack.begin(), op_id);
859         }
860       }
861     }
862   }
863   if (!state.op_tape.empty()) {
864     return tensorflow::errors::Internal("Invalid tape state.");
865   }
866   if (result.size() != source_tensor_ids.size()) {
867     return errors::Internal("Expected result Span to be of size ",
868                             source_tensor_ids.size(), " found ", result.size(),
869                             " in call to Tape::ComputeGradient.");
870   }
871   std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
872   for (int i = 0; i < source_tensor_ids.size(); i++) {
873     int64_t tensor_id = source_tensor_ids[i];
874     auto grad_it = gradients.find(tensor_id);
875     if (grad_it == gradients.end()) {
876       result[i] = nullptr;
877     } else {
878       if (grad_it->second.size() > 1) {
879         Gradient* grad = vspace.AggregateGradients(grad_it->second);
880         grad_it->second.clear();
881         grad_it->second.push_back(grad);
882       }
883       result[i] = grad_it->second[0];
884       used_gradient_ids.insert(tensor_id);
885     }
886   }
887   VLOG(1) << "Final gradients size: "
888           << gradients.size() - used_gradient_ids.size();
889   for (const auto& grad_pair : gradients) {
890     if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
891       for (const auto& g : grad_pair.second) {
892         vspace.DeleteGradient(g);
893       }
894     }
895   }
896   return Status::OK();
897 }
898 
899 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
ShouldRecord(gtl::ArraySlice<int64> tensor_ids,gtl::ArraySlice<tensorflow::DataType> dtypes)900 bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
901     gtl::ArraySlice<int64> tensor_ids,
902     gtl::ArraySlice<tensorflow::DataType> dtypes) {
903   if (call_state_.top().backward_tape != nullptr) {
904     // If we're forwarding Accumulate calls to backward_tape's RecordOperation,
905     // we should also delegate ShouldRecord.
906     return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
907   }
908   if (call_state_.top().accumulating) {
909     return false;
910   }
911   for (int i = 0; i < tensor_ids.size(); ++i) {
912     if (accumulated_gradients_.find(tensor_ids[i]) !=
913         accumulated_gradients_.end()) {
914       if (IsDtypeTrainable(dtypes[i])) {
915         return true;
916       }
917     }
918   }
919   return false;
920 }
921 
922 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
923 Status
ForwardpropFromTape(const string & op_type,const std::vector<TapeTensor> & output_tensors,const std::function<BackwardFunction * ()> & backward_function_getter,const std::function<void (BackwardFunction *)> & backward_function_deleter,const std::vector<Gradient * > & in_grads,absl::Span<Gradient * > out_grads)924 ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
925     const string& op_type, const std::vector<TapeTensor>& output_tensors,
926     const std::function<BackwardFunction*()>& backward_function_getter,
927     const std::function<void(BackwardFunction*)>& backward_function_deleter,
928     const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads) {
929   /* This function is approximately equivalent to this Python code:
930 
931   forwardprop_aids = tf.ones_like(output_tensors)
932   with tf.GradientTape() as g:
933     g.watch(forwardprop_aids)
934     grad = backward_function(forwardprop_aids)
935   forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
936   accumulated_gradients_[ID(output_tensors)] = forward_grads
937   */
938   std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
939       new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
940   AccumulatorCallState& call_state = call_state_.top();
941   call_state.backward_tape = tape.get();
942   auto pop_backward_tape =
943       gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
944   std::vector<Gradient*> forwardprop_aids;
945   std::vector<int64> sources;
946   std::unordered_set<int64> sources_set;
947   sources.reserve(output_tensors.size());
948   for (const TapeTensor& output_tensor : output_tensors) {
949     // Ownership of `aid` transferred to CallBackwardFunction below.
950     Gradient* aid;
951     if (output_tensor.GetDType() == tensorflow::DT_VARIANT) {
952       // Note: Needs to be zeros rather than ones since there's currently no
953       // ones_like for variants.
954       aid = output_tensor.ZerosLike();
955     } else {
956       // TODO(allenl): Figure out why using zeros_like everywhere causes issues
957       // for some gradient functions and if there's another way to work around
958       // it (e.g. conds instead of ifs). The value shouldn't really matter.
959       TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid));
960     }
961     if (TF_PREDICT_FALSE(aid == nullptr)) {
962       return tensorflow::errors::Internal(
963           "Failed to create ones tensor for tensor ", output_tensor.GetID(),
964           " with dtype ", output_tensor.GetDType());
965     }
966     forwardprop_aids.push_back(aid);
967     int64_t aid_id = vspace_.TensorId(aid);
968     sources.push_back(aid_id);
969     sources_set.insert(aid_id);
970     tape->Watch(aid_id);
971   }
972   std::vector<Gradient*> grad(in_grads.size());
973   auto delete_grad = gtl::MakeCleanup([&grad, this] {
974     for (Gradient* tensor : grad) {
975       this->vspace_.DeleteGradient(tensor);
976     }
977   });
978   {
979     std::vector<int64> unneeded_gradients;
980     std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
981         backward_function(backward_function_getter(),
982                           backward_function_deleter);
983     TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
984         op_type, backward_function.get(), unneeded_gradients, forwardprop_aids,
985         absl::MakeSpan(grad)));
986   }
987 
988   // Stop the tape from recording
989   pop_backward_tape.release()();
990 
991   std::vector<int64> targets;
992   std::vector<Gradient*> used_in_grads;
993   // We may end up with slightly fewer elements than we reserve, but grad.size()
994   // should be a reasonably tight upper bound.
995   targets.reserve(grad.size());
996   used_in_grads.reserve(grad.size());
997   std::unordered_map<int64, TapeTensor> sources_that_are_targets;
998   for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) {
999     Gradient* grad_tensor = grad[grad_index];
1000     if (grad_tensor != nullptr) {
1001       int64_t tensor_id = vspace_.TensorId(grad_tensor);
1002       targets.push_back(tensor_id);
1003       if (sources_set.find(tensor_id) != sources_set.end()) {
1004         sources_that_are_targets.emplace(
1005             tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
1006       }
1007       Gradient* in_grad = in_grads[grad_index];
1008       if (in_grad != nullptr) {
1009         // ComputeGradient steals a reference
1010         vspace_.MarkAsResult(in_grad);
1011       }
1012       used_in_grads.push_back(in_grad);
1013     }
1014   }
1015 
1016   return tape->ComputeGradient(vspace_, targets, sources,
1017                                sources_that_are_targets, used_in_grads,
1018                                out_grads);
1019 }
1020 
1021 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Accumulate(const string & op_type,const std::vector<TapeTensor> & input_tensors,const std::vector<TapeTensor> & output_tensors,gtl::ArraySlice<int64> input_tensor_id,gtl::ArraySlice<tensorflow::DataType> input_dtypes,const ForwardFunction<Gradient> * forward_function,const std::function<BackwardFunction * ()> & backward_function_getter,const std::function<void (BackwardFunction *)> & backward_function_deleter)1022 Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
1023     const string& op_type, const std::vector<TapeTensor>& input_tensors,
1024     const std::vector<TapeTensor>& output_tensors,
1025     gtl::ArraySlice<int64> input_tensor_id,
1026     gtl::ArraySlice<tensorflow::DataType> input_dtypes,
1027     const ForwardFunction<Gradient>* forward_function,
1028     const std::function<BackwardFunction*()>& backward_function_getter,
1029     const std::function<void(BackwardFunction*)>& backward_function_deleter) {
1030   if (call_state_.top().backward_tape != nullptr) {
1031     // If backward_tape is not null, then this call to Accumulate is the result
1032     // of a still-active call to Accumulate which is running operations. We
1033     // forward these operations to backward_tape so the outer Accumulate call
1034     // can do its work.
1035     //
1036     // Rather than re-entering and delegating Accumulate like this, we could
1037     // instead allow ForwardAccumulator some control over the current tape set
1038     // (so it can deactivate itself and activate its GradientTape). Currently
1039     // that is managed by the language binding and would require relatively
1040     // messy callbacks.
1041     call_state_.top().backward_tape->RecordOperation(
1042         op_type, output_tensors, input_tensor_id, input_dtypes,
1043         backward_function_getter, backward_function_deleter);
1044     return Status::OK();
1045   }
1046   if (!ShouldRecord(input_tensor_id, input_dtypes)) {
1047     return Status::OK();
1048   }
1049 
1050   // We may need to allocate zero inputs for trainable dtypes we don't have JVPs
1051   // for. Make sure they get cleaned up.
1052   std::vector<Gradient*> new_zeros;
1053   auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
1054     for (Gradient* tensor : new_zeros) {
1055       this->vspace_.DeleteGradient(tensor);
1056     }
1057   });
1058   std::vector<Gradient*> in_grads;
1059   in_grads.reserve(input_tensors.size());
1060   for (int target_index = 0; target_index < input_tensors.size();
1061        ++target_index) {
1062     const auto current_grad =
1063         accumulated_gradients_.find(input_tensors[target_index].GetID());
1064     if (current_grad == accumulated_gradients_.end()) {
1065       if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
1066         // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
1067         // GradientTape which uses ones.
1068         Gradient* zero = input_tensors[target_index].ZerosLike();
1069         new_zeros.push_back(zero);
1070         in_grads.push_back(zero);
1071       } else {
1072         in_grads.push_back(nullptr);
1073       }
1074     } else {
1075       in_grads.push_back(current_grad->second);
1076     }
1077   }
1078 
1079   // Avoid infinite recursion. Whichever forward function we run, it'll end up
1080   // executing ops, and we don't want to watch those with this accumulator.
1081   call_state_.emplace(nullptr, true);
1082   auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });
1083 
1084   std::vector<Gradient*> forward_grads;
1085   if (forward_function == nullptr) {
1086     // We have no special-cased forward gradient. Fall back to running the
1087     // backward function under a gradient tape.
1088     forward_grads.resize(output_tensors.size());
1089     TF_RETURN_IF_ERROR(ForwardpropFromTape(
1090         op_type, output_tensors, backward_function_getter,
1091         backward_function_deleter, in_grads, absl::MakeSpan(forward_grads)));
1092   } else {
1093     TF_RETURN_IF_ERROR(
1094         (*forward_function)(in_grads, &forward_grads, use_batch_));
1095   }
1096   for (int i = 0; i < forward_grads.size(); ++i) {
1097     if (forward_grads[i] != nullptr) {
1098       int64_t tensor_id = output_tensors[i].GetID();
1099       auto existing = accumulated_gradients_.find(tensor_id);
1100       if (existing != accumulated_gradients_.end()) {
1101         // This is a somewhat odd case to be in, since it means we have two
1102         // operations which supposedly both created the same Tensor. It comes up
1103         // in recompute_grad, where the gradients have the same value. However,
1104         // only the original gradient is connected to everything else, so we
1105         // should still use that.
1106         vspace_.DeleteGradient(forward_grads[i]);
1107       } else {
1108         accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
1109       }
1110     }
1111   }
1112   return Status::OK();
1113 }
1114 
1115 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Watch(int64_t tensor_id,Gradient * tangent)1116 void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
1117     int64_t tensor_id, Gradient* tangent) {
1118   typename std::unordered_map<int64, Gradient*>::iterator existing =
1119       accumulated_gradients_.find(tensor_id);
1120   vspace_.MarkAsResult(tangent);
1121   if (existing == accumulated_gradients_.end()) {
1122     accumulated_gradients_.emplace(tensor_id, tangent);
1123   } else {
1124     std::array<Gradient*, 2> to_aggregate;
1125     to_aggregate[0] = tangent;
1126     to_aggregate[1] = existing->second;
1127     // AggregateGradients steals a reference to each of its arguments. We
1128     // MarkAsResult on `tangent` above so we don't steal a reference to it.
1129     existing->second = vspace_.AggregateGradients(to_aggregate);
1130   }
1131 }
1132 
1133 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
DeleteGradient(int64_t tensor_id)1134 void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
1135     int64_t tensor_id) {
1136   auto existing = accumulated_gradients_.find(tensor_id);
1137   if (existing != accumulated_gradients_.end()) {
1138     vspace_.DeleteGradient(existing->second);
1139     accumulated_gradients_.erase(existing);
1140   }
1141 }
1142 
1143 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
FetchJVP(int64_t tensor_id)1144 Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
1145     int64_t tensor_id) {
1146   auto lookup = accumulated_gradients_.find(tensor_id);
1147   if (lookup == accumulated_gradients_.end()) {
1148     return nullptr;
1149   } else {
1150     return lookup->second;
1151   }
1152 }
1153 
1154 }  // namespace eager
1155 }  // namespace tensorflow
1156 
1157 #endif  // TENSORFLOW_C_EAGER_TAPE_H_
1158