• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/cc/experimental/libtf/tests/runtime_test.h"
16 
17 namespace tf {
18 namespace libtf {
19 namespace runtime {
20 
21 using ::tensorflow::testing::StatusIs;
22 using ::testing::HasSubstr;
23 using ::tf::libtf::impl::TaggedValueTensor;
24 
25 constexpr char kSimpleModel[] =
26     "tensorflow/cc/experimental/libtf/tests/testdata/simple-model";
27 
TEST_P(RuntimeTest,SimpleModelCallableFloatTest)28 TEST_P(RuntimeTest, SimpleModelCallableFloatTest) {
29   Runtime runtime = RuntimeTest::GetParam()();
30 
31   // Import the module and grab the callable
32   const std::string module_path =
33       tensorflow::GetDataDependencyFilepath(kSimpleModel);
34 
35   TF_ASSERT_OK_AND_ASSIGN(Object module,
36                           runtime.Load(String(module_path.c_str())));
37   std::cout << "Module imported." << std::endl;
38 
39   TF_ASSERT_OK_AND_ASSIGN(Callable fn,
40                           module.Get<Callable>(String("test_float")));
41   TF_ASSERT_OK_AND_ASSIGN(
42       Tensor tensor, runtime.CreateHostTensor<float>({}, TF_FLOAT, {2.0f}));
43   TF_ASSERT_OK_AND_ASSIGN(Tensor result, fn.Call<Tensor>(Tensor(tensor)));
44 
45   float out_val[1];
46   TF_ASSERT_OK(result.GetValue(absl::MakeSpan(out_val)));
47   EXPECT_EQ(out_val[0], 6.0);
48 }
49 
TEST_P(RuntimeTest,SimpleModelCallableIntTest)50 TEST_P(RuntimeTest, SimpleModelCallableIntTest) {
51   Runtime runtime = RuntimeTest::GetParam()();
52 
53   // Import the module and grab the callable
54   const std::string module_path =
55       tensorflow::GetDataDependencyFilepath(kSimpleModel);
56   TF_ASSERT_OK_AND_ASSIGN(Object module,
57                           runtime.Load(String(module_path.c_str())));
58 
59   TF_ASSERT_OK_AND_ASSIGN(Callable fn,
60                           module.Get<Callable>(String("test_int")));
61 
62   // Call the function
63   TF_ASSERT_OK_AND_ASSIGN(Tensor host_tensor,
64                           runtime.CreateHostTensor<int>({}, TF_INT32, {2}));
65 
66   TF_ASSERT_OK_AND_ASSIGN(Tensor tensor, fn.Call<Tensor>(Tensor(host_tensor)));
67 
68   int out_val[1];
69   TF_ASSERT_OK(tensor.GetValue(absl::MakeSpan(out_val)));
70   EXPECT_EQ(out_val[0], 6);
71 }
72 
TEST_P(RuntimeTest,SimpleModelCallableMultipleArgsTest)73 TEST_P(RuntimeTest, SimpleModelCallableMultipleArgsTest) {
74   Runtime runtime = RuntimeTest::GetParam()();
75 
76   // Import the module and grab the callable
77   const std::string module_path =
78       tensorflow::GetDataDependencyFilepath(kSimpleModel);
79   TF_ASSERT_OK_AND_ASSIGN(Object module,
80                           runtime.Load(String(module_path.c_str())));
81 
82   TF_ASSERT_OK_AND_ASSIGN(Callable fn,
83                           module.Get<Callable>(String("test_add")));
84 
85   TF_ASSERT_OK_AND_ASSIGN(Tensor tensor1,
86                           runtime.CreateHostTensor<float>({}, TF_FLOAT, {2.0f}))
87   TF_ASSERT_OK_AND_ASSIGN(Tensor tensor2,
88                           runtime.CreateHostTensor<float>({}, TF_FLOAT, {3.0f}))
89 
90   TF_ASSERT_OK_AND_ASSIGN(Tensor result_tensor,
91                           fn.Call<Tensor>(tensor1, tensor2));
92   float out_val[1];
93   TF_ASSERT_OK(result_tensor.GetValue(absl::MakeSpan(out_val)));
94   EXPECT_EQ(out_val[0], 5.0f);
95 }
96 
TEST_P(RuntimeTest,CreateHostTensorIncompatibleShape)97 TEST_P(RuntimeTest, CreateHostTensorIncompatibleShape) {
98   Runtime runtime = RuntimeTest::GetParam()();
99   EXPECT_THAT(runtime.CreateHostTensor<float>({2}, TF_FLOAT, {2.0f}),
100               StatusIs(tensorflow::error::INVALID_ARGUMENT,
101                        HasSubstr("Mismatched shape and data size")));
102 }
103 
TEST_P(RuntimeTest,CreateHostTensorNonFullyDefinedShapeRaises)104 TEST_P(RuntimeTest, CreateHostTensorNonFullyDefinedShapeRaises) {
105   Runtime runtime = RuntimeTest::GetParam()();
106   EXPECT_THAT(runtime.CreateHostTensor<float>({-1}, TF_FLOAT, {2.0f}),
107               StatusIs(tensorflow::error::INVALID_ARGUMENT,
108                        HasSubstr("Shape must be fully-defined")));
109 }
110 
TEST_P(RuntimeTest,CreateHostTensorIncompatibleDataType)111 TEST_P(RuntimeTest, CreateHostTensorIncompatibleDataType) {
112   Runtime runtime = RuntimeTest::GetParam()();
113   EXPECT_THAT(runtime.CreateHostTensor<float>({1}, TF_BOOL, {2.0f}),
114               StatusIs(tensorflow::error::INVALID_ARGUMENT,
115                        HasSubstr("Invalid number of bytes in data buffer")));
116 }
117 
TEST_P(RuntimeTest,TensorCopyInvalidSize)118 TEST_P(RuntimeTest, TensorCopyInvalidSize) {
119   Runtime runtime = RuntimeTest::GetParam()();
120   TF_ASSERT_OK_AND_ASSIGN(
121       Tensor tensor, runtime.CreateHostTensor<float>({1}, TF_FLOAT, {2.0f}))
122   float val[2];
123 
124   EXPECT_THAT(tensor.GetValue(absl::MakeSpan(val)),
125               StatusIs(tensorflow::error::INVALID_ARGUMENT,
126                        HasSubstr("Mismatched number of elements")));
127 }
128 
129 }  // namespace runtime
130 }  // namespace libtf
131 }  // namespace tf
132