/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TAPE_H_
#define TENSORFLOW_C_EAGER_TAPE_H_

// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.

#include <functional>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace eager {

// Represents an entry in the tape.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  string op_type;
  std::vector<TapeTensor> output_tensor_info;
  std::vector<int64> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};
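
// For illustration, a tape entry recorded for a single matrix multiplication
// z = MatMul(x, y) would conceptually look like the following (the tensor ids
// and the names below are hypothetical; TapeTensor and BackwardFunction are
// whatever the client language supplies):
//
//   OpTapeEntry<BackwardFunction, TapeTensor> entry;
//   entry.op_type = "MatMul";
//   entry.output_tensor_info = {tape_tensor_for_z};  // shape/dtype info for z
//   entry.input_tensor_id = {id_of_x, id_of_y};
//   entry.backward_function = matmul_vjp_closure;    // captures x and y
//   entry.backward_function_deleter = delete_matmul_vjp_closure;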

// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = gtl::FlatMap<int64, int64>;

// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;

// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's either
// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
// to allow their size to be computed and they need to be passable to a backward
// function and deleted (as the backprop code creates lots of gradients the user
// is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64 NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;

  // Returns a tensor of the same shape and dtype as the given tensor, filled
  // with zeros.
  virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;

  // Returns a tensor of the same shape and dtype as the given tensor, filled
  // with ones.
  virtual Gradient* Ones(const TapeTensor& tensor) const = 0;

  // Calls the passed-in backward function.
  virtual Status CallBackwardFunction(
      BackwardFunction* backward_function,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result) const = 0;

  // Marks the given gradient as a result so it's not consumed by backward
  // functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes the given gradient.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};
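
// As a rough sketch, a C++ client that represents gradients as
// tensorflow::Tensor (the class and helper names below are hypothetical and
// not part of this header) might implement the interface along these lines:
//
//   class TensorVSpace final
//       : public VSpace<Tensor, TensorBackwardFn, TapeTensorInfo> {
//    public:
//     int64 NumElements(Tensor* t) const override { return t->NumElements(); }
//     Tensor* AggregateGradients(
//         gtl::ArraySlice<Tensor*> gradient_tensors) const override {
//       // Sum the tensors element-wise, e.g. by dispatching an AddN kernel.
//       ...
//     }
//     // Zeros/Ones would allocate constant tensors matching the TapeTensor's
//     // shape and dtype; CallBackwardFunction would invoke the stored
//     // vector-jacobian product; DeleteGradient would drop the reference.
//   };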

// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  void Watch(int64 tensor_id);

  void RecordOperation(
      const string& op_type, std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  void DeleteTrace(int64 tensor_id);

  // Consumes the internal state of the tape (so cannot be called more than
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result);

  bool IsPersistent() const { return persistent_; }

 private:
  TensorTape tensor_tape_;
  OpTape<BackwardFunction, TapeTensor> op_tape_;
  int64 next_op_id_{0};

  // Map from tensor id to number of remaining usages (i.e. how many entries in
  // the tape refer to it), used to aid in tape garbage collection.
  gtl::FlatMap<int64, int64> tensor_usage_;

  // If false, all activations are deleted in the first call to ComputeGradient.
  // Otherwise, they are deleted only when the tape itself is destructed.
  bool persistent_;
};
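
// A hypothetical usage sketch. The tensor ids, the vspace instance, and the
// backward-function closures below are placeholders supplied by the client
// language, not defined in this header:
//
//   GradientTape<Gradient, BackwardFn, TapeTensor> tape(/*persistent=*/false);
//   tape.Watch(x_id);  // mark x as a source
//   if (tape.ShouldRecord({x_id}, {DT_FLOAT})) {
//     std::vector<TapeTensor> outputs = {y_info};
//     tape.RecordOperation("Square", outputs, {x_id}, {DT_FLOAT},
//                          make_square_vjp, delete_square_vjp);
//   }
//   std::vector<Gradient*> grads;
//   Status s = tape.ComputeGradient(vspace, /*target_tensor_ids=*/{y_id},
//                                   /*source_tensor_ids=*/{x_id},
//                                   /*sources_that_are_targets=*/{},
//                                   /*output_gradients=*/{}, &grads);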

// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    case DT_RESOURCE:
    case DT_VARIANT:
      return true;
    default:
      return false;
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  CHECK_EQ(tensor_ids.size(), dtypes.size());
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
    const string& op_type, std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return;
  }
  std::vector<int64> ids;
  ids.reserve(input_tensor_id.size());
  for (int64 i : input_tensor_id) {
    tensor_usage_[i]++;
    ids.push_back(i);
  }
  const int64 op_id = next_op_id_++;
  std::vector<TapeTensor> tensors;
  tensors.reserve(output_tensors.size());
  for (const TapeTensor& o : output_tensors) {
    // Note: the tensor can have already been watched and hence be in the tape,
    // so we cannot check that we're inserting it here.
    tensor_tape_[o.GetID()] = op_id;
    tensor_usage_[o.GetID()] = 1;
    tensors.push_back(o);
  }
  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
      op_type, std::move(tensors), std::move(ids), backward_function_getter(),
      backward_function_deleter};
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64 tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64 op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  for (int64 id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}
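
// For example, if z = MatMul(x, y) is the only recorded op and the client
// later releases z, DeleteTrace(z_id) drops z's usage count to zero, finds
// that no other output of the MatMul entry is still referenced, recursively
// calls DeleteTrace on x_id and y_id, runs the backward function's deleter,
// and erases the MatMul entry from the op tape. Watched tensors (op id -1)
// are never removed this way. (The names x, y, z are illustrative.)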

// Terminology:
//
//  - op: a possibly composite operation, which has an entry in the tape
//  - target: y in dy/dx, i.e. a tensor being differentiated
//  - source: x in dy/dx, i.e. a tensor the gradient is taken with respect to
//  - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each Tensor
// is used as an input to an op (so we know when we're done computing gradients
// for that Tensor). We also count, for each tape entry, how many of its output
// Tensors need gradients to be computed (Tensors which are not used do not need
// any gradients to be computed).
//
// Next, we seed a backprop stack with the set of tape entries for which all
// output gradients are already available. This is usually a subset of the ops
// producing the targets (not all of them, since a target which is itself an
// input to other ops in the tape will not have its full gradient available
// initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and update
// the gradients of its inputs. Once we have computed all gradients for a single
// input we can mark this input as done, and this can trigger adding an entry to
// the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.
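
// As a concrete illustration, suppose the tape recorded z = x * y followed by
// w = z + c, and we ask for dw/dx (names are illustrative). Filtering keeps
// both entries; the Add entry is missing no output gradients, so it seeds the
// stack. Popping it produces a gradient for z, which completes the Mul entry's
// outputs and pushes it onto the stack; popping the Mul entry then yields the
// gradient for x (and y), at which point the stack is empty and we are done.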

namespace {

template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  gtl::FlatMap<int64, int64> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  gtl::FlatMap<int64, int64> op_missing_tensor;
};

// If `persistent_tape` is true, op_tape is not changed and none of the
// backward functions are deleted.
// If `persistent_tape` is false, op_tape is cleared and backward functions
// not needed for gradient computation are deleted. Backward functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
    OpTape<BackwardFunction, TapeTensor>* op_tape,
    const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
  std::vector<int64> tensor_stack;
  tensor_stack.reserve(target.size());
  for (auto t : target) {
    tensor_stack.push_back(t);
  }
  BackpropInitialState<BackwardFunction, TapeTensor> result;
  while (!tensor_stack.empty()) {
    int64 tensor_id = tensor_stack.back();
    tensor_stack.pop_back();
    auto op_id_it = tensor_tape.find(tensor_id);
    if (op_id_it == tensor_tape.end()) {
      continue;
    }
    int64 op_id = op_id_it->second;
    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
    for (auto it : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(it);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[it] = 1;
        if (tensor_tape.find(it) != tensor_tape.end()) {
          tensor_stack.push_back(it);
        }
      }
    }
    if (!persistent_tape) {
      op_tape->erase(op_it);
    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
    if (it != tensor_tape.end() && it->second != -1) {
      result.op_missing_tensor[it->second] += 1;
    }
  }
  if (!persistent_tape) {
    // Call destructors for all unneeded gradient functions and
    // clear the op_tape. We can clear the tape because ownership of
    // backward functions that will be used for gradient computation
    // has been transferred to `result`.
    for (const auto& op_pair : *op_tape) {
      op_pair.second.backward_function_deleter(
          op_pair.second.backward_function);
    }
    op_tape->clear();
  }
  return result;
}

template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    const gtl::FlatMap<int64, int64>& op_missing_tensor) {
  std::vector<int64> result;
  for (auto& op_entry : op_tape) {
    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
      result.push_back(op_entry.first);
    }
  }
  return result;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    gtl::ArraySlice<int64> target_tensor_ids,
    gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    const int64 id = target_tensor_ids[i];
    if (output_gradients.empty() || output_gradients[i] == nullptr) {
      auto tensor_it = tensor_tape.find(id);
      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
          if (op_it->second.output_tensor_info[j].GetID() == id) {
            found = true;
            (*result)[id].push_back(
                vspace.Ones(op_it->second.output_tensor_info[j]));
            break;
          }
        }
        if (!found) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of the operation's outputs match the expected tensor");
        }
      } else {
        // This target tensor was not generated by any operation recorded on
        // the tape, so no gradient needs to be computed from it unless this
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          (*result)[id].push_back(vspace.Ones(source_tensor->second));
        }
      }
    } else {
      (*result)[id].push_back(output_gradients[i]);
    }
  }
  return Status::OK();
}

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Some gradient functions can accept None arguments for gradients. The
// following maps the operation name to the indices at which the corresponding
// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
// values and hence receives 5 gradient values during backprop. However the
// gradient function uses only the first of those values and ignores the rest.
// The entry {"FusedBatchNorm", {1, 2, 3, 4}} indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
  static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
      {"SoftmaxCrossEntropyWithLogits", {1}},
      {"SparseSoftmaxCrossEntropyWithLogits", {1}},
      {"FusedBatchNorm", {1, 2, 3, 4}},
  });
  return m;
}

}  // namespace

// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensors and save memory.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
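
// For instance, with float32 gradients (4 bytes per element, the assumption
// baked into the size check below), five pending 8M-element gradients amount
// to 5 * 8M * 4 bytes = 160MB, which exceeds kMinAggregateBytes and triggers
// an early AggregateGradients call; four or fewer pending gradients are never
// aggregated early, regardless of size.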

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) {
  gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
                                  source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids,
                              sources_that_are_targets, output_gradients,
                              tensor_tape_, state.op_tape, &gradients);
  auto cleanup = [this, &state]() {
    if (!persistent_) {
      // Release all backprop functions
      for (const auto& pair : state.op_tape) {
        pair.second.backward_function_deleter(pair.second.backward_function);
      }
    }
  };
  if (!s.ok()) {
    cleanup();
    return s;
  }
  gtl::FlatMap<int64, int64> gradients_size;
  // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
  // time, for better CPU backprop performance.
  VLOG(1) << "Initial stack:";
  if (VLOG_IS_ON(1)) {
    for (auto t : op_stack) {
      VLOG(1) << "  " << t;
    }
  }
  while (!op_stack.empty()) {
    const int64 op = op_stack.back();
    VLOG(1) << "Popped " << op;
    op_stack.pop_back();
    auto op_it = state.op_tape.find(op);
    if (op_it == state.op_tape.end()) {
      // It is possible for ops to end up on the stack if they are unrelated to
      // the target; we should just skip them.
      continue;
    }
    auto trace = std::move(op_it->second);
    state.op_tape.erase(op_it);
    std::vector<Gradient*> out_gradients;
    out_gradients.reserve(trace.output_tensor_info.size());
    bool any_gradient_nonzero = false;
    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
      const int64 id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        auto func_name_it =
            FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
        if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
            func_name_it->second.find(i) != func_name_it->second.end()) {
          out_gradients.push_back(nullptr);
        } else {
          out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
        }
      } else {
        any_gradient_nonzero = true;
        Gradient* new_gradients = nullptr;
        if (grad_it->second.size() == 1) {
          new_gradients = grad_it->second.at(0);
        } else {
          new_gradients = vspace.AggregateGradients(grad_it->second);
        }
        if (sources_set.find(grad_it->first) == sources_set.end()) {
          gradients.erase(grad_it);
        } else {
          grad_it->second.clear();
          grad_it->second.push_back(new_gradients);
          vspace.MarkAsResult(new_gradients);
        }
        out_gradients.push_back(new_gradients);
      }
    }
    std::vector<Gradient*> in_gradients;
    if (any_gradient_nonzero) {
      Status s = vspace.CallBackwardFunction(trace.backward_function,
                                             out_gradients, &in_gradients);
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      if (!s.ok()) {
        cleanup();
        return s;
      }
    } else {
      in_gradients.resize(trace.input_tensor_id.size());
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      for (Gradient* grad : out_gradients) {
        if (grad != nullptr) {
          vspace.DeleteGradient(grad);
        }
      }
    }
    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
            << trace.input_tensor_id.size() << " sources";
    for (int i = 0; i < in_gradients.size(); ++i) {
      const int64 id = trace.input_tensor_id[i];
      if (in_gradients[i] != nullptr) {
        auto& unaggregated_grads = gradients[id];
        unaggregated_grads.push_back(in_gradients[i]);
        if (unaggregated_grads.size() > kMinAggregateCount) {
          auto size_it = gradients_size.find(id);
          int64 size;
          if (size_it == gradients_size.end()) {
            size = vspace.NumElements(unaggregated_grads[0]);
            gradients_size.emplace(id, size);
          } else {
            size = size_it->second;
          }
          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
            unaggregated_grads.clear();
            unaggregated_grads.push_back(grad);
          }
        }
      }
      auto usage_count_it = state.tensor_usage_counts.find(id);
      if (usage_count_it == state.tensor_usage_counts.end()) {
        VLOG(1) << "Tensor " << id << " not used";
        continue;
      }
      usage_count_it->second--;
      if (usage_count_it->second > 0) {
        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
        continue;
      }
      auto tape_it = tensor_tape_.find(id);
      if (tape_it == tensor_tape_.end()) {
        VLOG(1) << "Tensor " << id
                << " has no associated op. Deleting gradient";
        auto grad_it = gradients.find(id);
        if (grad_it != gradients.end()) {
          for (auto g : grad_it->second) {
            vspace.DeleteGradient(g);
          }
          gradients.erase(grad_it);
        }
        continue;
      }
      const int64 op_id = tape_it->second;
      if (op_id == -1) {
        VLOG(1) << "Tensor " << id << " is source";
        continue;
      }
      auto missing_it = state.op_missing_tensor.find(op_id);
      if (missing_it != state.op_missing_tensor.end()) {
        missing_it->second--;
        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
                << " output gradients";
        if (missing_it->second == 0) {
          op_stack.push_back(op_id);
        }
      }
    }
  }
  if (!state.op_tape.empty()) {
    return tensorflow::errors::Internal("Invalid tape state.");
  }
  result->reserve(source_tensor_ids.size());
  gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
  for (auto is : source_tensor_ids) {
    auto grad_it = gradients.find(is);
    if (grad_it == gradients.end()) {
      result->push_back(nullptr);
    } else {
      if (grad_it->second.size() > 1) {
        Gradient* grad = vspace.AggregateGradients(grad_it->second);
        grad_it->second.clear();
        grad_it->second.push_back(grad);
      }
      result->push_back(grad_it->second[0]);
      used_gradient_ids.insert(is);
    }
  }
  VLOG(1) << "Final gradients size: "
          << gradients.size() - used_gradient_ids.size();
  for (auto grad_pair : gradients) {
    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
      for (const auto& g : grad_pair.second) {
        vspace.DeleteGradient(g);
      }
    }
  }
  return Status::OK();
}

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TAPE_H_