1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include <string>
19
20 #include "tensorflow/core/framework/fake_input.h"
21 #include "tensorflow/core/framework/node_def_builder.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/resource_base.h"
24 #include "tensorflow/core/framework/resource_handle.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/kernels/dense_update_functor.h"
27 #include "tensorflow/core/kernels/ops_testutil.h"
28 #include "tensorflow/core/platform/test.h"
29
30 namespace tensorflow {
31 namespace {
32
33 class MockResource : public ResourceBase {
34 public:
MockResource(bool * alive,int payload)35 MockResource(bool* alive, int payload) : alive_(alive), payload_(payload) {
36 if (alive_ != nullptr) {
37 *alive_ = true;
38 }
39 }
~MockResource()40 ~MockResource() override {
41 if (alive_ != nullptr) {
42 *alive_ = false;
43 }
44 }
DebugString() const45 string DebugString() const override { return ""; }
46 bool* alive_;
47 int payload_;
48 };
49
50 class MockHandleCreationOpKernel : public OpKernel {
51 public:
MockHandleCreationOpKernel(OpKernelConstruction * ctx)52 explicit MockHandleCreationOpKernel(OpKernelConstruction* ctx)
53 : OpKernel(ctx) {}
54
Compute(OpKernelContext * ctx)55 void Compute(OpKernelContext* ctx) override {
56 bool* alive = reinterpret_cast<bool*>(ctx->input(0).scalar<int64>()());
57 int payload = ctx->input(1).scalar<int>()();
58 AllocatorAttributes attr;
59 Tensor handle_tensor;
60 OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
61 &handle_tensor, attr));
62 handle_tensor.scalar<ResourceHandle>()() =
63 ResourceHandle::MakeRefCountingHandle(new MockResource(alive, payload),
64 ctx->device()->name(), {},
65 ctx->stack_trace());
66 ctx->set_output(0, handle_tensor);
67 }
68 };
69
70 REGISTER_OP("MockHandleCreationOp")
71 .Input("alive: int64")
72 .Input("payload: int32")
73 .Output("output: resource");
74
75 REGISTER_KERNEL_BUILDER(Name("MockHandleCreationOp").Device(DEVICE_CPU),
76 MockHandleCreationOpKernel);
77
78 class MockHandleCreationOpTest : public OpsTestBase {
79 protected:
MakeOp()80 void MakeOp() {
81 TF_ASSERT_OK(
82 NodeDefBuilder("mock_handle_creation_op", "MockHandleCreationOp")
83 .Input(FakeInput(DT_INT64))
84 .Input(FakeInput(DT_INT32))
85 .Finalize(node_def()));
86 TF_ASSERT_OK(InitOp());
87 }
88 };
89
// Checks that the handle emitted by MockHandleCreationOp carries exactly one
// reference to its MockResource, and that the resource is destroyed once the
// kernel context (which holds the output tensor) is released.
TEST_F(MockHandleCreationOpTest, RefCounting) {
  MakeOp();
  bool alive = false;
  int payload = -123;

  // Feed and run. The address of `alive` travels through the int64 input; the
  // kernel casts it back to a bool* and hands it to MockResource.
  AddInputFromArray<int64>(TensorShape({}), {reinterpret_cast<int64>(&alive)});
  AddInputFromArray<int32>(TensorShape({}), {payload});
  TF_ASSERT_OK(RunOpKernel());
  EXPECT_TRUE(alive);

  // Check the output.
  Tensor* output = GetOutput(0);
  ResourceHandle& output_handle = output->scalar<ResourceHandle>()();
  ResourceBase* base = output_handle.resource().get();
  // EXPECT_* is non-fatal, so the dereferences below are guarded explicitly.
  // A plain `if` is used instead of ASSERT_* so that the context_.reset() at
  // the end of the test still runs when a check fails.
  EXPECT_TRUE(base != nullptr);
  if (base != nullptr) {
    EXPECT_EQ(base->RefCount(), 1);
    MockResource* mock = output_handle.GetResource<MockResource>().ValueOrDie();
    EXPECT_TRUE(mock != nullptr);
    if (mock != nullptr) {
      EXPECT_EQ(mock->payload_, payload);
    }
    // Looking up the typed resource must not have added a reference.
    EXPECT_EQ(base->RefCount(), 1);
  }

  // context_->outputs_ holds the last ref to MockResource.
  //
  // If we don't call context_.reset() here, it will trigger a segfault (only
  // in -c fastbuild) when it's called by ~OpsTestBase().
  // NOTE(review): presumably because ~MockResource writes through `alive_`,
  // which points at the local `alive` above and would be dangling by then;
  // confirm.
  context_.reset();
  EXPECT_FALSE(alive);
}
118
119 using CPUDevice = Eigen::ThreadPoolDevice;
120
121 template <typename T>
122 class MockCopyOpKernel : public OpKernel {
123 public:
MockCopyOpKernel(OpKernelConstruction * ctx)124 explicit MockCopyOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
125
Compute(OpKernelContext * ctx)126 void Compute(OpKernelContext* ctx) override {
127 const Tensor& input_tensor = ctx->input(0);
128 AllocatorAttributes attr;
129 Tensor output_tensor;
130 OP_REQUIRES_OK(
131 ctx, ctx->allocate_temp(input_tensor.dtype(), input_tensor.shape(),
132 &output_tensor, attr));
133 // copy_functor will properly call copy constructors on the elements
134 functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor;
135 copy_functor(ctx->eigen_device<CPUDevice>(), output_tensor.flat<T>(),
136 input_tensor.flat<T>());
137 ctx->set_output(0, output_tensor);
138 }
139 };
140
141 REGISTER_OP("MockCopyOp").Attr("T: type").Input("input: T").Output("output: T");
142
143 REGISTER_KERNEL_BUILDER(
144 Name("MockCopyOp").Device(DEVICE_CPU).TypeConstraint<ResourceHandle>("T"),
145 MockCopyOpKernel<ResourceHandle>);
146
147 class MockCopyOpTest : public OpsTestBase {
148 protected:
MakeOp()149 void MakeOp() {
150 TF_ASSERT_OK(NodeDefBuilder("mock_copy_op", "MockCopyOp")
151 .Input(FakeInput(DT_RESOURCE))
152 .Finalize(node_def()));
153 TF_ASSERT_OK(InitOp());
154 }
155 };
156
is_equal_handles(const ResourceHandle & a,const ResourceHandle & b)157 bool is_equal_handles(const ResourceHandle& a, const ResourceHandle& b) {
158 return a.resource() == b.resource() && a.name() == b.name() &&
159 a.maybe_type_name() == b.maybe_type_name() &&
160 a.hash_code() == b.hash_code() && a.device() == b.device() &&
161 a.container() == b.container();
162 }
163
// Checks that running a ResourceHandle tensor through MockCopyOp produces an
// output equal to the input (per is_equal_handles) and raises the resource's
// ref count from 1 to 2.
TEST_F(MockCopyOpTest, RefCounting) {
  MakeOp();
  const int payload = -123;

  // Feed. The handle is deliberately passed as a temporary: a named local
  // would keep its own reference and shift the ref counts checked below.
  AddInputFromArray<ResourceHandle>(
      TensorShape({}),
      {ResourceHandle::MakeRefCountingHandle(new MockResource(nullptr, payload),
                                             device_->name(), {}, {})});
  const Tensor* input = inputs_[0].tensor;
  // Reads the current ref count of the resource behind the input handle.
  const auto input_ref_count = [input] {
    return input->scalar<ResourceHandle>()().resource()->RefCount();
  };
  EXPECT_EQ(input_ref_count(), 1);

  // Run.
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor* output = GetOutput(0);
  test::ExpectTensorEqual<ResourceHandle>(*output, *input, is_equal_handles);
  EXPECT_EQ(input_ref_count(), 2);
}
182
183 } // namespace
184 } // namespace tensorflow
185