• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/tensor_flag_utils.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace {
25 
26 using tensorflow::DataType;
27 using tensorflow::int32;
28 using tensorflow::int64;
29 using tensorflow::Tensor;
30 using tensorflow::TTypes;
31 using tensorflow::error::INVALID_ARGUMENT;
32 using tensorflow::tensor_flag_utils::FindConfigValueForKey;
33 using tensorflow::tensor_flag_utils::GetLinearBucket;
34 using tensorflow::tensor_flag_utils::GetPowerBucket;
35 using tensorflow::tensor_flag_utils::ValidateScalarQuantityShardingConfig;
36 using tensorflow::tensor_flag_utils::ValidateSparseMatrixShardingConfig;
37 
TEST(SparseUtilsTest,ValidateSparseMatrixShardingConfig)38 TEST(SparseUtilsTest, ValidateSparseMatrixShardingConfig) {
39   // Only a default is specified.
40   {
41     Tensor t(DataType::DT_FLOAT, {});
42     t.scalar<float>()() = 0.7;
43     EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok());
44   }
45   {
46     Tensor t(DataType::DT_FLOAT, {});
47     t.scalar<float>()() = 1.0;
48     EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok());
49   }
50 
51   // Misshapen.
52   {
53     Tensor t(DataType::DT_FLOAT, {1, 1});
54     int indx = 0;
55     for (const float v : {60.0}) {
56       t.flat<float>()(indx++) = v;
57     }
58     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
59   }
60   {
61     Tensor t(DataType::DT_FLOAT, {1, 2});
62     int indx = 0;
63     for (const float v : {
64              60.0,
65              50.0,
66          }) {
67       t.flat<float>()(indx++) = v;
68     }
69     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
70   }
71 
72   // Only one key is specified.
73   {
74     Tensor t(DataType::DT_FLOAT, {1, 3});
75     int indx = 0;
76     for (const float v : {30.0, 20.0, 1.0}) {
77       t.flat<float>()(indx++) = v;
78     }
79     EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok());
80   }
81 
82   // Two keys are specified.
83   {
84     Tensor t(DataType::DT_FLOAT, {2, 3});
85     int indx = 0;
86     for (const float v : {60.0, 50.0, 0.41, 30.0, 20.0, 0.7}) {
87       t.flat<float>()(indx++) = v;
88     }
89     EXPECT_TRUE(ValidateSparseMatrixShardingConfig(t).ok());
90   }
91 
92   // Out of range.
93   {
94     Tensor t(DataType::DT_FLOAT, {2, 3});
95     int indx = 0;
96     for (const float v : {60.0, 40.0, 0.41, 30.0, 20.0, 10.7}) {
97       t.flat<float>()(indx++) = v;
98     }
99     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
100   }
101   {
102     Tensor t(DataType::DT_FLOAT, {2, 3});
103     int indx = 0;
104     for (const float v : {60.0, 40.0, 0.41, 30.0, 20.0, -0.7}) {
105       t.flat<float>()(indx++) = v;
106     }
107     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
108   }
109   {
110     Tensor t(DataType::DT_FLOAT, {2, 3});
111     int indx = 0;
112     for (const float v : {60.0, -40.0, 0.41, 30.0, 20.0, 0.7}) {
113       t.flat<float>()(indx++) = v;
114     }
115     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
116   }
117   {
118     Tensor t(DataType::DT_FLOAT, {});
119     t.scalar<float>()() = -0.5;
120     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
121   }
122   {
123     Tensor t(DataType::DT_FLOAT, {});
124     t.scalar<float>()() = 0;
125     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
126   }
127   {
128     Tensor t(DataType::DT_FLOAT, {});
129     t.scalar<float>()() = 1.2;
130     EXPECT_EQ(INVALID_ARGUMENT, ValidateSparseMatrixShardingConfig(t).code());
131   }
132 }
133 
TEST(SparseUtilsTest,ValidateScalarQuantityShardingConfig)134 TEST(SparseUtilsTest, ValidateScalarQuantityShardingConfig) {
135   // Only a default is specified.
136   {
137     Tensor t(DataType::DT_FLOAT, {});
138     t.scalar<float>()() = 0.7;
139     EXPECT_TRUE(ValidateScalarQuantityShardingConfig(t).ok());
140   }
141   {
142     Tensor t(DataType::DT_FLOAT, {});
143     t.scalar<float>()() = 1.0;
144     EXPECT_TRUE(ValidateScalarQuantityShardingConfig(t).ok());
145   }
146   {
147     Tensor t(DataType::DT_FLOAT, {});
148     t.scalar<float>()() = 1.2;
149     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
150   }
151 
152   // Misshapen.
153   {
154     Tensor t(DataType::DT_FLOAT, {1, 1});
155     int indx = 0;
156     for (const float v : {60.0}) {
157       t.flat<float>()(indx++) = v;
158     }
159     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
160   }
161   {
162     Tensor t(DataType::DT_FLOAT, {1, 2});
163     int indx = 0;
164     for (const float v : {
165              60.0,
166              50.0,
167          }) {
168       t.flat<float>()(indx++) = v;
169     }
170     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
171   }
172 
173   // Two keys are specified.
174   {
175     Tensor t(DataType::DT_FLOAT, {1, 3});
176     int indx = 0;
177     for (const float v : {30.0, 20.0, 1.0}) {
178       t.flat<float>()(indx++) = v;
179     }
180     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
181   }
182 
183   // Only one key is specified.
184   {
185     Tensor t(DataType::DT_FLOAT, {2, 2});
186     int indx = 0;
187     for (const float v : {60.0, 0.41, 30.0, 0.7}) {
188       t.flat<float>()(indx++) = v;
189     }
190     EXPECT_TRUE(ValidateScalarQuantityShardingConfig(t).ok());
191   }
192 
193   // Out of range.
194   {
195     Tensor t(DataType::DT_FLOAT, {2, 2});
196     int indx = 0;
197     for (const float v : {60.0, 0.41, 30.0, 10.7}) {
198       t.flat<float>()(indx++) = v;
199     }
200     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
201   }
202   {
203     Tensor t(DataType::DT_FLOAT, {2, 2});
204     int indx = 0;
205     for (const float v : {60.0, 0.41, 30.0, -0.7}) {
206       t.flat<float>()(indx++) = v;
207     }
208     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
209   }
210   {
211     Tensor t(DataType::DT_FLOAT, {2, 2});
212     int indx = 0;
213     for (const float v : {-40.0, 0.41, 20.0, 0.7}) {
214       t.flat<float>()(indx++) = v;
215     }
216     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
217   }
218   {
219     Tensor t(DataType::DT_FLOAT, {});
220     t.scalar<float>()() = -0.5;
221     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
222   }
223   {
224     Tensor t(DataType::DT_FLOAT, {});
225     t.scalar<float>()() = 0;
226     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
227   }
228   {
229     Tensor t(DataType::DT_FLOAT, {});
230     t.scalar<float>()() = 1.2;
231     EXPECT_EQ(INVALID_ARGUMENT, ValidateScalarQuantityShardingConfig(t).code());
232   }
233 }
234 
TEST(SparseUtils,FindConfigValueForKey)235 TEST(SparseUtils, FindConfigValueForKey) {
236   {
237     float data[] = {60.0, 50.0, 0.41, 30.0, 20.0, 0.1, 0, 0, 0.7};
238     TTypes<float>::ConstMatrix config_mat(data, 3, 3);
239     auto val = FindConfigValueForKey<float, int32>(config_mat, {70, 40});
240     EXPECT_FLOAT_EQ(0.1, val);
241     val = FindConfigValueForKey<float, int32>(config_mat, {60, 50});
242     EXPECT_FLOAT_EQ(0.41, val);
243     val = FindConfigValueForKey<float, int32>(config_mat, {60, 60});
244     EXPECT_FLOAT_EQ(0.41, val);
245     val = FindConfigValueForKey<float, int32>(config_mat, {60, 40});
246     EXPECT_FLOAT_EQ(0.1, val);
247     val = FindConfigValueForKey<float, int32>(config_mat, {50, 60});
248     EXPECT_FLOAT_EQ(0.1, val);
249     val = FindConfigValueForKey<float, int32>(config_mat, {20, 30});
250     EXPECT_FLOAT_EQ(0.7, val);
251     val = FindConfigValueForKey<float, int32>(config_mat, {30, 10});
252     EXPECT_FLOAT_EQ(0.7, val);
253   }
254   {
255     float data[] = {0, 0, 0.7};
256     TTypes<float>::ConstMatrix config_mat(data, 1, 3);
257     auto val = FindConfigValueForKey<float, int64>(config_mat, {70, 40});
258     EXPECT_FLOAT_EQ(0.7, val);
259     val = FindConfigValueForKey<float, int64>(config_mat, {60, 50});
260     EXPECT_FLOAT_EQ(0.7, val);
261     val = FindConfigValueForKey<float, int64>(config_mat, {60, 60});
262     EXPECT_FLOAT_EQ(0.7, val);
263     val = FindConfigValueForKey<float, int64>(config_mat, {60, 40});
264     EXPECT_FLOAT_EQ(0.7, val);
265     val = FindConfigValueForKey<float, int64>(config_mat, {50, 60});
266     EXPECT_FLOAT_EQ(0.7, val);
267     val = FindConfigValueForKey<float, int64>(config_mat, {20, 30});
268     EXPECT_FLOAT_EQ(0.7, val);
269     val = FindConfigValueForKey<float, int64>(config_mat, {30, 10});
270     EXPECT_FLOAT_EQ(0.7, val);
271   }
272   {
273     float data[] = {60.0, 50.0, 0.41, 0, 0, 0.7};
274     TTypes<float>::ConstMatrix config_mat(data, 2, 3);
275     auto val = FindConfigValueForKey<float, int32>(config_mat, {70, 40});
276     EXPECT_FLOAT_EQ(0.7, val);
277     val = FindConfigValueForKey<float, int32>(config_mat, {60, 50});
278     EXPECT_FLOAT_EQ(0.41, val);
279     val = FindConfigValueForKey<float, int32>(config_mat, {60, 60});
280     EXPECT_FLOAT_EQ(0.41, val);
281     val = FindConfigValueForKey<float, int32>(config_mat, {60, 40});
282     EXPECT_FLOAT_EQ(0.7, val);
283     val = FindConfigValueForKey<float, int32>(config_mat, {50, 60});
284     EXPECT_FLOAT_EQ(0.7, val);
285     val = FindConfigValueForKey<float, int32>(config_mat, {20, 30});
286     EXPECT_FLOAT_EQ(0.7, val);
287     val = FindConfigValueForKey<float, int32>(config_mat, {30, 10});
288     EXPECT_FLOAT_EQ(0.7, val);
289   }
290   {
291     float data[] = {60.0, 0.41, 50.0, 0.14, 0, 0.7};
292     TTypes<float>::ConstMatrix config_mat(data, 3, 2);
293     auto val = FindConfigValueForKey<float, int32>(config_mat, 70);
294     EXPECT_FLOAT_EQ(0.41, val);
295     val = FindConfigValueForKey<float, int32>(config_mat, 60);
296     EXPECT_FLOAT_EQ(0.41, val);
297     val = FindConfigValueForKey<float, int32>(config_mat, 55);
298     EXPECT_FLOAT_EQ(0.14, val);
299     val = FindConfigValueForKey<float, int32>(config_mat, 50);
300     EXPECT_FLOAT_EQ(0.14, val);
301     val = FindConfigValueForKey<float, int32>(config_mat, 20);
302     EXPECT_FLOAT_EQ(0.7, val);
303     val = FindConfigValueForKey<float, int32>(config_mat, 30);
304     EXPECT_FLOAT_EQ(0.7, val);
305   }
306 }
307 
TEST(SparseUtils, GetLinearBucket) {
  // int64 arguments: value 4 with bucket size 5 maps to 1.
  EXPECT_EQ(1, GetLinearBucket(int64{4}, int64{5}));
  // Plain int arguments: values 11 and 12 with bucket size 5 both map to 11.
  EXPECT_EQ(11, GetLinearBucket(12, 5));
  EXPECT_EQ(11, GetLinearBucket(11, 5));
}
313 
TEST(SparseUtils, GetPowerBucket) {
  // int64 arguments; a bucket size of 1 maps value 4 to 1.
  EXPECT_EQ(1, GetPowerBucket(int64{4}, int64{1}));
  EXPECT_EQ(5, GetPowerBucket(int64{5}, int64{4}));
  // Plain int arguments.
  EXPECT_EQ(1332, GetPowerBucket(1335, 11));
  EXPECT_EQ(6, GetPowerBucket(12, 5));
  EXPECT_EQ(6, GetPowerBucket(11, 5));
}
321 
322 }  // namespace
323