• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/validate.h"
17 
18 #include <string>
19 
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/graph_def_util.h"
22 #include "tensorflow/core/framework/op_def_builder.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/graph_def_builder.h"
25 #include "tensorflow/core/graph/subgraph.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/test.h"
31 
32 namespace tensorflow {
33 namespace {
34 
35 REGISTER_OP("FloatInput").Output("o: float");
36 REGISTER_OP("Int32Input").Output("o: int32");
37 
// Two 'FloatInput' nodes feeding a float 'Mul' node: ValidateGraphDef is
// expected to return OK when checked against the global op registry.
TEST(ValidateGraphDefTest, TestValidGraph) {
  const string graph_text =
      "node { name: 'A' op: 'FloatInput' }"
      "node { name: 'B' op: 'FloatInput' }"
      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
      " input: ['A', 'B'] }";

  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  // A fixture that does not parse is a bug in the test itself; print the text.
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  const OpRegistry& registry = *OpRegistry::Global();
  TF_ASSERT_OK(graph::ValidateGraphDef(parsed_graph, registry));
}
49 
// A 'Sum' node that only sets attr 'T' must first be rejected with a
// "NodeDef missing attr" error, and must validate once
// AddDefaultAttrsToGraphDef has been applied to the graph.
TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) {
  const string graph_text =
      "node { name: 'A' op: 'FloatInput' }"
      "node { name: 'B' op: 'Int32Input' }"
      "node { "
      "       name: 'C' op: 'Sum' "
      "       attr { key: 'T' value { type: DT_FLOAT } }"
      "       input: ['A', 'B'] "
      "}";

  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  // Before defaults are filled in, validation has to report the missing attr.
  const Status before_defaults =
      graph::ValidateGraphDef(parsed_graph, *OpRegistry::Global());
  EXPECT_FALSE(before_defaults.ok());
  EXPECT_TRUE(
      absl::StrContains(before_defaults.ToString(), "NodeDef missing attr"));

  // Fill in the default-valued attrs.
  TF_ASSERT_OK(
      AddDefaultAttrsToGraphDef(&parsed_graph, *OpRegistry::Global(), 0));

  // With the defaults present the same graph is expected to validate.
  TF_ASSERT_OK(graph::ValidateGraphDef(parsed_graph, *OpRegistry::Global()));
}
72 
// A 'Cast' node without its "DstT" attribute must fail validation with
// "NodeDef missing attr" both before and after AddDefaultAttrsToGraphDef is
// applied: adding defaults is expected not to supply the missing attr.
TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) {
  // Only "SrcT" is set on the Cast node; "DstT" is deliberately left out.
  const string graph_text =
      "node { name: 'A' op: 'FloatInput' }"
      "node { "
      "       name: 'B' op: 'Cast' "
      "       attr { key: 'SrcT' value { type: DT_FLOAT } }"
      "       input: ['A'] "
      "}";

  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  const Status before_defaults =
      graph::ValidateGraphDef(parsed_graph, *OpRegistry::Global());
  EXPECT_FALSE(before_defaults.ok());
  EXPECT_TRUE(
      absl::StrContains(before_defaults.ToString(), "NodeDef missing attr"));

  // Fill in whatever defaults are available.
  TF_ASSERT_OK(
      AddDefaultAttrsToGraphDef(&parsed_graph, *OpRegistry::Global(), 0));

  // The attr is still absent, so the second validation must fail the same way.
  const Status after_defaults =
      graph::ValidateGraphDef(parsed_graph, *OpRegistry::Global());
  EXPECT_FALSE(after_defaults.ok());
  EXPECT_TRUE(
      absl::StrContains(after_defaults.ToString(), "NodeDef missing attr"));
}
97 
// A graph whose only node uses 'UniqueSnowflake' is expected to validate
// against an OpList that contains exactly that op definition.
TEST(ValidateGraphDefAgainstOpListTest, GraphWithOpOnlyInOpList) {
  // Build the op definition locally and place it in the list to check against.
  OpRegistrationData snowflake_reg;
  TF_ASSERT_OK(OpDefBuilder("UniqueSnowflake").Finalize(&snowflake_reg));
  OpList allowed_ops;
  *allowed_ops.add_op() = snowflake_reg.op_def;

  const string graph_text = "node { name: 'A' op: 'UniqueSnowflake' }";
  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  TF_ASSERT_OK(graph::ValidateGraphDefAgainstOpList(parsed_graph, allowed_ops));
}
109 
// A graph using 'FloatInput' must be rejected when the OpList it is checked
// against contains only the unrelated op 'NotAnywhere'.
TEST(ValidateGraphDefAgainstOpListTest, GraphWithGlobalOpNotInOpList) {
  // The list holds a single op that the graph below never references.
  OpRegistrationData unrelated_reg;
  TF_ASSERT_OK(OpDefBuilder("NotAnywhere").Finalize(&unrelated_reg));
  OpList allowed_ops;
  *allowed_ops.add_op() = unrelated_reg.op_def;

  const string graph_text = "node { name: 'A' op: 'FloatInput' }";
  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  ASSERT_FALSE(
      graph::ValidateGraphDefAgainstOpList(parsed_graph, allowed_ops).ok());
}
121 
122 REGISTER_OP("HasDocs").Doc("This is in the summary.");
123 
// The list produced by GetOpListForValidation must contain each of
// 'FloatInput', 'Int32Input' and 'HasDocs' exactly once, and the 'HasDocs'
// entry must come back with an empty summary.
TEST(GetOpListForValidationTest, ShouldStripDocs) {
  OpList validation_ops;
  graph::GetOpListForValidation(&validation_ops);

  // One flag per op of interest; seeing a name twice trips the EXPECT_FALSE.
  bool saw_float_input = false;
  bool saw_int32_input = false;
  bool saw_has_docs = false;

  for (const OpDef& candidate : validation_ops.op()) {
    const auto& op_name = candidate.name();
    if (op_name == "FloatInput") {
      EXPECT_FALSE(saw_float_input);
      saw_float_input = true;
    } else if (op_name == "Int32Input") {
      EXPECT_FALSE(saw_int32_input);
      saw_int32_input = true;
    } else if (op_name == "HasDocs") {
      EXPECT_FALSE(saw_has_docs);
      saw_has_docs = true;
      // The registration supplied a summary; it must not appear here.
      EXPECT_TRUE(candidate.summary().empty());
    }
  }

  EXPECT_TRUE(saw_float_input);
  EXPECT_TRUE(saw_int32_input);
  EXPECT_TRUE(saw_has_docs);
}
149 
// Nodes named 'A', 'B' and 'C' are all distinct, so
// VerifyNoDuplicateNodeNames is expected to return OK.
TEST(VerifyNoDuplicateNodeNames, NoDuplicateNodeNames) {
  const string graph_text =
      "node { name: 'A' op: 'FloatInput' }"
      "node { name: 'B' op: 'Int32Input' }"
      "node { "
      "       name: 'C' op: 'Sum' "
      "       attr { key: 'T' value { type: DT_FLOAT } }"
      "       input: ['A', 'B'] "
      "}";

  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  TF_ASSERT_OK(graph::VerifyNoDuplicateNodeNames(parsed_graph));
}
164 
// Two nodes share the name 'A', so VerifyNoDuplicateNodeNames is expected to
// return a status whose code is ALREADY_EXISTS.
TEST(VerifyNoDuplicateNodeNames, DuplicateNodeNames) {
  const string graph_text =
      "node { name: 'A' op: 'FloatInput' }"
      "node { name: 'A' op: 'Int32Input' }"
      "node { "
      "       name: 'C' op: 'Sum' "
      "       attr { key: 'T' value { type: DT_FLOAT } }"
      "       input: ['A', 'A'] "
      "}";

  GraphDef parsed_graph;
  protobuf::TextFormat::Parser text_parser;
  CHECK(text_parser.MergeFromString(graph_text, &parsed_graph)) << graph_text;

  const Status duplicate_check = graph::VerifyNoDuplicateNodeNames(parsed_graph);
  EXPECT_EQ(duplicate_check.code(), tensorflow::error::ALREADY_EXISTS);
}
180 
181 }  // namespace
182 }  // namespace tensorflow
183