/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {
namespace {

// TODO(b/172558015): Using the pointer address as the identifier for the
// tensor may lead to collisions. Introduce another way to get a unique id for
// this tensor.
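// Returns an id for `t` derived from its pointer address.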
int64 ToId(const AbstractTensorHandle* t) {
  return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}

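// Builds a tensor of zeros with the same shape and dtype as `t` by running a
// `ZerosLike` op in `ctx`.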
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
                 AbstractTensorHandle** result) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(t)).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}
}  // namespace

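// Registers `gradient_function_factory` as the factory for building gradient
// functions for `op_name`. Fails if a factory is already registered.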
Status GradientRegistry::Register(
    const string& op_name, GradientFunctionFactory gradient_function_factory) {
  auto iter = registry_.find(op_name);
  if (iter != registry_.end()) {
    const string error_msg = "Gradient already exists for op: " + op_name + ".";
    return errors::AlreadyExists(error_msg);
  }
  registry_.insert({op_name, gradient_function_factory});
  return Status::OK();
}
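// Finds the factory registered for `op.op_name` and uses it to instantiate a
// GradientFunction for the given forward operation.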
Status GradientRegistry::Lookup(
    const ForwardOperation& op,
    std::unique_ptr<GradientFunction>* gradient_function) const {
  auto iter = registry_.find(op.op_name);
  if (iter == registry_.end()) {
    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
    return errors::NotFound(error_msg);
  }
  gradient_function->reset(iter->second(op));
  return Status::OK();
}

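// TapeTensor wraps an AbstractTensorHandle and holds a reference on it for
// the lifetime of the wrapper.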
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
  handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
  handle_ = other.handle_;
  handle_->Ref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }

tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }

tensorflow::DataType TapeTensor::GetDType() const {
  return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }

AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }

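// VSpace implementation backing GradientTape: builds ones-like tensors,
// aggregates gradients with `AddN` and invokes gradient functions through the
// abstract eager/tracing API.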
class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}

  // Returns the number of elements in the gradient tensor.
  int64 NumElements(AbstractTensorHandle* tensor) const override;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  AbstractTensorHandle* AggregateGradients(
      gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;

  // Calls the passed-in backward function.
  // op_type is the op's name provided in RecordOperation.
  Status CallBackwardFunction(
      const string& op_type, GradientFunction* gradient_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      absl::Span<AbstractTensorHandle*> result) const override;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const TapeTensor& t,
                       AbstractTensorHandle** result) const override;

  // Looks up the ID of a Gradient.
  int64 TensorId(AbstractTensorHandle* tensor) const override;

  // Converts a Gradient to a TapeTensor.
  TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;

  void MarkAsResult(AbstractTensorHandle* gradient) const override;

  void DeleteGradient(AbstractTensorHandle* gradient) const override;

 private:
  // The context where the aggregation op `Add` is to be created.
  AbstractContext* ctx_;
};

// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
  // TODO(srbs): It seems like this is used only for performance optimization
  // and not for correctness. The only downside of keeping this 1 seems to be
  // that the gradient accumulation is unbounded and we will never
  // aggressively aggregate accumulated gradients to recover memory.
  // Revisit and fix.
  return 1;
}

// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
    gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
  if (gradient_tensors.size() == 1) {
    return gradient_tensors[0];
  }

  AbstractOperationPtr op(ctx_->CreateOperation());
  Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  s = op->AddInputList(gradient_tensors);
  if (!s.ok()) {
    return nullptr;
  }

  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status TapeVSpace::CallBackwardFunction(
    const string& op_type, GradientFunction* gradient_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    absl::Span<AbstractTensorHandle*> result) const {
  if (gradient_function == nullptr) {
    return errors::InvalidArgument(
        "Provided null gradient_function for '", op_type, "'.\n",
        "If the intent is to treat this op as non-differentiable consider "
        "using RegisterNotDifferentiable or "
        "NotDifferentiableGradientFunction.");
  }
  return gradient_function->Compute(ctx_, output_gradients, result);
}

Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
                                 AbstractTensorHandle** result) const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}

// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
  return ToId(tensor);
}

// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
  return TapeTensor(g);
}

void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}

void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
  gradient->Unref();
}

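// Marks `t` as watched so that operations involving it are recorded.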
void Tape::Watch(const AbstractTensorHandle* t) {
  GradientTape::Watch(ToId(t));
}
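// Records an executed operation on the tape: captures input ids and dtypes,
// wraps the outputs as TapeTensors and takes ownership of `gradient_function`.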
void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                           absl::Span<AbstractTensorHandle* const> outputs,
                           GradientFunction* gradient_function,
                           const string& op_name) {
  std::vector<int64> input_ids(inputs.size());
  std::vector<tensorflow::DataType> input_dtypes(inputs.size());
  for (int i = 0; i < inputs.size(); i++) {
    input_ids[i] = ToId(inputs[i]);
    input_dtypes[i] = inputs[i]->DataType();
  }
  std::vector<TapeTensor> tape_tensors;
  for (auto t : outputs) {
    tape_tensors.push_back(TapeTensor(t));
  }
  GradientTape::RecordOperation(
      op_name, tape_tensors, input_ids, input_dtypes,
      [gradient_function]() -> GradientFunction* { return gradient_function; },
      [](GradientFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
      });
}
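// Returns true if the tape is watching any of `tensors`.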
bool Tape::ShouldRecord(
    absl::Span<const AbstractTensorHandle* const> tensors) const {
  std::vector<int64> tensor_ids(tensors.size());
  std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    tensor_ids[i] = ToId(tensors[i]);
    tensor_dtypes[i] = tensors[i]->DataType();
  }
  return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
}
void Tape::DeleteTrace(const AbstractTensorHandle* t) {
  GradientTape::DeleteTrace(ToId(t));
}

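// Returns the tape ids of `tensors`.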
std::vector<int64> MakeTensorIDList(
    absl::Span<AbstractTensorHandle* const> tensors) {
  std::vector<int64> ids(tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    ids[i] = ToId(tensors[i]);
  }
  return ids;
}

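// Computes gradients of `targets` with respect to `sources` from the
// operations recorded on this tape and writes them into `result`.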
Status Tape::ComputeGradient(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
    absl::Span<AbstractTensorHandle* const> sources,
    absl::Span<AbstractTensorHandle* const> output_gradients,
    absl::Span<AbstractTensorHandle*> result) {
  TapeVSpace vspace(ctx);
  std::vector<int64> target_tensor_ids = MakeTensorIDList(targets);
  std::vector<int64> source_tensor_ids = MakeTensorIDList(sources);
  tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(
      source_tensor_ids.begin(), source_tensor_ids.end());
  std::unordered_map<int64, TapeTensor> sources_that_are_targets;
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    int64_t target_id = target_tensor_ids[i];
    if (sources_set.find(target_id) != sources_set.end()) {
      auto tensor = targets[i];
      sources_that_are_targets.insert(
          std::make_pair(target_id, TapeTensor(tensor)));
    }
  }

  TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
      vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
      output_gradients, result, /*build_default_zeros_grads*/ false));
  return Status::OK();
}

// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
             const char* raw_device_name, ForwardOperation* forward_op_) {
  forward_op_->op_name = op;
  forward_op_->attrs.Reset(op);
  return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
                ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInput(input));
  forward_op_->inputs.push_back(input);
  return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
  for (auto input : inputs) {
    forward_op_->inputs.push_back(input);
  }
  return Status::OK();
}

Status SetAttrString(AbstractOperation* op_, const char* attr_name,
                     const char* data, size_t length,
                     ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, StringPiece(data, length));
  return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
                  ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
  return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
                    ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
                   ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
                   DataType value, ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
                    const int64_t* dims, const int num_dims,
                    ForwardOperation* forward_op_) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_->attrs.Set(attr_name, proto);
  return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
                       const AbstractOperation* value,
                       ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
                           const char* value, size_t length,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented "
      "yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
                     AbstractTensorInterface* tensor,
                     ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
                         const void* const* values, const size_t* lengths,
                         int num_values, ForwardOperation* forward_op_) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  forward_op_->attrs.Set(attr_name, v);
  return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
                        const float* values, int num_values,
                        ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const float>(values, num_values));
  return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
                      const int64_t* values, int num_values,
                      ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
  return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
                       const DataType* values, int num_values,
                       ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const DataType>(values, num_values));
  return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
                       const unsigned char* values, int num_values,
                       ForwardOperation* forward_op_) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const bool>(b.get(), num_values));
  return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
                        const int64_t** dims, const int* num_dims,
                        int num_values, ForwardOperation* forward_op_) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
                           absl::Span<const AbstractOperation*> values,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been "
      "implemented yet.");
}
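// Executes `op_`, looks up its gradient function in `registry` and records
// the operation on `tape`. Outputs are returned through `retvals`.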
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
               ForwardOperation* forward_op_, Tape* tape,
               const GradientRegistry& registry) {
  TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_->outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_->attrs.BuildNodeDef();
  std::unique_ptr<GradientFunction> gradient_fn;
  TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
  tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
                        op_->Name());
  return Status::OK();
}
}  // namespace internal

}  // namespace gradients
}  // namespace tensorflow