• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/gradients/array_grad.h"
16 
17 #include "tensorflow/c/eager/c_api_test_util.h"
18 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
19 #include "tensorflow/c/eager/unified_api_testutil.h"
20 #include "tensorflow/c/experimental/gradients/grad_test_helper.h"
21 #include "tensorflow/c/experimental/gradients/tape/tape_context.h"
22 #include "tensorflow/c/experimental/ops/array_ops.h"
23 #include "tensorflow/c/tf_status_helper.h"
24 #include "tensorflow/core/platform/tensor_float_32_utils.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace gradients {
29 namespace internal {
30 namespace {
31 
32 using tensorflow::TF_StatusPtr;
33 
IdentityNModel(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs)34 Status IdentityNModel(AbstractContext* ctx,
35                       absl::Span<AbstractTensorHandle* const> inputs,
36                       absl::Span<AbstractTensorHandle*> outputs) {
37   std::vector<AbstractTensorHandle*> temp_outputs(2);
38   TF_RETURN_IF_ERROR(
39       ops::IdentityN(ctx, inputs, absl::MakeSpan(temp_outputs), "IdentityN"));
40   // Although, `ops::IdentityN` returns 2 tensors, the first tensor isn't needed
41   // for computing gradient so we could safely drop it.
42   outputs[0] = temp_outputs[1];
43   temp_outputs[0]->Unref();
44   return OkStatus();
45 }
46 
47 class CppGradients
48     : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
49  protected:
SetUp()50   void SetUp() override {
51     TF_StatusPtr status(TF_NewStatus());
52     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
53     status_ = StatusFromTF_Status(status.get());
54     ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
55 
56     {
57       AbstractContext* ctx_raw = nullptr;
58       status_ =
59           BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
60       ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
61       immediate_execution_ctx_.reset(ctx_raw);
62     }
63 
64     // Computing numerical gradients with TensorFloat-32 is numerically
65     // unstable. Some forward pass tests also fail with TensorFloat-32 due to
66     // low tolerances
67     enable_tensor_float_32_execution(false);
68   }
69 
70   AbstractContextPtr immediate_execution_ctx_;
71   GradientRegistry registry_;
72   Status status_;
73 
74  public:
UseMlir() const75   bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
UseFunction() const76   bool UseFunction() const { return std::get<2>(GetParam()); }
77 };
78 
// Exercises the registered IdentityN gradient through IdentityNModel, which
// only exposes the second IdentityN output.
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here so
// we get back [nullptr, 1].
TEST_P(CppGradients, TestIdentityNGrad) {
  AbstractTensorHandlePtr x1;
  {
    AbstractTensorHandle* x1_raw = nullptr;
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 1.0f, &x1_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x1.reset(x1_raw);
  }

  AbstractTensorHandlePtr x2;
  {
    AbstractTensorHandle* x2_raw = nullptr;
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 1.0f, &x2_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x2.reset(x2_raw);
  }

  status_ = registry_.Register("IdentityN", IdentityNRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  auto IdentityNGradModel = BuildGradModel(IdentityNModel, registry_);

  std::vector<AbstractTensorHandle*> outputs(2);
  status_ =
      RunModel(IdentityNGradModel, immediate_execution_ctx_.get(),
               {x1.get(), x2.get()}, absl::MakeSpan(outputs), UseFunction());
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

  // Take ownership of both returned gradients immediately so they are released
  // on every exit path, including when one of the checks below fails (a fatal
  // failure inside ASSERT_NO_FATAL_FAILURE returns from the test body early).
  AbstractTensorHandlePtr dx1(outputs[0]);
  AbstractTensorHandlePtr dx2(outputs[1]);

  EXPECT_EQ(dx1.get(), nullptr);
  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(dx2.get(), {1.0f}, /*dims*/ {},
                                           /*abs_error*/ 0));
}
116 
117 #ifdef PLATFORM_GOOGLE
118 INSTANTIATE_TEST_SUITE_P(
119     UnifiedCAPI, CppGradients,
120     ::testing::Combine(::testing::Values("graphdef", "mlir"),
121                        /*tfrt*/ ::testing::Values(false),
122                        /*use_function*/ ::testing::Values(true, false)));
123 #else
124 INSTANTIATE_TEST_SUITE_P(
125     UnifiedCAPI, CppGradients,
126     ::testing::Combine(::testing::Values("graphdef", "mlir"),
127                        /*tfrt*/ ::testing::Values(false),
128                        /*use_function*/ ::testing::Values(true, false)));
129 #endif
130 }  // namespace
131 }  // namespace internal
132 }  // namespace gradients
133 }  // namespace tensorflow
134