• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/core/tensor.h"
18 
19 #include <cstdint>
20 #include <initializer_list>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "fcp/aggregation/core/datatype.h"
28 #include "fcp/aggregation/core/tensor.pb.h"
29 #include "fcp/aggregation/testing/test_data.h"
30 #include "fcp/aggregation/testing/testing.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/testing/testing.h"
33 
34 namespace fcp {
35 namespace aggregation {
36 namespace {
37 
38 using testing::Eq;
39 
// Verifies that a tensor created from three float values reports the
// expected dtype, shape, density and element count.
TEST(TensorTest, Create_Dense) {
  auto t = Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1, 2, 3}));
  // Fatal assertion: every check below dereferences `t`, which is only
  // meaningful when creation succeeded.
  ASSERT_THAT(t, IsOk());
  EXPECT_THAT(t->dtype(), Eq(DT_FLOAT));
  EXPECT_THAT(t->shape(), Eq(TensorShape{3}));
  EXPECT_TRUE(t->is_dense());
  EXPECT_THAT(t->AsAggVector<float>().size(), Eq(3));
}
48 
// Verifies that a tensor created from two string values reports the expected
// dtype, shape, density and element count.
TEST(TensorTest, Create_StringTensor) {
  auto t = Tensor::Create(DT_STRING, {2},
                          CreateTestData<string_view>({"foo", "bar"}));
  // Fatal assertion: every check below dereferences `t`, which is only
  // meaningful when creation succeeded.
  ASSERT_THAT(t, IsOk());
  EXPECT_THAT(t->dtype(), Eq(DT_STRING));
  EXPECT_THAT(t->shape(), Eq(TensorShape{2}));
  EXPECT_TRUE(t->is_dense());
  EXPECT_THAT(t->AsAggVector<string_view>().size(), Eq(2));
}
58 
// Creating a DT_FLOAT tensor with an empty shape from three chars of test
// data is expected to be rejected with FAILED_PRECONDITION.
TEST(TensorTest, Create_DataValidationError) {
  auto result =
      Tensor::Create(DT_FLOAT, {}, CreateTestData<char>({'a', 'b', 'c'}));
  EXPECT_THAT(result, IsCode(FAILED_PRECONDITION));
}
63 
// Creating a tensor of shape {1} from two float values is expected to be
// rejected with FAILED_PRECONDITION.
TEST(TensorTest, Create_DataSizeError) {
  auto result =
      Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1, 2}));
  EXPECT_THAT(result, IsCode(FAILED_PRECONDITION));
}
68 
69 struct FooBar {};
70 
TEST(TensorTest,AsAggVector_TypeCheckFailure)71 TEST(TensorTest, AsAggVector_TypeCheckFailure) {
72   auto t = Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1}));
73   EXPECT_DEATH(t->AsAggVector<FooBar>(), "Incompatible tensor dtype()");
74   EXPECT_DEATH(t->AsAggVector<int>(), "Incompatible tensor dtype()");
75 }
76 
// Returns the in-memory byte representation of `values`, concatenated in
// order, as a string of `values.size() * sizeof(T)` bytes. The tests use this
// as the `content` field of a TensorProto for fixed-size element types.
//
// The bytes are read directly from the initializer_list's backing array, so
// no intermediate container is needed.
template <typename T>
std::string ToProtoContent(std::initializer_list<T> values) {
  const char* const first = reinterpret_cast<const char*>(values.begin());
  // The iterator-range constructor keeps an empty list well-defined: both
  // ends of the range are then the same pointer.
  return std::string(first, first + values.size() * sizeof(T));
}
82 
83 template <>
ToProtoContent(std::initializer_list<string_view> values)84 std::string ToProtoContent(std::initializer_list<string_view> values) {
85   // The following is the simplified version of serializing the string values
86   // that works only for short strings that are shorter than 128 characters, in
87   // which case string lengths can be encoded with one byte each.
88   std::string content(values.size(), '\0');
89   size_t index = 0;
90   // Write sizes of strings first.
91   for (string_view value : values) {
92     FCP_CHECK(value.size() < 128);
93     content[index++] = static_cast<char>(value.size());
94   }
95   // Append data of all strings.
96   for (string_view value : values) {
97     content.append(value.data(), value.size());
98   }
99   return content;
100 }
101 
// Verifies that a 2x2 DT_INT32 tensor serializes to a TensorProto with the
// matching dtype, dimension sizes and raw content.
TEST(TensorTest, ToProto_Numeric_Success) {
  std::initializer_list<int32_t> values{1, 2, 3, 4};
  auto t = Tensor::Create(DT_INT32, {2, 2}, CreateTestData(values));
  // Fatal assertion: `t` is dereferenced below, so stop here if creation
  // failed.
  ASSERT_THAT(t, IsOk());
  TensorProto expected_proto;
  expected_proto.set_dtype(DT_INT32);
  expected_proto.mutable_shape()->add_dim_sizes(2);
  expected_proto.mutable_shape()->add_dim_sizes(2);
  expected_proto.set_content(ToProtoContent(values));
  EXPECT_THAT(t->ToProto(), EqualsProto(expected_proto));
}
112 
// Verifies that a 2x3 DT_STRING tensor (including an empty string) serializes
// to a TensorProto with the matching dtype, dimension sizes and content.
TEST(TensorTest, ToProto_String_Success) {
  std::initializer_list<string_view> values{"abc",  "de",    "",
                                            "fghi", "jklmn", "o"};
  auto t = Tensor::Create(DT_STRING, {2, 3}, CreateTestData(values));
  // Fatal assertion: `t` is dereferenced below, so stop here if creation
  // failed.
  ASSERT_THAT(t, IsOk());
  TensorProto expected_proto;
  expected_proto.set_dtype(DT_STRING);
  expected_proto.mutable_shape()->add_dim_sizes(2);
  expected_proto.mutable_shape()->add_dim_sizes(3);
  expected_proto.set_content(ToProtoContent(values));
  EXPECT_THAT(t->ToProto(), EqualsProto(expected_proto));
}
124 
// Verifies that a 2x3 DT_INT32 TensorProto is parsed into a tensor with the
// same shape and values.
TEST(TensorTest, FromProto_Numeric_Success) {
  std::initializer_list<int32_t> values{5, 6, 7, 8, 9, 10};
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.mutable_shape()->add_dim_sizes(2);
  tensor_proto.mutable_shape()->add_dim_sizes(3);
  tensor_proto.set_content(ToProtoContent(values));
  auto t = Tensor::FromProto(tensor_proto);
  // Fatal assertion: `*t` below is only valid when parsing succeeded.
  ASSERT_THAT(t, IsOk());
  EXPECT_THAT(*t, IsTensor({2, 3}, values));
}
136 
// Verifies that a 2x2 DT_STRING TensorProto is parsed into a tensor with the
// same shape and string values.
TEST(TensorTest, FromProto_String_Success) {
  std::initializer_list<string_view> values{"aaaaaaaa", "b", "cccc", "ddddddd"};
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_STRING);
  tensor_proto.mutable_shape()->add_dim_sizes(2);
  tensor_proto.mutable_shape()->add_dim_sizes(2);
  tensor_proto.set_content(ToProtoContent(values));
  auto t = Tensor::FromProto(tensor_proto);
  // Fatal assertion: `*t` below is only valid when parsing succeeded.
  ASSERT_THAT(t, IsOk());
  EXPECT_THAT(*t, IsTensor({2, 2}, values));
}
148 
// Round-trips a string tensor through ToProto/FromProto using strings far
// longer than the 128-character limit of the simplified ToProtoContent
// helper, and checks that the values survive unchanged.
TEST(TensorTest, LargeStringValuesSerialization) {
  std::string s1(123456, 'a');
  std::string s2(7890, 'b');
  std::string s3(1357924, 'c');
  auto t1 =
      Tensor::Create(DT_STRING, {3}, CreateTestData<string_view>({s1, s2, s3}));
  // Fatal assertion: `t1` is dereferenced on the next line.
  ASSERT_THAT(t1, IsOk());
  auto proto = t1->ToProto();
  auto t2 = Tensor::FromProto(proto);
  // Fatal assertion: `*t2` below is only valid when parsing succeeded.
  ASSERT_THAT(t2, IsOk());
  EXPECT_THAT(*t2, IsTensor<string_view>({3}, {s1, s2, s3}));
}
159 
// Verifies that parsing from an rvalue TensorProto yields the expected tensor
// and that the tensor's data lives at the same address as the proto's
// original content buffer.
TEST(TensorTest, FromProto_Mutable_Success) {
  std::initializer_list<int32_t> values{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.mutable_shape()->add_dim_sizes(10);
  tensor_proto.set_content(ToProtoContent(values));
  // Store the data pointer to make sure that the tensor retains the same data.
  void* data_ptr = tensor_proto.mutable_content()->data();
  auto t = Tensor::FromProto(std::move(tensor_proto));
  // Fatal assertion: `t` is dereferenced by both checks below.
  ASSERT_THAT(t, IsOk());
  EXPECT_THAT(*t, IsTensor({10}, values));
  EXPECT_EQ(data_ptr, t->data().data());
}
173 
// A TensorProto whose shape contains a dimension size of -1 is expected to be
// rejected with INVALID_ARGUMENT.
TEST(TensorTest, FromProto_NegativeDimSize) {
  TensorProto proto;
  proto.set_dtype(DT_INT32);
  proto.mutable_shape()->add_dim_sizes(-1);
  proto.set_content(ToProtoContent<int32_t>({1}));

  auto result = Tensor::FromProto(proto);
  EXPECT_THAT(result, IsCode(INVALID_ARGUMENT));
}
181 
// A single-element DT_STRING TensorProto with malformed content is expected to
// be rejected with INVALID_ARGUMENT. Three contents are tried in turn.
TEST(TensorTest, FromProto_InvalidStringContent) {
  TensorProto proto;
  proto.set_dtype(DT_STRING);
  proto.mutable_shape()->add_dim_sizes(1);

  // Case 1: no content at all.
  proto.set_content("");
  EXPECT_THAT(Tensor::FromProto(proto), IsCode(INVALID_ARGUMENT));

  // Case 2: a single byte of value 5 and nothing after it.
  std::string payload(1, '\5');
  proto.set_content(payload);
  EXPECT_THAT(Tensor::FromProto(proto), IsCode(INVALID_ARGUMENT));

  // Case 3: the same byte followed by only three characters.
  payload.append("abc");
  proto.set_content(payload);
  EXPECT_THAT(Tensor::FromProto(proto), IsCode(INVALID_ARGUMENT));
}
197 
198 }  // namespace
199 }  // namespace aggregation
200 }  // namespace fcp
201