/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

using std::vector;
using tensorflow::ops::BiasAddGrad;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;

namespace tensorflow {
namespace gradients {
namespace {

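// Gradient function for Relu. Holds references to the forward op's outputs so
// the activations are still alive when Compute() uses them to mask the
// upstream gradient.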
class ReluGradientFunction : public GradientFunction {
 public:
  explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
      : forward_outputs_(f_outputs) {
    for (auto output : forward_outputs_) {
      if (output) {
        output->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    AbstractTensorHandle* activations = forward_outputs_[0];

    // Calculate the gradient: pass the upstream gradient through only where
    // the forward activations are positive.
    std::string name = "relu_grad";
    TF_RETURN_IF_ERROR(ReluGrad(ctx, upstream_grad, activations,
                                &grad_inputs[0], name.c_str()));
    return OkStatus();
  }
  ~ReluGradientFunction() override {
    for (auto output : forward_outputs_) {
      if (output) {
        output->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed outputs.
  vector<AbstractTensorHandle*> forward_outputs_;
};

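// Multiplies `vec` by `mat` with broadcasting: `vec` is expanded with a
// trailing axis of size 1 before the elementwise Mul, and the product is
// written to outputs[0].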
Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
                    AbstractTensorHandle* mat,
                    absl::Span<AbstractTensorHandle*> outputs) {
  if (!isa<ImmediateExecutionContext>(ctx)) {
    // TODO(b/168850692): Fix this.
    return errors::Unimplemented(
        "BroadcastMul is not supported in tracing mode yet.");
  }
  auto imm_ctx = dyn_cast<ImmediateExecutionContext>(ctx);
  AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1));
  ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get()));
  AbstractTensorHandle* expand_dims_outputs;
  TF_RETURN_IF_ERROR(
      ops::ExpandDims(ctx, vec, dim.get(), &expand_dims_outputs, "ExpandDims"));
  TF_RETURN_IF_ERROR(
      ops::Mul(ctx, expand_dims_outputs, mat, &outputs[0], "Mul"));
  expand_dims_outputs->Unref();
  return OkStatus();
}

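// Gradient function for SparseSoftmaxCrossEntropyWithLogits. The forward op's
// second output is the per-example softmax gradient, so the logits gradient is
// the upstream loss gradient broadcast-multiplied against it; the labels input
// receives no gradient.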
class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
    : public GradientFunction {
 public:
  explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
      vector<AbstractTensorHandle*> f_outputs)
      : forward_outputs_(f_outputs) {}

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // Grad for Softmax Input
    TF_RETURN_IF_ERROR(BroadcastMul(
        ctx, grad_outputs[0], forward_outputs_[1],
        grad_inputs.subspan(0, 1)));  // upstream_grad * local softmax grad

    // Grad for labels is null
    grad_inputs[1] = nullptr;
    return OkStatus();
  }
  ~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}

 private:
  vector<AbstractTensorHandle*> forward_outputs_;
};

// TODO(vnvo2409): Add python test
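// Gradient function for BiasAdd. The forward op's attributes are kept so the
// gradient can read back the `data_format` used in the forward pass.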
class BiasAddGradientFunction : public GradientFunction {
 public:
  explicit BiasAddGradientFunction(AttrBuilder f_attrs)
      : forward_attrs_(f_attrs) {}

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a BiasAdd: A + bias, the gradients are:
     *
     *    dA = U
     *    dbias = reduceSum(U, dims = channel_dim)
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    DCHECK(upstream_grad);

    // Recover data format from forward pass for gradient.
    std::string data_format;
    TF_RETURN_IF_ERROR(forward_attrs_.Get("data_format", &data_format));

    // Grad for A
    grad_inputs[0] = upstream_grad;
    grad_inputs[0]->Ref();

    // Grad for bias
    std::string name = "bias_add_grad";
    TF_RETURN_IF_ERROR(BiasAddGrad(ctx, upstream_grad, &grad_inputs[1],
                                   data_format.c_str(), name.c_str()));

    return OkStatus();
  }
  ~BiasAddGradientFunction() override {}

 private:
  AttrBuilder forward_attrs_;
};

}  // namespace

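// The registerers below construct the gradient function for their op from the
// recorded forward operation (its outputs or attributes), for use with a
// GradientRegistry.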
GradientFunction* ReluRegisterer(const ForwardOperation& op) {
  return new ReluGradientFunction(op.outputs);
}

GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
    const ForwardOperation& op) {
  return new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
}

GradientFunction* BiasAddRegisterer(const ForwardOperation& op) {
  return new BiasAddGradientFunction(op.attrs);
}

}  // namespace gradients
}  // namespace tensorflow