/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

namespace tensorflow {
namespace gradients {

// =============== Experimental C++ API for computing gradients ===============

// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
//  public:
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     grad_inputs[0] = grad_outputs[0];
//     grad_inputs[1] = grad_outputs[0];
//     grad_inputs[0]->Ref();
//     grad_inputs[1]->Ref();
//     return Status::OK();
//   }
//   ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   // More complex gradient functions can use inputs/attrs etc. from the
//   // forward `op`.
//   return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
 public:
  // Given the incoming gradients w.r.t. the op's outputs (`grad_outputs`),
  // populates `grad_inputs` with the gradients w.r.t. the op's inputs.
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
 public:
  string op_name;
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  std::vector<int64> skip_input_indices;
  AttrBuilder attrs;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  GradientFunctionFactory gradient_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* gradient_function) const;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
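// As a further illustration of the note above that a registerer can use
// inputs/attrs/outputs from the forward `op`, the sketch below captures the
// forward op's output so the backward pass can reuse it. This is hypothetical
// and not part of this header: `ExpGradientFunction` and `ExpRegisterer` are
// made up for illustration, and the actual elementwise multiply is omitted.
//
// class ExpGradientFunction : public GradientFunction {
//  public:
//   explicit ExpGradientFunction(AbstractTensorHandle* exp_output)
//       : exp_output_(exp_output) {
//     exp_output_->Ref();  // Keep the forward result alive for Compute.
//   }
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     // d(exp(x))/dx = exp(x). A real implementation would multiply
//     // grad_outputs[0] by `exp_output_` here; this sketch just forwards
//     // the incoming gradient.
//     grad_inputs[0] = grad_outputs[0];
//     grad_inputs[0]->Ref();
//     return Status::OK();
//   }
//   ~ExpGradientFunction() override { exp_output_->Unref(); }
//
//  private:
//   AbstractTensorHandle* exp_output_;
// };
//
// GradientFunction* ExpRegisterer(const ForwardOperation& op) {
//   return new ExpGradientFunction(op.outputs[0]);
// }
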
// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`). The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to an `OpTapeEntry`, which stores the
// `op_type`, inputs, outputs, and the gradient function. These data structures
// combined allow us to trace the data dependencies between operations and
// hence compute gradients.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  explicit TapeTensor(AbstractTensorHandle* handle);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  tensorflow::int64 GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* ZerosLike() const;

  AbstractTensorHandle* GetHandle() const;

 private:
  AbstractTensorHandle* handle_;
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this tape must support handling null incoming
// gradients.
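//
// Sample usage (a sketch only: how the forward op is executed and how
// `forward_op`/`output` get populated is outside this header, and `registry`,
// `t`, `x` and `MatMulGradExample` are hypothetical; ownership of the
// looked-up gradient function is assumed to pass to the tape):
//
// Status MatMulGradExample(AbstractContext* ctx,
//                          const GradientRegistry& registry,
//                          AbstractTensorHandle* t, AbstractTensorHandle* x,
//                          absl::Span<AbstractTensorHandle*> grads) {
//   Tape tape(/*persistent=*/false);
//   tape.Watch(t);
//   tape.Watch(x);
//
//   ForwardOperation forward_op;   // Filled in while executing the forward op.
//   AbstractTensorHandle* output;  // The forward op's result.
//   // ... execute "MatMul" on (t, x) here, populating `forward_op` and
//   // `output` ...
//
//   std::unique_ptr<GradientFunction> gradient_function;
//   TF_RETURN_IF_ERROR(registry.Lookup(forward_op, &gradient_function));
//   tape.RecordOperation(forward_op.inputs, {output},
//                        gradient_function.release(), forward_op.op_name);
//
//   // Gradients of `output` w.r.t. {t, x}; `grads` needs one slot per source.
//   return tape.ComputeGradient(ctx, /*targets=*/{output}, /*sources=*/{t, x},
//                               /*output_gradients=*/{}, grads);
// }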
class Tape : protected eager::GradientTape<AbstractTensorHandle,
                                           GradientFunction, TapeTensor> {
 public:
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::GradientTape;
  // Returns whether the tape is persistent, i.e., whether the tape will hold
  // onto its internal state after a call to `ComputeGradient`.
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::IsPersistent;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched, either from an
  // earlier call to `GradientTape::Watch` or because it is an output of an op
  // with watched inputs.
  void Watch(const AbstractTensorHandle*);
  // Records an operation with the given inputs and outputs on the tape and
  // marks all its outputs as watched if at least one input of the op is
  // watched and has a trainable dtype. `op_name` is optional and is used for
  // debugging only.
  void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle* const> outputs,
                       GradientFunction* gradient_function,
                       const string& op_name = "");
  // Returns whether any tensor in the given list is being watched and has a
  // trainable dtype.
  bool ShouldRecord(
      absl::Span<const AbstractTensorHandle* const> tensors) const;
  // Unwatches this tensor on the tape. Mainly used for cleanup when deleting
  // eager tensors.
  void DeleteTrace(const AbstractTensorHandle*);

  // Consumes the internal state of the tape (so it cannot be called more than
  // once unless the tape is persistent) and produces the gradients of the
  // target tensors with respect to the source tensors. The output gradients
  // are used if not empty and not null. The result is populated with one
  // tensor per source element.
  Status ComputeGradient(
      AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
      absl::Span<AbstractTensorHandle* const> sources,
      absl::Span<AbstractTensorHandle* const> output_gradients,
      absl::Span<AbstractTensorHandle*> result);
};

}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_