1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/cc/experimental/libtf/tests/runtime_test.h"
16
17 namespace tf {
18 namespace libtf {
19 namespace runtime {
20
21 using ::tensorflow::testing::StatusIs;
22 using ::testing::HasSubstr;
23 using ::tf::libtf::impl::TaggedValueTensor;
24
// Data-dependency path of the model fixture shared by the tests below; it is
// resolved to an on-disk location with tensorflow::GetDataDependencyFilepath.
constexpr char kSimpleModel[] =
    "tensorflow/cc/experimental/libtf/tests/testdata/"
    "simple-model";
27
// Loads the simple test model, calls its `test_float` function with a scalar
// float tensor holding 2.0, and checks that the returned scalar equals 6.0.
TEST_P(RuntimeTest, SimpleModelCallableFloatTest) {
  Runtime runtime = RuntimeTest::GetParam()();

  // Import the module and grab the callable.
  const std::string module_path =
      tensorflow::GetDataDependencyFilepath(kSimpleModel);
  TF_ASSERT_OK_AND_ASSIGN(Object module,
                          runtime.Load(String(module_path.c_str())));
  TF_ASSERT_OK_AND_ASSIGN(Callable fn,
                          module.Get<Callable>(String("test_float")));

  // Call the function with a scalar (empty shape) float host tensor.
  TF_ASSERT_OK_AND_ASSIGN(
      Tensor tensor, runtime.CreateHostTensor<float>({}, TF_FLOAT, {2.0f}));
  TF_ASSERT_OK_AND_ASSIGN(Tensor result, fn.Call<Tensor>(Tensor(tensor)));

  float out_val[1];
  TF_ASSERT_OK(result.GetValue(absl::MakeSpan(out_val)));
  // Compare float against a float literal with gtest's ULP-tolerant float
  // assertion instead of exact equality against a double literal.
  EXPECT_FLOAT_EQ(out_val[0], 6.0f);
}
49
// Verifies that the `test_int` function exported by the simple test model
// turns a scalar int32 input of 2 into 6.
TEST_P(RuntimeTest, SimpleModelCallableIntTest) {
  Runtime runtime = GetParam()();

  // Resolve the fixture path, load it, and look up the function under test.
  const std::string path = tensorflow::GetDataDependencyFilepath(kSimpleModel);
  TF_ASSERT_OK_AND_ASSIGN(Object loaded, runtime.Load(String(path.c_str())));
  TF_ASSERT_OK_AND_ASSIGN(Callable test_int,
                          loaded.Get<Callable>(String("test_int")));

  // Build a scalar int32 host tensor and invoke the callable with it.
  TF_ASSERT_OK_AND_ASSIGN(Tensor input,
                          runtime.CreateHostTensor<int>({}, TF_INT32, {2}));
  TF_ASSERT_OK_AND_ASSIGN(Tensor output, test_int.Call<Tensor>(Tensor(input)));

  // Copy the single output element back to the host and check it.
  int value[1];
  TF_ASSERT_OK(output.GetValue(absl::MakeSpan(value)));
  EXPECT_EQ(value[0], 6);
}
72
// Loads the simple test model and checks that its `test_add` function, called
// with two scalar float tensors holding 2.0 and 3.0, returns 5.0.
TEST_P(RuntimeTest, SimpleModelCallableMultipleArgsTest) {
  Runtime runtime = RuntimeTest::GetParam()();

  // Import the module and grab the callable.
  const std::string module_path =
      tensorflow::GetDataDependencyFilepath(kSimpleModel);
  TF_ASSERT_OK_AND_ASSIGN(Object module,
                          runtime.Load(String(module_path.c_str())));

  TF_ASSERT_OK_AND_ASSIGN(Callable fn,
                          module.Get<Callable>(String("test_add")));

  // Each macro invocation is terminated with ';' explicitly, like every other
  // use in this file, instead of relying on the macro expansion to supply it.
  TF_ASSERT_OK_AND_ASSIGN(
      Tensor tensor1, runtime.CreateHostTensor<float>({}, TF_FLOAT, {2.0f}));
  TF_ASSERT_OK_AND_ASSIGN(
      Tensor tensor2, runtime.CreateHostTensor<float>({}, TF_FLOAT, {3.0f}));

  TF_ASSERT_OK_AND_ASSIGN(Tensor result_tensor,
                          fn.Call<Tensor>(tensor1, tensor2));
  float out_val[1];
  TF_ASSERT_OK(result_tensor.GetValue(absl::MakeSpan(out_val)));
  EXPECT_EQ(out_val[0], 5.0f);
}
96
// A shape of {2} paired with a single data value must be rejected with
// INVALID_ARGUMENT because the element counts disagree.
TEST_P(RuntimeTest, CreateHostTensorIncompatibleShape) {
  Runtime runtime = GetParam()();
  auto created = runtime.CreateHostTensor<float>({2}, TF_FLOAT, {2.0f});
  EXPECT_THAT(created, StatusIs(tensorflow::error::INVALID_ARGUMENT,
                                HasSubstr("Mismatched shape and data size")));
}
103
// A shape containing an unknown (-1) dimension is not fully defined, so
// creating a host tensor with it must fail with INVALID_ARGUMENT.
TEST_P(RuntimeTest, CreateHostTensorNonFullyDefinedShapeRaises) {
  Runtime runtime = GetParam()();
  auto created = runtime.CreateHostTensor<float>({-1}, TF_FLOAT, {2.0f});
  EXPECT_THAT(created, StatusIs(tensorflow::error::INVALID_ARGUMENT,
                                HasSubstr("Shape must be fully-defined")));
}
110
// Supplying float data for a TF_BOOL tensor must fail with INVALID_ARGUMENT.
// Presumably this is because the float buffer's byte size does not match
// what one TF_BOOL element needs.
TEST_P(RuntimeTest, CreateHostTensorIncompatibleDataType) {
  Runtime runtime = GetParam()();
  auto created = runtime.CreateHostTensor<float>({1}, TF_BOOL, {2.0f});
  EXPECT_THAT(created,
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("Invalid number of bytes in data buffer")));
}
117
// Copying a one-element tensor into a two-element destination span must fail
// with INVALID_ARGUMENT because the element counts do not match.
TEST_P(RuntimeTest, TensorCopyInvalidSize) {
  Runtime runtime = RuntimeTest::GetParam()();
  // Terminated with ';' explicitly rather than relying on the macro expansion.
  TF_ASSERT_OK_AND_ASSIGN(
      Tensor tensor, runtime.CreateHostTensor<float>({1}, TF_FLOAT, {2.0f}));
  // Value-initialized so the buffer never holds indeterminate values, even
  // though the copy is expected to be rejected.
  float val[2] = {};

  EXPECT_THAT(tensor.GetValue(absl::MakeSpan(val)),
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("Mismatched number of elements")));
}
128
129 } // namespace runtime
130 } // namespace libtf
131 } // namespace tf
132