/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TAPE_H_
#define TENSORFLOW_C_EAGER_TAPE_H_

// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.

#include <functional>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace eager {

// Represents an entry in the tape.
template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
  string op_type;
  std::vector<TapeTensor> output_tensor_info;
  std::vector<int64> input_tensor_id;

  // TODO(apassos) consider narrowing down this interface.
  BackwardFunction* backward_function;

  // Should be called before deleting the backward function. TODO(apassos) use
  // unique_ptrs to ensure this happens.
  std::function<void(BackwardFunction*)> backward_function_deleter;
};

// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = gtl::FlatMap<int64, int64>;

// Map from operation-id to tape entry.
template <typename BackwardFunction, typename TapeTensor>
using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
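
// Illustrative sketch (not part of the API): after watching a tensor with id 1
// and recording one op (op-id 0) that reads tensor 1 and produces tensor 2,
// the two maps would conceptually hold
//
//   TensorTape tensor_tape = {{1, -1},   // tensor 1: directly watched
//                             {2, 0}};   // tensor 2: produced by op 0
//   // op_tape[0].op_type == "Mul", op_tape[0].input_tensor_id == {1}, ...
//
// where the concrete BackwardFunction and TapeTensor types are chosen by the
// client (e.g. the Python bindings).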

// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
// adding gradients, getting zeroes, etc. Currently cannot be implemented
// without using tensorflow python code, hence left unspecified here.
//
// Gradient is the type returned by gradient functions. In Python TF it's
// either Tensor or IndexedSlices or None, which here we map to nullptr.
// Gradients need to allow their size to be computed and they need to be
// passable to a backward function and deleted (as the backprop code creates
// lots of gradients the user is not interested in).
//
// BackwardFunction needs to be a closure which stores intermediate activations
// from the forward computation and calls a vector-jacobian product function
// (also known as adjoint function) to compute, given downstream gradients,
// upstream gradients.
//
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
 public:
  virtual ~VSpace() {}

  // Returns the number of elements in the gradient tensor.
  virtual int64 NumElements(Gradient* tensor) const = 0;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  virtual Gradient* AggregateGradients(
      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;

  // Returns a tensor of the right shape and dtype filled with zeros.
  virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;

  // Returns a tensor filled with ones, with the same shape and dtype as the
  // input.
  virtual Gradient* Ones(const TapeTensor& tensor) const = 0;

  // Calls the passed-in backward function.
  virtual Status CallBackwardFunction(
      BackwardFunction* backward_function,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result) const = 0;

  // Marks the given gradient as a result of the computation so it is not
  // consumed (deleted) by backward functions.
  virtual void MarkAsResult(Gradient* gradient) const = 0;

  // Deletes the input gradient tensor.
  virtual void DeleteGradient(Gradient* gradient) const = 0;
};
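
// A minimal sketch of a VSpace specialization, purely for illustration. The
// `MyGradient`, `MyBackwardFn`, `MyTapeTensor` types and the helpers they use
// are hypothetical (not part of this header); a real implementation, such as
// the Python one, delegates these operations to actual TensorFlow kernels.
//
//   class MyVSpace final
//       : public VSpace<MyGradient, MyBackwardFn, MyTapeTensor> {
//    public:
//     int64 NumElements(MyGradient* g) const override { return g->size(); }
//     MyGradient* AggregateGradients(
//         gtl::ArraySlice<MyGradient*> gs) const override {
//       return ElementwiseSum(gs);  // hypothetical helper: sums the buffers
//     }
//     MyGradient* Zeros(const MyTapeTensor& t) const override {
//       return MyGradient::Filled(t.shape(), 0.0f);
//     }
//     MyGradient* Ones(const MyTapeTensor& t) const override {
//       return MyGradient::Filled(t.shape(), 1.0f);
//     }
//     Status CallBackwardFunction(
//         MyBackwardFn* fn, gtl::ArraySlice<MyGradient*> out_grads,
//         std::vector<MyGradient*>* result) const override {
//       return (*fn)(out_grads, result);  // assumes the closure is callable
//     }
//     void MarkAsResult(MyGradient* g) const override { g->Ref(); }
//     void DeleteGradient(MyGradient* g) const override { g->Unref(); }
//   };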

// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
 public:
  // If `persistent` is true, GradientTape will not eagerly delete backward
  // functions (and hence the tensors they keep alive). Instead, everything
  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
  // users want to compute multiple gradients over the same tape.
  GradientTape(bool persistent) : persistent_(persistent) {}
  ~GradientTape() {
    for (const auto& pair : op_tape_) {
      pair.second.backward_function_deleter(pair.second.backward_function);
    }
  }

  bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
                    gtl::ArraySlice<tensorflow::DataType> dtypes);

  void Watch(int64 tensor_id);

  void RecordOperation(
      const string& op_type, std::vector<TapeTensor>& output_tensors,
      gtl::ArraySlice<int64> input_tensor_id,
      gtl::ArraySlice<tensorflow::DataType> input_dtypes,
      const std::function<BackwardFunction*()>& backward_function_getter,
      const std::function<void(BackwardFunction*)>& backward_function_deleter);

  void DeleteTrace(int64 tensor_id);

  // Consumes the internal state of the tape (so it cannot be called more than
  // once) and produces the gradient of the target tensors with respect to the
  // source tensors. The output gradients are used if not empty and not
  // null. The result is populated with one tensor per target element.
  Status ComputeGradient(
      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
      const gtl::ArraySlice<int64> target_tensor_ids,
      const gtl::ArraySlice<int64> source_tensor_ids,
      const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
      gtl::ArraySlice<Gradient*> output_gradients,
      std::vector<Gradient*>* result);

  bool IsPersistent() const { return persistent_; }

 private:
  TensorTape tensor_tape_;
  OpTape<BackwardFunction, TapeTensor> op_tape_;
  int64 next_op_id_{0};

  // Map from tensor id to number of remaining usages (i.e. how many entries in
  // the tape refer to it); used to aid in tape garbage collection.
  gtl::FlatMap<int64, int64> tensor_usage_;

  // If false, all activations are deleted in the first call to
  // ComputeGradient. Otherwise, they are deleted only when the tape itself is
  // destructed.
  bool persistent_;
};
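
// Typical call sequence, sketched for illustration only. The concrete
// Gradient/BackwardFunction/TapeTensor types, `my_vspace`, `y_info`, the ids
// and the `MakeMulBackward` helper below are hypothetical, not part of this
// header.
//
//   GradientTape<MyGradient, MyBackwardFn, MyTapeTensor> tape(
//       /*persistent=*/false);
//   tape.Watch(x_id);  // mark x as watched so ops using it are recorded
//   // ... run y = x * x eagerly, then record it on the tape:
//   std::vector<MyTapeTensor> outputs = {y_info};  // y_info.GetID() == y_id
//   if (tape.ShouldRecord({x_id, x_id}, {DT_FLOAT, DT_FLOAT})) {
//     tape.RecordOperation("Mul", outputs,
//                          /*input_tensor_id=*/{x_id, x_id},
//                          /*input_dtypes=*/{DT_FLOAT, DT_FLOAT},
//                          [&]() { return MakeMulBackward(x); },
//                          [](MyBackwardFn* f) { delete f; });
//   }
//   std::vector<MyGradient*> result;
//   Status s = tape.ComputeGradient(my_vspace, /*target_tensor_ids=*/{y_id},
//                                   /*source_tensor_ids=*/{x_id},
//                                   /*sources_that_are_targets=*/{},
//                                   /*output_gradients=*/{}, &result);
//   // result[0] now holds dy/dx (i.e. 2 * x).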

// Template instantiations here

inline bool IsDtypeTrainable(DataType dtype) {
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_COMPLEX64:
    case DT_COMPLEX128:
    case DT_RESOURCE:
    case DT_VARIANT:
      return true;
    default:
      return false;
  }
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
    gtl::ArraySlice<int64> tensor_ids,
    gtl::ArraySlice<tensorflow::DataType> dtypes) {
  CHECK_EQ(tensor_ids.size(), dtypes.size());
  for (int i = 0; i < tensor_ids.size(); ++i) {
    if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
      if (IsDtypeTrainable(dtypes[i])) {
        return true;
      }
    }
  }
  return false;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
    const string& op_type, std::vector<TapeTensor>& output_tensors,
    gtl::ArraySlice<int64> input_tensor_id,
    gtl::ArraySlice<tensorflow::DataType> input_dtypes,
    const std::function<BackwardFunction*()>& backward_function_getter,
    const std::function<void(BackwardFunction*)>& backward_function_deleter) {
  if (!ShouldRecord(input_tensor_id, input_dtypes)) {
    return;
  }
  std::vector<int64> ids;
  ids.reserve(input_tensor_id.size());
  for (int64 i : input_tensor_id) {
    tensor_usage_[i]++;
    ids.push_back(i);
  }
  const int64 op_id = next_op_id_++;
  std::vector<TapeTensor> tensors;
  tensors.reserve(output_tensors.size());
  for (const TapeTensor& o : output_tensors) {
    // Note: the tensor can have already been watched and hence be in the tape,
    // so we cannot check that we're inserting it here.
    tensor_tape_[o.GetID()] = op_id;
    tensor_usage_[o.GetID()] = 1;
    tensors.push_back(o);
  }
  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
      op_type, std::move(tensors), std::move(ids), backward_function_getter(),
      backward_function_deleter};
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
    int64 tensor_id) {
  auto it = tensor_usage_.find(tensor_id);
  if (it == tensor_usage_.end()) {
    return;
  }
  it->second--;
  if (it->second != 0) {
    return;
  }
  tensor_usage_.erase(it);
  auto tensor_op_it = tensor_tape_.find(tensor_id);
  if (tensor_op_it == tensor_tape_.end()) {
    return;
  }
  const int64 op_id = tensor_op_it->second;
  if (op_id == -1) {
    // Do not delete watched tensors.
    return;
  }
  tensor_tape_.erase(tensor_op_it);
  auto op_it = op_tape_.find(op_id);
  CHECK(op_it != op_tape_.end());
  for (const auto& output : op_it->second.output_tensor_info) {
    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
      // Found a usage for an output, so cannot delete the op.
      return;
    }
  }
  for (int64 id : op_it->second.input_tensor_id) {
    DeleteTrace(id);
  }
  op_it->second.backward_function_deleter(op_it->second.backward_function);
  op_tape_.erase(op_it);
}
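
// Worked illustration of the reference counting above (hypothetical ids, not
// part of the API): suppose op 0 computed tensor 2 from tensor 1, and tensor 2
// is not used by any later op, so tensor_usage_[2] == 1. When the client's
// handle to tensor 2 goes away it calls DeleteTrace(2): the usage count drops
// to 0, none of op 0's outputs are still in use, so op 0's backward function
// is deleted and DeleteTrace(1) is called recursively to release the input.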

// Terminology:
//
// - op: a possibly composite operation, which has an entry in the tape
// - target: y in dy/dx, i.e. a tensor we are differentiating
// - source: x in dy/dx, i.e. a tensor the gradient is taken with respect to
// - tensor: one of the many inputs or outputs of an operation
//
// Below here we do the gradient algorithm. It works as follows:
//
// First we filter the tape to just the subset of operations we want to
// differentiate. In the process of doing so we count how many times each
// Tensor is used as an input to an op (so we know when we're done computing
// gradients for that Tensor). We also count, for each tape entry, how many of
// its output Tensors need gradients to be computed (Tensors which are not used
// do not need any gradients to be computed).
//
// Next, we start a backprop stack with the set of tape entries for which we
// already have all output gradients available. This set usually is a subset of
// the set of targets (not all of them, since a target which also feeds another
// differentiated op in the tape will not have all of its gradients available
// initially).
//
// Then we repeatedly pop an entry from the stack, run its backprop, and update
// the gradients of its inputs. Once we have computed all gradients for a
// single input we can mark this input as done, and this can trigger adding an
// entry to the stack if all outputs of that entry are now done.
//
// When the stack is empty we have gradients for all tensors we're interested
// in.
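//
// Worked example (hypothetical ids, for illustration only): take the recorded
// computation z = x * y (op 0) followed by w = z + x (op 1), with target w and
// source x. Filtering keeps both ops; x is used as an input twice and z once.
// Only op 1 has all of its output gradients available (dw = 1), so the stack
// starts as {op 1}. Popping op 1 produces gradients for z and x; z's gradient
// is now complete, so op 0 is pushed and popped next, adding y's contribution
// to x. When the stack empties, the accumulated gradient for x is
// dw/dx = y + 1.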

namespace {

template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
  OpTape<BackwardFunction, TapeTensor> op_tape;

  // Map from tensor ID to how many references still exist for this tensor in
  // the tape.
  gtl::FlatMap<int64, int64> tensor_usage_counts;

  // Maps from op ID to how many output tensors of this op still need to have
  // their gradients computed.
  gtl::FlatMap<int64, int64> op_missing_tensor;
};

// If `persistent_tape` is true, op_tape is not changed and none of the
// backward functions are deleted.
// If `persistent_tape` is false, op_tape is cleared and backward functions
// not needed for gradient computation are deleted. Backward functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction, typename TapeTensor>
BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
    OpTape<BackwardFunction, TapeTensor>* op_tape,
    const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
  std::vector<int64> tensor_stack;
  tensor_stack.reserve(target.size());
  for (auto t : target) {
    tensor_stack.push_back(t);
  }
  BackpropInitialState<BackwardFunction, TapeTensor> result;
  while (!tensor_stack.empty()) {
    int64 tensor_id = tensor_stack.back();
    tensor_stack.pop_back();
    auto op_id_it = tensor_tape.find(tensor_id);
    if (op_id_it == tensor_tape.end()) {
      continue;
    }
    int64 op_id = op_id_it->second;
    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
    for (auto it : op_it->second.input_tensor_id) {
      auto count_it = result.tensor_usage_counts.find(it);
      if (count_it != result.tensor_usage_counts.end()) {
        count_it->second++;
      } else {
        result.tensor_usage_counts[it] = 1;
        if (tensor_tape.find(it) != tensor_tape.end()) {
          tensor_stack.push_back(it);
        }
      }
    }
    if (!persistent_tape) {
      op_tape->erase(op_it);
    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
    if (it != tensor_tape.end() && it->second != -1) {
      result.op_missing_tensor[it->second] += 1;
    }
  }
  if (!persistent_tape) {
    // Call destructors for all unneeded gradient functions and
    // clear the op_tape. We can clear the tape because ownership of
    // backward functions that will be used for gradient computation
    // has been transferred to `result`.
    for (const auto& op_pair : *op_tape) {
      op_pair.second.backward_function_deleter(
          op_pair.second.backward_function);
    }
    op_tape->clear();
  }
  return result;
}

template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    const gtl::FlatMap<int64, int64>& op_missing_tensor) {
  std::vector<int64> result;
  for (auto& op_entry : op_tape) {
    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
      result.push_back(op_entry.first);
    }
  }
  return result;
}

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    gtl::ArraySlice<int64> target_tensor_ids,
    gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
    const OpTape<BackwardFunction, TapeTensor>& op_tape,
    gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    const int64 id = target_tensor_ids[i];
    if (output_gradients.empty() || output_gradients[i] == nullptr) {
      auto tensor_it = tensor_tape.find(id);
      if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
          if (op_it->second.output_tensor_info[j].GetID() == id) {
            found = true;
            (*result)[id].push_back(
                vspace.Ones(op_it->second.output_tensor_info[j]));
            break;
          }
        }
        if (!found) {
          return errors::Internal(
              "Internal state of the gradient tape is invalid: "
              "none of the operation's outputs match the expected tensor");
        }
      } else {
        // This target tensor was not generated by any operation recorded on
        // the tape, so no gradient needs to be computed from it unless this
        // target is also a source.
        auto source_tensor = sources_that_are_targets.find(id);
        if (source_tensor != sources_that_are_targets.end()) {
          (*result)[id].push_back(vspace.Ones(source_tensor->second));
        }
      }
    } else {
      (*result)[id].push_back(output_gradients[i]);
    }
  }
  return Status::OK();
}
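
// Illustration (hypothetical ids): for a single target y recorded on the tape
// and an empty `output_gradients`, InitialGradients seeds the map with
// (*result)[y_id] = {vspace.Ones(y_info)}, i.e. dy/dy = 1. If the caller
// instead passes an explicit output gradient for y, that tensor is used as the
// seed unchanged.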

// TODO(agarwal): use an automatic mechanism for handling None arguments to
// gradient functions.
//
// Some gradient functions can accept None arguments for gradients. The
// following maps the operation name to the indices at which the corresponding
// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
// values and hence receives 5 gradient values during backprop. However the
// gradient function uses only the first of those values and ignores the rest.
// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
// corresponding to index 0 is used, and the gradient values at indices 1-4 are
// ignored (and hence can be None). The backprop algorithm can then leverage
// this by not constructing zeros to pass for those indices.
gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
  static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
      {"SoftmaxCrossEntropyWithLogits", {1}},
      {"SparseSoftmaxCrossEntropyWithLogits", {1}},
      {"FusedBatchNorm", {1, 2, 3, 4}},
  });
  return m;
}

}  // namespace

// If over kMinAggregateCount gradients are accumulated and the total
// memory consumption is over kMinAggregateBytes, do an early aggregation
// so as to release the gradient tensors and save memory.
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
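
// Sizing illustration (assuming 4-byte gradients, as the check below does):
// with gradients of 8M elements each (32 MB), the early-aggregation path
// triggers once more than kMinAggregateCount (4) partial gradients are
// buffered and 5 * 32 MB = 160 MB exceeds kMinAggregateBytes (128 MB), at
// which point the partial gradients are summed into one tensor and released.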

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
    const gtl::ArraySlice<int64> target_tensor_ids,
    const gtl::ArraySlice<int64> source_tensor_ids,
    const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) {
  gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
                                  source_tensor_ids.end());
  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids,
                              sources_that_are_targets, output_gradients,
                              tensor_tape_, state.op_tape, &gradients);
  auto cleanup = [this, &state]() {
    if (!persistent_) {
      // Release all backprop functions
      for (const auto& pair : state.op_tape) {
        pair.second.backward_function_deleter(pair.second.backward_function);
      }
    }
  };
  if (!s.ok()) {
    cleanup();
    return s;
  }
  gtl::FlatMap<int64, int64> gradients_size;
  // TODO(apassos) multiple threads could be dequeuing from op_stack at the
  // same time, for better CPU backprop performance.
  VLOG(1) << "Initial stack:";
  if (VLOG_IS_ON(1)) {
    for (auto t : op_stack) {
      VLOG(1) << " " << t;
    }
  }
  while (!op_stack.empty()) {
    const int64 op = op_stack.back();
    VLOG(1) << "Popped " << op;
    op_stack.pop_back();
    auto op_it = state.op_tape.find(op);
    if (op_it == state.op_tape.end()) {
      // It is possible for ops to end up on the stack if they are unrelated to
      // the target; we should just skip them.
      continue;
    }
    auto trace = std::move(op_it->second);
    state.op_tape.erase(op_it);
    std::vector<Gradient*> out_gradients;
    out_gradients.reserve(trace.output_tensor_info.size());
    bool any_gradient_nonzero = false;
    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
      const int64 id = trace.output_tensor_info[i].GetID();
      auto grad_it = gradients.find(id);
      if (grad_it == gradients.end()) {
        auto func_name_it =
            FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
        if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
            func_name_it->second.find(i) != func_name_it->second.end()) {
          out_gradients.push_back(nullptr);
        } else {
          out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
        }
      } else {
        any_gradient_nonzero = true;
        Gradient* new_gradients = nullptr;
        if (grad_it->second.size() == 1) {
          new_gradients = grad_it->second.at(0);
        } else {
          new_gradients = vspace.AggregateGradients(grad_it->second);
        }
        if (sources_set.find(grad_it->first) == sources_set.end()) {
          gradients.erase(grad_it);
        } else {
          grad_it->second.clear();
          grad_it->second.push_back(new_gradients);
          vspace.MarkAsResult(new_gradients);
        }
        out_gradients.push_back(new_gradients);
      }
    }
    std::vector<Gradient*> in_gradients;
    if (any_gradient_nonzero) {
      Status s = vspace.CallBackwardFunction(trace.backward_function,
                                             out_gradients, &in_gradients);
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      if (!s.ok()) {
        cleanup();
        return s;
      }
    } else {
      // None of this op's outputs had an accumulated gradient, so skip calling
      // the backward function: its inputs get null (zero) gradients, and any
      // zero placeholders created above are released.
      in_gradients.resize(trace.input_tensor_id.size());
      if (!persistent_) {
        trace.backward_function_deleter(trace.backward_function);
      }
      for (Gradient* grad : out_gradients) {
        if (grad != nullptr) {
          vspace.DeleteGradient(grad);
        }
      }
    }
    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
            << trace.input_tensor_id.size() << " sources";
    for (int i = 0; i < in_gradients.size(); ++i) {
      const int64 id = trace.input_tensor_id[i];
      if (in_gradients[i] != nullptr) {
        auto& unaggregated_grads = gradients[id];
        unaggregated_grads.push_back(in_gradients[i]);
        if (unaggregated_grads.size() > kMinAggregateCount) {
          auto size_it = gradients_size.find(id);
          int64 size;
          if (size_it == gradients_size.end()) {
            size = vspace.NumElements(unaggregated_grads[0]);
            gradients_size.emplace(id, size);
          } else {
            size = size_it->second;
          }
          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
            unaggregated_grads.clear();
            unaggregated_grads.push_back(grad);
          }
        }
      }
      auto usage_count_it = state.tensor_usage_counts.find(id);
      if (usage_count_it == state.tensor_usage_counts.end()) {
        VLOG(1) << "Tensor " << id << " not used";
        continue;
      }
      usage_count_it->second--;
      if (usage_count_it->second > 0) {
        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
        continue;
      }
      auto tape_it = tensor_tape_.find(id);
      if (tape_it == tensor_tape_.end()) {
        VLOG(1) << "Tensor " << id
                << " has no associated op. Deleting gradient";
        auto grad_it = gradients.find(id);
        if (grad_it != gradients.end()) {
          for (auto g : grad_it->second) {
            vspace.DeleteGradient(g);
          }
          gradients.erase(grad_it);
        }
        continue;
      }
      const int64 op_id = tape_it->second;
      if (op_id == -1) {
        VLOG(1) << "Tensor " << id << " is source";
        continue;
      }
      auto missing_it = state.op_missing_tensor.find(op_id);
      if (missing_it != state.op_missing_tensor.end()) {
        missing_it->second--;
        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
                << " output gradients";
        if (missing_it->second == 0) {
          op_stack.push_back(op_id);
        }
      }
    }
  }
  if (!state.op_tape.empty()) {
    return tensorflow::errors::Internal("Invalid tape state.");
  }
  result->reserve(source_tensor_ids.size());
  gtl::FlatSet<int64> used_gradient_ids(source_tensor_ids.size());
  for (auto is : source_tensor_ids) {
    auto grad_it = gradients.find(is);
    if (grad_it == gradients.end()) {
      result->push_back(nullptr);
    } else {
      if (grad_it->second.size() > 1) {
        Gradient* grad = vspace.AggregateGradients(grad_it->second);
        grad_it->second.clear();
        grad_it->second.push_back(grad);
      }
      result->push_back(grad_it->second[0]);
      used_gradient_ids.insert(is);
    }
  }
  VLOG(1) << "Final gradients size: "
          << gradients.size() - used_gradient_ids.size();
  for (auto grad_pair : gradients) {
    if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
      for (const auto& g : grad_pair.second) {
        vspace.DeleteGradient(g);
      }
    }
  }
  return Status::OK();
}

}  // namespace eager
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TAPE_H_