/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"

using std::vector;
using tensorflow::ops::AddV2;
using tensorflow::ops::Div;
using tensorflow::ops::DivNoNan;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::Neg;
using tensorflow::ops::OnesLike;
using tensorflow::ops::SqrtGrad;

namespace tensorflow {
namespace gradients {
namespace {

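// Conjugate helper: for floating-point and integer dtypes Conj is a no-op, so
// an Identity is emitted instead; complex and variant dtypes go through the
// real Conj op, and any other dtype is rejected as invalid.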
static Status SafeConj(AbstractContext* ctx, AbstractTensorHandle* const input,
                       AbstractTensorHandle** output, const char* name) {
  auto dtype = input->DataType();
  if (DataTypeIsFloating(BaseType(dtype)) ||
      DataTypeIsInteger(BaseType(dtype))) {
    return tensorflow::ops::Identity(ctx, input, output, name);
  } else if (!DataTypeIsComplex(BaseType(dtype)) &&
             BaseType(dtype) != DT_VARIANT) {
    return errors::InvalidArgument(
        "Expected numeric or variant tensor, got dtype ", dtype);
  }
  return tensorflow::ops::Conj(ctx, input, output, name);
}

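// Gradient of Z = A + B. With upstream grad U, dA = U and dB = U, so the
// upstream handle is simply re-used (and Ref'd) for both inputs.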
class AddGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // TODO(b/161805092): Support broadcasting.

    DCHECK(grad_outputs[0]);
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[1] = grad_outputs[0];

    grad_inputs[0]->Ref();
    grad_inputs[1]->Ref();
    return OkStatus();
  }
  ~AddGradientFunction() override {}
};

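// Gradient of Y = exp(X). With upstream grad U, dX = U * conj(Y); the forward
// output Y is retained (Ref'd) so it can be reused in the backward pass.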
class ExpGradientFunction : public GradientFunction {
 public:
  explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
    exp->Ref();
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    AbstractTensorHandle* conj_output;
    std::string name = "Conj_Exp_Grad";
    TF_RETURN_IF_ERROR(SafeConj(ctx, exp_.get(), &conj_output, name.c_str()));
    AbstractTensorHandlePtr conj_output_releaser(conj_output);

    name = "Mul_Exp_Grad";
    TF_RETURN_IF_ERROR(
        Mul(ctx, conj_output, grad_outputs[0], &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~ExpGradientFunction() override {}

 private:
  AbstractTensorHandlePtr exp_;
};

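// Gradient of Y = sqrt(X). With upstream grad U, dX = U * 0.5 / Y; the
// SqrtGrad kernel computes this directly from the forward output Y.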
class SqrtGradientFunction : public GradientFunction {
 public:
  explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
    sqrt->Ref();
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    std::string name = "Sqrt_Grad";
    TF_RETURN_IF_ERROR(SqrtGrad(ctx, sqrt_.get(), grad_outputs[0],
                                &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~SqrtGradientFunction() override {}

 private:
  AbstractTensorHandlePtr sqrt_;
};

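// Gradient for MatMul. The constructor Refs the forward inputs A and B so they
// stay alive for the backward pass; they are Unref'd in the destructor. The
// forward attrs are kept to recover transpose_a/transpose_b.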
class MatMulGradientFunction : public GradientFunction {
 public:
  explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
                                  AttrBuilder f_attrs)
      : forward_inputs_(f_inputs), forward_attrs_(f_attrs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a matmul op A*B, the gradients are:
     *
     *    dA = U * B.T
     *    dB = A.T * U
     *
     *    where A.T means `transpose(A)`
     */
    AbstractTensorHandle* upstream_grad = grad_outputs[0];

    // Get transpose attrs
    bool t_a;
    TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_a", &t_a));

    bool t_b;
    TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_b", &t_b));

    // Conj each input
    AbstractTensorHandle* conj_output;
    std::string name = "Conj_A_MatMul_Grad";
    TF_RETURN_IF_ERROR(
        SafeConj(ctx, forward_inputs_[0], &conj_output, name.c_str()));

    AbstractTensorHandlePtr A(conj_output);

    name = "Conj_B_MatMul_Grad";
    TF_RETURN_IF_ERROR(
        SafeConj(ctx, forward_inputs_[1], &conj_output, name.c_str()));

    AbstractTensorHandlePtr B(conj_output);

    // Calc Grad
    AbstractTensorHandle* matmul_A_output;
    AbstractTensorHandle* matmul_B_output;
    std::string name_grad_A = "MatMul_Grad_A";
    std::string name_grad_B = "MatMul_Grad_B";
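    // The four branches below cover every (transpose_a, transpose_b)
    // combination of the forward op C = op_a(A) * op_b(B); each one picks the
    // operand order and transposes that make the gradient shapes line up.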
    if (!t_a && !t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, B.get(), &matmul_A_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ true, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, A.get(), upstream_grad, &matmul_B_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ false, name_grad_B.c_str()));
    } else if (!t_a && t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, B.get(), &matmul_A_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ false, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, A.get(), &matmul_B_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ false, name_grad_B.c_str()));

    } else if (t_a && !t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx, B.get(), upstream_grad, &matmul_A_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ true, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, A.get(), upstream_grad, &matmul_B_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ false, name_grad_B.c_str()));
    } else {  // t_a && t_b
      TF_RETURN_IF_ERROR(MatMul(ctx, B.get(), upstream_grad, &matmul_A_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ true, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, A.get(), &matmul_B_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ true, name_grad_B.c_str()));
    }

    // Gradient for A
    grad_inputs[0] = matmul_A_output;

    // Gradient for B
    grad_inputs[1] = matmul_B_output;
    return OkStatus();
  }
  ~MatMulGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs.
  vector<AbstractTensorHandle*> forward_inputs_;
  AttrBuilder forward_attrs_;
};

class NegGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a Neg op Y = -X, the gradients are:
     *
     *    dX = -U
     *
     */

    std::string name = "Neg_Grad";
    TF_RETURN_IF_ERROR(
        ops::Neg(ctx, grad_outputs[0], &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~NegGradientFunction() override {}
};

class SubGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a Sub op A-B, the gradients are:
     *
     *    dA = U
     *    dB = -U
     *
     */

    // Grad for A
    DCHECK(grad_outputs[0]);
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[0]->Ref();

    // Grad for B
    // negate the upstream grad
    std::string name = "Neg_Sub_Grad_B";
    TF_RETURN_IF_ERROR(
        ops::Neg(ctx, grad_outputs[0], &grad_inputs[1], name.c_str()));

    return OkStatus();
  }
  ~SubGradientFunction() override {}
};

class MulGradientFunction : public GradientFunction {
 public:
  explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
      : forward_inputs_(f_inputs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a mul op A*B, the gradients are:
     *
     *    dA = U * B
     *    dB = A * U
     *
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];

    // Gradient for A
    std::string name = "Mul_Grad_A";
    TF_RETURN_IF_ERROR(Mul(ctx, upstream_grad, forward_inputs_[1],
                           &grad_inputs[0], name.c_str()));

    // Gradient for B
    name = "Mul_Grad_B";
    TF_RETURN_IF_ERROR(Mul(ctx, forward_inputs_[0], upstream_grad,
                           &grad_inputs[1], name.c_str()));
    return OkStatus();
  }
  ~MulGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs.
  vector<AbstractTensorHandle*> forward_inputs_;
};

class Log1pGradientFunction : public GradientFunction {
 public:
  explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
      : forward_inputs_(f_inputs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // TODO(vnvo2409): Add control dependency
    /* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
     *
     *    dX = U / (1 + X)
     *
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    AbstractTensorHandle* X = forward_inputs_[0];

    AbstractTensorHandle* temp_output;

    // Calculate conjugate of X
    std::string name = "Conj_Log1p_Grad_X";
    TF_RETURN_IF_ERROR(SafeConj(ctx, X, &temp_output, name.c_str()));

    AbstractTensorHandlePtr Conj_X(temp_output);

    // Creates Ones
    name = "OnesLike_Log1p_Grad_X";
    TF_RETURN_IF_ERROR(OnesLike(ctx, Conj_X.get(), &temp_output, name.c_str()));

    AbstractTensorHandlePtr Ones_X(temp_output);

    name = "Add_Log1p_Grad_X";
    // Calculate 1 + Conj(X)
    TF_RETURN_IF_ERROR(
        AddV2(ctx, Ones_X.get(), Conj_X.get(), &temp_output, name.c_str()));

    AbstractTensorHandlePtr Conj_XP1(temp_output);

    name = "Div_Log1p_Grad_X";
    // Calculate U / (1 + Conj(X))
    TF_RETURN_IF_ERROR(
        Div(ctx, upstream_grad, Conj_XP1.get(), &grad_inputs[0], name.c_str()));

    return OkStatus();
  }
  ~Log1pGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs.
  vector<AbstractTensorHandle*> forward_inputs_;
};

class DivNoNanGradientFunction : public GradientFunction {
 public:
  explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
                                    vector<AbstractTensorHandle*> f_outputs)
      : forward_inputs_(f_inputs), forward_outputs_(f_outputs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
    for (auto output : forward_outputs_) {
      if (output) {
        output->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // TODO(vnvo2409): Add shape broadcasting
    /* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
     *
     *    dX = U / Y
     *    dY = -U*X / Y^2 = (X/Y) * -U / Y = -U*Z / Y
     *
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    AbstractTensorHandle* Y = forward_inputs_[1];
    AbstractTensorHandle* Z = forward_outputs_[0];

    // Calculate dX = U / Y
    std::string name = "Div_Grad_X";
    TF_RETURN_IF_ERROR(
        DivNoNan(ctx, upstream_grad, Y, &grad_inputs[0], name.c_str()));

    AbstractTensorHandle* temp_output;
    // Calculate dY = -U*Z / Y
    name = "Neg_Div_Grad_Y";
    TF_RETURN_IF_ERROR(Neg(ctx, upstream_grad, &temp_output,
                           name.c_str()));  // -U
    AbstractTensorHandlePtr MinusU(temp_output);

    name = "Mul_Div_Grad_Y";
    TF_RETURN_IF_ERROR(Mul(ctx, MinusU.get(), Z, &temp_output,
                           name.c_str()));  // -U*Z
    AbstractTensorHandlePtr UZ(temp_output);

    name = "Div_Grad_Y";
    TF_RETURN_IF_ERROR(DivNoNan(ctx, UZ.get(), Y, &grad_inputs[1],
                                name.c_str()));  // -U*Z / Y

    return OkStatus();
  }
  ~DivNoNanGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
    for (auto output : forward_outputs_) {
      if (output) {
        output->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs and outputs.
  vector<AbstractTensorHandle*> forward_inputs_;
  vector<AbstractTensorHandle*> forward_outputs_;
};

}  // namespace

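// The registerer factories below construct the GradientFunction for each op so
// the tape can look it up by op name via a GradientRegistry (declared in
// tensorflow/c/eager/gradients.h). A minimal usage sketch (the actual wiring
// lives outside this file, e.g. in registration or test code):
//
//   GradientRegistry registry;
//   TF_RETURN_IF_ERROR(registry.Register("AddV2", AddRegisterer));
//   TF_RETURN_IF_ERROR(registry.Register("MatMul", MatMulRegisterer));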
GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction;
}

GradientFunction* ExpRegisterer(const ForwardOperation& op) {
  return new ExpGradientFunction(op.outputs[0]);
}

GradientFunction* MatMulRegisterer(const ForwardOperation& op) {
  return new MatMulGradientFunction(op.inputs, op.attrs);
}

GradientFunction* SqrtRegisterer(const ForwardOperation& op) {
  return new SqrtGradientFunction(op.outputs[0]);
}

GradientFunction* NegRegisterer(const ForwardOperation& op) {
  return new NegGradientFunction;
}

GradientFunction* SubRegisterer(const ForwardOperation& op) {
  return new SubGradientFunction;
}

GradientFunction* MulRegisterer(const ForwardOperation& op) {
  return new MulGradientFunction(op.inputs);
}

GradientFunction* Log1pRegisterer(const ForwardOperation& op) {
  return new Log1pGradientFunction(op.inputs);
}

GradientFunction* DivNoNanRegisterer(const ForwardOperation& op) {
  return new DivNoNanGradientFunction(op.inputs, op.outputs);
}

}  // namespace gradients
}  // namespace tensorflow