/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

using std::vector;
using tensorflow::ops::BiasAddGrad;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;

namespace tensorflow {
namespace gradients {
namespace {

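// Gradient function for Relu: the ReluGrad op passes the upstream gradient
// through wherever the saved forward activation is positive and zeroes it
// elsewhere, so only the forward outputs need to be retained here.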
class ReluGradientFunction : public GradientFunction {
 public:
  explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
      : forward_outputs_(f_outputs) {
    for (auto output : forward_outputs_) {
      if (output) {
        output->Ref();
      }
    }
  }

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    AbstractTensorHandle* activations = forward_outputs_[0];

    // Calculate Grad
    std::string name = "relu_grad";
    TF_RETURN_IF_ERROR(ReluGrad(ctx, upstream_grad, activations,
                                &grad_inputs[0], name.c_str()));
    return OkStatus();
  }

  ~ReluGradientFunction() override {
    for (auto output : forward_outputs_) {
      if (output) {
        output->Unref();
      }
    }
  }

 private:
  // TODO(b/174778737): Only hold needed outputs.
  vector<AbstractTensorHandle*> forward_outputs_;
};

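// Computes outputs[0] = vec * mat, reshaping `vec` into a column via
// ExpandDims(axis = -1) so that the Mul broadcasts it across the last
// dimension of `mat`.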
Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
                    AbstractTensorHandle* mat,
                    absl::Span<AbstractTensorHandle*> outputs) {
  if (!isa<ImmediateExecutionContext>(ctx)) {
    // TODO(b/168850692): Fix this.
    return errors::Unimplemented(
        "BroadcastMul is not supported in tracing mode yet.");
  }
  auto imm_ctx = dyn_cast<ImmediateExecutionContext>(ctx);
  AbstractTensorPtr minus_1(imm_ctx->CreateInt32Scalar(-1));
  ImmediateTensorHandlePtr dim(imm_ctx->CreateLocalHandle(minus_1.get()));
  AbstractTensorHandle* expand_dims_outputs;
  TF_RETURN_IF_ERROR(
      ops::ExpandDims(ctx, vec, dim.get(), &expand_dims_outputs, "ExpandDims"));
  TF_RETURN_IF_ERROR(
      ops::Mul(ctx, expand_dims_outputs, mat, &outputs[0], "Mul"));
  expand_dims_outputs->Unref();
  return OkStatus();
}

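// Gradient function for SparseSoftmaxCrossEntropyWithLogits: the gradient
// w.r.t. the logits is the upstream loss gradient broadcast-multiplied with
// the per-example softmax gradient returned as the forward op's second
// output; the integer labels input receives no gradient.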
class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
    : public GradientFunction {
 public:
  explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
      vector<AbstractTensorHandle*> f_outputs)
      : forward_outputs_(f_outputs) {}

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // Grad for Softmax Input
    TF_RETURN_IF_ERROR(BroadcastMul(
        ctx, grad_outputs[0], forward_outputs_[1],
        grad_inputs.subspan(0, 1)));  // upstream_grad * local softmax grad

    // Grad for labels is null
    grad_inputs[1] = nullptr;
    return OkStatus();
  }

  ~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}

 private:
  vector<AbstractTensorHandle*> forward_outputs_;
};

// TODO(vnvo2409): Add python test
class BiasAddGradientFunction : public GradientFunction {
 public:
  explicit BiasAddGradientFunction(AttrBuilder f_attrs)
      : forward_attrs_(f_attrs) {}

  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    /* Given upstream grad U and a BiasAdd: A + bias, the gradients are:
     *
     * dA = U
     * dbias = reduceSum(U, dims = channel_dim)
     */

    AbstractTensorHandle* upstream_grad = grad_outputs[0];
    DCHECK(upstream_grad);

    // Recover data format from forward pass for gradient.
    std::string data_format;
    TF_RETURN_IF_ERROR(forward_attrs_.Get("data_format", &data_format));

    // Grad for A
    grad_inputs[0] = upstream_grad;
    grad_inputs[0]->Ref();

    // Grad for bias
    std::string name = "bias_add_grad";
    TF_RETURN_IF_ERROR(BiasAddGrad(ctx, upstream_grad, &grad_inputs[1],
                                   data_format.c_str(), name.c_str()));

    return OkStatus();
  }

  ~BiasAddGradientFunction() override {}

 private:
  AttrBuilder forward_attrs_;
};

}  // namespace

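// The registerers below construct the gradient functions above from the
// state of the forward operation they need (its outputs or its attributes).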
GradientFunction* ReluRegisterer(const ForwardOperation& op) {
  return new ReluGradientFunction(op.outputs);
}

GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
    const ForwardOperation& op) {
  return new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
}

GradientFunction* BiasAddRegisterer(const ForwardOperation& op) {
  return new BiasAddGradientFunction(op.attrs);
}

}  // namespace gradients
}  // namespace tensorflow