1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_segment.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/allocator.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/kernels/ops_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/public/version.h"
29
30 namespace tensorflow {
31
32 class OpSegmentTest : public ::testing::Test {
33 protected:
34 DeviceBase device_;
35 std::vector<NodeDef> int32_nodedefs_;
36 std::vector<NodeDef> float_nodedefs_;
37
OpSegmentTest()38 OpSegmentTest() : device_(Env::Default()) {
39 for (int i = 0; i < 10; ++i) {
40 NodeDef def;
41 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
42 .Input("x", 0, DT_INT32)
43 .Input("y", 0, DT_INT32)
44 .Finalize(&def));
45 int32_nodedefs_.push_back(def);
46 TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
47 .Input("x", 0, DT_FLOAT)
48 .Input("y", 0, DT_FLOAT)
49 .Finalize(&def));
50 float_nodedefs_.push_back(def);
51 }
52 }
53
ValidateOpAndTypes(OpKernel * op,const NodeDef & expected,DataType dt)54 void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) {
55 ASSERT_NE(op, nullptr);
56 EXPECT_EQ(expected.DebugString(), op->def().DebugString());
57 EXPECT_EQ(2, op->num_inputs());
58 EXPECT_EQ(dt, op->input_type(0));
59 EXPECT_EQ(dt, op->input_type(1));
60 EXPECT_EQ(1, op->num_outputs());
61 EXPECT_EQ(dt, op->output_type(0));
62 }
63
GetFn(const NodeDef * ndef)64 OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) {
65 return [this, ndef](OpKernel** kernel) {
66 Status s;
67 auto created = CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(),
68 *ndef, TF_GRAPH_DEF_VERSION, &s);
69 if (s.ok()) {
70 *kernel = created.release();
71 }
72 return s;
73 };
74 }
75 };
76
TEST_F(OpSegmentTest,Basic)77 TEST_F(OpSegmentTest, Basic) {
78 OpSegment opseg;
79 OpKernel* op;
80
81 opseg.AddHold("A");
82 opseg.AddHold("B");
83 for (int i = 0; i < 10; ++i) {
84 // Register in session A.
85 auto* ndef = &float_nodedefs_[i];
86 TF_EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef)));
87 ValidateOpAndTypes(op, *ndef, DT_FLOAT);
88
89 // Register in session B.
90 ndef = &int32_nodedefs_[i];
91 TF_EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef)));
92 ValidateOpAndTypes(op, *ndef, DT_INT32);
93 }
94
95 auto reterr = [](OpKernel** kernel) {
96 return errors::Internal("Should not be called");
97 };
98 for (int i = 0; i < 10; ++i) {
99 // Lookup op in session A.
100 TF_EXPECT_OK(
101 opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr));
102 ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT);
103
104 // Lookup op in session B.
105 TF_EXPECT_OK(
106 opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr));
107 ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32);
108 }
109
110 opseg.RemoveHold("A");
111 opseg.RemoveHold("B");
112 }
113
TEST_F(OpSegmentTest,SessionNotFound)114 TEST_F(OpSegmentTest, SessionNotFound) {
115 OpSegment opseg;
116 OpKernel* op;
117 NodeDef def = float_nodedefs_[0];
118 Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
119 EXPECT_TRUE(errors::IsNotFound(s)) << s;
120 }
121
TEST_F(OpSegmentTest,CreateFailure)122 TEST_F(OpSegmentTest, CreateFailure) {
123 OpSegment opseg;
124 OpKernel* op;
125 NodeDef def = float_nodedefs_[0];
126 def.set_op("nonexistop");
127 opseg.AddHold("A");
128 Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
129 EXPECT_TRUE(errors::IsNotFound(s)) << s;
130 opseg.RemoveHold("A");
131 }
132
TEST_F(OpSegmentTest,AddRemoveHolds)133 TEST_F(OpSegmentTest, AddRemoveHolds) {
134 OpSegment opseg;
135 OpKernel* op;
136 const auto& ndef = int32_nodedefs_[0];
137
138 // No op.
139 opseg.RemoveHold("null");
140
141 // Thread1 register the op and wants to ensure it alive.
142 opseg.AddHold("foo");
143 TF_EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef)));
144
145 // Thread2 starts some execution needs "op" to be alive.
146 opseg.AddHold("foo");
147
148 // Thread1 clears session "foo". E.g., a master sends CleanupGraph
149 // before an execution finishes.
150 opseg.RemoveHold("foo");
151
152 // Thread2 should still be able to access "op".
153 ValidateOpAndTypes(op, ndef, DT_INT32);
154
155 // Thread2 then remove its hold on "foo".
156 opseg.RemoveHold("foo");
157 }
158
159 } // namespace tensorflow
160