/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {
namespace {

// TODO(b/172558015): Using the pointer address as the identifier for the
// tensor may lead to collisions. Introduce another way to get a unique id
// for this tensor.
int64 ToId(const AbstractTensorHandle* t) {
  return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}

Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
                 AbstractTensorHandle** result) {
  AbstractOperationPtr op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(t)).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}
}  // namespace

Status GradientRegistry::Register(
    const string& op_name, GradientFunctionFactory gradient_function_factory) {
  auto iter = registry_.find(op_name);
  if (iter != registry_.end()) {
    const string error_msg = "Gradient already exists for op: " + op_name + ".";
    return errors::AlreadyExists(error_msg);
  }
  registry_.insert({op_name, gradient_function_factory});
  return Status::OK();
}
Status GradientRegistry::Lookup(
    const ForwardOperation& op,
    std::unique_ptr<GradientFunction>* gradient_function) const {
  auto iter = registry_.find(op.op_name);
  if (iter == registry_.end()) {
    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
    return errors::NotFound(error_msg);
  }
  gradient_function->reset(iter->second(op));
  return Status::OK();
}
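
// Example (hedged sketch; `Exp` and the `ExpRegisterer`/`ExpGradientFunction`
// names are illustrative, not part of this file): a client registers a
// GradientFunctionFactory once per op type, and `internal::Execute` below
// looks it up after each forward run:
//
//   GradientFunction* ExpRegisterer(const ForwardOperation& op) {
//     return new ExpGradientFunction(op.outputs[0]);  // hypothetical class
//   }
//
//   GradientRegistry registry;
//   TF_RETURN_IF_ERROR(registry.Register("Exp", ExpRegisterer));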

TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
  handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
  handle_ = other.handle_;
  handle_->Ref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }

tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }

tensorflow::DataType TapeTensor::GetDType() const {
  return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }

AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
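// Returning nullptr appears safe here: Tape::ComputeGradient below passes
// /*build_default_zeros_grads=*/false, so the tape machinery should never
// request zero gradients through this hook.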

class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}

  // Returns the number of elements in the gradient tensor.
  int64 NumElements(AbstractTensorHandle* tensor) const override;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  AbstractTensorHandle* AggregateGradients(
      gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;

  // Calls the passed-in backward function.
  // op_type is the op's name provided in RecordOperation.
  Status CallBackwardFunction(
      const string& op_type, GradientFunction* gradient_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      absl::Span<AbstractTensorHandle*> result) const override;

  // Builds a tensor filled with ones with the same shape and dtype as `t`.
  Status BuildOnesLike(const TapeTensor& t,
                       AbstractTensorHandle** result) const override;

  // Looks up the ID of a Gradient.
  int64 TensorId(AbstractTensorHandle* tensor) const override;

  // Converts a Gradient to a TapeTensor.
  TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;

  void MarkAsResult(AbstractTensorHandle* gradient) const override;

  void DeleteGradient(AbstractTensorHandle* gradient) const override;

 private:
  // The context where the aggregation op `AddN` is to be created.
  AbstractContext* ctx_;
};

// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
  // TODO(srbs): It seems like this is used only for performance optimization
  // and not for correctness. The only downside of keeping this 1 seems to be
  // that the gradient accumulation is unbounded and we will never
  // aggressively aggregate accumulated gradients to recover memory.
  // Revisit and fix.
  return 1;
}

// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
    gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
  if (gradient_tensors.size() == 1) {
    return gradient_tensors[0];
  }

  AbstractOperationPtr op(ctx_->CreateOperation());
  Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  s = op->AddInputList(gradient_tensors);
  if (!s.ok()) {
    return nullptr;
  }

  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status TapeVSpace::CallBackwardFunction(
    const string& op_type, GradientFunction* gradient_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    absl::Span<AbstractTensorHandle*> result) const {
  if (gradient_function == nullptr) {
    return errors::InvalidArgument(
        "Provided null gradient_function for '", op_type, "'.\n",
        "If the intent is to treat this op as non-differentiable consider "
        "using RegisterNotDifferentiable or "
        "NotDifferentiableGradientFunction.");
  }
  return gradient_function->Compute(ctx_, output_gradients, result);
}

Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
                                 AbstractTensorHandle** result) const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
  }
  TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  TF_RETURN_IF_ERROR(
      op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
  *result = outputs[0];
  return Status::OK();
}

// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
  return ToId(tensor);
}

// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
  return TapeTensor(g);
}

void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}

void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
  gradient->Unref();
}

void Tape::Watch(const AbstractTensorHandle* t) {
  GradientTape::Watch(ToId(t));
}
void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                           absl::Span<AbstractTensorHandle* const> outputs,
                           GradientFunction* gradient_function,
                           const string& op_name) {
  std::vector<int64> input_ids(inputs.size());
  std::vector<tensorflow::DataType> input_dtypes(inputs.size());
  for (int i = 0; i < inputs.size(); i++) {
    input_ids[i] = ToId(inputs[i]);
    input_dtypes[i] = inputs[i]->DataType();
  }
  std::vector<TapeTensor> tape_tensors;
  for (auto t : outputs) {
    tape_tensors.push_back(TapeTensor(t));
  }
  GradientTape::RecordOperation(
      op_name, tape_tensors, input_ids, input_dtypes,
      [gradient_function]() -> GradientFunction* { return gradient_function; },
      [](GradientFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
      });
}
bool Tape::ShouldRecord(
    absl::Span<const AbstractTensorHandle* const> tensors) const {
  std::vector<int64> tensor_ids(tensors.size());
  std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    tensor_ids[i] = ToId(tensors[i]);
    tensor_dtypes[i] = tensors[i]->DataType();
  }
  return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
}
void Tape::DeleteTrace(const AbstractTensorHandle* t) {
  GradientTape::DeleteTrace(ToId(t));
}

std::vector<int64> MakeTensorIDList(
    absl::Span<AbstractTensorHandle* const> tensors) {
  std::vector<int64> ids(tensors.size());
  for (int i = 0; i < tensors.size(); i++) {
    ids[i] = ToId(tensors[i]);
  }
  return ids;
}

Status Tape::ComputeGradient(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
    absl::Span<AbstractTensorHandle* const> sources,
    absl::Span<AbstractTensorHandle* const> output_gradients,
    absl::Span<AbstractTensorHandle*> result) {
  TapeVSpace vspace(ctx);
  std::vector<int64> target_tensor_ids = MakeTensorIDList(targets);
  std::vector<int64> source_tensor_ids = MakeTensorIDList(sources);
  tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(
      source_tensor_ids.begin(), source_tensor_ids.end());
  std::unordered_map<int64, TapeTensor> sources_that_are_targets;
  for (int i = 0; i < target_tensor_ids.size(); ++i) {
    int64_t target_id = target_tensor_ids[i];
    if (sources_set.find(target_id) != sources_set.end()) {
      auto tensor = targets[i];
      sources_that_are_targets.insert(
          std::make_pair(target_id, TapeTensor(tensor)));
    }
  }

  TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
      vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
      output_gradients, result, /*build_default_zeros_grads=*/false));
  return Status::OK();
}
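
// End-to-end sketch (hedged; assumes an AbstractContext* `ctx`, a watched
// tensor `x`, a target `y` produced by recorded ops, registered gradient
// functions, and the `Tape(bool persistent)` constructor from gradients.h;
// error handling elided):
//
//   Tape tape(/*persistent=*/false);
//   tape.Watch(x);
//   // ... run ops via internal::Execute below so they are recorded ...
//   std::vector<AbstractTensorHandle*> out_grads(1);
//   Status s = tape.ComputeGradient(ctx, /*targets=*/{y}, /*sources=*/{x},
//                                   /*output_gradients=*/{},
//                                   absl::MakeSpan(out_grads));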

// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
             const char* raw_device_name, ForwardOperation* forward_op_) {
  forward_op_->op_name = op;
  forward_op_->attrs.Reset(op);
  return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
                ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInput(input));
  forward_op_->inputs.push_back(input);
  return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
  for (auto input : inputs) {
    forward_op_->inputs.push_back(input);
  }
  return Status::OK();
}

Status SetAttrString(AbstractOperation* op_, const char* attr_name,
                     const char* data, size_t length,
                     ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, StringPiece(data, length));
  return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
                  ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
  return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
                    ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
                   ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
                   DataType value, ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
                    const int64_t* dims, const int num_dims,
                    ForwardOperation* forward_op_) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_->attrs.Set(attr_name, proto);
  return op_->SetAttrShape(attr_name, dims, num_dims);
}
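
// Shape attribute convention (illustrative values): in SetAttrShape above,
// num_dims = -1 records a shape of unknown rank, while dims = {2, -1} with
// num_dims = 2 records a rank-2 shape whose second dimension is unknown
// (a per-dimension size of -1 conventionally means "unknown").
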
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
                       const AbstractOperation* value,
                       ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
                           const char* value, size_t length,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented "
      "yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
                     AbstractTensorInterface* tensor,
                     ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
                         const void* const* values, const size_t* lengths,
                         int num_values, ForwardOperation* forward_op_) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  forward_op_->attrs.Set(attr_name, v);
  return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
                        const float* values, int num_values,
                        ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const float>(values, num_values));
  return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
                      const int64_t* values, int num_values,
                      ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
  return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
                       const DataType* values, int num_values,
                       ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const DataType>(values, num_values));
  return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
                       const unsigned char* values, int num_values,
                       ForwardOperation* forward_op_) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const bool>(b.get(), num_values));
  return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
                        const int64_t** dims, const int* num_dims,
                        int num_values, ForwardOperation* forward_op_) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
                           absl::Span<const AbstractOperation*> values,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been "
      "implemented yet.");
}
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
               ForwardOperation* forward_op_, Tape* tape,
               const GradientRegistry& registry) {
  TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_->outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_->attrs.BuildNodeDef();
  std::unique_ptr<GradientFunction> gradient_fn;
  TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
  tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
                        op_->Name());
  return Status::OK();
}
}  // namespace internal
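
// Hedged sketch of how these helpers compose (assumes `ctx`, an input `x`,
// a `Tape* tape`, and a populated `GradientRegistry registry` already exist;
// error handling elided). Executing an op through this path both runs it and
// records it on the tape:
//
//   AbstractOperationPtr op(ctx->CreateOperation());
//   ForwardOperation forward_op;
//   TF_RETURN_IF_ERROR(
//       internal::Reset(op.get(), "Exp", /*raw_device_name=*/nullptr,
//                       &forward_op));
//   TF_RETURN_IF_ERROR(internal::AddInput(op.get(), x, &forward_op));
//   std::vector<AbstractTensorHandle*> outputs(1);
//   int num_retvals = 1;
//   TF_RETURN_IF_ERROR(
//       internal::Execute(op.get(), ctx, absl::MakeSpan(outputs),
//                         &num_retvals, &forward_op, tape, registry));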

}  // namespace gradients
}  // namespace tensorflow