1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/cc/experimental/libtf/module.h"

#include <string>
#include <vector>

#include "tensorflow/cc/experimental/libtf/runtime/core/core.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
27
28 namespace tf {
29 namespace libtf {
30 namespace impl {
31
32 using ::tensorflow::libexport::TFPackage;
33 using ::tensorflow::testing::StatusIs;
34 using ::tf::libtf::runtime::Runtime;
35
// Builds a program from a default-constructed (empty) TFPackage on the core
// runtime and requires that BuildProgram reports a non-OK status. Only the
// failure itself is checked, not a specific error code (presumably because
// BuildProgram is still a stub -- confirm before tightening this check).
TEST(ModuleTest, TestStubbedFunctions) {
  Runtime rt = runtime::core::Runtime();
  TFPackage empty_package;
  const tensorflow::StatusOr<Handle> program = BuildProgram(rt, empty_package);
  ASSERT_FALSE(program.ok());
}
42
// Loads the checked-in "data-structure-model" SavedModel and verifies that
// BuildObjects yields exactly seven handles whose runtime types follow the
// model's node layout: Dictionary, List x3, Dictionary x3.
TEST(ModuleTest, TestBuildObjectsDataStructures) {
  const std::string path = tensorflow::GetDataDependencyFilepath(
      "tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model");
  TF_ASSERT_OK_AND_ASSIGN(TFPackage tf_package, TFPackage::Load(path));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Handle> objects,
                          BuildObjects(tf_package));
  // Fatal check: the per-index assertions below use objects.at(i) up to index
  // 6, which would throw (rather than fail cleanly) on a shorter vector.
  ASSERT_EQ(objects.size(), 7u);

  // The first node of data-structure-model is a dictionary.
  TF_ASSERT_OK(Cast<tf::libtf::Dictionary>(objects.front()).status());

  // The next three nodes of data-structure-model are lists.
  for (size_t i = 1; i < 4; ++i) {
    TF_ASSERT_OK(Cast<tf::libtf::List>(objects.at(i)).status())
        << "objects[" << i << "] is not a List";
  }
  // The last three nodes of data-structure-model are dictionaries.
  for (size_t i = 4; i < 7; ++i) {
    TF_ASSERT_OK(Cast<tf::libtf::Dictionary>(objects.at(i)).status())
        << "objects[" << i << "] is not a Dictionary";
  }
}
66
// A user_object with identifier "trackable_list_wrapper" must be built into
// an empty libtf List.
TEST(ModuleTest, TestBuildEmptyList) {
  tensorflow::SavedObject saved_object_proto;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "trackable_list_wrapper"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &saved_object_proto));
  TF_ASSERT_OK_AND_ASSIGN(Handle result,
                          BuildSavedUserObject(saved_object_proto));
  // Unwrap the cast explicitly: using `->` on a non-OK StatusOr crashes the
  // test binary instead of reporting an ordinary test failure.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::List list, Cast<tf::libtf::List>(result));
  EXPECT_EQ(list.size(), 0);
}
82
// A user_object with identifier "trackable_dict_wrapper" must be built into
// an empty libtf Dictionary.
TEST(ModuleTest, TestBuildEmptyDict) {
  tensorflow::SavedObject saved_object_proto;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "trackable_dict_wrapper"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &saved_object_proto));

  TF_ASSERT_OK_AND_ASSIGN(Handle result,
                          BuildSavedUserObject(saved_object_proto));
  // Unwrap the cast explicitly: using `->` on a non-OK StatusOr crashes the
  // test binary instead of reporting an ordinary test failure.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary dict,
                          Cast<tf::libtf::Dictionary>(result));
  EXPECT_EQ(dict.size(), 0);
}
99
// A user_object with identifier "signature_map" must be built into an empty
// libtf Dictionary.
TEST(ModuleTest, TestBuildSignatureMap) {
  tensorflow::SavedObject saved_object_proto;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "signature_map"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &saved_object_proto));
  TF_ASSERT_OK_AND_ASSIGN(Handle result,
                          BuildSavedUserObject(saved_object_proto));
  // Unwrap the cast explicitly: using `->` on a non-OK StatusOr crashes the
  // test binary instead of reporting an ordinary test failure.
  TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary signatures,
                          Cast<tf::libtf::Dictionary>(result));
  EXPECT_EQ(signatures.size(), 0);
}
115
// A user_object whose identifier BuildSavedUserObject does not handle (here
// "foo") must be rejected with UNIMPLEMENTED, and the error message must
// mention the offending identifier.
TEST(ModuleTest, TestUnimplementedUserObject) {
  tensorflow::SavedObject unknown_object;
  const std::string pb_txt = R"pb(
    user_object {
      identifier: "foo"
      version { producer: 1 min_consumer: 1 }
    }
  )pb";

  ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
      pb_txt, &unknown_object));

  const auto built = BuildSavedUserObject(unknown_object);
  EXPECT_THAT(built, StatusIs(tensorflow::error::UNIMPLEMENTED,
                              ::testing::HasSubstr("foo")));
}
132
133 } // namespace impl
134 } // namespace libtf
135 } // namespace tf
136