/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

namespace tensorflow {
namespace gradients {

// =============== Experimental C++ API for computing gradients ===============

// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
//  public:
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     grad_inputs[0] = grad_outputs[0];
//     grad_inputs[1] = grad_outputs[0];
//     grad_inputs[0]->Ref();
//     grad_inputs[1]->Ref();
//     return Status::OK();
//   }
//   ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   // More complex gradient functions can use inputs/attrs etc. from the
//   // forward `op`; see the sketch after `GradientFunctionFactory` below.
//   return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
 public:
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
 public:
  string op_name;
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  std::vector<int64> skip_input_indices;
  AttrBuilder attrs;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;
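
// A minimal sketch of a factory that uses metadata from the forward `op`
// (hypothetical `AddNGradientFunction`/`AddNRegisterer`, not part of this
// header): the gradient of "AddN" fans the incoming gradient out to every
// input, so the factory captures the forward op's input count.
//
// class AddNGradientFunction : public GradientFunction {
//  public:
//   explicit AddNGradientFunction(int num_inputs) : num_inputs_(num_inputs) {}
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     // Each input of AddN receives the (shared) incoming gradient.
//     for (int i = 0; i < num_inputs_; ++i) {
//       grad_inputs[i] = grad_outputs[0];
//       grad_inputs[i]->Ref();
//     }
//     return Status::OK();
//   }
//   ~AddNGradientFunction() override {}
//
//  private:
//   int num_inputs_;
// };
//
// GradientFunction* AddNRegisterer(const ForwardOperation& op) {
//   return new AddNGradientFunction(static_cast<int>(op.inputs.size()));
// }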

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  GradientFunctionFactory gradient_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* gradient_function) const;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
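
// Resolution sketch (assumes `registry` is a populated `GradientRegistry` and
// `forward_op` is a `ForwardOperation` describing a recorded op):
//
//   std::unique_ptr<GradientFunction> gradient_function;
//   Status s = registry.Lookup(forward_op, &gradient_function);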

// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`). The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maps each `op_id` to an `OpTapeEntry`, which stores the `op_type`, inputs,
// outputs, and the gradient function. Together these data structures allow us
// to trace the data dependencies between operations and hence compute
// gradients.
//
// `ZerosLike` is not expected to be called and returns nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  explicit TapeTensor(AbstractTensorHandle* handle);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  tensorflow::int64 GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* ZerosLike() const;

  AbstractTensorHandle* GetHandle() const;

 private:
  AbstractTensorHandle* handle_;
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this tape must support handling null incoming
// gradients.
class Tape : protected eager::GradientTape<AbstractTensorHandle,
                                           GradientFunction, TapeTensor> {
 public:
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::GradientTape;
  // Returns whether the tape is persistent, i.e., whether the tape will hold
  // onto its internal state after a call to `ComputeGradient`.
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::IsPersistent;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched, either from an
  // earlier call to `GradientTape::Watch` or because it is an output of an op
  // with watched inputs.
  void Watch(const AbstractTensorHandle*);
  // Records an operation with the given inputs and outputs on the tape and
  // marks all of its outputs as watched if at least one input of the op is
  // watched and has a trainable dtype. `op_name` is optional and is used only
  // for debugging.
  void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle* const> outputs,
                       GradientFunction* gradient_function,
                       const string& op_name = "");
  // Returns whether any tensor in a list of tensors is being watched and has
  // a trainable dtype.
  bool ShouldRecord(
      absl::Span<const AbstractTensorHandle* const> tensors) const;
  // Unwatches this tensor on the tape. Mainly used for cleanup when deleting
  // eager tensors.
  void DeleteTrace(const AbstractTensorHandle*);

  // Consumes the internal state of the tape (so it cannot be called more than
  // once unless the tape is persistent) and produces the gradient of the
  // target tensors with respect to the source tensors. `output_gradients` are
  // used if not empty and not null. `result` is populated with one tensor per
  // target element.
  Status ComputeGradient(
      AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
      absl::Span<AbstractTensorHandle* const> sources,
      absl::Span<AbstractTensorHandle* const> output_gradients,
      absl::Span<AbstractTensorHandle*> result);
};
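
// End-to-end usage sketch (a minimal, hypothetical example; assumes `ctx` is
// an `AbstractContext*`, `x` and `y` are `AbstractTensorHandle*` inputs, and
// `z` is the output handle produced by running the forward "Add" op, whose
// execution is elided here):
//
//   Tape tape(/*persistent=*/false);
//   tape.Watch(x);
//   tape.Watch(y);
//   // ... execute the forward "Add" op on `ctx`, producing `z` ...
//   tape.RecordOperation({x, y}, {z}, new AddGradientFunction, "Add");
//
//   std::vector<AbstractTensorHandle*> grads(2);
//   Status s = tape.ComputeGradient(ctx, /*targets=*/{z}, /*sources=*/{x, y},
//                                   /*output_gradients=*/{},
//                                   absl::MakeSpan(grads));
//   // `grads[0]` and `grads[1]` now hold the gradients of `z` with respect
//   // to `x` and `y`.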

}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_