• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
16 
17 #include "tensorflow/c/eager/abstract_context.h"
18 #include "tensorflow/c/eager/gradients.h"
19 
20 namespace tensorflow {
21 namespace gradients {
TapeOperation(AbstractOperation * parent_op,Tape * tape,const GradientRegistry & registry)22 TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
23                              const GradientRegistry& registry)
24     : AbstractOperation(kTape),
25       parent_op_(parent_op),
26       tape_(tape),
27       registry_(registry) {
28   // TODO(b/172003047): Consider making AbstractOperation RefCounted.
29   // parent_op_->Ref();
30 }
Release()31 void TapeOperation::Release() {
32   // TODO(srbs): Change to Unref().
33   delete this;
34 }
~TapeOperation()35 TapeOperation::~TapeOperation() {
36   // TODO(b/172003047): Consider making AbstractOperation RefCounted.
37   // parent_op->Unref();
38 }
Reset(const char * op,const char * raw_device_name)39 Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
40   forward_op_.op_name = op;
41   forward_op_.attrs.Reset(op);
42   forward_op_.inputs.clear();
43   forward_op_.outputs.clear();
44   return parent_op_->Reset(op, raw_device_name);
45 }
Name() const46 const string& TapeOperation::Name() const { return parent_op_->Name(); }
DeviceName() const47 const string& TapeOperation::DeviceName() const {
48   return parent_op_->DeviceName();
49 }
SetDeviceName(const char * name)50 Status TapeOperation::SetDeviceName(const char* name) {
51   return parent_op_->SetDeviceName(name);
52 }
AddInput(AbstractTensorHandle * input)53 Status TapeOperation::AddInput(AbstractTensorHandle* input) {
54   TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
55   forward_op_.inputs.push_back(input);
56   return Status::OK();
57 }
AddInputList(absl::Span<AbstractTensorHandle * const> inputs)58 Status TapeOperation::AddInputList(
59     absl::Span<AbstractTensorHandle* const> inputs) {
60   TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
61   for (auto input : inputs) {
62     forward_op_.inputs.push_back(input);
63   }
64   return Status::OK();
65 }
SetAttrString(const char * attr_name,const char * data,size_t length)66 Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
67                                     size_t length) {
68   forward_op_.attrs.Set(attr_name, StringPiece(data, length));
69   return parent_op_->SetAttrString(attr_name, data, length);
70 }
SetAttrInt(const char * attr_name,int64_t value)71 Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
72   forward_op_.attrs.Set(attr_name, static_cast<int64>(value));
73   return parent_op_->SetAttrInt(attr_name, value);
74 }
SetAttrFloat(const char * attr_name,float value)75 Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
76   forward_op_.attrs.Set(attr_name, value);
77   return parent_op_->SetAttrFloat(attr_name, value);
78 }
SetAttrBool(const char * attr_name,bool value)79 Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
80   forward_op_.attrs.Set(attr_name, value);
81   return parent_op_->SetAttrBool(attr_name, value);
82 }
SetAttrType(const char * attr_name,DataType value)83 Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
84   forward_op_.attrs.Set(attr_name, value);
85   return parent_op_->SetAttrType(attr_name, value);
86 }
SetAttrShape(const char * attr_name,const int64_t * dims,const int num_dims)87 Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
88                                    const int num_dims) {
89   if (num_dims > TensorShape::MaxDimensions()) {
90     return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
91                                    num_dims,
92                                    " dimensions which is over the limit of ",
93                                    TensorShape::MaxDimensions(), ".");
94   }
95   TensorShapeProto proto;
96   if (num_dims < 0) {
97     proto.set_unknown_rank(true);
98   } else {
99     for (int d = 0; d < num_dims; ++d) {
100       proto.add_dim()->set_size(dims[d]);
101     }
102   }
103 
104   forward_op_.attrs.Set(attr_name, proto);
105   return parent_op_->SetAttrShape(attr_name, dims, num_dims);
106 }
SetAttrFunction(const char * attr_name,const AbstractOperation * value)107 Status TapeOperation::SetAttrFunction(const char* attr_name,
108                                       const AbstractOperation* value) {
109   return tensorflow::errors::Unimplemented(
110       "SetAttrFunction has not been implemented yet.");
111 }
SetAttrFunctionName(const char * attr_name,const char * value,size_t length)112 Status TapeOperation::SetAttrFunctionName(const char* attr_name,
113                                           const char* value, size_t length) {
114   return tensorflow::errors::Unimplemented(
115       "SetAttrFunctionName has not been implemented "
116       "yet.");
117 }
SetAttrTensor(const char * attr_name,AbstractTensorInterface * tensor)118 Status TapeOperation::SetAttrTensor(const char* attr_name,
119                                     AbstractTensorInterface* tensor) {
120   return tensorflow::errors::Unimplemented(
121       "SetAttrTensor has not been implemented yet.");
122 }
SetAttrStringList(const char * attr_name,const void * const * values,const size_t * lengths,int num_values)123 Status TapeOperation::SetAttrStringList(const char* attr_name,
124                                         const void* const* values,
125                                         const size_t* lengths, int num_values) {
126   std::vector<StringPiece> v(num_values);
127   for (int i = 0; i < num_values; ++i) {
128     v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
129   }
130   forward_op_.attrs.Set(attr_name, v);
131   return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
132 }
SetAttrFloatList(const char * attr_name,const float * values,int num_values)133 Status TapeOperation::SetAttrFloatList(const char* attr_name,
134                                        const float* values, int num_values) {
135   forward_op_.attrs.Set(attr_name,
136                         gtl::ArraySlice<const float>(values, num_values));
137   return parent_op_->SetAttrFloatList(attr_name, values, num_values);
138 }
SetAttrIntList(const char * attr_name,const int64_t * values,int num_values)139 Status TapeOperation::SetAttrIntList(const char* attr_name,
140                                      const int64_t* values, int num_values) {
141   forward_op_.attrs.Set(
142       attr_name, gtl::ArraySlice<const int64>(
143                      reinterpret_cast<const int64*>(values), num_values));
144   return parent_op_->SetAttrIntList(attr_name, values, num_values);
145 }
SetAttrTypeList(const char * attr_name,const DataType * values,int num_values)146 Status TapeOperation::SetAttrTypeList(const char* attr_name,
147                                       const DataType* values, int num_values) {
148   forward_op_.attrs.Set(attr_name,
149                         gtl::ArraySlice<const DataType>(values, num_values));
150   return parent_op_->SetAttrTypeList(attr_name, values, num_values);
151 }
SetAttrBoolList(const char * attr_name,const unsigned char * values,int num_values)152 Status TapeOperation::SetAttrBoolList(const char* attr_name,
153                                       const unsigned char* values,
154                                       int num_values) {
155   std::unique_ptr<bool[]> b(new bool[num_values]);
156   for (int i = 0; i < num_values; ++i) {
157     b[i] = values[i];
158   }
159   forward_op_.attrs.Set(attr_name,
160                         gtl::ArraySlice<const bool>(b.get(), num_values));
161   return parent_op_->SetAttrBoolList(attr_name, values, num_values);
162 }
SetAttrShapeList(const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values)163 Status TapeOperation::SetAttrShapeList(const char* attr_name,
164                                        const int64_t** dims,
165                                        const int* num_dims, int num_values) {
166   std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
167   for (int i = 0; i < num_values; ++i) {
168     const auto num_dims_i = num_dims[i];
169 
170     if (num_dims_i > TensorShape::MaxDimensions()) {
171       return errors::InvalidArgument(
172           strings::StrCat("Value specified for `", attr_name, "` has ",
173                           num_dims_i, " dimensions which is over the limit of ",
174                           TensorShape::MaxDimensions(), "."));
175     }
176     if (num_dims_i < 0) {
177       proto[i].set_unknown_rank(true);
178     } else {
179       const int64_t* dims_i = dims[i];
180       auto proto_i = &proto[i];
181       for (int d = 0; d < num_dims_i; ++d) {
182         proto_i->add_dim()->set_size(dims_i[d]);
183       }
184     }
185   }
186   forward_op_.attrs.Set(
187       attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
188   return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
189 }
SetAttrFunctionList(const char * attr_name,absl::Span<const AbstractOperation * > values)190 Status TapeOperation::SetAttrFunctionList(
191     const char* attr_name, absl::Span<const AbstractOperation*> values) {
192   return tensorflow::errors::Unimplemented(
193       "SetAttrFunctionList has not been "
194       "implemented yet.");
195 }
GetBackingOperation()196 AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
Execute(absl::Span<AbstractTensorHandle * > retvals,int * num_retvals)197 Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
198                               int* num_retvals) {
199   TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
200   for (int i = 0; i < *num_retvals; i++) {
201     // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
202     forward_op_.outputs.push_back(retvals[i]);
203   }
204   // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
205   // attributes. Number type attrs and DataType attrs work fine without this.
206   // Consider getting rid of this and making the behavior between number types
207   // and string consistent.
208   forward_op_.attrs.BuildNodeDef();
209   // TODO(b/170307493): Populate skip_input_indices here.
210   std::unique_ptr<GradientFunction> backward_fn;
211   TF_RETURN_IF_ERROR(registry_.Lookup(forward_op_, &backward_fn));
212   tape_->RecordOperation(forward_op_.inputs, forward_op_.outputs,
213                          backward_fn.release(), parent_op_->Name());
214   return Status::OK();
215 }
216 
217 }  // namespace gradients
218 }  // namespace tensorflow
219