1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
17
18 #include <string>
19
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/casts.h"
27 #include "tensorflow/core/platform/protobuf.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/status_matchers.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
32
33 namespace tensorflow {
34 namespace {
35
ParseTextProto(absl::string_view text_proto,protobuf::Message * parsed_proto)36 Status ParseTextProto(absl::string_view text_proto,
37 protobuf::Message* parsed_proto) {
38 protobuf::TextFormat::Parser parser;
39 // Attempt to parse as text.
40 protobuf::io::ArrayInputStream input_stream(text_proto.data(),
41 text_proto.size());
42 if (parser.Parse(&input_stream, parsed_proto)) {
43 return OkStatus();
44 }
45 parsed_proto->Clear();
46 return errors::InvalidArgument("Could not parse text proto: ", text_proto);
47 }
48
// Checks the case where the config has tables with num_features and a
// batch_size_per_tensor_core but no feature_descriptor entries. After
// PopulateMissingFieldsInTPUEmbeddingConfig there must be one descriptor per
// table, pointing at that table, with a 1-D input_shape of
// batch_size_per_tensor_core * num_features.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, FillFeatureDescriptor) {
  const std::string config_str = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      num_features: 3
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    table_descriptor {
      name: "T1"
      vocabulary_size: 3122176
      dimension: 128
      num_features: 2
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    mode: TRAINING
    batch_size_per_tensor_core: 256
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration tpu_embedding_config;
  TF_ASSERT_OK(ParseTextProto(config_str, &tpu_embedding_config));
  TF_ASSERT_OK(
      PopulateMissingFieldsInTPUEmbeddingConfig(&tpu_embedding_config));

  // ASSERT (not EXPECT): the indexed accesses below would read past the end
  // of the repeated field if fewer than two descriptors were produced.
  ASSERT_EQ(tpu_embedding_config.feature_descriptor_size(), 2);
  const auto& feature_0 = tpu_embedding_config.feature_descriptor(0);
  EXPECT_EQ(feature_0.table_id(), 0);
  // batch_size_per_tensor_core (256) * T0.num_features (3).
  EXPECT_THAT(feature_0.input_shape(), ::testing::ElementsAre(256 * 3));
  const auto& feature_1 = tpu_embedding_config.feature_descriptor(1);
  EXPECT_EQ(feature_1.table_id(), 1);
  // batch_size_per_tensor_core (256) * T1.num_features (2).
  EXPECT_THAT(feature_1.input_shape(), ::testing::ElementsAre(256 * 2));
}
90
// The inverse of FillFeatureDescriptor. Here the config carries explicit
// feature_descriptor entries but neither batch_size_per_tensor_core nor
// TableDescriptor.num_features. Both must be filled in by
// PopulateMissingFieldsInTPUEmbeddingConfig.
//
// The expected numbers are consistent with
//   num_features(table) = sum of element counts of the table's features / batch
// with batch = 50:
//   T0: F0 [100,5] + F2 [50] + F3 [100,2,3] -> (500 + 50 + 600) / 50 = 23
//   T1: F1 [200,5,20]                       -> 20000 / 50          = 400
// NOTE(review): 50 is the smallest leading dimension among the features (F2).
// Confirm the exact batch-size derivation rule in the rewrite implementation.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, FillBatchSizeAndNumFeatures) {
  constexpr char kConfigText[] = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    table_descriptor {
      name: "T1"
      vocabulary_size: 3122176
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    feature_descriptor {
      name: "F0"
      table_id: 0
      input_shape: [ 100, 5 ]
    }
    feature_descriptor {
      name: "F1"
      table_id: 1
      input_shape: [ 200, 5, 20 ]
    }
    feature_descriptor {
      name: "F2"
      table_id: 0
      input_shape: [ 50 ]
    }
    feature_descriptor {
      name: "F3"
      table_id: 0
      input_shape: [ 100, 2, 3 ]
    }
    mode: TRAINING
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration config;
  TF_ASSERT_OK(ParseTextProto(kConfigText, &config));
  TF_ASSERT_OK(PopulateMissingFieldsInTPUEmbeddingConfig(&config));

  EXPECT_EQ(config.batch_size_per_tensor_core(), 50);
  EXPECT_EQ(config.table_descriptor(0).num_features(), 23);
  EXPECT_EQ(config.table_descriptor(1).num_features(), 400);
}
147
// Starting from a config that has feature_descriptor, batch_size_per_tensor_core
// and num_features all populated, each check removes (or keeps) certain fields.
// It then verifies that PopulateMissingFieldsInTPUEmbeddingConfig rejects the
// result with INVALID_ARGUMENT and an error message containing a specific
// substring.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, InvalidBatchSizeOrNumFeatures) {
  constexpr char kConfigText[] = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      num_features: 3
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    feature_descriptor {
      table_id: 0
      input_shape: [ 768 ]
    }
    mode: TRAINING
    batch_size_per_tensor_core: 256
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration base_config;
  TF_ASSERT_OK(ParseTextProto(kConfigText, &base_config));

  // Mutates a private copy of `base_config` with `mutate`, runs the rewrite on
  // it, and expects INVALID_ARGUMENT whose message contains `needle`.
  auto expect_invalid = [&base_config](auto mutate, const char* needle) {
    tpu::TPUEmbeddingConfiguration candidate = base_config;
    mutate(candidate);
    EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&candidate),
                tensorflow::testing::StatusIs(
                    tensorflow::error::INVALID_ARGUMENT,
                    ::testing::HasSubstr(needle)));
  };

  // No feature descriptors and no batch size.
  expect_invalid(
      [](tpu::TPUEmbeddingConfiguration& c) {
        c.clear_feature_descriptor();
        c.clear_batch_size_per_tensor_core();
      },
      "Invalid batch_size_per_tensor_core");

  // No feature descriptors and no num_features on the table.
  expect_invalid(
      [](tpu::TPUEmbeddingConfiguration& c) {
        c.clear_feature_descriptor();
        c.mutable_table_descriptor(0)->clear_num_features();
      },
      "Invalid num_features");

  // Everything populated: feature descriptors plus a batch size.
  expect_invalid([](tpu::TPUEmbeddingConfiguration&) {},
                 "The batch_size_per_tensor_core field must NOT be populated");

  // Feature descriptors present alongside TableDescriptor.num_features.
  expect_invalid(
      [](tpu::TPUEmbeddingConfiguration& c) {
        c.clear_batch_size_per_tensor_core();
      },
      "The TableDescriptor.num_features field must NOT be populated");
}
210
// Table-driven negative tests for malformed feature_descriptor entries. Each
// case applies one mutation to a copy of a valid two-table and two-feature
// config. It then expects PopulateMissingFieldsInTPUEmbeddingConfig to fail
// with INVALID_ARGUMENT and a message containing the listed substring.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, InvalidFeatureDescriptor) {
  constexpr char kConfigText[] = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    table_descriptor {
      name: "T1"
      vocabulary_size: 3122176
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    feature_descriptor {
      name: "F1"
      table_id: 0
      input_shape: [ 768 ]
    }
    feature_descriptor {
      name: "F2"
      table_id: 1
      input_shape: [ 512 ]
    }
    mode: TRAINING
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration base_config;
  TF_ASSERT_OK(ParseTextProto(kConfigText, &base_config));

  struct InvalidCase {
    void (*mutate)(tpu::TPUEmbeddingConfiguration&);
    const char* expected_substr;
  };
  const InvalidCase kCases[] = {
      // table_id refers to a table that does not exist (only 0 and 1 do).
      {[](tpu::TPUEmbeddingConfiguration& c) {
         c.mutable_feature_descriptor(0)->set_table_id(2);
       },
       "Invalid table_id"},
      // Empty input_shape.
      {[](tpu::TPUEmbeddingConfiguration& c) {
         c.mutable_feature_descriptor(0)->clear_input_shape();
       },
       "The input_shape field cannot be empty"},
      // Negative dimension size.
      {[](tpu::TPUEmbeddingConfiguration& c) {
         c.mutable_feature_descriptor(0)->set_input_shape(0, -5);
       },
       "The input_shape dimension sizes must all"},
      // Both features point at T0, leaving T1 with none.
      {[](tpu::TPUEmbeddingConfiguration& c) {
         c.mutable_feature_descriptor(1)->set_table_id(0);
       },
       "No feature_descriptor fields found for table: T1"},
  };

  for (const InvalidCase& test_case : kCases) {
    SCOPED_TRACE(test_case.expected_substr);
    tpu::TPUEmbeddingConfiguration candidate = base_config;
    test_case.mutate(candidate);
    EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&candidate),
                tensorflow::testing::StatusIs(
                    tensorflow::error::INVALID_ARGUMENT,
                    ::testing::HasSubstr(test_case.expected_substr)));
  }
}
284
285 } // namespace
286 } // namespace tensorflow
287