• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/cc/saved_model/bundle_v2.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 namespace {
25 
// Directory, relative to the TensorFlow source root, holding the SavedModel
// fixtures that the tests below load via io::JoinPath.
constexpr char kTestData[]{"cc/saved_model/testdata"};
27 
28 class BundleV2Test : public ::testing::Test {
29  protected:
BundleV2Test()30   BundleV2Test() {}
31 
RestoreVarsAndVerify(SavedModelV2Bundle * bundle,std::vector<std::string> expected_names)32   void RestoreVarsAndVerify(SavedModelV2Bundle* bundle,
33                             std::vector<std::string> expected_names) {
34     // Collect saved_node_id, full_name, checkpoint_key into a vector.
35     using RestoredVarType = std::tuple<int, std::string, std::string>;
36     std::vector<RestoredVarType> restored_vars;
37     TF_ASSERT_OK(bundle->VisitObjectsToRestore(
38         [&](int saved_node_id,
39             const TrackableObjectGraph::TrackableObject& trackable_object)
40             -> Status {
41           for (const auto& attr : trackable_object.attributes()) {
42             if (attr.name() == "VARIABLE_VALUE") {
43               restored_vars.emplace_back(saved_node_id, attr.full_name(),
44                                          attr.checkpoint_key());
45             }
46           }
47           return OkStatus();
48         }));
49 
50     // Should be one of each var name restored.
51     for (const auto& expected_name : expected_names) {
52       EXPECT_EQ(1, std::count_if(restored_vars.begin(), restored_vars.end(),
53                                  [&](RestoredVarType t) {
54                                    return std::get<1>(t) == expected_name;
55                                  }));
56     }
57 
58     for (const auto& restored_var : restored_vars) {
59       // Each restored var should match a SavedObjectGraph node with the same
60       // variable name.
61       const auto& saved_node =
62           bundle->saved_object_graph().nodes(std::get<0>(restored_var));
63       EXPECT_EQ(std::get<1>(restored_var), saved_node.variable().name());
64 
65       // And should be able to load it from the tensor_bundle.
66       Tensor value;
67       TF_ASSERT_OK(
68           bundle->variable_reader()->Lookup(std::get<2>(restored_var), &value));
69     }
70   }
71 };
72 
// Loads the "VarsAndArithmeticObjectGraph" fixture, checks that its trackable
// object graph is populated, then restores and validates its three variables.
TEST_F(BundleV2Test, LoadsVarsAndArithmeticObjectGraph) {
  const std::string model_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData,
                   "VarsAndArithmeticObjectGraph");

  SavedModelV2Bundle bundle;
  TF_ASSERT_OK(SavedModelV2Bundle::Load(model_dir, &bundle));

  // A successful load must have produced at least one trackable object.
  const auto& object_graph = bundle.trackable_object_graph();
  EXPECT_GT(object_graph.nodes_size(), 0);

  const std::vector<std::string> expected_vars = {"variable_x", "variable_y",
                                                  "child_variable"};
  RestoreVarsAndVerify(&bundle, expected_vars);
}
85 
// Loads the "CyclicModule" fixture (presumably a module whose object graph
// contains a reference cycle; confirm against the fixture generator), checks
// that its trackable object graph is populated, and verifies that its single
// variable restores.
TEST_F(BundleV2Test, LoadsCyclicModule) {
  const std::string model_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestData, "CyclicModule");

  SavedModelV2Bundle bundle;
  TF_ASSERT_OK(SavedModelV2Bundle::Load(model_dir, &bundle));

  // A successful load must have produced at least one trackable object.
  const auto& object_graph = bundle.trackable_object_graph();
  EXPECT_GT(object_graph.nodes_size(), 0);

  RestoreVarsAndVerify(&bundle, std::vector<std::string>{"MyVariable"});
}
98 
// Loading a bundle should increment, by exactly one, both the SavedModelRead
// counter for write version "2" and the SavedModelReadApi counter for the
// "cc_load_bundle_v2" API label.
TEST_F(BundleV2Test, UpdatesMetrics) {
  const string kCCLoadBundleV2Label = "cc_load_bundle_v2";
  // Snapshot the counters before loading. value() is presumably a 64-bit count
  // (monitoring counter cell), so keep the snapshots in int64_t instead of
  // narrowing them to int before the before/after comparison.
  const int64_t read_count = metrics::SavedModelRead("2").value();
  const int64_t api_count =
      metrics::SavedModelReadApi(kCCLoadBundleV2Label).value();
  const string export_dir = io::JoinPath(
      testing::TensorFlowSrcRoot(), kTestData, "VarsAndArithmeticObjectGraph");

  SavedModelV2Bundle bundle;
  TF_ASSERT_OK(SavedModelV2Bundle::Load(export_dir, &bundle));

  EXPECT_EQ(metrics::SavedModelRead("2").value(), read_count + 1);
  EXPECT_EQ(metrics::SavedModelReadApi(kCCLoadBundleV2Label).value(),
            api_count + 1);
}
114 
115 }  // namespace
116 }  // namespace tensorflow
117