1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/gradients/array_grad.h"
16
17 #include "tensorflow/c/eager/c_api_test_util.h"
18 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
19 #include "tensorflow/c/eager/unified_api_testutil.h"
20 #include "tensorflow/c/experimental/gradients/grad_test_helper.h"
21 #include "tensorflow/c/experimental/gradients/tape/tape_context.h"
22 #include "tensorflow/c/experimental/ops/array_ops.h"
23 #include "tensorflow/c/tf_status_helper.h"
24 #include "tensorflow/core/platform/tensor_float_32_utils.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace gradients {
29 namespace internal {
30 namespace {
31
32 using tensorflow::TF_StatusPtr;
33
IdentityNModel(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs)34 Status IdentityNModel(AbstractContext* ctx,
35 absl::Span<AbstractTensorHandle* const> inputs,
36 absl::Span<AbstractTensorHandle*> outputs) {
37 std::vector<AbstractTensorHandle*> temp_outputs(2);
38 TF_RETURN_IF_ERROR(
39 ops::IdentityN(ctx, inputs, absl::MakeSpan(temp_outputs), "IdentityN"));
40 // Although, `ops::IdentityN` returns 2 tensors, the first tensor isn't needed
41 // for computing gradient so we could safely drop it.
42 outputs[0] = temp_outputs[1];
43 temp_outputs[0]->Unref();
44 return OkStatus();
45 }
46
47 class CppGradients
48 : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
49 protected:
SetUp()50 void SetUp() override {
51 TF_StatusPtr status(TF_NewStatus());
52 TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
53 status_ = StatusFromTF_Status(status.get());
54 ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
55
56 {
57 AbstractContext* ctx_raw = nullptr;
58 status_ =
59 BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
60 ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
61 immediate_execution_ctx_.reset(ctx_raw);
62 }
63
64 // Computing numerical gradients with TensorFloat-32 is numerically
65 // unstable. Some forward pass tests also fail with TensorFloat-32 due to
66 // low tolerances
67 enable_tensor_float_32_execution(false);
68 }
69
70 AbstractContextPtr immediate_execution_ctx_;
71 GradientRegistry registry_;
72 Status status_;
73
74 public:
UseMlir() const75 bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
UseFunction() const76 bool UseFunction() const { return std::get<2>(GetParam()); }
77 };
78
TEST_P(CppGradients, TestIdentityNGrad) {
  // Interesting case: the current GradientTape implementation would return
  // [0, 1], but with build_default_zeros_grads=false we get back
  // [nullptr, 1] instead.

  AbstractTensorHandlePtr input_a;
  {
    AbstractTensorHandle* raw_handle = nullptr;
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 1.0f, &raw_handle);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    input_a.reset(raw_handle);
  }

  AbstractTensorHandlePtr input_b;
  {
    AbstractTensorHandle* raw_handle = nullptr;
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 1.0f, &raw_handle);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    input_b.reset(raw_handle);
  }

  status_ = registry_.Register("IdentityN", IdentityNRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  auto grad_model = BuildGradModel(IdentityNModel, registry_);

  std::vector<AbstractTensorHandle*> grads(2);
  status_ = RunModel(grad_model, immediate_execution_ctx_.get(),
                     {input_a.get(), input_b.get()}, absl::MakeSpan(grads),
                     UseFunction());
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  // No gradient is built for the dropped first output.
  EXPECT_EQ(grads[0], nullptr);
  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(grads[1], {1.0f}, /*dims*/ {},
                                           /*abs_error*/ 0));
  grads[1]->Unref();
}
116
117 #ifdef PLATFORM_GOOGLE
118 INSTANTIATE_TEST_SUITE_P(
119 UnifiedCAPI, CppGradients,
120 ::testing::Combine(::testing::Values("graphdef", "mlir"),
121 /*tfrt*/ ::testing::Values(false),
122 /*use_function*/ ::testing::Values(true, false)));
123 #else
124 INSTANTIATE_TEST_SUITE_P(
125 UnifiedCAPI, CppGradients,
126 ::testing::Combine(::testing::Values("graphdef", "mlir"),
127 /*tfrt*/ ::testing::Values(false),
128 /*use_function*/ ::testing::Values(true, false)));
129 #endif
130 } // namespace
131 } // namespace internal
132 } // namespace gradients
133 } // namespace tensorflow
134