1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/node_matchers.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/control_flow_ops.h"
22 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
23 #include "tensorflow/cc/ops/math_ops.h"
24
25 namespace tensorflow {
26 namespace testing {
27 namespace {
28
29 using ::testing::_;
30
31 using testing::matchers::AssignedDevice;
32 using testing::matchers::Attr;
33 using testing::matchers::ConstantValue;
34 using testing::matchers::CtrlDeps;
35 using testing::matchers::Inputs;
36 using testing::matchers::Name;
37 using testing::matchers::NodeWith;
38 using testing::matchers::Op;
39 using testing::matchers::Out;
40
41 template <typename M, typename T>
Explain(const T & t,const M & m)42 string Explain(const T& t, const M& m) {
43 ::testing::StringMatchResultListener listener;
44 EXPECT_THAT(t, ::testing::Not(m)); // For the error message.
45 EXPECT_FALSE(m.MatchAndExplain(t, &listener));
46 return listener.str();
47 }
48
// Exercises the Op/Name/Inputs matchers against a single Placeholder node that
// has no data inputs. Also checks the explanation text produced when one of
// those properties does not match.
TEST(NodeMatchers, CheckAgainstConstant) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output placeholder =
      ops::Placeholder(scope.WithOpName("placeholder"), DT_FLOAT);
  Node* ph_node = placeholder.node();

  // Single properties, pairs in either order, and the full set all match.
  EXPECT_THAT(ph_node, NodeWith(Op("Placeholder")));
  EXPECT_THAT(ph_node, NodeWith(Name("placeholder")));
  EXPECT_THAT(ph_node, NodeWith(Op("Placeholder"), Name("placeholder")));
  EXPECT_THAT(ph_node, NodeWith(Name("placeholder"), Op("Placeholder")));
  // The placeholder has no data inputs, so an empty Inputs() list matches.
  EXPECT_THAT(ph_node, NodeWith(Inputs()));
  EXPECT_THAT(ph_node,
              NodeWith(Op("Placeholder"), Name("placeholder"), Inputs()));

  // Each mismatching property yields its own explanation.
  const string wrong_op = Explain(ph_node, NodeWith(Op("Add")));
  EXPECT_EQ(wrong_op, "\nexpected op Add but found Placeholder");

  const string wrong_name = Explain(ph_node, NodeWith(Name("add")));
  EXPECT_EQ(wrong_name, "\nexpected name add but found placeholder");

  const string wrong_arity =
      Explain(ph_node, NodeWith(Inputs(Out(NodeWith()))));
  EXPECT_EQ(wrong_arity, "\nexpected 1 inputs but node has 0");
}
71
// Verifies Inputs/Out matching on a two-input Add node. Also checks that a
// failing nested name matcher reports which input index did not match.
TEST(NodeMatchers, CheckAgainstBinary) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output lhs = ops::Placeholder(scope.WithOpName("placeholder_a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("placeholder_b"), DT_FLOAT);
  Output sum = ops::Add(scope.WithOpName("add"), lhs, rhs);
  Node* add_node = sum.node();

  EXPECT_THAT(add_node,
              NodeWith(Op("Add"), Name("add"),
                       Inputs(Out(NodeWith(Name("placeholder_a"))),
                              Out(NodeWith(Name("placeholder_b"))))));

  // Wrong number of inputs.
  EXPECT_EQ(Explain(add_node, NodeWith(Inputs())),
            "\nexpected 0 inputs but node has 2");

  // A failing nested matcher on input 0, then on input 1.
  EXPECT_EQ(
      Explain(add_node, NodeWith(Inputs(Out(NodeWith(Name("blah"))), _))),
      "\ninput 0 does not match expected:\n"
      "name: blah, \n"
      "source does not match expected name: blah\n\t\n"
      "expected name blah but found placeholder_a");
  EXPECT_EQ(
      Explain(add_node, NodeWith(Inputs(_, Out(NodeWith(Name("blah")))))),
      "\ninput 1 does not match expected:\n"
      "name: blah, \n"
      "source does not match expected name: blah\n\t\n"
      "expected name blah but found placeholder_b");
}
97
// Verifies CtrlDeps matching on a node with two incoming control edges and on
// a node with none, including the explanations for both kinds of mismatch.
TEST(NodeMatchers, CheckControlDependence) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output ph_a = ops::Placeholder(scope.WithOpName("placeholder_a"), DT_FLOAT);
  Output ph_b = ops::Placeholder(scope.WithOpName("placeholder_b"), DT_FLOAT);
  Output ph_c = ops::Placeholder(scope.WithOpName("placeholder_c"), DT_FLOAT);
  Output ph_d = ops::Placeholder(scope.WithOpName("placeholder_d"), DT_FLOAT);

  // placeholder_c receives control edges from placeholder_a and
  // placeholder_b; placeholder_d is given none.
  Graph* graph = scope.graph();
  graph->AddControlEdge(ph_a.node(), ph_c.node());
  graph->AddControlEdge(ph_b.node(), ph_c.node());

  EXPECT_THAT(ph_c.node(),
              NodeWith(Name("placeholder_c"),
                       CtrlDeps(NodeWith(Name("placeholder_a")),
                                NodeWith(Name("placeholder_b")))));
  EXPECT_THAT(ph_d.node(), NodeWith(Name("placeholder_d"), CtrlDeps()));

  EXPECT_EQ(
      Explain(ph_c.node(), NodeWith(CtrlDeps())),
      "ctrl_deps, which has 2 elements, does not match expected: is empty");
  EXPECT_EQ(Explain(ph_d.node(), NodeWith(CtrlDeps(NodeWith()))),
            "ctrl_deps does not match expected: has 1 element and that "
            "element is any node");
}
127
// Verifies ConstantValue against a scalar Const and a 2x2 Const. Also checks
// the explanations for a non-Const node, a value mismatch, and element-count
// mismatches in both directions.
TEST(NodeMatchers, ConstValue) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output placeholder =
      ops::Placeholder(scope.WithOpName("placeholder"), DT_FLOAT);
  Output scalar = ops::Const(scope.WithOpName("const_0d"), 42);

  Output matrix = ops::Const(scope.WithOpName("const_2d"), {{1, 2}, {4, 3}});

  EXPECT_THAT(scalar.node(), NodeWith(ConstantValue(42)));
  EXPECT_THAT(scalar.node(), NodeWith(ConstantValue(42), Name("const_0d")));

  EXPECT_THAT(matrix.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}})));

  // A Placeholder is rejected by the op check.
  EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))),
            "\nexpected op Const but found Placeholder");
  // Both sides are single-element tensors; only the value differs.
  EXPECT_EQ(
      Explain(scalar.node(), NodeWith(ConstantValue(43))),
      "\nmismatch in constant tensor at index 0 expected = 43 actual = 42");
  // Element counts differ (4 vs 1, then 1 vs 4).
  EXPECT_EQ(
      Explain(scalar.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))),
      "\nwas looking for tensor with 4 elements, found tensor with 1 elements");
  EXPECT_EQ(
      Explain(matrix.node(), NodeWith(ConstantValue(42))),
      "\nwas looking for tensor with 1 elements, found tensor with 4 elements");
}
153
// Verifies AssignedDevice on two Add nodes: one explicitly assigned to a CPU
// device, and one whose assigned device is never set.
TEST(NodeMatchers, AssignedDevice) {
  constexpr char kCpuDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output lhs = ops::Placeholder(scope.WithOpName("placeholder_a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("placeholder_b"), DT_FLOAT);

  Output assigned_add = ops::Add(scope.WithOpName("assigned_add"), lhs, rhs);
  assigned_add.node()->set_assigned_device_name(kCpuDevice);

  Output unassigned_add =
      ops::Add(scope.WithOpName("unassigned_add"), lhs, rhs);

  EXPECT_THAT(assigned_add.node(), NodeWith(AssignedDevice(kCpuDevice)));
  // The node that was never assigned is expected to match an empty device.
  EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice("")));

  EXPECT_EQ(
      Explain(unassigned_add.node(), NodeWith(AssignedDevice(kCpuDevice))),
      "\nexpected assigned_device "
      "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\"");
}
181
// Verifies that Out() can pin the producer's output slot. The Add consumes the
// Switch's output_true, which the matcher sees as slot 1.
TEST(NodeMatchers, OutputIndices) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output pred = ops::Placeholder(scope.WithOpName("pred"), DT_BOOL);

  Output data = ops::Placeholder(scope.WithOpName("data"), DT_FLOAT);
  ops::Switch sw(scope.WithOpName("switch"), data, pred);
  Output addend = ops::Placeholder(scope.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(scope.WithOpName("add"), sw.output_true, addend);
  Node* add_node = add.node();

  EXPECT_THAT(add_node, NodeWith(Inputs(Out(1, NodeWith(Op("Switch"))), _)));
  EXPECT_EQ(
      Explain(add_node, NodeWith(Inputs(Out(0, NodeWith(Op("Switch"))), _))),
      "\ninput 0 does not match expected:\nop: Switch, \n"
      "expected output slot to be 0 but found 1");
}
197
// Verifies Attr() on an Enter node's boolean `is_constant` attribute. Covers
// both a value mismatch and an attribute name the node does not have.
TEST(NodeMatchers, Attrs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output data = ops::Placeholder(scope.WithOpName("data"), DT_FLOAT);
  const auto enter_attrs = ops::internal::Enter::Attrs{}.IsConstant(true);
  Output enter = ops::internal::Enter(scope.WithOpName("enter"), data,
                                      "frame_name", enter_attrs);
  Node* enter_node = enter.node();

  EXPECT_THAT(enter_node, NodeWith(Attr("is_constant", true)));
  EXPECT_EQ(Explain(enter_node, NodeWith(Attr("is_constant", false))),
            "attribute named is_constant does not match value; expected: "
            "\"false\", found: \"true\"");
  EXPECT_EQ(Explain(enter_node, NodeWith(Attr("missing_attr", false))),
            "did not find attribute named \"missing_attr\" in node");
}
211
212 } // namespace
213 } // namespace testing
214 } // namespace tensorflow
215