1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/saved_model/reader.h"
17
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/cc/saved_model/metrics.h"
20 #include "tensorflow/cc/saved_model/tag_constants.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/path.h"
26 #include "tensorflow/core/platform/resource_loader.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace {
31
TestDataPbTxt()32 string TestDataPbTxt() {
33 return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
34 "half_plus_two_pbtxt", "00000123");
35 }
36
TestDataSharded()37 string TestDataSharded() {
38 return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
39 "half_plus_two", "00000123");
40 }
41
42 class ReaderTest : public ::testing::Test {
43 protected:
ReaderTest()44 ReaderTest() {}
45
CheckMetaGraphDef(const MetaGraphDef & meta_graph_def)46 void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
47 const auto& tags = meta_graph_def.meta_info_def().tags();
48 EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
49 tags.end());
50 EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
51 EXPECT_EQ(
52 meta_graph_def.signature_def().at("serving_default").method_name(),
53 "tensorflow/serving/predict");
54 }
55 };
56
// Reading the half_plus_two model with a tag it was exported under must
// succeed and yield the expected serving MetaGraphDef.
TEST_F(ReaderTest, TagMatch) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  MetaGraphDef meta_graph_def;
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def));
  CheckMetaGraphDef(meta_graph_def);
}
65
// Requesting a tag set that no MetaGraphDef carries must fail, and the error
// message must name the unmatched tag set.
TEST_F(ReaderTest, NoTagMatch) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  MetaGraphDef meta_graph_def;
  const Status status = ReadMetaGraphDefFromSavedModel(
      export_dir, {"missing-tag"}, &meta_graph_def);

  EXPECT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(absl::StrContains(
      message,
      "Could not find meta graph def matching supplied tags: { missing-tag }"))
      << message;
}
78
// A tag set that is only partially present (one valid tag plus one unknown
// tag) must still fail to match. Only the fixed prefix of the error message is
// checked, not how the supplied tags are rendered after it.
TEST_F(ReaderTest, NoTagMatchMultiple) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  MetaGraphDef meta_graph_def;
  const Status status = ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);

  EXPECT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(absl::StrContains(
      message, "Could not find meta graph def matching supplied tags: "))
      << message;
}
91
// The half_plus_two_pbtxt export must be readable as well, and must pass the
// same serving-graph checks as the other export.
TEST_F(ReaderTest, PbtxtFormat) {
  const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
  MetaGraphDef meta_graph_def;
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def));
  CheckMetaGraphDef(meta_graph_def);
}
100
// Reading from an export directory that does not exist must return a
// non-OK status.
TEST_F(ReaderTest, InvalidExportPath) {
  const string missing_dir = GetDataDependencyFilepath("missing-path");
  MetaGraphDef meta_graph_def;
  const Status status = ReadMetaGraphDefFromSavedModel(
      missing_dir, {kSavedModelTagServe}, &meta_graph_def);
  EXPECT_FALSE(status.ok());
}
109
// Reading optional graph debug info from the half_plus_two model must succeed.
// The test only checks the returned status; it does not inspect whether any
// debug info was actually loaded.
TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) {
  std::unique_ptr<GraphDebugInfo> debug_info;
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info));
}
115
// A read that fails (here, because the export directory does not exist) must
// leave the SavedModel read counters for both "1" and "2" unchanged.
TEST_F(ReaderTest, MetricsNotUpdatedFailedRead) {
  MetaGraphDef meta_graph_def;
  // Snapshot the counters before the read. `auto` keeps the counter's own
  // value type (int64 for TF monitoring counters) instead of narrowing it to
  // `int`.
  const auto read_count_v1 = metrics::SavedModelRead("1").value();
  const auto read_count_v2 = metrics::SavedModelRead("2").value();

  const string export_dir = GetDataDependencyFilepath("missing-path");
  const Status st = ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def);

  EXPECT_FALSE(st.ok());
  EXPECT_EQ(metrics::SavedModelRead("1").value(), read_count_v1);
  EXPECT_EQ(metrics::SavedModelRead("2").value(), read_count_v2);
}
129
// A successful read of the half_plus_two model must increment the "1"
// SavedModel read counter by exactly one.
TEST_F(ReaderTest, MetricsUpdatedSuccessfulRead) {
  MetaGraphDef meta_graph_def;
  // `auto` keeps the counter's own value type rather than narrowing to `int`.
  const auto read_count_v1 = metrics::SavedModelRead("1").value();

  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  // Check the read itself. The original code ignored this status, so a failed
  // read only showed up as a puzzling "counter did not advance" mismatch below
  // (failed reads leave the counter untouched).
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def));
  EXPECT_EQ(metrics::SavedModelRead("1").value(), read_count_v1 + 1);
}
139
140 } // namespace
141 } // namespace tensorflow
142