/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();

    // Computing numerical gradients with TensorFloat-32 is numerically
    // unstable. Some forward pass tests also fail with TensorFloat-32 due to
    // low tolerances.
    enable_tensor_float_32_execution(false);
  }
};

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}

TEST_P(CppGradients, TestMatMulGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(A)
   * tape.watch(B)
   * Y = AB
   * outputs = tape.gradient(Y, [A, B])
   */
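  // Hand-worked check (a sketch, assuming the tape seeds the backward pass
  // with dY = ones(2, 2), which is what the expected values below imply):
  //   dA = dY . B^T, so every row of dA is [.5 + (-1), 1 + 1] = [-.5, 2].
  //   dB = A^T . dY, so the rows of dB are the column sums of A: [4, 4]
  //   and [6, 6].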

  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
  }

  TF_Tensor* dB_tensor;
  s = GetValue(outputs[1], &dB_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dB_tensor),
         TF_TensorByteSize(dB_tensor));

  float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(dA_tensor);
  TF_DeleteTensor(dB_tensor);
}

TEST_P(CppGradients, TestMNISTForward) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t dims_y[] = {2};
  num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, dims_y, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

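  // Hand-worked check (a sketch, assuming MNISTForwardModel computes
  // scores = Relu(X * W1) * W2): X * W1 = [[0, 12], [-1, 34]], Relu zeroes
  // the -1, and multiplying by W2 gives [[3.6, -6.0], [10.2, -17.0]].
  // Since label class 1 has the much smaller logit in each row, each sparse
  // softmax loss is approximately the logit gap: 9.6 and 27.2.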
  float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[2] = {9.6f, 27.2f};
  for (int j = 0; j < 2; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}

TEST_P(CppGradients, TestMNISTForward2) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {3, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1, 1};
  int64_t y_dims[] = {3};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

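  // Same arithmetic as TestMNISTForward, plus a third example row
  // X = [5, 6]: Relu([-2, 56]) = [0, 56], which gives scores [16.8, -28.0]
  // and a loss of ~44.8 (again roughly the logit gap).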
  float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[3] = {9.6f, 27.2f, 44.8f};
  for (int j = 0; j < 3; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}

TEST_P(CppGradients, TestMatMulTranspose) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {2, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  GradientRegistry registry;

  // Run the MatMul Op
  std::vector<AbstractTensorHandle*> outputs(1);

  Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);

  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

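  // Hand-worked check (a sketch, assuming MatMulTransposeModel computes
  // X^T * W1): X^T is [[1, 4], [2, 5], [3, 6]], so the product's rows are
  // [13, 18], [17, 24], and [21, 30].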
  float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(scores_tensor);
}

TEST_P(CppGradients, TestMNISTGrad) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(W1)
   * tape.watch(W2)
   * mm = X*W1
   * hidden = Relu(mm)
   * scores = hidden*W2
   * loss = SoftmaxLoss(scores, y)
   * outputs = tape.gradient(loss, [W1, W2])
   */
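  // Hand-worked check: the softmax probabilities are ~[1, 0] per row while
  // the label is class 1, so dscores ~= [[1, -1], [1, -1]]. Then
  // dW2 = hidden^T . dscores = [[0, 0], [46, -46]], and backing through W2
  // and the Relu mask gives dW1 = X^T . [[0, .8], [0, .8]] =
  // [[0, 3.2], [0, 4.8]], matching the expected values below.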

  std::vector<AbstractTensorHandle*> outputs(3);
  s = RunModel(MNISTGradModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float tolerance = 1e-3;
  TF_Tensor* dW1_tensor;
  s = GetValue(outputs[0], &dW1_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dW1_tensor),
         TF_TensorByteSize(dW1_tensor));

  float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
  }

  TF_Tensor* dW2_tensor;
  s = GetValue(outputs[1], &dW2_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dW2_tensor),
         TF_TensorByteSize(dW2_tensor));

  float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f};  // dLoss
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  outputs[2]->Unref();
  TF_DeleteTensor(dW1_tensor);
  TF_DeleteTensor(dW2_tensor);
}

TEST_P(CppGradients, TestScalarMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr eta;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    eta.reset(x_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);

  GradientRegistry registry;
  std::vector<AbstractTensorHandle*> outputs(1);
  Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float tolerance = 1e-3;
  float eta_val = 1.5f;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(dA_tensor);
}

TEST_P(CppGradients, TestMNIST_Training) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // TODO(amturati): use random initializer for weights instead of
  // constant values.

  // W1 = first weights
  float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Prepare for training
  std::vector<AbstractTensorHandle*> weights;
  weights.push_back(W1.get());
  weights.push_back(W2.get());

  // Set the learning rate to 1e-1
  AbstractTensorHandle* learning_rate = nullptr;
  s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Train
  int num_iters = 10;
  std::vector<AbstractTensorHandle*> mnist_outputs(3);
  std::vector<AbstractTensorHandle*> grads(2);
  for (int i = 0; i < num_iters; i++) {
    // Run the forward pass and compute the gradients
    s = RunModel(MNISTGradModel, ctx.get(),
                 {X.get(), weights[0], weights[1], y.get()},
                 absl::MakeSpan(mnist_outputs),
                 /*use_function=*/!std::get<2>(GetParam()), registry);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();

    // Fill grads
    grads[0] = mnist_outputs[0];
    grads[1] = mnist_outputs[1];

    // Gradient Update
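    // Assumed behavior, based on its use here: UpdateWeights applies an
    // in-place SGD step, w <- w - learning_rate * dw, to each weight.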
    s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  }

  grads[0]->Unref();          // release W1_grad
  grads[1]->Unref();          // release W2_grad
  mnist_outputs[2]->Unref();  // release loss
}

#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow