• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/cc/experimental/libtf/module.h"
16 
17 #include <string>
18 
19 #include "tensorflow/cc/experimental/libtf/runtime/core/core.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/platform/resource_loader.h"
22 #include "tensorflow/core/platform/status_matchers.h"
23 #include "tensorflow/core/platform/statusor.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/core/protobuf/error_codes.pb.h"
26 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
27 
28 namespace tf {
29 namespace libtf {
30 namespace impl {
31 
32 using ::tensorflow::libexport::TFPackage;
33 using ::tensorflow::testing::StatusIs;
34 using ::tf::libtf::runtime::Runtime;
35 
TEST(ModuleTest, TestStubbedFunctions) {
  // Building a program from a default-constructed TFPackage must not succeed;
  // this test only checks that an error status (of any code) is returned.
  Runtime rt = runtime::core::Runtime();
  TFPackage empty_package;
  const tensorflow::StatusOr<Handle> program = BuildProgram(rt, empty_package);
  ASSERT_FALSE(program.ok());
}
42 
TEST(ModuleTest, TestBuildObjectsDataStructures) {
  const std::string path = tensorflow::GetDataDependencyFilepath(
      "tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model");
  TF_ASSERT_OK_AND_ASSIGN(TFPackage tf_package, TFPackage::Load(path));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Handle> objects,
                          BuildObjects(tf_package));
  // Fatal assertion: everything below indexes objects[0..6], so continuing
  // after a size mismatch would read past the end of the vector.
  ASSERT_EQ(objects.size(), 7u);

  // The first node of data-structure-model is a dictionary.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary root,
                          Cast<tf::libtf::Dictionary>(objects.front()));

  // The next three nodes of data-structure-model are lists.
  for (size_t i = 1; i < 4; ++i) {
    SCOPED_TRACE(::testing::Message() << "object index " << i);
    TF_ASSERT_OK_AND_ASSIGN(tf::libtf::List list_node,
                            Cast<tf::libtf::List>(objects.at(i)));
  }
  // The last three nodes of data-structure-model are dictionaries.
  for (size_t i = 4; i < 7; ++i) {
    SCOPED_TRACE(::testing::Message() << "object index " << i);
    TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary dict_node,
                            Cast<tf::libtf::Dictionary>(objects.at(i)));
  }
}
66 
TEST(ModuleTest, TestBuildEmptyList) {
  // A "trackable_list_wrapper" user object with no children should be built
  // into an empty List.
  tensorflow::SavedObject saved_object_proto;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "trackable_list_wrapper"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &saved_object_proto));
  TF_ASSERT_OK_AND_ASSIGN(Handle result,
                          BuildSavedUserObject(saved_object_proto));
  // Unwrap the cast through the status macro instead of dereferencing a
  // StatusOr that may hold an error: a wrong object type is then reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::List list, Cast<tf::libtf::List>(result));
  EXPECT_EQ(list.size(), 0);
}
82 
TEST(ModuleTest, TestBuildEmptyDict) {
  // A "trackable_dict_wrapper" user object with no children should be built
  // into an empty Dictionary.
  tensorflow::SavedObject saved_object_proto;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "trackable_dict_wrapper"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &saved_object_proto));

  TF_ASSERT_OK_AND_ASSIGN(Handle result,
                          BuildSavedUserObject(saved_object_proto));
  // Unwrap the cast through the status macro instead of dereferencing a
  // StatusOr that may hold an error: a wrong object type is then reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary dict,
                          Cast<tf::libtf::Dictionary>(result));
  EXPECT_EQ(dict.size(), 0);
}
99 
TEST(ModuleTest, TestBuildSignatureMap) {
  // A "signature_map" user object with no children should be built into an
  // empty Dictionary.
  tensorflow::SavedObject saved_object_proto;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "signature_map"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &saved_object_proto));
  TF_ASSERT_OK_AND_ASSIGN(Handle result,
                          BuildSavedUserObject(saved_object_proto));
  // Unwrap the cast through the status macro instead of dereferencing a
  // StatusOr that may hold an error: a wrong object type is then reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary signatures,
                          Cast<tf::libtf::Dictionary>(result));
  EXPECT_EQ(signatures.size(), 0);
}
115 
TEST(ModuleTest, TestUnimplementedUserObject) {
  // An unrecognised user_object identifier ("foo") must be rejected with
  // UNIMPLEMENTED, and the error message must mention that identifier.
  const std::string unknown_object_text = R"pb(
    user_object {
      identifier: "foo"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";
  tensorflow::SavedObject unknown_object;
  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      unknown_object_text, &unknown_object));

  const auto is_unimplemented_foo =
      StatusIs(tensorflow::error::UNIMPLEMENTED, ::testing::HasSubstr("foo"));
  EXPECT_THAT(BuildSavedUserObject(unknown_object), is_unimplemented_foo);
}
132 
133 }  // namespace impl
134 }  // namespace libtf
135 }  // namespace tf
136