/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/mnist_gradients_testutil.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace gradients {
namespace internal {

using std::vector;

//===================== Test Models to run =========================

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(inputs[0]);  // Watch x.
  tape->Watch(inputs[1]);  // Watch y.
  std::vector<AbstractTensorHandle*> add_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(
      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
                                           /*sources=*/inputs,
                                           /*output_gradients=*/{}, outputs));
  for (auto add_output : add_outputs) {
    add_output->Unref();
  }
  delete tape;
  return Status::OK();
}
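
// Example usage (sketch): callers typically drive these models through a
// tracing/execution helper and a GradientRegistry populated with the gradient
// function of the op being differentiated. The helper RunModel() and the
// registerer spelling ("AddV2"/AddRegisterer) are assumptions based on
// gradients_util.h and the experimental math_grad library; ctx, x and y stand
// for handles created elsewhere in the test.
//
//   GradientRegistry registry;
//   TF_RETURN_IF_ERROR(registry.Register("AddV2", AddRegisterer));
//   std::vector<AbstractTensorHandle*> grads(2);
//   TF_RETURN_IF_ERROR(RunModel(AddGradModel, ctx, {x, y},
//                               absl::MakeSpan(grads),
//                               /*use_function=*/false, registry));
//   // For Add, both grads[0] and grads[1] are tensors of ones shaped like
//   // the corresponding input.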

// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       const GradientRegistry& registry) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(inputs[0]);  // Watch x.
  tape->Watch(inputs[1]);  // Watch y.
  vector<AbstractTensorHandle*> mm_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
                                 absl::MakeSpan(mm_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute x*y.

  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mm_outputs,
                                           /*sources=*/inputs,
                                           /*output_gradients=*/{}, outputs));
  for (auto mm_output : mm_outputs) {
    mm_output->Unref();
  }
  delete tape;
  return Status::OK();
}
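
// Example usage (sketch): same pattern as AddGradModel above, but with a
// MatMul gradient registered first. "MatMul"/MatMulRegisterer is an assumed
// spelling from the experimental math_grad library; A and B are placeholder
// handles.
//
//   TF_RETURN_IF_ERROR(registry.Register("MatMul", MatMulRegisterer));
//   std::vector<AbstractTensorHandle*> grads(2);
//   TF_RETURN_IF_ERROR(RunModel(MatMulGradModel, ctx, {A, B},
//                               absl::MakeSpan(grads),
//                               /*use_function=*/false, registry));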

// Forward model for a 2-layer fully connected network.
Status MNISTForwardModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
                         absl::Span<AbstractTensorHandle*> outputs,
                         const GradientRegistry& registry) {
  /**
   * We will trace a 2-layer fully connected network for an MNIST model:
   *
   *   def mnist_forward(X, W1, W2, y_labels):
   *     mm_out_1 = tf.matmul(X, W1)
   *     hidden_layer = tf.nn.relu(mm_out_1)
   *     scores = tf.matmul(hidden_layer, W2)
   *     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
   *                                                           y_labels)
   *     return scores, loss
   *
   * Use this convention for inputs:
   *
   *   inputs = [X, W1, W2, y_labels]
   */
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, absl::MakeSpan(temp_outputs),
                                 "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  TF_RETURN_IF_ERROR(ops::Relu(ctx, {temp_outputs[0]},
                               absl::MakeSpan(temp_outputs),
                               "relu"));  // Compute Relu(X*W1)

  TF_RETURN_IF_ERROR(ops::MatMul(
      ctx, {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      ctx, {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmax_loss"));  // Compute loss(scores, y_labels)

  AbstractTensorHandle* loss_vals = temp_outputs[0];

  outputs[0] = scores;
  outputs[1] = loss_vals;
  return Status::OK();
}
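
// Example usage (sketch): the forward model needs no gradient registrations
// since nothing is differentiated; ctx, X, W1, W2 and y_labels stand for
// handles created elsewhere in the test, and RunModel() is the assumed helper
// from gradients_util.h.
//
//   std::vector<AbstractTensorHandle*> fwd_outputs(2);
//   TF_RETURN_IF_ERROR(RunModel(MNISTForwardModel, ctx, {X, W1, W2, y_labels},
//                               absl::MakeSpan(fwd_outputs),
//                               /*use_function=*/false, registry));
//   AbstractTensorHandle* scores = fwd_outputs[0];     // [batch, num_classes]
//   AbstractTensorHandle* loss_vals = fwd_outputs[1];  // per-example loss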

// Computes matmul(transpose(inputs[0]), inputs[1]).
Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];

  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, outputs, "matmul0",
                                 /*transpose_a=*/true,
                                 /*transpose_b=*/false));  // Compute X^T*W1
  return Status::OK();
}

// Runs the 2-layer forward pass above under a persistent tape and returns
// [grad(loss, W1), grad(loss, W2), loss].
Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  auto tape = new Tape(/*persistent=*/true);
  tape->Watch(X);   // Watch X.
  tape->Watch(W1);  // Watch W1.
  tape->Watch(W2);  // Watch W2.
  vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
                                 absl::MakeSpan(temp_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  AbstractTensorHandle* mm = temp_outputs[0];

  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
                               absl::MakeSpan(temp_outputs),
                               "relu0"));  // Compute Relu(X*W1)

  AbstractTensorHandle* hidden = temp_outputs[0];

  TF_RETURN_IF_ERROR(ops::MatMul(
      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmaxloss"));  // Compute loss(scores, y_labels)

  AbstractTensorHandle* loss = temp_outputs[0];

  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/{loss},
                                           /*sources=*/{W1, W2},
                                           /*output_gradients=*/{},
                                           outputs.subspan(0, 2)));

  // Only release the 2nd temp output, as the first holds the loss values.
  temp_outputs[1]->Unref();
  outputs[2] = loss;
  delete tape;
  return Status::OK();
}
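
// Example usage (sketch): MNISTGradModel differentiates through MatMul, Relu
// and the softmax loss, so a gradient must be registered for each; the
// registerer names below are assumptions based on the experimental
// math_grad/nn_grad libraries. Outputs are [dW1, dW2, loss].
//
//   GradientRegistry registry;
//   TF_RETURN_IF_ERROR(registry.Register("MatMul", MatMulRegisterer));
//   TF_RETURN_IF_ERROR(registry.Register("Relu", ReluRegisterer));
//   TF_RETURN_IF_ERROR(registry.Register(
//       "SparseSoftmaxCrossEntropyWithLogits",
//       SparseSoftmaxCrossEntropyWithLogitsRegisterer));
//   std::vector<AbstractTensorHandle*> grad_outputs(3);
//   TF_RETURN_IF_ERROR(RunModel(MNISTGradModel, ctx, {X, W1, W2, y_labels},
//                               absl::MakeSpan(grad_outputs),
//                               /*use_function=*/false, registry));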

Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  return ops::Mul(ctx, inputs, outputs,
                  "scalarMul0");  // Compute eta*A
}

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry) {
  return ops::MatMul(ctx, inputs, outputs, "matmul0",
                     /*transpose_a=*/false,
                     /*transpose_b=*/false);  // Compute X*W1
}

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs,
                const GradientRegistry& registry) {
  return ops::Mul(ctx, inputs, outputs,
                  "mul0");  // Compute x*y
}

// ============================= End Models ================================

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow