/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"

using std::vector;
using tensorflow::ops::AddV2;
using tensorflow::ops::Div;
using tensorflow::ops::DivNoNan;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::Neg;
using tensorflow::ops::OnesLike;
using tensorflow::ops::SqrtGrad;

namespace tensorflow {
namespace gradients {
namespace {

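// Conj is only defined for complex (and variant) dtypes. For floating-point
// and integer inputs this helper falls back to Identity, so gradient code can
// conjugate unconditionally.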
static Status SafeConj(AbstractContext* ctx, AbstractTensorHandle* const input,
                       AbstractTensorHandle** output, const char* name) {
  auto dtype = input->DataType();
  if (DataTypeIsFloating(BaseType(dtype)) ||
      DataTypeIsInteger(BaseType(dtype))) {
    return tensorflow::ops::Identity(ctx, input, output, name);
  } else if (!DataTypeIsComplex(BaseType(dtype)) &&
             BaseType(dtype) != DT_VARIANT) {
    return errors::InvalidArgument(
        "Expected numeric or variant tensor, got dtype ", dtype);
  }
  return tensorflow::ops::Conj(ctx, input, output, name);
}

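// Gradient for AddV2: with Z = A + B, dA = dB = U. The upstream gradient is
// forwarded to both inputs; each forwarded handle is Ref'd because the
// returned grad_inputs are released by the caller.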
class AddGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // TODO(b/161805092): Support broadcasting.

    DCHECK(grad_outputs[0]);
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[1] = grad_outputs[0];

    grad_inputs[0]->Ref();
    grad_inputs[1]->Ref();
    return OkStatus();
  }
  ~AddGradientFunction() override {}
};

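// Gradient for Exp: with Y = exp(X), dX = U * conj(Y). The forward output is
// captured (and Ref'd) so exp(X) does not have to be recomputed.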
class ExpGradientFunction : public GradientFunction {
 public:
  explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
    exp->Ref();
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    AbstractTensorHandle* conj_output;
    std::string name = "Conj_Exp_Grad";
    TF_RETURN_IF_ERROR(SafeConj(ctx, exp_.get(), &conj_output, name.c_str()));
    AbstractTensorHandlePtr conj_output_releaser(conj_output);

    name = "Mul_Exp_Grad";
    TF_RETURN_IF_ERROR(
        Mul(ctx, conj_output, grad_outputs[0], &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~ExpGradientFunction() override {}

 private:
  AbstractTensorHandlePtr exp_;
};

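// Gradient for Sqrt: with Y = sqrt(X), dX = U * 0.5 / Y. The SqrtGrad kernel
// computes this directly from the captured forward output.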
class SqrtGradientFunction : public GradientFunction {
 public:
  explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
    sqrt->Ref();
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    std::string name = "Sqrt_Grad";
    TF_RETURN_IF_ERROR(SqrtGrad(ctx, sqrt_.get(), grad_outputs[0],
                                &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~SqrtGradientFunction() override {}

 private:
  AbstractTensorHandlePtr sqrt_;
};

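// Gradient for MatMul. Handles all four combinations of the transpose_a and
// transpose_b attributes recorded on the forward op; both forward inputs are
// captured so they can be conjugated and re-multiplied with the upstream grad.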
class MatMulGradientFunction : public GradientFunction {
 public:
  explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
                                  AttrBuilder f_attrs)
      : forward_inputs_(f_inputs), forward_attrs_(f_attrs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a matmul op A*B, the gradients are:
     *
     *    dA = U * B.T
     *    dB = A.T * U
     *
     *    where A.T means `transpose(A)`
     */
    AbstractTensorHandle* upstream_grad = grad_outputs[0];

    // Get transpose attrs
    bool t_a;
    TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_a", &t_a));

    bool t_b;
    TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_b", &t_b));

    // Conj each input
    AbstractTensorHandle* conj_output;
    std::string name = "Conj_A_MatMul_Grad";
    TF_RETURN_IF_ERROR(
        SafeConj(ctx, forward_inputs_[0], &conj_output, name.c_str()));

    AbstractTensorHandlePtr A(conj_output);

    name = "Conj_B_MatMul_Grad";
    TF_RETURN_IF_ERROR(
        SafeConj(ctx, forward_inputs_[1], &conj_output, name.c_str()));

    AbstractTensorHandlePtr B(conj_output);

    // Calc Grad
    AbstractTensorHandle* matmul_A_output;
    AbstractTensorHandle* matmul_B_output;
    std::string name_grad_A = "MatMul_Grad_A";
    std::string name_grad_B = "MatMul_Grad_B";
    if (!t_a && !t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, B.get(), &matmul_A_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ true, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, A.get(), upstream_grad, &matmul_B_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ false, name_grad_B.c_str()));
    } else if (!t_a && t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, B.get(), &matmul_A_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ false, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, A.get(), &matmul_B_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ false, name_grad_B.c_str()));

    } else if (t_a && !t_b) {
      TF_RETURN_IF_ERROR(MatMul(ctx, B.get(), upstream_grad, &matmul_A_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ true, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, A.get(), upstream_grad, &matmul_B_output,
                                /*transpose_a = */ false,
                                /*transpose_b = */ false, name_grad_B.c_str()));
    } else {  // t_a && t_b
      TF_RETURN_IF_ERROR(MatMul(ctx, B.get(), upstream_grad, &matmul_A_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ true, name_grad_A.c_str()));

      TF_RETURN_IF_ERROR(MatMul(ctx, upstream_grad, A.get(), &matmul_B_output,
                                /*transpose_a = */ true,
                                /*transpose_b = */ true, name_grad_B.c_str()));
    }

    // Gradient for A
    grad_inputs[0] = matmul_A_output;

    // Gradient for B
    grad_inputs[1] = matmul_B_output;
    return OkStatus();
  }
  ~MatMulGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs.
  vector<AbstractTensorHandle*> forward_inputs_;
  AttrBuilder forward_attrs_;
};

class NegGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a Neg op Y = -X, the gradients are:
     *
     *    dX =  -U
     *
     */

    std::string name = "Neg_Grad";
    TF_RETURN_IF_ERROR(
        ops::Neg(ctx, grad_outputs[0], &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~NegGradientFunction() override {}
};

class SubGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a Sub op A-B, the gradients are:
     *
     *    dA =  U
     *    dB = -U
     *
     */

    // Grad for A
    DCHECK(grad_outputs[0]);
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[0]->Ref();

    // Grad for B
    // negate the upstream grad
    std::string name = "Neg_Sub_Grad_B";
    TF_RETURN_IF_ERROR(
        ops::Neg(ctx, grad_outputs[0], &grad_inputs[1], name.c_str()));

    return OkStatus();
  }
  ~SubGradientFunction() override {}
};

class MulGradientFunction : public GradientFunction {
 public:
  explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
      : forward_inputs_(f_inputs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a mul op A*B, the gradients are:
     *
     *    dA = U * B
     *    dB = A * U
     *
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];

    // Gradient for A
    std::string name = "Mul_Grad_A";
    TF_RETURN_IF_ERROR(Mul(ctx, upstream_grad, forward_inputs_[1],
                           &grad_inputs[0], name.c_str()));

    // Gradient for B
    name = "Mul_Grad_B";
    TF_RETURN_IF_ERROR(Mul(ctx, forward_inputs_[0], upstream_grad,
                           &grad_inputs[1], name.c_str()));
    return OkStatus();
  }
  ~MulGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs.
  vector<AbstractTensorHandle*> forward_inputs_;
};

class Log1pGradientFunction : public GradientFunction {
 public:
  explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
      : forward_inputs_(f_inputs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // TODO(vnvo2409): Add control dependency
    /* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
     *
     *    dX = U / (1 + X)
     *
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    AbstractTensorHandle* X = forward_inputs_[0];

    AbstractTensorHandle* temp_output;

    // Calculate conjugate of X
    std::string name = "Conj_Log1p_Grad_X";
    TF_RETURN_IF_ERROR(SafeConj(ctx, X, &temp_output, name.c_str()));

    AbstractTensorHandlePtr Conj_X(temp_output);

    // Creates Ones
    name = "OnesLike_Log1p_Grad_X";
    TF_RETURN_IF_ERROR(OnesLike(ctx, Conj_X.get(), &temp_output, name.c_str()));

    AbstractTensorHandlePtr Ones_X(temp_output);

    name = "Add_Log1p_Grad_X";
    // Calculate 1 + Conj(X)
    TF_RETURN_IF_ERROR(
        AddV2(ctx, Ones_X.get(), Conj_X.get(), &temp_output, name.c_str()));

    AbstractTensorHandlePtr Conj_XP1(temp_output);

    name = "Div_Log1p_Grad_X";
    // Calculate U / (1 + Conj(X))
    TF_RETURN_IF_ERROR(
        Div(ctx, upstream_grad, Conj_XP1.get(), &grad_inputs[0], name.c_str()));

    return OkStatus();
  }
  ~Log1pGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs.
  vector<AbstractTensorHandle*> forward_inputs_;
};

class DivNoNanGradientFunction : public GradientFunction {
 public:
  explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
                                    vector<AbstractTensorHandle*> f_outputs)
      : forward_inputs_(f_inputs), forward_outputs_(f_outputs) {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Ref();
      }
    }
    for (auto output : forward_outputs_) {
      if (output) {
        output->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // TODO(vnvo2409): Add shape broadcasting
    /* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
     *
     *    dX = U / Y
     *    dY = -U*X / Y^2 = (X/Y) * -U / Y = -U*Z / Y
     *
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    AbstractTensorHandle* Y = forward_inputs_[1];
    AbstractTensorHandle* Z = forward_outputs_[0];

    // Calculate dX =  U / Y
    std::string name = "Div_Grad_X";
    TF_RETURN_IF_ERROR(
        DivNoNan(ctx, upstream_grad, Y, &grad_inputs[0], name.c_str()));

    AbstractTensorHandle* temp_output;
    // Calculate dY = -U*Z / Y
    name = "Neg_Div_Grad_Y";
    TF_RETURN_IF_ERROR(Neg(ctx, upstream_grad, &temp_output,
                           name.c_str()));  // -U
    AbstractTensorHandlePtr MinusU(temp_output);

    name = "Mul_Div_Grad_Y";
    TF_RETURN_IF_ERROR(Mul(ctx, MinusU.get(), Z, &temp_output,
                           name.c_str()));  // -U*Z
    AbstractTensorHandlePtr UZ(temp_output);

    name = "Div_Grad_Y";
    TF_RETURN_IF_ERROR(DivNoNan(ctx, UZ.get(), Y, &grad_inputs[1],
                                name.c_str()));  // -U*Z / Y

    return OkStatus();
  }
  ~DivNoNanGradientFunction() override {
    for (auto input : forward_inputs_) {
      if (input) {
        input->Unref();
      }
    }
    for (auto output : forward_outputs_) {
      if (output) {
        output->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed inputs and outputs.
  vector<AbstractTensorHandle*> forward_inputs_;
  vector<AbstractTensorHandle*> forward_outputs_;
};

}  // namespace

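// Registerer functions: each constructs the GradientFunction for the
// corresponding forward op, capturing whichever inputs, outputs, or attrs
// that op's gradient needs.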
GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction;
}

GradientFunction* ExpRegisterer(const ForwardOperation& op) {
  return new ExpGradientFunction(op.outputs[0]);
}

GradientFunction* MatMulRegisterer(const ForwardOperation& op) {
  return new MatMulGradientFunction(op.inputs, op.attrs);
}

GradientFunction* SqrtRegisterer(const ForwardOperation& op) {
  return new SqrtGradientFunction(op.outputs[0]);
}

GradientFunction* NegRegisterer(const ForwardOperation& op) {
  return new NegGradientFunction;
}

GradientFunction* SubRegisterer(const ForwardOperation& op) {
  return new SubGradientFunction;
}

GradientFunction* MulRegisterer(const ForwardOperation& op) {
  return new MulGradientFunction(op.inputs);
}

GradientFunction* Log1pRegisterer(const ForwardOperation& op) {
  return new Log1pGradientFunction(op.inputs);
}

GradientFunction* DivNoNanRegisterer(const ForwardOperation& op) {
  return new DivNoNanGradientFunction(op.inputs, op.outputs);
}

}  // namespace gradients
}  // namespace tensorflow