/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/node_matchers.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/control_flow_ops.h"
22 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
23 #include "tensorflow/cc/ops/math_ops.h"
24 
25 namespace tensorflow {
26 namespace testing {
27 namespace {
28 
29 using ::testing::_;
30 
31 using testing::matchers::AssignedDevice;
32 using testing::matchers::Attr;
33 using testing::matchers::ConstantValue;
34 using testing::matchers::CtrlDeps;
35 using testing::matchers::Inputs;
36 using testing::matchers::Name;
37 using testing::matchers::NodeWith;
38 using testing::matchers::Op;
39 using testing::matchers::Out;
40 
41 template <typename M, typename T>
Explain(const T & t,const M & m)42 string Explain(const T& t, const M& m) {
43   ::testing::StringMatchResultListener listener;
44   EXPECT_THAT(t, ::testing::Not(m));  // For the error message.
45   EXPECT_FALSE(m.MatchAndExplain(t, &listener));
46   return listener.str();
47 }
48 
// Uses a lone Placeholder node, which has no data inputs. It checks the Op,
// Name and zero-argument Inputs() matchers individually and combined, and the
// explanation each matcher gives when its expectation is wrong.
TEST(NodeMatchers, CheckAgainstConstant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output ph = ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
  Node* n = ph.node();

  // Positive cases: single matchers, then combinations in either order.
  EXPECT_THAT(n, NodeWith(Op("Placeholder")));
  EXPECT_THAT(n, NodeWith(Name("placeholder")));
  EXPECT_THAT(n, NodeWith(Op("Placeholder"), Name("placeholder")));
  EXPECT_THAT(n, NodeWith(Name("placeholder"), Op("Placeholder")));
  EXPECT_THAT(n, NodeWith(Inputs()));
  EXPECT_THAT(n, NodeWith(Op("Placeholder"), Name("placeholder"), Inputs()));

  // Negative cases: each failing matcher must explain the mismatch.
  const string wrong_op = Explain(n, NodeWith(Op("Add")));
  EXPECT_EQ(wrong_op, "\nexpected op Add but found Placeholder");

  const string wrong_name = Explain(n, NodeWith(Name("add")));
  EXPECT_EQ(wrong_name, "\nexpected name add but found placeholder");

  const string wrong_arity = Explain(n, NodeWith(Inputs(Out(NodeWith()))));
  EXPECT_EQ(wrong_arity, "\nexpected 1 inputs but node has 0");
}
71 
// Uses an Add node fed by two placeholders. It checks that Inputs()/Out()
// match each data input by position, and that a mismatch on either input is
// explained with that input's index.
TEST(NodeMatchers, CheckAgainstBinary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output lhs = ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
  Output rhs = ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);
  Node* add_node = sum.node();

  EXPECT_THAT(add_node,
              NodeWith(Op("Add"), Name("add"),
                       Inputs(Out(NodeWith(Name("placeholder_a"))),
                              Out(NodeWith(Name("placeholder_b"))))));

  // Wrong input count.
  EXPECT_EQ(Explain(add_node, NodeWith(Inputs())),
            "\nexpected 0 inputs but node has 2");

  // Wrong producer for the first input.
  const string bad_first =
      Explain(add_node, NodeWith(Inputs(Out(NodeWith(Name("blah"))), _)));
  EXPECT_EQ(
      bad_first,
      "\ninput 0 does not match expected:\nname: blah, \nsource does not match "
      "expected name: blah\n\t\nexpected name blah but found placeholder_a");

  // Wrong producer for the second input.
  const string bad_second =
      Explain(add_node, NodeWith(Inputs(_, Out(NodeWith(Name("blah"))))));
  EXPECT_EQ(
      bad_second,
      "\ninput 1 does not match expected:\nname: blah, \nsource does not match "
      "expected name: blah\n\t\nexpected name blah but found placeholder_b");
}
97 
// Checks CtrlDeps(). placeholder_c receives control edges from placeholder_a
// and placeholder_b; placeholder_d receives none. It also checks the
// explanations for a wrong number of control dependencies.
TEST(NodeMatchers, CheckControlDependence) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
  Output c = ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT);
  Output d = ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT);

  // Edges are added a -> c first, then b -> c.
  Graph* graph = root.graph();
  for (const Output& src : {a, b}) {
    graph->AddControlEdge(src.node(), c.node());
  }

  EXPECT_THAT(c.node(), NodeWith(Name("placeholder_c"),
                                 CtrlDeps(NodeWith(Name("placeholder_a")),
                                          NodeWith(Name("placeholder_b")))));
  EXPECT_THAT(d.node(), NodeWith(Name("placeholder_d"), CtrlDeps()));

  // Expecting no control deps on a node that has two.
  EXPECT_EQ(
      Explain(c.node(), NodeWith(CtrlDeps())),
      "ctrl_deps, which has 2 elements, does not match expected: is empty");
  // Expecting one control dep on a node that has none.
  EXPECT_EQ(Explain(d.node(), NodeWith(CtrlDeps(NodeWith()))),
            "ctrl_deps does not match expected: has 1 element and that element "
            "is any node");
}
127 
// Checks ConstantValue(). It matches scalar and 2-D Const nodes by value and
// explains op, element-value and element-count mismatches.
TEST(NodeMatchers, ConstValue) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output not_const = ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
  Output scalar = ops::Const(root.WithOpName("const_0d"), 42);

  Output matrix = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}});

  // Positive cases.
  EXPECT_THAT(scalar.node(), NodeWith(ConstantValue(42)));
  EXPECT_THAT(scalar.node(), NodeWith(ConstantValue(42), Name("const_0d")));
  EXPECT_THAT(matrix.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}})));

  // A non-Const node is reported as an op mismatch.
  EXPECT_EQ(Explain(not_const.node(), NodeWith(ConstantValue(42))),
            "\nexpected op Const but found Placeholder");
  // Same element count, different value.
  EXPECT_EQ(
      Explain(scalar.node(), NodeWith(ConstantValue(43))),
      "\nmismatch in constant tensor at index 0 expected = 43 actual = 42");
  // Element-count mismatches in both directions.
  EXPECT_EQ(
      Explain(scalar.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))),
      "\nwas looking for tensor with 4 elements, found tensor with 1 elements");
  EXPECT_EQ(
      Explain(matrix.node(), NodeWith(ConstantValue(42))),
      "\nwas looking for tensor with 1 elements, found tensor with 4 elements");
}
153 
// Checks AssignedDevice(). It matches the device name set through
// set_assigned_device_name(), and matches "" for a node that was never
// assigned. Also checks the explanation on a mismatch.
TEST(NodeMatchers, AssignedDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();
  const string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";

  Output a = ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);

  Output assigned = ops::Add(root.WithOpName("assigned_add"), a, b);
  assigned.node()->set_assigned_device_name(cpu_device);

  Output unassigned = ops::Add(root.WithOpName("unassigned_add"), a, b);

  EXPECT_THAT(assigned.node(), NodeWith(AssignedDevice(cpu_device)));
  EXPECT_THAT(unassigned.node(), NodeWith(AssignedDevice("")));

  EXPECT_EQ(Explain(unassigned.node(), NodeWith(AssignedDevice(cpu_device))),
            "\nexpected assigned_device "
            "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\"");
}
181 
// Checks that Out(slot, matcher) also compares the producer's output slot.
// Add consumes Switch's output_true, which the matcher sees as slot 1.
TEST(NodeMatchers, OutputIndices) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output pred = ops::Placeholder(root.WithOpName("pred"), DT_BOOL);

  Output data = ops::Placeholder(root.WithOpName("data"), DT_FLOAT);
  ops::Switch sw(root.WithOpName("switch"), data, pred);
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), sw.output_true, addend);
  Node* add_node = add.node();

  EXPECT_THAT(add_node, NodeWith(Inputs(Out(1, NodeWith(Op("Switch"))), _)));

  const string wrong_slot =
      Explain(add_node, NodeWith(Inputs(Out(0, NodeWith(Op("Switch"))), _)));
  EXPECT_EQ(wrong_slot,
            "\ninput 0 does not match expected:\nop: Switch, \nexpected output "
            "slot to be 0 but found 1");
}
197 
// Checks Attr() against the bool attribute is_constant of an Enter node. Also
// checks the explanations for a value mismatch and for a missing attribute.
TEST(NodeMatchers, Attrs) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output data = ops::Placeholder(root.WithOpName("data"), DT_FLOAT);
  const auto enter_attrs = ops::internal::Enter::Attrs().IsConstant(true);
  Output enter = ops::internal::Enter(root.WithOpName("enter"), data,
                                      "frame_name", enter_attrs);

  EXPECT_THAT(enter.node(), NodeWith(Attr("is_constant", true)));

  const string value_mismatch =
      Explain(enter.node(), NodeWith(Attr("is_constant", false)));
  EXPECT_EQ(value_mismatch,
            "attribute named is_constant does not match value; expected: "
            "\"false\", found: \"true\"");

  const string missing =
      Explain(enter.node(), NodeWith(Attr("missing_attr", false)));
  EXPECT_EQ(missing, "did not find attribute named \"missing_attr\" in node");
}
211 
212 }  // namespace
213 }  // namespace testing
214 }  // namespace tensorflow
215