1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/tensor_flag_utils.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace {
25
26 using tensorflow::DataType;
27 using tensorflow::int32;
28 using tensorflow::int64;
29 using tensorflow::Tensor;
30 using tensorflow::TTypes;
31 using tensorflow::error::INVALID_ARGUMENT;
32 using tensorflow::tensor_flag_utils::FindConfigValueForKey;
33 using tensorflow::tensor_flag_utils::GetLinearBucket;
34 using tensorflow::tensor_flag_utils::GetPowerBucket;
35 using tensorflow::tensor_flag_utils::ValidateScalarQuantityShardingConfig;
36 using tensorflow::tensor_flag_utils::ValidateSparseMatrixShardingConfig;
37
TEST(SparseUtilsTest, ValidateSparseMatrixShardingConfig) {
  // Copies `values`, in order, into the flat view of `*config`. The tensor
  // must already have room for every value.
  auto fill = [](Tensor* config, std::initializer_list<float> values) {
    auto flat = config->flat<float>();
    int pos = 0;
    for (const float value : values) {
      flat(pos++) = value;
    }
  };

  // Scalar (default-only) configs: 0.7 and 1.0 are accepted, while -0.5, 0 and
  // 1.2 are rejected.
  for (const float accepted : {0.7, 1.0}) {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = accepted;
    EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok())
        << "scalar=" << accepted;
  }
  for (const float rejected : {-0.5, 0.0, 1.2}) {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = rejected;
    EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code())
        << "scalar=" << rejected;
  }

  // Misshapen: matrices with one or two columns are rejected.
  {
    Tensor t(DataType::DT_FLOAT, {1, 1});
    fill(&t, {60.0});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {1, 2});
    fill(&t, {60.0, 50.0});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
  }

  // Well-formed three-column rows (two keys followed by a value) are accepted,
  // both as a single row and as two rows.
  {
    Tensor t(DataType::DT_FLOAT, {1, 3});
    fill(&t, {30.0, 20.0, 1.0});
    EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok());
  }
  {
    Tensor t(DataType::DT_FLOAT, {2, 3});
    fill(&t, {60.0, 50.0, 0.41, 30.0, 20.0, 0.7});
    EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok());
  }

  // Out of range: a value above 1 or below 0, or a negative key, is rejected.
  {
    Tensor t(DataType::DT_FLOAT, {2, 3});
    fill(&t, {60.0, 40.0, 0.41, 30.0, 20.0, 10.7});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {2, 3});
    fill(&t, {60.0, 40.0, 0.41, 30.0, 20.0, -0.7});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {2, 3});
    fill(&t, {60.0, -40.0, 0.41, 30.0, 20.0, 0.7});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
  }
}
133
TEST(SparseUtilsTest, ValidateScalarQuantityShardingConfig) {
  // Copies `values`, in order, into the flat view of `*config`. The tensor
  // must already have room for every value.
  auto fill = [](Tensor* config, std::initializer_list<float> values) {
    auto flat = config->flat<float>();
    int pos = 0;
    for (const float value : values) {
      flat(pos++) = value;
    }
  };

  // Only a default is specified: a bare scalar of 0.7 or 1.0 is accepted.
  {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = 0.7;
    EXPECT_TRUE(ValidateScalarQuantityShardingConfig(t).ok());
  }
  {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = 1.0;
    EXPECT_TRUE(ValidateScalarQuantityShardingConfig(t).ok());
  }

  // Misshapen: this config takes rows of (key, value). A single column is
  // rejected, and so is the three-column (key, key, value) layout used by
  // sparse-matrix sharding configs.
  {
    Tensor t(DataType::DT_FLOAT, {1, 1});
    fill(&t, {60.0});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {1, 3});
    fill(&t, {30.0, 20.0, 1.0});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }

  // Well-formed: two (key, value) rows are accepted.
  {
    Tensor t(DataType::DT_FLOAT, {2, 2});
    fill(&t, {60.0, 0.41, 30.0, 0.7});
    EXPECT_TRUE(ValidateScalarQuantityShardingConfig(t).ok());
  }

  // Out of range: a value outside (0, 1] or a negative key is rejected.
  {
    // Two columns is the accepted layout, so this (1, 2) config is meant to be
    // rejected for its value 50.0 rather than for its shape.
    Tensor t(DataType::DT_FLOAT, {1, 2});
    fill(&t, {60.0, 50.0});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {2, 2});
    fill(&t, {60.0, 0.41, 30.0, 10.7});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {2, 2});
    fill(&t, {60.0, 0.41, 30.0, -0.7});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {2, 2});
    fill(&t, {-40.0, 0.41, 20.0, 0.7});
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = -0.5;
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = 0;
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
  {
    Tensor t(DataType::DT_FLOAT, {});
    t.scalar<float>()() = 1.2;
    EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
  }
}
234
TEST(SparseUtils, FindConfigValueForKey) {
  // Each case is a lookup key and the value the config should yield for it.
  // The expectations are consistent with returning the value of the first row
  // whose key(s) are all <= the lookup key, so a trailing all-zero key row
  // acts as the default.
  struct Int32PairCase {
    int32 first;
    int32 second;
    float expected;
  };
  struct Int64PairCase {
    int64 first;
    int64 second;
    float expected;
  };
  struct Int32ScalarCase {
    int32 key;
    float expected;
  };

  // Two keyed rows followed by a {0, 0} default row; int32 pair keys.
  {
    float config_data[] = {60.0, 50.0, 0.41, 30.0, 20.0, 0.1, 0, 0, 0.7};
    TTypes<float>::ConstMatrix config_mat(config_data, 3, 3);
    const Int32PairCase cases[] = {
        {70, 40, 0.1}, {60, 50, 0.41}, {60, 60, 0.41}, {60, 40, 0.1},
        {50, 60, 0.1}, {20, 30, 0.7},  {30, 10, 0.7},
    };
    for (const Int32PairCase& c : cases) {
      // Bound to a local first: the comma in the template argument list would
      // otherwise split the EXPECT macro's arguments.
      const float actual =
          FindConfigValueForKey<float, int32>(config_mat, {c.first, c.second});
      EXPECT_FLOAT_EQ(c.expected, actual)
          << "key=(" << c.first << ", " << c.second << ")";
    }
  }

  // Only the {0, 0} default row; int64 pair keys.
  {
    float config_data[] = {0, 0, 0.7};
    TTypes<float>::ConstMatrix config_mat(config_data, 1, 3);
    const Int64PairCase cases[] = {
        {70, 40, 0.7}, {60, 50, 0.7}, {60, 60, 0.7}, {60, 40, 0.7},
        {50, 60, 0.7}, {20, 30, 0.7}, {30, 10, 0.7},
    };
    for (const Int64PairCase& c : cases) {
      const float actual =
          FindConfigValueForKey<float, int64>(config_mat, {c.first, c.second});
      EXPECT_FLOAT_EQ(c.expected, actual)
          << "key=(" << c.first << ", " << c.second << ")";
    }
  }

  // One keyed row followed by the {0, 0} default row; int32 pair keys.
  {
    float config_data[] = {60.0, 50.0, 0.41, 0, 0, 0.7};
    TTypes<float>::ConstMatrix config_mat(config_data, 2, 3);
    const Int32PairCase cases[] = {
        {70, 40, 0.7}, {60, 50, 0.41}, {60, 60, 0.41}, {60, 40, 0.7},
        {50, 60, 0.7}, {20, 30, 0.7},  {30, 10, 0.7},
    };
    for (const Int32PairCase& c : cases) {
      const float actual =
          FindConfigValueForKey<float, int32>(config_mat, {c.first, c.second});
      EXPECT_FLOAT_EQ(c.expected, actual)
          << "key=(" << c.first << ", " << c.second << ")";
    }
  }

  // Single-key rows (key, value) ending in a 0-key default row; int32 keys.
  {
    float config_data[] = {60.0, 0.41, 50.0, 0.14, 0, 0.7};
    TTypes<float>::ConstMatrix config_mat(config_data, 3, 2);
    const Int32ScalarCase cases[] = {
        {70, 0.41}, {60, 0.41}, {55, 0.14}, {50, 0.14}, {20, 0.7}, {30, 0.7},
    };
    for (const Int32ScalarCase& c : cases) {
      const float actual =
          FindConfigValueForKey<float, int32>(config_mat, c.key);
      EXPECT_FLOAT_EQ(c.expected, actual) << "key=" << c.key;
    }
  }
}
307
TEST(SparseUtils, GetLinearBucket) {
  // int arguments with a bucket size of 5: both 11 and 12 map to 11.
  for (const int value : {11, 12}) {
    EXPECT_EQ(11, GetLinearBucket(value, 5)) << "value=" << value;
  }

  // int64 arguments: 4 with a bucket size of 5 maps to 1.
  const int64 small_value = 4;
  const int64 bucket_size = 5;
  EXPECT_EQ(1, GetLinearBucket(small_value, bucket_size));
}
313
TEST(SparseUtils, GetPowerBucket) {
  // (value, bucket_size) -> expected result, exercised with int arguments.
  struct IntCase {
    int value;
    int bucket_size;
    int expected;
  };
  const IntCase int_cases[] = {{11, 5, 6}, {12, 5, 6}, {1335, 11, 1332}};
  for (const IntCase& c : int_cases) {
    EXPECT_EQ(c.expected, GetPowerBucket(c.value, c.bucket_size))
        << "value=" << c.value << ", bucket_size=" << c.bucket_size;
  }

  // Same shape of table for int64 arguments, including a bucket size of 1.
  struct Int64Case {
    int64 value;
    int64 bucket_size;
    int64 expected;
  };
  const Int64Case int64_cases[] = {{5, 4, 5}, {4, 1, 1}};
  for (const Int64Case& c : int64_cases) {
    EXPECT_EQ(c.expected, GetPowerBucket(c.value, c.bucket_size))
        << "value=" << c.value << ", bucket_size=" << c.bucket_size;
  }
}
321
322 } // namespace
323