1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tensor-view.h"
18
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21
22 namespace libtextclassifier3 {
23 namespace {
24
TEST(TensorViewTest, TestSize) {
  // Six contiguous values viewed as a tensor of shape {3, 1, 2}.
  std::vector<float> values{0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
  const std::vector<int> expected_shape{3, 1, 2};
  const TensorView<float> view(values.data(), {3, 1, 2});

  // The accessors report exactly the buffer and shape passed at construction.
  EXPECT_TRUE(view.is_valid());
  EXPECT_EQ(view.shape(), expected_shape);
  EXPECT_EQ(view.data(), values.data());
  EXPECT_EQ(view.size(), 6);
  EXPECT_EQ(view.dims(), 3);
  for (int axis = 0; axis < 3; ++axis) {
    EXPECT_EQ(view.dim(axis), expected_shape[axis]);
  }

  // A destination holding exactly size() elements receives a full copy.
  std::vector<float> copied(values.size());
  EXPECT_TRUE(view.copy_to(copied.data(), copied.size()));
  EXPECT_EQ(values, copied);

  // A destination that is too small is rejected and must be left untouched.
  const std::vector<float> sentinel{-1, -1, -1};
  std::vector<float> too_small = sentinel;
  EXPECT_FALSE(view.copy_to(too_small.data(), too_small.size()));
  EXPECT_EQ(too_small, sentinel);

  // Invalid() yields a view whose is_valid() reports false.
  const TensorView<float> invalid_view = TensorView<float>::Invalid();
  EXPECT_FALSE(invalid_view.is_valid());
}
50
51 } // namespace
52 } // namespace libtextclassifier3
53