1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17
18 #include "tensorflow/cc/saved_model/metrics.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20 #include "tensorflow/core/lib/io/path.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24 namespace {
25
// Directory containing the SavedModel test fixtures. The tests in this file
// join it onto testing::TensorFlowSrcRoot() to build each model's export path.
constexpr char kTestData[] = "cc/saved_model/testdata";
27
28 class BundleV2Test : public ::testing::Test {
29 protected:
BundleV2Test()30 BundleV2Test() {}
31
RestoreVarsAndVerify(SavedModelV2Bundle * bundle,std::vector<std::string> expected_names)32 void RestoreVarsAndVerify(SavedModelV2Bundle* bundle,
33 std::vector<std::string> expected_names) {
34 // Collect saved_node_id, full_name, checkpoint_key into a vector.
35 using RestoredVarType = std::tuple<int, std::string, std::string>;
36 std::vector<RestoredVarType> restored_vars;
37 TF_ASSERT_OK(bundle->VisitObjectsToRestore(
38 [&](int saved_node_id,
39 const TrackableObjectGraph::TrackableObject& trackable_object)
40 -> Status {
41 for (const auto& attr : trackable_object.attributes()) {
42 if (attr.name() == "VARIABLE_VALUE") {
43 restored_vars.emplace_back(saved_node_id, attr.full_name(),
44 attr.checkpoint_key());
45 }
46 }
47 return OkStatus();
48 }));
49
50 // Should be one of each var name restored.
51 for (const auto& expected_name : expected_names) {
52 EXPECT_EQ(1, std::count_if(restored_vars.begin(), restored_vars.end(),
53 [&](RestoredVarType t) {
54 return std::get<1>(t) == expected_name;
55 }));
56 }
57
58 for (const auto& restored_var : restored_vars) {
59 // Each restored var should match a SavedObjectGraph node with the same
60 // variable name.
61 const auto& saved_node =
62 bundle->saved_object_graph().nodes(std::get<0>(restored_var));
63 EXPECT_EQ(std::get<1>(restored_var), saved_node.variable().name());
64
65 // And should be able to load it from the tensor_bundle.
66 Tensor value;
67 TF_ASSERT_OK(
68 bundle->variable_reader()->Lookup(std::get<2>(restored_var), &value));
69 }
70 }
71 };
72
// Loads the "VarsAndArithmeticObjectGraph" test model and checks that its
// three variables are reported for restoration and can be read back.
TEST_F(BundleV2Test, LoadsVarsAndArithmeticObjectGraph) {
  const string model_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData,
                   "VarsAndArithmeticObjectGraph");

  SavedModelV2Bundle loaded;
  TF_ASSERT_OK(SavedModelV2Bundle::Load(model_dir, &loaded));

  // A successfully loaded model must expose a non-empty
  // trackable_object_graph.
  EXPECT_GT(loaded.trackable_object_graph().nodes_size(), 0);

  const std::vector<std::string> expected_vars = {"variable_x", "variable_y",
                                                  "child_variable"};
  RestoreVarsAndVerify(&loaded, expected_vars);
}
85
// Loads the "CyclicModule" test model and checks that its single variable,
// "MyVariable", is reported for restoration and can be read back.
TEST_F(BundleV2Test, LoadsCyclicModule) {
  const string model_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData, "CyclicModule");

  SavedModelV2Bundle loaded;
  TF_ASSERT_OK(SavedModelV2Bundle::Load(model_dir, &loaded));

  // A successfully loaded model must expose a non-empty
  // trackable_object_graph.
  EXPECT_GT(loaded.trackable_object_graph().nodes_size(), 0);

  const std::vector<std::string> expected_vars = {"MyVariable"};
  RestoreVarsAndVerify(&loaded, expected_vars);
}
98
// Checks that a single SavedModelV2Bundle::Load() call increments both the
// SavedModelRead("2") counter and the SavedModelReadApi counter for the
// "cc_load_bundle_v2" label by exactly one.
TEST_F(BundleV2Test, UpdatesMetrics) {
  const string api_label = "cc_load_bundle_v2";

  // Snapshot both counters before loading so the test only measures the
  // delta caused by this Load() call.
  const int reads_before = metrics::SavedModelRead("2").value();
  const int api_reads_before = metrics::SavedModelReadApi(api_label).value();

  const string model_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData,
                   "VarsAndArithmeticObjectGraph");
  SavedModelV2Bundle loaded;
  TF_ASSERT_OK(SavedModelV2Bundle::Load(model_dir, &loaded));

  EXPECT_EQ(metrics::SavedModelRead("2").value(), reads_before + 1);
  EXPECT_EQ(metrics::SavedModelReadApi(api_label).value(),
            api_reads_before + 1);
}
114
115 } // namespace
116 } // namespace tensorflow
117