/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();

    // Computing numerical gradients with TensorFloat-32 is numerically
    // unstable. Some forward pass tests also fail with TensorFloat-32 due to
    // low tolerances.
    enable_tensor_float_32_execution(false);
  }
};

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}

TEST_P(CppGradients, TestMatMulGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(A)
   * tape.watch(B)
   * Y = AB
   * outputs = tape.gradient(Y, [A, B])
   */

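  // The expected values used below can be derived by hand, assuming the
  // standard MatMul gradients with an upstream gradient of ones
  // (dY = [[1, 1], [1, 1]]):
  //   dA = dY . B^T = [[1, 1], [1, 1]] . [[.5, 1], [-1, 1]] = [[-.5, 2], [-.5, 2]]
  //   dB = A^T . dY = [[1, 3], [2, 4]] . [[1, 1], [1, 1]] = [[4, 4], [6, 6]]
  // which matches expected_dA and expected_dB in the assertions below.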
  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
  }

  TF_Tensor* dB_tensor;
  s = GetValue(outputs[1], &dB_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dB_tensor),
         TF_TensorByteSize(dB_tensor));

  float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(dA_tensor);
  TF_DeleteTensor(dB_tensor);
}

TEST_P(CppGradients, TestMNISTForward) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t dims_y[] = {2};
  num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, dims_y, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

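  // The expected values below were worked out by hand, assuming
  // MNISTForwardModel computes scores = Relu(X . W1) . W2 and per-example
  // losses via SparseSoftmaxCrossEntropyWithLogits:
  //   X . W1 = [[0, 12], [-1, 34]], Relu -> [[0, 12], [0, 34]]
  //   scores = [[3.6, -6.0], [10.2, -17.0]]
  //   loss_i = log(sum_j exp(scores_ij)) - scores_i[y_i] ~= [9.6, 27.2]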
  float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[2] = {9.6f, 27.2f};
  for (int j = 0; j < 2; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}

TEST_P(CppGradients, TestMNISTForward2) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {3, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1, 1};
  int64_t y_dims[] = {3};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

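  // Same hand derivation as in TestMNISTForward, extended to a third row of X:
  // assuming the same Relu(X . W1) . W2 computation, row [5, 6] gives
  // Relu([-2, 56]) . W2 = [16.8, -28.0] and a loss of roughly 44.8.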
  float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[3] = {9.6f, 27.2f, 44.8f};
  for (int j = 0; j < 3; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}

TEST_P(CppGradients, TestMatMulTranspose) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {2, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  GradientRegistry registry;

  // Run the MatMul Op
  std::vector<AbstractTensorHandle*> outputs(1);

  Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);

  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

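  // The expected values assume MatMulTransposeModel computes X^T . W1 (i.e.
  // MatMul with transpose_a = true): X^T is 3x2, so the 3x2 product is
  // [[13, 18], [17, 24], [21, 30]].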
  float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }
}

TEST_P(CppGradients, TestMNISTGrad) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(W1)
   * tape.watch(W2)
   * mm = X*W1
   * hidden = Relu(mm)
   * scores = hidden*W2
   * loss = SoftmaxLoss(scores, y)
   * outputs = tape.gradient(loss, [W1, W2])
   */

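  // The expected gradients below can be checked by hand, assuming standard
  // backprop through the graph sketched above with an incoming gradient of 1
  // for each example's loss:
  //   dscores = softmax(scores) - onehot(y) ~= [[1, -1], [1, -1]]
  //   dW2 = hidden^T . dscores = [[0, 0], [46, -46]]
  //   dhidden = (dscores . W2^T) * Relu'(mm) = [[0, 0.8], [0, 0.8]]
  //   dW1 = X^T . dhidden = [[0, 3.2], [0, 4.8]]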
  std::vector<AbstractTensorHandle*> outputs(3);
  s = RunModel(MNISTGradModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float tolerance = 1e-3;
  TF_Tensor* dW1_tensor;
  s = GetValue(outputs[0], &dW1_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dW1_tensor),
         TF_TensorByteSize(dW1_tensor));

  float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
  }

  TF_Tensor* dW2_tensor;
  s = GetValue(outputs[1], &dW2_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dW2_tensor),
         TF_TensorByteSize(dW2_tensor));

  float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f};  // dLoss
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  outputs[2]->Unref();
  TF_DeleteTensor(dW1_tensor);
  TF_DeleteTensor(dW2_tensor);
}

TEST_P(CppGradients, TestScalarMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr eta;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    eta.reset(x_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);

  GradientRegistry registry;
  std::vector<AbstractTensorHandle*> outputs(1);
  Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

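  // ScalarMulModel presumably returns eta * A elementwise, so each output
  // element should equal eta_val * A_vals[j].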
  float tolerance = 1e-3;
  float eta_val = 1.5f;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(dA_tensor);
}

TEST_P(CppGradients, TestMNIST_Training) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // TODO(amturati): use random initializer for weights instead of
  // constant values.

  // W1 = first weights
  float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Prepare for training
  std::vector<AbstractTensorHandle*> weights;
  weights.push_back(W1.get());
  weights.push_back(W2.get());

  // Set learning rate to be 1e-1
  AbstractTensorHandle* learning_rate = nullptr;
  s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Train
  int num_iters = 10;
  std::vector<AbstractTensorHandle*> mnist_outputs(3);
  std::vector<AbstractTensorHandle*> grads(2);
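  // Each iteration runs MNISTGradModel, which is expected to return
  // {dW1, dW2, loss}, and then UpdateWeights applies what is presumably a
  // plain SGD step, W <- W - learning_rate * dW, to W1 and W2 in place.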
  for (int i = 0; i < num_iters; i++) {
    // Run Forward Pass
    s = RunModel(MNISTGradModel, ctx.get(),
                 {X.get(), weights[0], weights[1], y.get()},
                 absl::MakeSpan(mnist_outputs),
                 /*use_function=*/!std::get<2>(GetParam()), registry);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();

    // Fill grads
    grads[0] = mnist_outputs[0];
    grads[1] = mnist_outputs[1];

    // Gradient Update
    s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  }

  grads[0]->Unref();          // release W1_grad
  grads[1]->Unref();          // release W2_grad
  mnist_outputs[2]->Unref();  // release loss
}

#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow