• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <string>
19 
20 #include "tensorflow/core/framework/fake_input.h"
21 #include "tensorflow/core/framework/node_def_builder.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/resource_base.h"
24 #include "tensorflow/core/framework/resource_handle.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/kernels/dense_update_functor.h"
27 #include "tensorflow/core/kernels/ops_testutil.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 class MockResource : public ResourceBase {
34  public:
MockResource(bool * alive,int payload)35   MockResource(bool* alive, int payload) : alive_(alive), payload_(payload) {
36     if (alive_ != nullptr) {
37       *alive_ = true;
38     }
39   }
~MockResource()40   ~MockResource() override {
41     if (alive_ != nullptr) {
42       *alive_ = false;
43     }
44   }
DebugString() const45   string DebugString() const override { return ""; }
46   bool* alive_;
47   int payload_;
48 };
49 
50 class MockHandleCreationOpKernel : public OpKernel {
51  public:
MockHandleCreationOpKernel(OpKernelConstruction * ctx)52   explicit MockHandleCreationOpKernel(OpKernelConstruction* ctx)
53       : OpKernel(ctx) {}
54 
Compute(OpKernelContext * ctx)55   void Compute(OpKernelContext* ctx) override {
56     bool* alive = reinterpret_cast<bool*>(ctx->input(0).scalar<int64>()());
57     int payload = ctx->input(1).scalar<int>()();
58     AllocatorAttributes attr;
59     Tensor handle_tensor;
60     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
61                                            &handle_tensor, attr));
62     handle_tensor.scalar<ResourceHandle>()() =
63         ResourceHandle::MakeRefCountingHandle(new MockResource(alive, payload),
64                                               ctx->device()->name(), {},
65                                               ctx->stack_trace());
66     ctx->set_output(0, handle_tensor);
67   }
68 };
69 
70 REGISTER_OP("MockHandleCreationOp")
71     .Input("alive: int64")
72     .Input("payload: int32")
73     .Output("output: resource");
74 
75 REGISTER_KERNEL_BUILDER(Name("MockHandleCreationOp").Device(DEVICE_CPU),
76                         MockHandleCreationOpKernel);
77 
78 class MockHandleCreationOpTest : public OpsTestBase {
79  protected:
MakeOp()80   void MakeOp() {
81     TF_ASSERT_OK(
82         NodeDefBuilder("mock_handle_creation_op", "MockHandleCreationOp")
83             .Input(FakeInput(DT_INT64))
84             .Input(FakeInput(DT_INT32))
85             .Finalize(node_def()));
86     TF_ASSERT_OK(InitOp());
87   }
88 };
89 
// Checks that the handle produced by MockHandleCreationOp holds the only
// reference to a MockResource, and that releasing the op context destroys it.
TEST_F(MockHandleCreationOpTest, RefCounting) {
  // ASSERT_* failures inside MakeOp() only return from MakeOp(). Propagate
  // them so the test stops instead of running an uninitialized kernel.
  ASSERT_NO_FATAL_FAILURE(MakeOp());
  bool alive = false;
  int payload = -123;

  // MockResource writes to `alive` from its destructor, and its last ref ends
  // up in context_->outputs_. If context_ were only destroyed by
  // ~OpsTestBase(), that write would target this function's already-dead
  // stack frame. That is presumably the cause of the fastbuild-only segfault
  // previously seen when context_ was not reset here.
  // This guard resets context_ on every exit path, including early returns
  // from failed ASSERTs below. It is declared after `alive`, so it runs first.
  struct ResetContextOnExit {
    decltype(context_)& context;
    ~ResetContextOnExit() { context.reset(); }
  } reset_context_on_exit{context_};

  // Feed and run. The address of `alive` travels through the int64 input.
  AddInputFromArray<int64>(TensorShape({}), {reinterpret_cast<int64>(&alive)});
  AddInputFromArray<int32>(TensorShape({}), {payload});
  TF_ASSERT_OK(RunOpKernel());
  EXPECT_TRUE(alive);

  // Check the output. Pointers are ASSERTed non-null before they are
  // dereferenced, so a regression fails the test instead of crashing it.
  Tensor* output = GetOutput(0);
  ASSERT_NE(output, nullptr);
  ResourceHandle& output_handle = output->scalar<ResourceHandle>()();
  ResourceBase* base = output_handle.resource().get();
  ASSERT_NE(base, nullptr);
  EXPECT_EQ(base->RefCount(), 1);
  auto mock_or = output_handle.GetResource<MockResource>();
  TF_ASSERT_OK(mock_or.status());
  MockResource* mock = mock_or.ValueOrDie();
  ASSERT_NE(mock, nullptr);
  EXPECT_EQ(mock->payload_, payload);
  EXPECT_EQ(base->RefCount(), 1);

  // context_->outputs_ holds the last ref to MockResource, so releasing the
  // context must destroy it.
  context_.reset();
  EXPECT_FALSE(alive);
}
118 
119 using CPUDevice = Eigen::ThreadPoolDevice;
120 
121 template <typename T>
122 class MockCopyOpKernel : public OpKernel {
123  public:
MockCopyOpKernel(OpKernelConstruction * ctx)124   explicit MockCopyOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
125 
Compute(OpKernelContext * ctx)126   void Compute(OpKernelContext* ctx) override {
127     const Tensor& input_tensor = ctx->input(0);
128     AllocatorAttributes attr;
129     Tensor output_tensor;
130     OP_REQUIRES_OK(
131         ctx, ctx->allocate_temp(input_tensor.dtype(), input_tensor.shape(),
132                                 &output_tensor, attr));
133     // copy_functor will properly call copy constructors on the elements
134     functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor;
135     copy_functor(ctx->eigen_device<CPUDevice>(), output_tensor.flat<T>(),
136                  input_tensor.flat<T>());
137     ctx->set_output(0, output_tensor);
138   }
139 };
140 
141 REGISTER_OP("MockCopyOp").Attr("T: type").Input("input: T").Output("output: T");
142 
143 REGISTER_KERNEL_BUILDER(
144     Name("MockCopyOp").Device(DEVICE_CPU).TypeConstraint<ResourceHandle>("T"),
145     MockCopyOpKernel<ResourceHandle>);
146 
147 class MockCopyOpTest : public OpsTestBase {
148  protected:
MakeOp()149   void MakeOp() {
150     TF_ASSERT_OK(NodeDefBuilder("mock_copy_op", "MockCopyOp")
151                      .Input(FakeInput(DT_RESOURCE))
152                      .Finalize(node_def()));
153     TF_ASSERT_OK(InitOp());
154   }
155 };
156 
is_equal_handles(const ResourceHandle & a,const ResourceHandle & b)157 bool is_equal_handles(const ResourceHandle& a, const ResourceHandle& b) {
158   return a.resource() == b.resource() && a.name() == b.name() &&
159          a.maybe_type_name() == b.maybe_type_name() &&
160          a.hash_code() == b.hash_code() && a.device() == b.device() &&
161          a.container() == b.container();
162 }
163 
// Checks that copying a ref-counting ResourceHandle through MockCopyOp yields
// an equal handle and adds exactly one reference to the underlying resource.
TEST_F(MockCopyOpTest, RefCounting) {
  // ASSERT_* failures inside MakeOp() only return from MakeOp(). Propagate
  // them so the test stops instead of running an uninitialized kernel.
  ASSERT_NO_FATAL_FAILURE(MakeOp());
  int payload = -123;

  // Feed and run. The handle is built inline so that the input tensor holds
  // the only reference to the resource.
  AddInputFromArray<ResourceHandle>(
      TensorShape({}),
      {ResourceHandle::MakeRefCountingHandle(new MockResource(nullptr, payload),
                                             device_->name(), {}, {})});
  const Tensor* input = inputs_[0].tensor;
  ASSERT_NE(input, nullptr);
  EXPECT_EQ(input->scalar<ResourceHandle>()().resource()->RefCount(), 1);
  TF_ASSERT_OK(RunOpKernel());

  // Check the output: an equal handle, and one extra ref on the resource.
  Tensor* output = GetOutput(0);
  ASSERT_NE(output, nullptr);
  test::ExpectTensorEqual<ResourceHandle>(*output, *input, is_equal_handles);
  EXPECT_EQ(input->scalar<ResourceHandle>()().resource()->RefCount(), 2);
}
182 
183 }  // namespace
184 }  // namespace tensorflow
185