/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TAPE_H_
#define TENSORFLOW_C_EAGER_TAPE_H_

// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.

#include <array>
#include <functional>
#include <memory>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace eager {

// Represents an entry in the tape.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  string op_type;
  std::vector<TapeTensor> output_tensor_info;
  std::vector<int64> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};

// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = std::unordered_map<int64, int64>;

// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape =
    std::unordered_map<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;

// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's
// either Tensor or IndexedSlices or None, which here we map to nullptr.
// Gradients need to allow their size to be computed and they need to be
// passable to a backward function and deleted (as the backprop code creates
// lots of gradients the user is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64 NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;

  // Calls the passed-in backward function.
  virtual Status CallBackwardFunction(
      BackwardFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result) const = 0;

  // Looks up the ID of a Gradient.
  virtual int64 TensorId(Gradient* tensor) const = 0;

  // Converts a Gradient to a TapeTensor.
  virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;

  // Marks the given gradient as a result so it's not consumed by backward
  // functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes the input tensor.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};
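
// A minimal sketch, not part of this header, of how a language binding might
// implement VSpace. The `MyGradient`, `MyBackwardFn`, and `MyTapeTensor` types
// and the `NumElementsOf` / `AddN` helpers are hypothetical placeholders; real
// bindings (e.g. the Python one) supply their own equivalents.
//
//   class MyVSpace final
//       : public VSpace<MyGradient, MyBackwardFn, MyTapeTensor> {
//    public:
//     int64 NumElements(MyGradient* tensor) const override {
//       return NumElementsOf(tensor);
//     }
//     MyGradient* AggregateGradients(
//         gtl::ArraySlice<MyGradient*> gradient_tensors) const override {
//       return AddN(gradient_tensors);  // consumes the passed-in references
//     }
//     // ... the remaining pure virtual methods are implemented analogously.
//   };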

// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not
// thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  void Watch(int64 tensor_id);

  void RecordOperation(
      const string& op_type, const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  void DeleteTrace(int64 tensor_id);

  // Consumes the internal state of the tape (so cannot be called more than
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result);

  bool IsPersistent() const { return persistent_; }

 private:
  TensorTape tensor_tape_;
  OpTape<BackwardFunction, TapeTensor> op_tape_;
  int64 next_op_id_{0};

  // Map from tensor id to number of remaining usages (i.e. how many entries in
  // the tape refer to it); to aid in tape garbage collection.
  std::unordered_map<int64, int64> tensor_usage_;

  // If false, all activations are deleted in the first call to
  // ComputeGradient. Else, only when this is destructed.
  bool persistent_;
};
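
// A usage sketch, not part of the API: how a language binding would typically
// drive a GradientTape for dy/dx of y = Square(x). The identifiers `vspace`,
// `x_id`, `y_id`, `y_info`, `getter`, and `deleter` are placeholders the
// binding supplies.
//
//   GradientTape<Gradient, BackwardFunction, TapeTensor> tape(
//       /*persistent=*/false);
//   tape.Watch(x_id);
//   if (tape.ShouldRecord({x_id}, {DT_FLOAT})) {
//     tape.RecordOperation("Square", {y_info}, {x_id}, {DT_FLOAT}, getter,
//                          deleter);
//   }
//   std::vector<Gradient*> result;
//   Status s = tape.ComputeGradient(vspace, /*target_tensor_ids=*/{y_id},
//                                   /*source_tensor_ids=*/{x_id},
//                                   /*sources_that_are_targets=*/{},
//                                   /*output_gradients=*/{}, &result);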

// Describes a callback for special-cased and more efficient jvp computation.
//
// Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on
// that.
template <typename Gradient>
class ForwardFunction
    : public std::function<Status(const std::vector<Gradient*>&,
                                  std::vector<Gradient*>*)> {
 public:
  template <typename lambda_type>
  explicit ForwardFunction(lambda_type lambda)
      : std::function<Status(const std::vector<Gradient*>&,
                             std::vector<Gradient*>*)>(lambda) {}
};
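
// For illustration only: a ForwardFunction wraps any callable with the
// matching signature, e.g. a special-cased jvp for y = 2 * x. The
// `DoubleGradient` helper is a hypothetical binding-provided function.
//
//   ForwardFunction<Gradient> forward_fn(
//       [](const std::vector<Gradient*>& in_tangents,
//          std::vector<Gradient*>* out_tangents) -> Status {
//         // The jvp of y = 2 * x is simply 2 * tangent(x).
//         out_tangents->push_back(DoubleGradient(in_tangents[0]));
//         return Status::OK();
//       });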

// Computes Jacobian-vector products using forward-mode automatic
// differentiation.
//
// While GradientTape's RecordOperation is trivial, ForwardAccumulator's
// Accumulate runs the gradient computation immediately.
//
// Keeps references to Tensors watched via Watch and computed in Accumulate
// corresponding to output_tensors, and releases these references in its
// destructor. However, waiting until the destructor runs loses the memory
// efficiency of forward-mode autodiff. Instead, language bindings should call
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
// Tensor passed to Accumulate goes out of scope.
//
// Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class ForwardAccumulator {
 public:
  // Does not take ownership of `vspace`, which must outlive the
  // ForwardAccumulator.
  explicit ForwardAccumulator(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace)
      : vspace_(vspace) {
    call_state_.emplace(nullptr, false);
  }

  virtual ~ForwardAccumulator() {
    for (auto accumulated : accumulated_gradients_) {
      vspace_.DeleteGradient(accumulated.second);
    }
  }

  // Tell the forward accumulator to watch tensor_id, with a Tensor tangent
  // vector `tangent` of matching shape and dtype. Tangents are the "vector" in
  // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
  // FetchJVP for it would return `tangent`.
  void Watch(int64 tensor_id, Gradient* tangent);

  // Removes the gradient associated with tensor_id. Should be called when the
  // Tensor associated with `tensor_id` is deleted.
  void DeleteGradient(int64 tensor_id);

  // Runs forward autodiff. Should be called whenever a new operation is
  // available and the accumulator is active.
  //
  // Like GradientTape::RecordOperation, this method takes the operation type
  // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
  // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
  // redundant but taken as arguments to avoid repeatedly fetching these values
  // between calls to ShouldRecord and Accumulate), and its outputs
  // (`output_tensors`).
  //
  // If provided, a non-null `forward_function` will be used instead of the
  // backward function (`backward_function_getter` /
  // `backward_function_deleter`) to compute jvps for this operation. If
  // `forward_function` is null, a GradientTape is used on the backward
  // function to compute the jvp, which will waste computation when executing
  // eagerly.
  //
  // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
  // immediately. It stores the results, which feed into Accumulate for future
  // operations and may be fetched by calling FetchJVP. ForwardAccumulator
  // maintains a reference to these JVPs: if an `output_tensors` Tensor is
  // deleted, `DeleteGradient` should be called as soon as possible to free the
  // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
  // will release remaining references.
  //
  // This method is not thread-safe (and in general ForwardAccumulator is not
  // thread-safe).
  Status Accumulate(
      const string& op_type, const std::vector<TapeTensor>& input_tensors,
      const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const ForwardFunction<Gradient>* forward_function,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  // Returns true if `Accumulate` is active somewhere above on the stack and
  // there isn't an intervening PushState. This is useful for ordering
  // ForwardAccumulators, where more deeply nested accumulators should not see
  // computations from less deeply nested accumulators.
  bool BusyAccumulating() const { return call_state_.top().accumulating; }

  // Fetches the current Jacobian-vector product associated with `tensor_id`,
  // or a nullptr if none is available.
  //
  // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on
  // its return value. The caller should increment the reference count before
  // deleting the ForwardAccumulator or calling DeleteGradient if keeping a
  // persistent reference to a non-null result.
  Gradient* FetchJVP(int64 tensor_id);

  // Indicates whether the forward accumulator should run on an operation with
  // the specified inputs and dtypes.
  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  // Temporarily push or pop transient state for this accumulator.
  //
  // Allows an accumulator which is currently processing an operation to
  // temporarily reset its state. Without pushing and popping, accumulators
  // ignore operations executed as a direct result of their own jvp
  // computations.
  void PushState() { call_state_.emplace(nullptr, false); }
  void PopState() { call_state_.pop(); }

 private:
  // Helper for Accumulate: uses a GradientTape to compute forward gradients
  // from a backward gradient function. Fills `out_grads` corresponding to
  // `output_tensors`. `out_grads` must not be null.
  //
  // Executes the backward function in order to trace its gradient, which will
  // waste computation if executing eagerly (when graph building the unneeded
  // computation is pruned). Temporarily sets `backward_tape` so that
  // Accumulate will forward op executions to the tape while the backward
  // function is running; this effectively adds the backward tape to the active
  // set (but does not require complicated callbacks to the language bindings).
  Status ForwardpropFromTape(
      const std::vector<TapeTensor>& output_tensors,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter,
      const std::vector<Gradient*>& in_grads,
      std::vector<Gradient*>* out_grads);

  // Maps from tensor IDs to corresponding JVPs.
  std::unordered_map<int64, Gradient*> accumulated_gradients_;
  // Not owned; provides operations on Tensors which are currently only
  // available in language bindings (e.g. Python).
  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;

  struct AccumulatorCallState {
    AccumulatorCallState(
        GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
        bool accumulating)
        : backward_tape(backward_tape), accumulating(accumulating) {}
    // Set temporarily while in the Accumulate method; if backward_tape is not
    // nullptr then we forward op executions to it so Accumulate can compute a
    // backward pass on its backward function.
    //
    // Not owned by the ForwardAccumulator. The method which sets
    // `backward_tape` keeps ownership.
    GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
    // While the Accumulate method is running (accumulating is True), any op
    // executions not forwarded to backward_tape should be ignored.
    bool accumulating;
  };
  // A deque-backed stack, whose element references are not invalidated by
  // pushes and pops at the back.
  std::stack<AccumulatorCallState> call_state_;
};
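
// A usage sketch with binding-supplied placeholders (`vspace`, `x_id`, `y_id`,
// `x_tangent`, tensor-info values, and the backward-function getter/deleter):
// watch a tensor together with its tangent, push an op through Accumulate,
// then fetch the JVP of the op's output.
//
//   ForwardAccumulator<Gradient, BackwardFunction, TapeTensor> acc(vspace);
//   acc.Watch(x_id, x_tangent);
//   if (acc.ShouldRecord({x_id}, {DT_FLOAT})) {
//     TF_RETURN_IF_ERROR(acc.Accumulate(
//         "Square", /*input_tensors=*/{x_info}, /*output_tensors=*/{y_info},
//         /*input_tensor_id=*/{x_id}, /*input_dtypes=*/{DT_FLOAT},
//         /*forward_function=*/nullptr, getter, deleter));
//   }
//   Gradient* jvp = acc.FetchJVP(y_id);  // borrowed reference; may be null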

// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    case DT_RESOURCE:
    case DT_VARIANT:
      return true;
    default:
      return false;
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  CHECK_EQ(tensor_ids.size(), dtypes.size());
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
    const string& op_type, const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return;
  }
  std::vector<int64> ids;
  ids.reserve(input_tensor_id.size());
  for (int64 i : input_tensor_id) {
    tensor_usage_[i]++;
    ids.push_back(i);
  }
  const int64 op_id = next_op_id_++;
  std::vector<TapeTensor> tensors;
  tensors.reserve(output_tensors.size());
  for (const TapeTensor& o : output_tensors) {
    // Note: the tensor can have already been watched and hence be in the tape,
    // so we cannot check that we're inserting it here.
    tensor_tape_[o.GetID()] = op_id;
    tensor_usage_[o.GetID()] = 1;
    tensors.push_back(o);
  }
  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
      op_type, std::move(tensors), std::move(ids), backward_function_getter(),
      backward_function_deleter};
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64 tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64 op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  for (int64 id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}

// Terminology:
//
// - op: a possibly composite operation, which has an entry in the tape
// - target: dy in dy/dx
// - source: dx in dy/dx
// - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each
// Tensor is used as an input to an op (so we know when we're done computing
// gradients for that Tensor). We also count, for each tape entry, how many of
// its output Tensors need gradients to be computed (Tensors which are not
// used do not need any gradients to be computed).
//
// Finally, we start a backprop stack with a set of tape entries for which we
// have all gradients available. This set usually is a subset of the set of
// targets (not all since targets which have outputs in the tape will not have
// gradients available initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and update
// the gradients of its inputs. Once we have computed all gradients for a
// single input we can mark this input as done, and this can trigger adding an
// entry to the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.
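//
// As a small illustrative walk-through (an added example, not exhaustive):
// for a trace y = Square(x); z = Add(y, y) with target z and source x, the
// filtered tape contains Add and Square, and y has two remaining usages.
// Backprop starts from Add, whose output gradient dz is available, produces
// two gradients for y, aggregates them once both have arrived, and only then
// pushes Square onto the stack, which finally yields dz/dx.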

namespace {

template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  std::unordered_map<int64, int64> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  std::unordered_map<int64, int64> op_missing_tensor;
};

// If `persistent_tape` is true, op_tape is not changed and none of the
// backwards functions are deleted.
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
    OpTape<BackwardFunction, TapeTensor>* op_tape,
    const std::unordered_set<int64>& sources_set, bool persistent_tape) {
  std::vector<int64> tensor_stack;
  tensor_stack.reserve(target.size());
  for (auto t : target) {
    tensor_stack.push_back(t);
  }
  BackpropInitialState<BackwardFunction, TapeTensor> result;
  while (!tensor_stack.empty()) {
    int64 tensor_id = tensor_stack.back();
    tensor_stack.pop_back();
    auto op_id_it = tensor_tape.find(tensor_id);
    if (op_id_it == tensor_tape.end()) {
      continue;
    }
    int64 op_id = op_id_it->second;
    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
    for (auto it : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(it);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[it] = 1;
        if (tensor_tape.find(it) != tensor_tape.end()) {
          tensor_stack.push_back(it);
        }
      }
    }
    if (!persistent_tape) {
      op_tape->erase(op_it);
    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
    if (it != tensor_tape.end() && it->second != -1) {
      result.op_missing_tensor[it->second] += 1;
    }
  }
  if (!persistent_tape) {
    // Call destructors for all unneeded gradient functions and
    // clear the op_tape. We can clear the tape because ownership of
    // backward functions that will be used for gradient computation
    // has been transferred to `result`.
    for (const auto& op_pair : *op_tape) {
      op_pair.second.backward_function_deleter(
          op_pair.second.backward_function);
    }
    op_tape->clear();
  }
  return result;
}

template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    const std::unordered_map<int64, int64>& op_missing_tensor) {
  std::vector<int64> result;
  for (auto& op_entry : op_tape) {
    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
      result.push_back(op_entry.first);
    }
  }
  return result;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    gtl::ArraySlice<int64> target_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    std::unordered_map<int64, std::vector<Gradient*>>* result) {
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    const int64 id = target_tensor_ids[i];
    if (output_gradients.empty() || output_gradients[i] == nullptr) {
      auto tensor_it = tensor_tape.find(id);
      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
          if (op_it->second.output_tensor_info[j].GetID() == id) {
            found = true;
            (*result)[id].push_back(
                op_it->second.output_tensor_info[j].OnesLike());
            break;
          }
        }
        if (!found) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of the operation's outputs match the expected tensor");
        }
      } else {
        // This target tensor was not generated by any operation recorded on
        // the tape, so no gradient needs to be computed from it unless this
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          (*result)[id].push_back(source_tensor->second.OnesLike());
        }
      }
    } else {
      (*result)[id].push_back(output_gradients[i]);
    }
  }
  return Status::OK();
}

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Some gradient functions can accept None arguments for gradients. The
// following maps the operation name to the indices at which the corresponding
// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
// values and hence receives 5 gradient values during backprop. However the
// gradient function uses only the first of those values and ignores the rest.
// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
std::unordered_map<string, std::unordered_set<int>>*
FunctionsAcceptingNoneForIndicesMap() {
  static auto* const m =
      new std::unordered_map<string, std::unordered_set<int>>({
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"SparseSoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      });
  return m;
}

}  // namespace

// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensor to save memory.
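//
// A rough worked example (illustrative numbers, not part of the original
// comment): with float32 gradients of 10^7 elements each, five unaggregated
// gradients for one tensor occupy 5 * 10^7 * 4 bytes = 200 MB. That exceeds
// both thresholds below (5 > 4 gradients, 200 MB > 128 MiB), so
// ComputeGradient aggregates them into a single tensor early.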
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) {
  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
                                        source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  std::unordered_map<int64, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids,
                              sources_that_are_targets, output_gradients,
                              tensor_tape_, state.op_tape, &gradients);
  auto cleanup = gtl::MakeCleanup([this, &state]() {
    if (!persistent_) {
      // Release all backprop functions
      for (const auto& pair : state.op_tape) {
        pair.second.backward_function_deleter(pair.second.backward_function);
      }
    }
  });
  if (!s.ok()) {
    return s;
  }

  std::unordered_map<int64, int64> gradients_size;
  // TODO(apassos) multiple threads could be dequeuing from op_stack at the
  // same time, for better CPU backprop performance.
  VLOG(1) << "Initial stack:";
  if (VLOG_IS_ON(1)) {
    for (auto t : op_stack) {
      VLOG(1) << " " << t;
    }
  }
  while (!op_stack.empty()) {
    const int64 op = op_stack.back();
    VLOG(1) << "Popped " << op;
    op_stack.pop_back();
    auto op_it = state.op_tape.find(op);
    if (op_it == state.op_tape.end()) {
      // It is possible for ops to end up on the stack if they are unrelated to
      // the target; we should just skip them.
      continue;
    }
    auto trace = std::move(op_it->second);
    state.op_tape.erase(op_it);
    std::vector<Gradient*> out_gradients;
    out_gradients.reserve(trace.output_tensor_info.size());
    std::vector<int64> unneeded_gradients;
    for (int i = 0; i < trace.input_tensor_id.size(); i++) {
      const auto& in_tensor_id = trace.input_tensor_id[i];
      if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
          sources_set.find(in_tensor_id) == sources_set.end()) {
        unneeded_gradients.push_back(i);
      }
    }

    bool any_gradient_nonzero = false;
    std::vector<int> zero_indices;
    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
      const int64 id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        auto func_name_it =
            FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
        if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
            func_name_it->second.find(i) != func_name_it->second.end()) {
          out_gradients.push_back(nullptr);
        } else {
          out_gradients.push_back(nullptr);
          zero_indices.push_back(i);
        }
      } else {
        any_gradient_nonzero = true;
        Gradient* new_gradients = nullptr;
        if (grad_it->second.size() == 1) {
          new_gradients = grad_it->second.at(0);
        } else {
          new_gradients = vspace.AggregateGradients(grad_it->second);
        }
        if (sources_set.find(grad_it->first) == sources_set.end()) {
          gradients.erase(grad_it);
        } else {
          grad_it->second.clear();
          grad_it->second.push_back(new_gradients);
          vspace.MarkAsResult(new_gradients);
        }
        out_gradients.push_back(new_gradients);
      }
    }
    std::vector<Gradient*> in_gradients;
    if (any_gradient_nonzero) {
      for (const auto i : zero_indices) {
        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
      }
      Status s;
      s = vspace.CallBackwardFunction(trace.backward_function,
                                      unneeded_gradients, out_gradients,
                                      &in_gradients);
      if (in_gradients.size() != trace.input_tensor_id.size()) {
        return tensorflow::errors::Internal(
            "Recorded operation '", trace.op_type,
            "' returned too few gradients. Expected ",
            trace.input_tensor_id.size(), " but received ",
            in_gradients.size());
      }
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      if (!s.ok()) {
        return s;
      }
    } else {
      in_gradients.resize(trace.input_tensor_id.size());
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      for (Gradient* grad : out_gradients) {
        if (grad != nullptr) {
          vspace.DeleteGradient(grad);
        }
      }
    }
    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
            << trace.input_tensor_id.size() << " sources";
    for (int i = 0; i < in_gradients.size(); ++i) {
      const int64 id = trace.input_tensor_id[i];
      if (in_gradients[i] != nullptr) {
        auto& unaggregated_grads = gradients[id];
        unaggregated_grads.push_back(in_gradients[i]);
        if (unaggregated_grads.size() > kMinAggregateCount) {
          auto size_it = gradients_size.find(id);
          int64 size;
          if (size_it == gradients_size.end()) {
            size = vspace.NumElements(unaggregated_grads[0]);
            gradients_size.emplace(id, size);
          } else {
            size = size_it->second;
          }
          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
            unaggregated_grads.clear();
            unaggregated_grads.push_back(grad);
          }
        }
      }
      auto usage_count_it = state.tensor_usage_counts.find(id);
      if (usage_count_it == state.tensor_usage_counts.end()) {
        VLOG(1) << "Tensor " << id << " not used";
        continue;
      }
      usage_count_it->second--;
      if (usage_count_it->second > 0) {
        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
        continue;
      }
      auto tape_it = tensor_tape_.find(id);
      if (tape_it == tensor_tape_.end()) {
        VLOG(1) << "Tensor " << id
                << " has no associated op. Deleting gradient";
        auto grad_it = gradients.find(id);
        if (grad_it != gradients.end()) {
          for (auto g : grad_it->second) {
            vspace.DeleteGradient(g);
          }
          gradients.erase(grad_it);
        }
        continue;
      }
      const int64 op_id = tape_it->second;
      if (op_id == -1) {
        VLOG(1) << "Tensor " << id << " is source";
        continue;
      }
      auto missing_it = state.op_missing_tensor.find(op_id);
      if (missing_it != state.op_missing_tensor.end()) {
        missing_it->second--;
        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
                << " output gradients";
        if (missing_it->second == 0) {
          op_stack.insert(op_stack.begin(), op_id);
        }
      }
    }
  }
  if (!state.op_tape.empty()) {
    return tensorflow::errors::Internal("Invalid tape state.");
  }
  result->reserve(source_tensor_ids.size());
  std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
  for (auto is : source_tensor_ids) {
    auto grad_it = gradients.find(is);
    if (grad_it == gradients.end()) {
      result->push_back(nullptr);
    } else {
      if (grad_it->second.size() > 1) {
        Gradient* grad = vspace.AggregateGradients(grad_it->second);
        grad_it->second.clear();
        grad_it->second.push_back(grad);
      }
      result->push_back(grad_it->second[0]);
      used_gradient_ids.insert(is);
    }
  }
  VLOG(1) << "Final gradients size: "
          << gradients.size() - used_gradient_ids.size();
  for (auto grad_pair : gradients) {
    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
      for (const auto& g : grad_pair.second) {
        vspace.DeleteGradient(g);
      }
    }
  }
  return Status::OK();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  if (call_state_.top().backward_tape != nullptr) {
    // If we're forwarding Accumulate calls to backward_tape's RecordOperation,
    // we should also delegate ShouldRecord.
    return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
  }
  if (call_state_.top().accumulating) {
    return false;
  }
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (accumulated_gradients_.find(tensor_ids[i]) !=
        accumulated_gradients_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
    const std::vector<TapeTensor>& output_tensors,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter,
    const std::vector<Gradient*>& in_grads, std::vector<Gradient*>* out_grads) {
  /* This function is approximately equivalent to this Python code:

  forwardprop_aids = tf.ones_like(output_tensors)
  with tf.GradientTape() as g:
    g.watch(forwardprop_aids)
    grad = backward_function(forwardprop_aids)
  forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
  accumulated_gradients_[ID(output_tensors)] = forward_grads
  */
  std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
      new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
  AccumulatorCallState& call_state = call_state_.top();
  call_state.backward_tape = tape.get();
  auto pop_backward_tape =
      gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
  std::vector<Gradient*> forwardprop_aids;
  std::vector<int64> sources;
  std::unordered_set<int64> sources_set;
  sources.reserve(output_tensors.size());
  for (const TapeTensor& output_tensor : output_tensors) {
    // Ownership of `aid` transferred to CallBackwardFunction below.
    Gradient* aid;
    if (output_tensor.GetDType() == tensorflow::DT_VARIANT) {
      // Note: Needs to be zeros rather than ones since there's currently no
      // ones_like for variants.
      aid = output_tensor.ZerosLike();
    } else {
      // TODO(allenl): Figure out why using zeros_like everywhere causes issues
      // for some gradient functions and if there's another way to work around
      // it (e.g. conds instead of ifs). The value shouldn't really matter.
      aid = output_tensor.OnesLike();
    }
    if (TF_PREDICT_FALSE(aid == nullptr)) {
      return tensorflow::errors::Internal(
          "Failed to create ones tensor for tensor ", output_tensor.GetID(),
          " with dtype ", output_tensor.GetDType());
    }
    forwardprop_aids.push_back(aid);
    int64 aid_id = vspace_.TensorId(aid);
    sources.push_back(aid_id);
    sources_set.insert(aid_id);
    tape->Watch(aid_id);
  }
  std::vector<Gradient*> grad;
  auto delete_grad = gtl::MakeCleanup([&grad, this] {
    for (Gradient* tensor : grad) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  {
    std::vector<int64> unneeded_gradients;
    std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
        backward_function(backward_function_getter(),
                          backward_function_deleter);
    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
        backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
  }

  // Stop the tape from recording
  pop_backward_tape.release()();

  if (grad.size() != in_grads.size()) {
    return tensorflow::errors::Internal("Wrong number of gradients returned.");
  }

  std::vector<int64> targets;
  std::vector<Gradient*> used_in_grads;
  // We may end up with slightly fewer elements than we reserve, but
  // grad.size() should be a reasonably tight upper bound.
  targets.reserve(grad.size());
  used_in_grads.reserve(grad.size());
  std::unordered_map<int64, TapeTensor> sources_that_are_targets;
  for (int grad_index = 0; grad_index < grad.size(); ++grad_index) {
    Gradient* grad_tensor = grad[grad_index];
    if (grad_tensor != nullptr) {
      int64 tensor_id = vspace_.TensorId(grad_tensor);
      targets.push_back(tensor_id);
      if (sources_set.find(tensor_id) != sources_set.end()) {
        sources_that_are_targets.emplace(
            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
      }
      Gradient* in_grad = in_grads[grad_index];
      if (in_grad != nullptr) {
        // ComputeGradient steals a reference
        vspace_.MarkAsResult(in_grad);
      }
      used_in_grads.push_back(in_grad);
    }
  }

  return tape->ComputeGradient(vspace_, targets, sources,
                               sources_that_are_targets, used_in_grads,
                               out_grads);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
    const string& op_type, const std::vector<TapeTensor>& input_tensors,
    const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const ForwardFunction<Gradient>* forward_function,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (call_state_.top().backward_tape != nullptr) {
    // If backward_tape is not null, then this call to Accumulate is the result
    // of a still-active call to Accumulate which is running operations. We
    // forward these operations to backward_tape so the outer Accumulate call
    // can do its work.
    //
    // Rather than re-entering and delegating Accumulate like this, we could
    // instead allow ForwardAccumulator some control over the current tape set
    // (so it can deactivate itself and activate its GradientTape). Currently
    // that is managed by the language binding and would require relatively
    // messy callbacks.
    call_state_.top().backward_tape->RecordOperation(
        op_type, output_tensors, input_tensor_id, input_dtypes,
        backward_function_getter, backward_function_deleter);
    return Status::OK();
  }
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return Status::OK();
  }

  // We may need to allocate zero inputs for trainable dtypes we don't have
  // JVPs for. Make sure they get cleaned up.
  std::vector<Gradient*> new_zeros;
  auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
    for (Gradient* tensor : new_zeros) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  std::vector<Gradient*> in_grads;
  in_grads.reserve(input_tensors.size());
  for (int target_index = 0; target_index < input_tensors.size();
       ++target_index) {
    const auto current_grad =
        accumulated_gradients_.find(input_tensors[target_index].GetID());
    if (current_grad == accumulated_gradients_.end()) {
      if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
        // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
        // GradientTape which uses ones.
        Gradient* zero = input_tensors[target_index].ZerosLike();
        new_zeros.push_back(zero);
        in_grads.push_back(zero);
      } else {
        in_grads.push_back(nullptr);
      }
    } else {
      in_grads.push_back(current_grad->second);
    }
  }

  // Avoid infinite recursion. Whichever forward function we run, it'll end up
  // executing ops, and we don't want to watch those with this accumulator.
  call_state_.emplace(nullptr, true);
  auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });

  std::vector<Gradient*> forward_grads;
  if (forward_function == nullptr) {
    // We have no special-cased forward gradient. Fall back to running the
    // backward function under a gradient tape.
    TF_RETURN_IF_ERROR(ForwardpropFromTape(
        output_tensors, backward_function_getter, backward_function_deleter,
        in_grads, &forward_grads));
  } else {
    TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
  }
  for (int i = 0; i < forward_grads.size(); ++i) {
    if (forward_grads[i] != nullptr) {
      int64 tensor_id = output_tensors[i].GetID();
      auto existing = accumulated_gradients_.find(tensor_id);
      if (existing != accumulated_gradients_.end()) {
        // This is a somewhat odd case to be in, since it means we have two
        // operations which supposedly both created the same Tensor. It comes
        // up in recompute_grad, where the gradients have the same value.
        // However, only the original gradient is connected to everything else,
        // so we should still use that.
        vspace_.DeleteGradient(forward_grads[i]);
      } else {
        accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
      }
    }
  }
  return Status::OK();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id, Gradient* tangent) {
  typename std::unordered_map<int64, Gradient*>::iterator existing =
      accumulated_gradients_.find(tensor_id);
  vspace_.MarkAsResult(tangent);
  if (existing == accumulated_gradients_.end()) {
    accumulated_gradients_.emplace(tensor_id, tangent);
  } else {
    std::array<Gradient*, 2> to_aggregate;
    to_aggregate[0] = tangent;
    to_aggregate[1] = existing->second;
    // AggregateGradients steals a reference to each of its arguments. We
    // MarkAsResult on `tangent` above so we don't steal a reference to it.
    existing->second = vspace_.AggregateGradients(to_aggregate);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::DeleteGradient(
    int64 tensor_id) {
  auto existing = accumulated_gradients_.find(tensor_id);
  if (existing != accumulated_gradients_.end()) {
    vspace_.DeleteGradient(existing->second);
    accumulated_gradients_.erase(existing);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
    int64 tensor_id) {
  auto lookup = accumulated_gradients_.find(tensor_id);
  if (lookup == accumulated_gradients_.end()) {
    return nullptr;
  } else {
    return lookup->second;
  }
}

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TAPE_H_