• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/reader.h"
17 
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/cc/saved_model/metrics.h"
20 #include "tensorflow/cc/saved_model/tag_constants.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/path.h"
26 #include "tensorflow/core/platform/resource_loader.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace {
31 
TestDataPbTxt()32 string TestDataPbTxt() {
33   return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
34                       "half_plus_two_pbtxt", "00000123");
35 }
36 
TestDataSharded()37 string TestDataSharded() {
38   return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
39                       "half_plus_two", "00000123");
40 }
41 
42 class ReaderTest : public ::testing::Test {
43  protected:
ReaderTest()44   ReaderTest() {}
45 
CheckMetaGraphDef(const MetaGraphDef & meta_graph_def)46   void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
47     const auto& tags = meta_graph_def.meta_info_def().tags();
48     EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
49                 tags.end());
50     EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
51     EXPECT_EQ(
52         meta_graph_def.signature_def().at("serving_default").method_name(),
53         "tensorflow/serving/predict");
54   }
55 };
56 
// Loading the sharded fixture with the serving tag yields the expected
// serving MetaGraphDef.
TEST_F(ReaderTest, TagMatch) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  MetaGraphDef meta_graph_def;
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def));
  CheckMetaGraphDef(meta_graph_def);
}
65 
// Requesting a tag that no MetaGraphDef carries must fail, and the error
// message must name the unmatched tag set.
TEST_F(ReaderTest, NoTagMatch) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  MetaGraphDef meta_graph_def;
  const Status status = ReadMetaGraphDefFromSavedModel(
      export_dir, {"missing-tag"}, &meta_graph_def);
  EXPECT_FALSE(status.ok());

  const bool names_missing_tag = absl::StrContains(
      status.error_message(),
      "Could not find meta graph def matching supplied tags: { missing-tag }");
  EXPECT_TRUE(names_missing_tag) << status.error_message();
}
78 
// A tag set that is only partially satisfied (serve + an unknown tag) must
// still fail with the "no matching meta graph def" error.
TEST_F(ReaderTest, NoTagMatchMultiple) {
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  MetaGraphDef meta_graph_def;
  const Status status = ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
  EXPECT_FALSE(status.ok());

  const bool has_expected_prefix = absl::StrContains(
      status.error_message(),
      "Could not find meta graph def matching supplied tags: ");
  EXPECT_TRUE(has_expected_prefix) << status.error_message();
}
91 
// The text-format fixture can be read with the serving tag and yields the
// same expected serving MetaGraphDef.
TEST_F(ReaderTest, PbtxtFormat) {
  const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
  MetaGraphDef meta_graph_def;
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def));
  CheckMetaGraphDef(meta_graph_def);
}
100 
// Pointing the reader at a directory that does not exist must produce a
// non-OK status.
TEST_F(ReaderTest, InvalidExportPath) {
  const string missing_dir = GetDataDependencyFilepath("missing-path");
  MetaGraphDef meta_graph_def;
  const Status status = ReadMetaGraphDefFromSavedModel(
      missing_dir, {kSavedModelTagServe}, &meta_graph_def);
  EXPECT_FALSE(status.ok());
}
109 
// ReadSavedModelDebugInfoIfPresent returns an OK status on the sharded
// fixture. The test does not inspect whether `debug_info` was populated.
TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) {
  std::unique_ptr<GraphDebugInfo> debug_info;
  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info));
}
115 
// A failed read (non-existent export dir) must not bump either the "1" or the
// "2" SavedModelRead counter (presumably keyed by SavedModel write version).
TEST_F(ReaderTest, MetricsNotUpdatedFailedRead) {
  MetaGraphDef meta_graph_def;
  // Snapshot the counters before the read. `auto` keeps value()'s own integer
  // type (a 64-bit count for TF monitoring counters — verify) instead of
  // narrowing it to `int`.
  const auto read_count_v1 = metrics::SavedModelRead("1").value();
  const auto read_count_v2 = metrics::SavedModelRead("2").value();

  const string export_dir = GetDataDependencyFilepath("missing-path");
  // Use the same tag constant as the other tests rather than a bare literal.
  const Status st = ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def);

  EXPECT_FALSE(st.ok());
  EXPECT_EQ(metrics::SavedModelRead("1").value(), read_count_v1);
  EXPECT_EQ(metrics::SavedModelRead("2").value(), read_count_v2);
}
129 
// A successful read of the sharded fixture increments the "1" SavedModelRead
// counter exactly once.
TEST_F(ReaderTest, MetricsUpdatedSuccessfulRead) {
  MetaGraphDef meta_graph_def;
  // `auto` keeps value()'s own integer type instead of narrowing to `int`.
  const auto read_count_v1 = metrics::SavedModelRead("1").value();

  const string export_dir = GetDataDependencyFilepath(TestDataSharded());
  // The original ignored this Status. Asserting on it means a failed load is
  // reported as such, not as a confusing counter mismatch below.
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe}, &meta_graph_def));
  EXPECT_EQ(metrics::SavedModelRead("1").value(), read_count_v1 + 1);
}
139 
140 }  // namespace
141 }  // namespace tensorflow
142