/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TAPE_H_
#define TENSORFLOW_C_EAGER_TAPE_H_

// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.

#include <array>
#include <functional>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace eager {

// Represents an entry in the tape.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  string op_type;
  std::vector<TapeTensor> output_tensor_info;
  std::vector<int64> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};

// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = std::unordered_map<int64, int64>;

// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape =
    std::unordered_map<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;

// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's
// either Tensor or IndexedSlices or None, which here we map to nullptr.
// Gradients need to allow their size to be computed, they need to be passable
// to a backward function, and they need to be deletable (as the backprop code
// creates lots of gradients the user is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64 NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;

  // Calls the passed-in backward function.
  //
  // `unneeded_gradients` contains a sorted list of input indices for which a
  // gradient is not required.
  virtual Status CallBackwardFunction(
      const string& op_type, BackwardFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<Gradient*> output_gradients,
      absl::Span<Gradient*> result) const = 0;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  virtual Status BuildOnesLike(const TapeTensor& t,
                               Gradient** result) const = 0;

  // Looks up the ID of a Gradient.
  virtual int64 TensorId(Gradient* tensor) const = 0;

  // Converts a Gradient to a TapeTensor.
  virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0;

  // Marks the following gradient as a result so it's not consumed by backward
  // functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes the input tensor.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};

// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not
// thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  explicit GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  // Returns whether any tensor in a list of tensors is being watched and has
  // a trainable dtype.
  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes) const;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched either from an
  // earlier call to `GradientTape::Watch` or being an output of an op with
  // watched inputs.
  void Watch(int64 tensor_id);

  // Records an operation with inputs `input_tensor_id` and outputs
  // `output_tensors` on the tape and marks all its outputs as watched if at
  // least one input of the op is watched and has a trainable dtype.
  //
  // op_type is used to decide which of the incoming gradients can be left as
  // nullptr instead of building zeros when build_default_zeros_grads == true.
  void RecordOperation(
      const string& op_type, const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  void DeleteTrace(int64 tensor_id);

  // Consumes the internal state of the tape (so cannot be called more than
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  // When running backward functions, builds zeros-like tensors for
  // incoming grads which are nullptrs, unless `build_default_zeros_grads`
  // is set to false.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      absl::Span<Gradient*> result, bool build_default_zeros_grads = true);

  // Whether the tape is persistent. See ctor for detailed description.
  bool IsPersistent() const { return persistent_; }

 private:
  TensorTape tensor_tape_;
  OpTape<BackwardFunction, TapeTensor> op_tape_;
  int64 next_op_id_{0};

  // Map from tensor id to number of remaining usages (i.e. how many entries in
  // the tape refer to it); to aid in tape garbage collection.
  std::unordered_map<int64, int64> tensor_usage_;

  // If false, all activations are deleted in the first call to
  // ComputeGradient. Else, only when this is destructed.
  bool persistent_;
};
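
// A typical reverse-mode flow through this interface, as a rough sketch. The
// concrete Gradient/BackwardFunction/TapeTensor types, the `vspace`
// implementation, and names like `x_id`, `y_id`, `y_info`, and the
// backward-function callbacks below are assumptions supplied by a language
// binding, not part of this header:
//
//   GradientTape<Gradient, BackwardFunction, TapeTensor> tape(
//       /*persistent=*/false);
//   tape.Watch(x_id);
//   if (tape.ShouldRecord({x_id}, {DT_FLOAT})) {
//     tape.RecordOperation("Square", {y_info}, {x_id}, {DT_FLOAT},
//                          backward_function_getter,
//                          backward_function_deleter);
//   }
//   std::vector<Gradient*> result(1);
//   Status s = tape.ComputeGradient(vspace, /*target_tensor_ids=*/{y_id},
//                                   /*source_tensor_ids=*/{x_id},
//                                   /*sources_that_are_targets=*/{},
//                                   /*output_gradients=*/{},
//                                   absl::MakeSpan(result));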

// Describes a callback for special-cased and more efficient jvp computation.
//
// Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on
// that.
template <typename Gradient>
class ForwardFunction
    : public std::function<Status(const std::vector<Gradient*>&,
                                  std::vector<Gradient*>*, bool)> {
 public:
  template <typename lambda_type>
  explicit ForwardFunction(lambda_type lambda)
      : std::function<Status(const std::vector<Gradient*>&,
                             std::vector<Gradient*>*, bool)>(lambda) {}
};
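
// For example, a binding with a special-cased jvp for an op might wrap it
// roughly like this (a sketch; the lambda body and the concrete `Gradient`
// instantiation are assumptions, not part of this header):
//
//   ForwardFunction<Gradient> forward_function(
//       [](const std::vector<Gradient*>& input_tangents,
//          std::vector<Gradient*>* output_tangents,
//          bool use_batch) -> Status {
//         // Fill *output_tangents with one jvp per op output (nullptr where
//         // no jvp is defined), computed from the input tangents.
//         return Status::OK();
//       });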

// Computes Jacobian-vector products using forward-mode automatic
// differentiation.
//
// While GradientTape's RecordOperation is trivial, ForwardAccumulator's
// Accumulate runs the gradient computation immediately.
//
// Keeps references to Tensors watched via Watch and computed in Accumulate
// corresponding to output_tensors, and releases these references in its
// destructor. However, waiting until the destructor runs loses the memory
// efficiency of forward-mode autodiff. Instead, language bindings should call
// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output
// Tensor passed to Accumulate goes out of scope.
//
// Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class ForwardAccumulator {
 public:
  // Does not take ownership of `vspace`, which must outlive the
  // ForwardAccumulator.
  explicit ForwardAccumulator(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      bool use_batch)
      : vspace_(vspace), use_batch_(use_batch) {
    call_state_.emplace(nullptr, false);
  }

  virtual ~ForwardAccumulator() {
    for (auto accumulated : accumulated_gradients_) {
      vspace_.DeleteGradient(accumulated.second);
    }
  }

  // Tell the forward accumulator to watch tensor_id, with a Tensor tangent
  // vector `tangent` of matching shape and dtype. Tangents are the "vector" in
  // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling
  // FetchJVP for it would return `tangent`.
  void Watch(int64 tensor_id, Gradient* tangent);

  // Removes the gradient associated with tensor_id. Should be called when the
  // Tensor associated with `tensor_id` is deleted.
  void DeleteGradient(int64 tensor_id);

  // Runs forward autodiff. Should be called whenever a new operation is
  // available and the accumulator is active.
  //
  // Like GradientTape::RecordOperation, this method takes the operation type
  // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`,
  // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat
  // redundant but taken as arguments to avoid repeatedly fetching these values
  // between calls to ShouldRecord and Accumulate), and its outputs
  // (`output_tensors`).
  //
  // If provided, a non-null `forward_function` will be used instead of the
  // backward function (`backward_function_getter` /
  // `backward_function_deleter`) to compute jvps for this operation. If
  // `forward_function` is null, a GradientTape is used on the backward
  // function to compute the jvp, which will waste computation when executing
  // eagerly.
  //
  // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
  // immediately. It stores the results, which feed into Accumulate for future
  // operations and may be fetched by calling FetchJVP. ForwardAccumulator
  // maintains a reference to these JVPs: if an `output_tensors` Tensor is
  // deleted, `DeleteGradient` should be called as soon as possible to free the
  // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor
  // will release remaining references.
  //
  // This method is not thread-safe (and in general ForwardAccumulator is not
  // thread-safe).
  Status Accumulate(
      const string& op_type, const std::vector<TapeTensor>& input_tensors,
      const std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const ForwardFunction<Gradient>* forward_function,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  // Returns true if `Accumulate` is active somewhere above on the stack and
  // there isn't an intervening PushState. This is useful for ordering
  // ForwardAccumulators, where more deeply nested accumulators should not see
  // computations from less deeply nested accumulators.
  bool BusyAccumulating() const { return call_state_.top().accumulating; }

  // Fetches the current Jacobian-vector product associated with `tensor_id`,
  // or a nullptr if none is available.
  //
  // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on
  // its return value. The caller should increment the reference count before
  // deleting the ForwardAccumulator or calling DeleteGradient if keeping a
  // persistent reference to a non-null result.
  Gradient* FetchJVP(int64 tensor_id);

  // Indicates whether the forward accumulator should run on an operation with
  // the specified inputs and dtypes.
  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  // Temporarily push or pop transient state for this accumulator.
  //
  // Allows an accumulator which is currently processing an operation to
  // temporarily reset its state. Without pushing and popping, accumulators
  // ignore operations executed as a direct result of their own jvp
  // computations.
  void PushState() { call_state_.emplace(nullptr, false); }
  void PopState() { call_state_.pop(); }

 private:
  // Helper for Accumulate: uses a GradientTape to compute forward gradients
  // from a backward gradient function. Fills `out_grads` corresponding to
  // `output_tensors`. `out_grads` must not be null.
  //
  // Executes the backward function in order to trace its gradient, which will
  // waste computation if executing eagerly (when graph building the unneeded
  // computation is pruned). Temporarily sets `backward_tape` so that
  // Accumulate will forward op executions to the tape while the backward
  // function is running; this effectively adds the backward tape to the active
  // set (but does not require complicated callbacks to the language bindings).
  Status ForwardpropFromTape(
      const string& op_type, const std::vector<TapeTensor>& output_tensors,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter,
      const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads);

  // Maps from tensor IDs to corresponding JVPs.
  std::unordered_map<int64, Gradient*> accumulated_gradients_;
  // Not owned; provides operations on Tensors which are currently only
  // available in language bindings (e.g. Python).
  const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace_;

  // Decides whether tangents are vectorized (batched) or not.
  bool use_batch_;

  struct AccumulatorCallState {
    AccumulatorCallState(
        GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape,
        bool accumulating)
        : backward_tape(backward_tape), accumulating(accumulating) {}
    // Set temporarily while in the Accumulate method; if backward_tape is not
    // nullptr then we forward op executions to it so Accumulate can compute a
    // backward pass on its backward function.
    //
    // Not owned by the ForwardAccumulator. The method which sets
    // `backward_tape` keeps ownership.
    GradientTape<Gradient, BackwardFunction, TapeTensor>* backward_tape;
    // While the Accumulate method is running (accumulating is true), any op
    // executions not forwarded to backward_tape should be ignored.
    bool accumulating;
  };
  // A deque-backed stack, whose element references are not invalidated by
  // pushes and pops at the back.
  std::stack<AccumulatorCallState> call_state_;
};
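
// Forward-mode usage, again as a hedged sketch: the ids, TapeTensor infos,
// tangent, and backward-function callbacks are assumptions provided by a
// language binding.
//
//   ForwardAccumulator<Gradient, BackwardFunction, TapeTensor> acc(
//       vspace, /*use_batch=*/false);
//   acc.Watch(x_id, x_tangent);  // x_tangent is the "vector" in the jvp.
//   // For each op executed while the accumulator is active:
//   if (acc.ShouldRecord({x_id}, {DT_FLOAT})) {
//     TF_RETURN_IF_ERROR(acc.Accumulate(
//         "Square", /*input_tensors=*/{x_info}, /*output_tensors=*/{y_info},
//         /*input_tensor_id=*/{x_id}, /*input_dtypes=*/{DT_FLOAT},
//         /*forward_function=*/nullptr, backward_function_getter,
//         backward_function_deleter));
//   }
//   Gradient* jvp = acc.FetchJVP(y_id);  // Borrowed reference; may be null.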

// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    case DT_RESOURCE:
    case DT_VARIANT:
      return true;
    default:
      return false;
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) const {
  CHECK_EQ(tensor_ids.size(), dtypes.size());
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
    const string& op_type, const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return;
  }
  std::vector<int64> ids;
  ids.reserve(input_tensor_id.size());
  for (int64 i : input_tensor_id) {
    tensor_usage_[i]++;
    ids.push_back(i);
  }
  const int64 op_id = next_op_id_++;
  std::vector<TapeTensor> tensors;
  tensors.reserve(output_tensors.size());
  for (const TapeTensor& o : output_tensors) {
    // Note: the tensor can have already been watched and hence be in the tape,
    // so we cannot check that we're inserting it here.
    tensor_tape_[o.GetID()] = op_id;
    tensor_usage_[o.GetID()] = 1;
    tensors.push_back(o);
  }
  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
      op_type, std::move(tensors), std::move(ids), backward_function_getter(),
      backward_function_deleter};
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64 tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64 op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  for (int64 id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}

// Terminology:
//
//  - op: a possibly composite operation, which has an entry in the tape
//  - target: dy in dx/dy
//  - source: dx in dx/dy
//  - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each
// Tensor is used as an input to an op (so we know when we're done computing
// gradients for that Tensor). We also count, for each tape entry, how many of
// its output Tensors need gradients to be computed (Tensors which are not
// used do not need any gradients to be computed).
//
// Next, we start a backprop stack with a set of tape entries for which we
// have all gradients available. This set usually is a subset of the set of
// targets (not all since targets which have outputs in the tape will not have
// gradients available initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and
// update the gradients of its inputs. Once we have computed all gradients for
// a single input we can mark this input as done, and this can trigger adding
// an entry to the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.

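// As a small worked example (illustrative only; actual tensor and op ids are
// internal): suppose the tape recorded y = Square(x) and z = Add(y, y) with x
// watched, and we ask for dz/dx. PrepareBackprop keeps both ops, counts y as
// used twice (once per Add input), and notes that Square has one output (y)
// whose gradient is still missing. The initial stack therefore holds only
// Add, since its output gradient (dz/dz, built as ones) is available. Popping
// Add produces two partial gradients for y; once both are recorded, y's usage
// count reaches zero, Square's missing-output count drops to zero, Square is
// pushed, and popping it aggregates the partials and yields dz/dx.
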
namespace {

template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  std::unordered_map<int64, int64> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  std::unordered_map<int64, int64> op_missing_tensor;
};

// If `persistent_tape` is true, op_tape is not changed and none of the
// backwards functions are deleted.
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
    OpTape<BackwardFunction, TapeTensor>* op_tape,
    const std::unordered_set<int64>& sources_set, bool persistent_tape) {
  std::vector<int64> tensor_stack;
  tensor_stack.reserve(target.size());
  for (auto t : target) {
    tensor_stack.push_back(t);
  }
  BackpropInitialState<BackwardFunction, TapeTensor> result;
  while (!tensor_stack.empty()) {
    int64 tensor_id = tensor_stack.back();
    tensor_stack.pop_back();
    auto op_id_it = tensor_tape.find(tensor_id);
    if (op_id_it == tensor_tape.end()) {
      continue;
    }
    int64 op_id = op_id_it->second;
    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
    for (auto it : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(it);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[it] = 1;
        if (tensor_tape.find(it) != tensor_tape.end()) {
          tensor_stack.push_back(it);
        }
      }
    }
    if (!persistent_tape) {
      op_tape->erase(op_it);
    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
    if (it != tensor_tape.end() && it->second != -1) {
      result.op_missing_tensor[it->second] += 1;
    }
  }
  if (!persistent_tape) {
    // Call destructors for all unneeded gradient functions and
    // clear the op_tape. We can clear the tape because ownership of
    // backward functions that will be used for gradient computation
    // has been transferred to `result`.
    for (const auto& op_pair : *op_tape) {
      op_pair.second.backward_function_deleter(
          op_pair.second.backward_function);
    }
    op_tape->clear();
  }
  return result;
}

template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    const std::unordered_map<int64, int64>& op_missing_tensor) {
  std::vector<int64> result;
  for (auto& op_entry : op_tape) {
    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
      result.push_back(op_entry.first);
    }
  }
  return result;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    gtl::ArraySlice<int64> target_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    std::unordered_map<int64, std::vector<Gradient*>>* result) {
  for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) {
    const int64 id = target_tensor_ids[i];
    if (output_gradients.empty() || output_gradients[i] == nullptr) {
      auto tensor_it = tensor_tape.find(id);
      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
          if (op_it->second.output_tensor_info[j].GetID() == id) {
            found = true;
            Gradient* ones_like = nullptr;
            TF_RETURN_IF_ERROR(vspace.BuildOnesLike(
                op_it->second.output_tensor_info[j], &ones_like));
            (*result)[id].push_back(ones_like);
            break;
          }
        }
        if (!found) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of operations outputs match expected tensor");
        }
      } else {
        // This target tensor was not generated by any operation recorded on
        // the tape, so no gradient needs to be computed from it unless this
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          Gradient* ones_like = nullptr;
          TF_RETURN_IF_ERROR(
              vspace.BuildOnesLike(source_tensor->second, &ones_like));
          (*result)[id].push_back(ones_like);
        }
      }
    } else {
      (*result)[id].push_back(output_gradients[i]);
    }
  }
  return Status::OK();
}

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Some gradient functions can accept None arguments for gradients. The
// following maps the operation name to the indices at which the corresponding
// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
// values and hence receives 5 gradient values during backprop. However the
// gradient function uses only the first of those values and ignores the rest.
// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
std::unordered_map<string, std::unordered_set<int>>*
FunctionsAcceptingNoneForIndicesMap() {
  static auto* const m =
      new std::unordered_map<string, std::unordered_set<int>>({
          {"SoftmaxCrossEntropyWithLogits", {1}},
          {"SparseSoftmaxCrossEntropyWithLogits", {1}},
          {"FusedBatchNorm", {1, 2, 3, 4}},
      });
  return m;
}

}  // namespace

// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensor to save memory.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
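// For example (a hedged illustration; the 4-bytes-per-element figure matches
// the size estimate used in ComputeGradient below): five pending float32
// gradients of 10 million elements each amount to roughly
// 5 * 10,000,000 * 4 bytes = 200MB > kMinAggregateBytes, so they would be
// aggregated into a single tensor rather than kept separately.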

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
    bool build_default_zeros_grads) {
  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
                                        source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  std::unordered_map<int64, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids,
                              sources_that_are_targets, output_gradients,
                              tensor_tape_, state.op_tape, &gradients);
  auto cleanup = gtl::MakeCleanup([this, &state]() {
    if (!persistent_) {
      // Release all backprop functions
      for (const auto& pair : state.op_tape) {
        pair.second.backward_function_deleter(pair.second.backward_function);
      }
    }
  });
  if (!s.ok()) {
    return s;
  }

  std::unordered_map<int64, int64> gradients_size;
  // TODO(apassos) multiple threads could be dequeuing from op_stack at the
  // same time, for better CPU backprop performance.
  VLOG(1) << "Initial stack:";
  if (VLOG_IS_ON(1)) {
    for (auto t : op_stack) {
      VLOG(1) << "  " << t;
    }
  }
  while (!op_stack.empty()) {
    const int64 op = op_stack.back();
    VLOG(1) << "Popped " << op;
    op_stack.pop_back();
    auto op_it = state.op_tape.find(op);
    if (op_it == state.op_tape.end()) {
      // It is possible for ops to end up on the stack if they are unrelated to
      // the target; we should just skip them.
      continue;
    }
    auto trace = std::move(op_it->second);
    state.op_tape.erase(op_it);
    std::vector<Gradient*> out_gradients;
    out_gradients.reserve(trace.output_tensor_info.size());
    std::vector<int64> unneeded_gradients;
    for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) {
      const auto& in_tensor_id = trace.input_tensor_id[i];
      if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() &&
          sources_set.find(in_tensor_id) == sources_set.end()) {
        unneeded_gradients.push_back(i);
      }
    }

    bool any_gradient_nonzero = false;
    std::vector<int> zero_indices;
    for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) {
      const int64 id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        out_gradients.push_back(nullptr);
        if (build_default_zeros_grads) {
          auto func_name_it =
              FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
          if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
              func_name_it->second.find(i) == func_name_it->second.end()) {
            zero_indices.push_back(i);
          }
        }
      } else {
        any_gradient_nonzero = true;
        Gradient* new_gradients = nullptr;
        if (grad_it->second.size() == 1) {
          new_gradients = grad_it->second.at(0);
        } else {
          new_gradients = vspace.AggregateGradients(grad_it->second);
        }
        if (sources_set.find(grad_it->first) == sources_set.end()) {
          gradients.erase(grad_it);
        } else {
          grad_it->second.clear();
          grad_it->second.push_back(new_gradients);
          vspace.MarkAsResult(new_gradients);
        }
        out_gradients.push_back(new_gradients);
      }
    }
    VLOG(1) << "Calling gradient function for '" << trace.op_type << "'";
    std::vector<Gradient*> in_gradients(trace.input_tensor_id.size());
    DCHECK(build_default_zeros_grads || zero_indices.empty());
    if (any_gradient_nonzero) {
      for (const auto i : zero_indices) {
        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
      }
      Status s;
      s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function,
                                      unneeded_gradients, out_gradients,
                                      absl::MakeSpan(in_gradients));
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      if (!s.ok()) {
        return s;
      }
    } else {
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      for (Gradient* grad : out_gradients) {
        if (grad != nullptr) {
          vspace.DeleteGradient(grad);
        }
      }
    }
    for (int i = 0, end = in_gradients.size(); i < end; ++i) {
      const int64 id = trace.input_tensor_id[i];
      if (in_gradients[i] != nullptr) {
        auto& unaggregated_grads = gradients[id];
        unaggregated_grads.push_back(in_gradients[i]);
        if (unaggregated_grads.size() > kMinAggregateCount) {
          auto size_it = gradients_size.find(id);
          int64 size;
          if (size_it == gradients_size.end()) {
            size = vspace.NumElements(unaggregated_grads[0]);
            gradients_size.emplace(id, size);
          } else {
            size = size_it->second;
          }
          // Rough memory estimate, assuming 4 bytes per gradient element.
          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
            unaggregated_grads.clear();
            unaggregated_grads.push_back(grad);
          }
        }
      }
      auto usage_count_it = state.tensor_usage_counts.find(id);
      if (usage_count_it == state.tensor_usage_counts.end()) {
        VLOG(1) << "Tensor " << id << " not used";
        continue;
      }
      usage_count_it->second--;
      if (usage_count_it->second > 0) {
        VLOG(1) << "Tensor " << id << " usage count "
                << usage_count_it->second;
        continue;
      }
      auto tape_it = tensor_tape_.find(id);
      if (tape_it == tensor_tape_.end()) {
        VLOG(1) << "Tensor " << id
                << " has no associated op. Deleting gradient";
        auto grad_it = gradients.find(id);
        if (grad_it != gradients.end()) {
          for (auto g : grad_it->second) {
            vspace.DeleteGradient(g);
          }
          gradients.erase(grad_it);
        }
        continue;
      }
      const int64 op_id = tape_it->second;
      if (op_id == -1) {
        VLOG(1) << "Tensor " << id << " is source";
        continue;
      }
      auto missing_it = state.op_missing_tensor.find(op_id);
      if (missing_it != state.op_missing_tensor.end()) {
        missing_it->second--;
        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
                << " output gradients";
        if (missing_it->second == 0) {
          op_stack.insert(op_stack.begin(), op_id);
        }
      }
    }
  }
  if (!state.op_tape.empty()) {
    return tensorflow::errors::Internal("Invalid tape state.");
  }
  if (result.size() != source_tensor_ids.size()) {
    return errors::Internal("Expected result Span to be of size ",
                            source_tensor_ids.size(), " found ", result.size(),
                            " in call to Tape::ComputeGradient.");
  }
  std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
  for (int i = 0; i < source_tensor_ids.size(); i++) {
    int64 tensor_id = source_tensor_ids[i];
    auto grad_it = gradients.find(tensor_id);
    if (grad_it == gradients.end()) {
      result[i] = nullptr;
    } else {
      if (grad_it->second.size() > 1) {
        Gradient* grad = vspace.AggregateGradients(grad_it->second);
        grad_it->second.clear();
        grad_it->second.push_back(grad);
      }
      result[i] = grad_it->second[0];
      used_gradient_ids.insert(tensor_id);
    }
  }
  VLOG(1) << "Final gradients size: "
          << gradients.size() - used_gradient_ids.size();
  for (const auto& grad_pair : gradients) {
    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
      for (const auto& g : grad_pair.second) {
        vspace.DeleteGradient(g);
      }
    }
  }
  return Status::OK();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  if (call_state_.top().backward_tape != nullptr) {
    // If we're forwarding Accumulate calls to backward_tape's RecordOperation,
    // we should also delegate ShouldRecord.
    return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes);
  }
  if (call_state_.top().accumulating) {
    return false;
  }
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (accumulated_gradients_.find(tensor_ids[i]) !=
        accumulated_gradients_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
    const string& op_type, const std::vector<TapeTensor>& output_tensors,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter,
    const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads) {
  /* This function is approximately equivalent to this Python code:

  forwardprop_aids = tf.ones_like(output_tensors)
  with tf.GradientTape() as g:
    g.watch(forwardprop_aids)
    grad = backward_function(forwardprop_aids)
  forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads)
  accumulated_gradients_[ID(output_tensors)] = forward_grads
  */
  std::unique_ptr<GradientTape<Gradient, BackwardFunction, TapeTensor>> tape(
      new GradientTape<Gradient, BackwardFunction, TapeTensor>(false));
  AccumulatorCallState& call_state = call_state_.top();
  call_state.backward_tape = tape.get();
  auto pop_backward_tape =
      gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; });
  std::vector<Gradient*> forwardprop_aids;
  std::vector<int64> sources;
  std::unordered_set<int64> sources_set;
  sources.reserve(output_tensors.size());
  for (const TapeTensor& output_tensor : output_tensors) {
    // Ownership of `aid` transferred to CallBackwardFunction below.
    Gradient* aid;
    if (output_tensor.GetDType() == tensorflow::DT_VARIANT) {
      // Note: Needs to be zeros rather than ones since there's currently no
      // ones_like for variants.
      aid = output_tensor.ZerosLike();
    } else {
      // TODO(allenl): Figure out why using zeros_like everywhere causes issues
      // for some gradient functions and if there's another way to work around
      // it (e.g. conds instead of ifs). The value shouldn't really matter.
      TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid));
    }
    if (TF_PREDICT_FALSE(aid == nullptr)) {
      return tensorflow::errors::Internal(
          "Failed to create ones tensor for tensor ", output_tensor.GetID(),
          " with dtype ", output_tensor.GetDType());
    }
    forwardprop_aids.push_back(aid);
    int64 aid_id = vspace_.TensorId(aid);
    sources.push_back(aid_id);
    sources_set.insert(aid_id);
    tape->Watch(aid_id);
  }
  std::vector<Gradient*> grad(in_grads.size());
  auto delete_grad = gtl::MakeCleanup([&grad, this] {
    for (Gradient* tensor : grad) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  {
    std::vector<int64> unneeded_gradients;
    std::unique_ptr<BackwardFunction, std::function<void(BackwardFunction*)>>
        backward_function(backward_function_getter(),
                          backward_function_deleter);
    TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
        op_type, backward_function.get(), unneeded_gradients, forwardprop_aids,
        absl::MakeSpan(grad)));
  }

  // Stop the tape from recording
  pop_backward_tape.release()();

  std::vector<int64> targets;
  std::vector<Gradient*> used_in_grads;
  // We may end up with slightly fewer elements than we reserve, but
  // grad.size() should be a reasonably tight upper bound.
  targets.reserve(grad.size());
  used_in_grads.reserve(grad.size());
  std::unordered_map<int64, TapeTensor> sources_that_are_targets;
  for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) {
    Gradient* grad_tensor = grad[grad_index];
    if (grad_tensor != nullptr) {
      int64 tensor_id = vspace_.TensorId(grad_tensor);
      targets.push_back(tensor_id);
      if (sources_set.find(tensor_id) != sources_set.end()) {
        sources_that_are_targets.emplace(
            tensor_id, vspace_.TapeTensorFromGradient(grad_tensor));
      }
      Gradient* in_grad = in_grads[grad_index];
      if (in_grad != nullptr) {
        // ComputeGradient steals a reference
        vspace_.MarkAsResult(in_grad);
      }
      used_in_grads.push_back(in_grad);
    }
  }

  return tape->ComputeGradient(vspace_, targets, sources,
                               sources_that_are_targets, used_in_grads,
                               out_grads);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
    const string& op_type, const std::vector<TapeTensor>& input_tensors,
    const std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const ForwardFunction<Gradient>* forward_function,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (call_state_.top().backward_tape != nullptr) {
    // If backward_tape is not null, then this call to Accumulate is the result
    // of a still-active call to Accumulate which is running operations. We
    // forward these operations to backward_tape so the outer Accumulate call
    // can do its work.
    //
    // Rather than re-entering and delegating Accumulate like this, we could
    // instead allow ForwardAccumulator some control over the current tape set
    // (so it can deactivate itself and activate its GradientTape). Currently
    // that is managed by the language binding and would require relatively
    // messy callbacks.
    call_state_.top().backward_tape->RecordOperation(
        op_type, output_tensors, input_tensor_id, input_dtypes,
        backward_function_getter, backward_function_deleter);
    return Status::OK();
  }
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return Status::OK();
  }

  // We may need to allocate zero inputs for trainable dtypes we don't have
  // JVPs for. Make sure they get cleaned up.
  std::vector<Gradient*> new_zeros;
  auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] {
    for (Gradient* tensor : new_zeros) {
      this->vspace_.DeleteGradient(tensor);
    }
  });
  std::vector<Gradient*> in_grads;
  in_grads.reserve(input_tensors.size());
  for (int target_index = 0; target_index < input_tensors.size();
       ++target_index) {
    const auto current_grad =
        accumulated_gradients_.find(input_tensors[target_index].GetID());
    if (current_grad == accumulated_gradients_.end()) {
      if (IsDtypeTrainable(input_tensors[target_index].GetDType())) {
        // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike
        // GradientTape which uses ones.
        Gradient* zero = input_tensors[target_index].ZerosLike();
        new_zeros.push_back(zero);
        in_grads.push_back(zero);
      } else {
        in_grads.push_back(nullptr);
      }
    } else {
      in_grads.push_back(current_grad->second);
    }
  }

  // Avoid infinite recursion. Whichever forward function we run, it'll end up
  // executing ops, and we don't want to watch those with this accumulator.
  call_state_.emplace(nullptr, true);
  auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); });

  std::vector<Gradient*> forward_grads;
  if (forward_function == nullptr) {
    // We have no special-cased forward gradient. Fall back to running the
    // backward function under a gradient tape.
    forward_grads.resize(output_tensors.size());
    TF_RETURN_IF_ERROR(ForwardpropFromTape(
        op_type, output_tensors, backward_function_getter,
        backward_function_deleter, in_grads, absl::MakeSpan(forward_grads)));
  } else {
    TF_RETURN_IF_ERROR(
        (*forward_function)(in_grads, &forward_grads, use_batch_));
  }
  for (int i = 0; i < forward_grads.size(); ++i) {
    if (forward_grads[i] != nullptr) {
      int64 tensor_id = output_tensors[i].GetID();
      auto existing = accumulated_gradients_.find(tensor_id);
      if (existing != accumulated_gradients_.end()) {
        // This is a somewhat odd case to be in, since it means we have two
        // operations which supposedly both created the same Tensor. It comes
        // up in recompute_grad, where the gradients have the same value.
        // However, only the original gradient is connected to everything
        // else, so we should still use that.
        vspace_.DeleteGradient(forward_grads[i]);
      } else {
        accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
      }
    }
  }
  return Status::OK();
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id, Gradient* tangent) {
  typename std::unordered_map<int64, Gradient*>::iterator existing =
      accumulated_gradients_.find(tensor_id);
  vspace_.MarkAsResult(tangent);
  if (existing == accumulated_gradients_.end()) {
    accumulated_gradients_.emplace(tensor_id, tangent);
  } else {
    std::array<Gradient*, 2> to_aggregate;
    to_aggregate[0] = tangent;
    to_aggregate[1] = existing->second;
    // AggregateGradients steals a reference to each of its arguments. We
    // MarkAsResult on `tangent` above so we don't steal a reference to it.
    existing->second = vspace_.AggregateGradients(to_aggregate);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void ForwardAccumulator<Gradient, BackwardFunction,
                        TapeTensor>::DeleteGradient(int64 tensor_id) {
  auto existing = accumulated_gradients_.find(tensor_id);
  if (existing != accumulated_gradients_.end()) {
    vspace_.DeleteGradient(existing->second);
    accumulated_gradients_.erase(existing);
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Gradient* ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::FetchJVP(
    int64 tensor_id) {
  auto lookup = accumulated_gradients_.find(tensor_id);
  if (lookup == accumulated_gradients_.end()) {
    return nullptr;
  } else {
    return lookup->second;
  }
}

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TAPE_H_