/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
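// TapeOperation wraps a concrete AbstractOperation and mirrors every call
// (op name, inputs, attributes) into `forward_op_`. On Execute, the recorded
// forward operation is pushed onto `tape_` together with a gradient function
// looked up in `registry_`, so the tape can later compute gradients.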
TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
                             const GradientRegistry& registry)
    : AbstractOperation(kTape),
      parent_op_(parent_op),
      tape_(tape),
      registry_(registry) {
  // TODO(b/172003047): Consider making AbstractOperation RefCounted.
  // parent_op_->Ref();
}
void TapeOperation::Release() {
  // TODO(srbs): Change to Unref().
  delete this;
}
TapeOperation::~TapeOperation() {
  // TODO(b/172003047): Consider making AbstractOperation RefCounted.
  // parent_op_->Unref();
}
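// Reset discards any previously recorded call state so this wrapper can be
// reused to build and record a new operation.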
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
  forward_op_.op_name = op;
  forward_op_.attrs.Reset(op);
  forward_op_.inputs.clear();
  forward_op_.outputs.clear();
  return parent_op_->Reset(op, raw_device_name);
}
const string& TapeOperation::Name() const { return parent_op_->Name(); }
const string& TapeOperation::DeviceName() const {
  return parent_op_->DeviceName();
}
Status TapeOperation::SetDeviceName(const char* name) {
  return parent_op_->SetDeviceName(name);
}
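// Inputs are forwarded to the wrapped op and also recorded in `forward_op_`
// so the tape can associate gradients with them after execution.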
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
  TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
  forward_op_.inputs.push_back(input);
  return Status::OK();
}
Status TapeOperation::AddInputList(
    absl::Span<AbstractTensorHandle* const> inputs) {
  TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
  for (auto input : inputs) {
    forward_op_.inputs.push_back(input);
  }
  return Status::OK();
}
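// Each SetAttr* override records the attribute in `forward_op_.attrs` before
// delegating to the wrapped op; gradient functions may read these attributes
// when the backward pass runs.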
Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
                                    size_t length) {
  forward_op_.attrs.Set(attr_name, StringPiece(data, length));
  return parent_op_->SetAttrString(attr_name, data, length);
}
Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
  forward_op_.attrs.Set(attr_name, static_cast<int64>(value));
  return parent_op_->SetAttrInt(attr_name, value);
}
Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
  forward_op_.attrs.Set(attr_name, value);
  return parent_op_->SetAttrFloat(attr_name, value);
}
Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
  forward_op_.attrs.Set(attr_name, value);
  return parent_op_->SetAttrBool(attr_name, value);
}
Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
  forward_op_.attrs.Set(attr_name, value);
  return parent_op_->SetAttrType(attr_name, value);
}
Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
                                   const int num_dims) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name,
                                   "` has ", num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_.attrs.Set(attr_name, proto);
  return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
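// For example (illustrative): num_dims == -1 records a shape of unknown rank,
// while dims == {2, -1} with num_dims == 2 records a rank-2 shape whose second
// dimension is unknown (-1 denotes an unknown dimension size).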
Status TapeOperation::SetAttrFunction(const char* attr_name,
                                      const AbstractOperation* value) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
Status TapeOperation::SetAttrFunctionName(const char* attr_name,
                                          const char* value, size_t length) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented yet.");
}
Status TapeOperation::SetAttrTensor(const char* attr_name,
                                    AbstractTensorInterface* tensor) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
Status TapeOperation::SetAttrStringList(const char* attr_name,
                                        const void* const* values,
                                        const size_t* lengths,
                                        int num_values) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  forward_op_.attrs.Set(attr_name, v);
  return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status TapeOperation::SetAttrFloatList(const char* attr_name,
                                       const float* values, int num_values) {
  forward_op_.attrs.Set(attr_name,
                        gtl::ArraySlice<const float>(values, num_values));
  return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrIntList(const char* attr_name,
                                     const int64_t* values, int num_values) {
  forward_op_.attrs.Set(
      attr_name, gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
  return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrTypeList(const char* attr_name,
                                      const DataType* values, int num_values) {
  forward_op_.attrs.Set(attr_name,
                        gtl::ArraySlice<const DataType>(values, num_values));
  return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrBoolList(const char* attr_name,
                                      const unsigned char* values,
                                      int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  forward_op_.attrs.Set(attr_name,
                        gtl::ArraySlice<const bool>(b.get(), num_values));
  return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
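// Note: the C API passes bools as unsigned char, while the attr builder's
// Set overload expects real bools, hence the temporary bool[] conversion
// above.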
Status TapeOperation::SetAttrShapeList(const char* attr_name,
                                       const int64_t** dims,
                                       const int* num_dims, int num_values) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_.attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status TapeOperation::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been implemented yet.");
}
AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                              int* num_retvals) {
  TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
  for (int i = 0; i < *num_retvals; i++) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_.outputs.push_back(retvals[i]);
  }
  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and strings consistent.
  forward_op_.attrs.BuildNodeDef();
  // TODO(b/170307493): Populate skip_input_indices here.
  std::unique_ptr<GradientFunction> backward_fn;
  TF_RETURN_IF_ERROR(registry_.Lookup(forward_op_, &backward_fn));
  tape_->RecordOperation(forward_op_.inputs, forward_op_.outputs,
                         backward_fn.release(), parent_op_->Name());
  return Status::OK();
}
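// A minimal usage sketch (illustrative only; assumes an op named "AddV2" with
// a registered gradient, and `ctx`, `tape`, `registry`, `x`, `y` set up by
// the caller):
//
//   AbstractOperation* op =
//       new TapeOperation(ctx->CreateOperation(), tape, registry);
//   TF_RETURN_IF_ERROR(op->Reset("AddV2", /*raw_device_name=*/nullptr));
//   TF_RETURN_IF_ERROR(op->AddInput(x));
//   TF_RETURN_IF_ERROR(op->AddInput(y));
//   int num_retvals = 1;
//   std::vector<AbstractTensorHandle*> outputs(1);
//   TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(outputs), &num_retvals));
//   // `tape` now holds a record of the AddV2 call and its gradient function.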

}  // namespace gradients
}  // namespace tensorflow