• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_segment.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/allocator.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/kernels/ops_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/public/version.h"
29 
30 namespace tensorflow {
31 
32 class OpSegmentTest : public ::testing::Test {
33  protected:
34   DeviceBase device_;
35   std::vector<NodeDef> int32_nodedefs_;
36   std::vector<NodeDef> float_nodedefs_;
37 
OpSegmentTest()38   OpSegmentTest() : device_(Env::Default()) {
39     for (int i = 0; i < 10; ++i) {
40       NodeDef def;
41       TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
42                       .Input("x", 0, DT_INT32)
43                       .Input("y", 0, DT_INT32)
44                       .Finalize(&def));
45       int32_nodedefs_.push_back(def);
46       TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
47                       .Input("x", 0, DT_FLOAT)
48                       .Input("y", 0, DT_FLOAT)
49                       .Finalize(&def));
50       float_nodedefs_.push_back(def);
51     }
52   }
53 
ValidateOpAndTypes(OpKernel * op,const NodeDef & expected,DataType dt)54   void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) {
55     ASSERT_NE(op, nullptr);
56     EXPECT_EQ(expected.DebugString(), op->def().DebugString());
57     EXPECT_EQ(2, op->num_inputs());
58     EXPECT_EQ(dt, op->input_type(0));
59     EXPECT_EQ(dt, op->input_type(1));
60     EXPECT_EQ(1, op->num_outputs());
61     EXPECT_EQ(dt, op->output_type(0));
62   }
63 
GetFn(const NodeDef * ndef)64   OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) {
65     return [this, ndef](OpKernel** kernel) {
66       Status s;
67       auto created = CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(),
68                                     *ndef, TF_GRAPH_DEF_VERSION, &s);
69       if (s.ok()) {
70         *kernel = created.release();
71       }
72       return s;
73     };
74   }
75 };
76 
// Registers ten float kernels under session "A" and ten int32 kernels under
// session "B" using the same node names, then looks each one up again with a
// creation callback that always fails. The lookups must succeed and return
// kernels matching the NodeDefs registered for that session.
TEST_F(OpSegmentTest, Basic) {
  OpSegment opseg;
  // `op` is reset to nullptr before every FindOrCreate call. TF_EXPECT_OK is
  // non-fatal, so if a call fails without writing `op`, the following
  // ValidateOpAndTypes must see nullptr (and fail its ASSERT_NE) rather than
  // an indeterminate pointer or the kernel left over from the previous call.
  OpKernel* op = nullptr;

  opseg.AddHold("A");
  opseg.AddHold("B");
  for (int i = 0; i < 10; ++i) {
    // Register in session A.
    auto* ndef = &float_nodedefs_[i];
    op = nullptr;
    TF_EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef)));
    ValidateOpAndTypes(op, *ndef, DT_FLOAT);

    // Register in session B.
    ndef = &int32_nodedefs_[i];
    op = nullptr;
    TF_EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef)));
    ValidateOpAndTypes(op, *ndef, DT_INT32);
  }

  // Every kernel is already registered, so a lookup that reached this
  // callback would surface as an Internal error from FindOrCreate.
  auto reterr = [](OpKernel** /*kernel*/) {
    return errors::Internal("Should not be called");
  };
  for (int i = 0; i < 10; ++i) {
    // Lookup op in session A.
    op = nullptr;
    TF_EXPECT_OK(
        opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr));
    ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT);

    // Lookup op in session B.
    op = nullptr;
    TF_EXPECT_OK(
        opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr));
    ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32);
  }

  opseg.RemoveHold("A");
  opseg.RemoveHold("B");
}
113 
// FindOrCreate on a session handle for which AddHold was never called is
// expected to report NotFound.
TEST_F(OpSegmentTest, SessionNotFound) {
  OpSegment segment;
  OpKernel* kernel;
  NodeDef node = float_nodedefs_[0];

  // No hold exists for "A" at this point.
  const Status status =
      segment.FindOrCreate("A", node.name(), &kernel, GetFn(&node));
  EXPECT_TRUE(errors::IsNotFound(status)) << status;
}
121 
// With a hold on session "A" in place, FindOrCreate is expected to report
// NotFound when the NodeDef names an op ("nonexistop") for which the creation
// callback cannot build a kernel.
TEST_F(OpSegmentTest, CreateFailure) {
  OpSegment segment;
  OpKernel* kernel;
  NodeDef bad_node = float_nodedefs_[0];
  bad_node.set_op("nonexistop");

  segment.AddHold("A");
  const Status status =
      segment.FindOrCreate("A", bad_node.name(), &kernel, GetFn(&bad_node));
  EXPECT_TRUE(errors::IsNotFound(status)) << status;
  segment.RemoveHold("A");
}
132 
// Simulates two threads holding the same session "foo": the kernel created
// while the first hold is active must still be usable after one of the two
// holds has been removed.
TEST_F(OpSegmentTest, AddRemoveHolds) {
  OpSegment opseg;
  // Initialized so that, if the non-fatal TF_EXPECT_OK below fails without
  // writing `op`, ValidateOpAndTypes fails its ASSERT_NE instead of
  // dereferencing an indeterminate pointer.
  OpKernel* op = nullptr;
  const auto& ndef = int32_nodedefs_[0];

  // Removing a hold on a session that was never added: the test expects this
  // to be harmless.
  opseg.RemoveHold("null");

  // Thread1 registers the op and wants to ensure it stays alive.
  opseg.AddHold("foo");
  TF_EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef)));

  // Thread2 starts some execution that needs "op" to be alive.
  opseg.AddHold("foo");

  // Thread1 clears session "foo".  E.g., a master sends CleanupGraph
  // before an execution finishes.
  opseg.RemoveHold("foo");

  // Thread2 should still be able to access "op".
  ValidateOpAndTypes(op, ndef, DT_INT32);

  // Thread2 then removes its hold on "foo".
  opseg.RemoveHold("foo");
}
158 
159 }  // namespace tensorflow
160