• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
17 
18 #include <string>
19 
20 #include <gmock/gmock.h>
21 #include <gtest/gtest.h>
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/casts.h"
27 #include "tensorflow/core/platform/protobuf.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/status_matchers.h"
30 #include "tensorflow/core/protobuf/error_codes.pb.h"
31 #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
32 
33 namespace tensorflow {
34 namespace {
35 
ParseTextProto(absl::string_view text_proto,protobuf::Message * parsed_proto)36 Status ParseTextProto(absl::string_view text_proto,
37                       protobuf::Message* parsed_proto) {
38   protobuf::TextFormat::Parser parser;
39   // Attempt to parse as text.
40   protobuf::io::ArrayInputStream input_stream(text_proto.data(),
41                                               text_proto.size());
42   if (parser.Parse(&input_stream, parsed_proto)) {
43     return OkStatus();
44   }
45   parsed_proto->Clear();
46   return errors::InvalidArgument("Could not parse text proto: ", text_proto);
47 }
48 
// Checks that PopulateMissingFieldsInTPUEmbeddingConfig fills in
// feature_descriptor entries for a config that only specifies
// batch_size_per_tensor_core and per-table num_features. The test expects one
// descriptor per table whose single input_shape dimension equals
// batch_size_per_tensor_core * num_features.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, FillFeatureDescriptor) {
  const std::string config_str = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      num_features: 3
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    table_descriptor {
      name: "T1"
      vocabulary_size: 3122176
      dimension: 128
      num_features: 2
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    mode: TRAINING
    batch_size_per_tensor_core: 256
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration tpu_embedding_config;
  TF_ASSERT_OK(ParseTextProto(config_str, &tpu_embedding_config));
  TF_ASSERT_OK(
      PopulateMissingFieldsInTPUEmbeddingConfig(&tpu_embedding_config));

  // Fatal assertion: the indexed accesses below are only valid when both
  // descriptors exist, so stop here rather than read out of range.
  ASSERT_EQ(tpu_embedding_config.feature_descriptor_size(), 2);
  const auto& feature_0 = tpu_embedding_config.feature_descriptor(0);
  EXPECT_EQ(feature_0.table_id(), 0);
  EXPECT_THAT(feature_0.input_shape(), ::testing::ElementsAre(256 * 3));
  const auto& feature_1 = tpu_embedding_config.feature_descriptor(1);
  EXPECT_EQ(feature_1.table_id(), 1);
  EXPECT_THAT(feature_1.input_shape(), ::testing::ElementsAre(256 * 2));
}
90 
// Checks the opposite direction of the rewrite: the config lists
// feature_descriptor entries but leaves batch_size_per_tensor_core and the
// per-table num_features unset, and the rewrite is expected to fill them in.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, FillBatchSizeAndNumFeatures) {
  const std::string config_text = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    table_descriptor {
      name: "T1"
      vocabulary_size: 3122176
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    feature_descriptor {
      name: "F0"
      table_id: 0
      input_shape: [ 100, 5 ]
    }
    feature_descriptor {
      name: "F1"
      table_id: 1
      input_shape: [ 200, 5, 20 ]
    }
    feature_descriptor {
      name: "F2"
      table_id: 0
      input_shape: [ 50 ]
    }
    feature_descriptor {
      name: "F3"
      table_id: 0
      input_shape: [ 100, 2, 3 ]
    }
    mode: TRAINING
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";

  tpu::TPUEmbeddingConfiguration config;
  TF_ASSERT_OK(ParseTextProto(config_text, &config));
  TF_ASSERT_OK(PopulateMissingFieldsInTPUEmbeddingConfig(&config));

  // Expected values for the shapes above: a batch size of 50, 23 features for
  // table T0 (fed by F0, F2 and F3) and 400 features for table T1 (fed by F1).
  EXPECT_EQ(config.batch_size_per_tensor_core(), 50);
  EXPECT_EQ(config.table_descriptor(0).num_features(), 23);
  EXPECT_EQ(config.table_descriptor(1).num_features(), 400);
}
147 
// Exercises the INVALID_ARGUMENT paths tied to batch_size_per_tensor_core and
// TableDescriptor.num_features. Every case starts from the same parsed
// baseline and tweaks a private copy of it.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, InvalidBatchSizeOrNumFeatures) {
  using ::tensorflow::testing::StatusIs;
  using ::testing::HasSubstr;

  const std::string config_text = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      num_features: 3
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    feature_descriptor {
      table_id: 0
      input_shape: [ 768 ]
    }
    mode: TRAINING
    batch_size_per_tensor_core: 256
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration base_config;
  TF_ASSERT_OK(ParseTextProto(config_text, &base_config));

  // No feature descriptors and no batch size.
  tpu::TPUEmbeddingConfiguration no_batch_size = base_config;
  no_batch_size.clear_feature_descriptor();
  no_batch_size.clear_batch_size_per_tensor_core();
  EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&no_batch_size),
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("Invalid batch_size_per_tensor_core")));

  // No feature descriptors and no num_features on the table.
  tpu::TPUEmbeddingConfiguration no_num_features = base_config;
  no_num_features.clear_feature_descriptor();
  no_num_features.mutable_table_descriptor(0)->clear_num_features();
  EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&no_num_features),
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("Invalid num_features")));

  // Feature descriptors present together with an explicit batch size.
  tpu::TPUEmbeddingConfiguration batch_size_with_features = base_config;
  EXPECT_THAT(
      PopulateMissingFieldsInTPUEmbeddingConfig(&batch_size_with_features),
      StatusIs(
          tensorflow::error::INVALID_ARGUMENT,
          HasSubstr(
              "The batch_size_per_tensor_core field must NOT be populated")));

  // Feature descriptors present together with an explicit num_features.
  tpu::TPUEmbeddingConfiguration num_features_with_features = base_config;
  num_features_with_features.clear_batch_size_per_tensor_core();
  EXPECT_THAT(
      PopulateMissingFieldsInTPUEmbeddingConfig(&num_features_with_features),
      StatusIs(tensorflow::error::INVALID_ARGUMENT,
               HasSubstr("The TableDescriptor.num_features "
                         "field must NOT be populated")));
}
210 
// Exercises the INVALID_ARGUMENT paths triggered by malformed
// feature_descriptor entries. Every case starts from the same parsed baseline
// (two tables, one feature each) and breaks a private copy of it.
TEST(TPUEmbeddingConfigurationProtoRewriteTest, InvalidFeatureDescriptor) {
  using ::tensorflow::testing::StatusIs;
  using ::testing::HasSubstr;

  const std::string config_text = R"pb(
    table_descriptor {
      name: "T0"
      vocabulary_size: 35324928
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    table_descriptor {
      name: "T1"
      vocabulary_size: 3122176
      dimension: 128
      optimization_parameters {
        adagrad {}
        learning_rate { constant: 0.1 }
      }
    }
    feature_descriptor {
      name: "F1"
      table_id: 0
      input_shape: [ 768 ]
    }
    feature_descriptor {
      name: "F2"
      table_id: 1
      input_shape: [ 512 ]
    }
    mode: TRAINING
    num_hosts: 16
    num_tensor_cores: 128
    pipeline_execution_with_tensor_core: true
  )pb";
  tpu::TPUEmbeddingConfiguration base_config;
  TF_ASSERT_OK(ParseTextProto(config_text, &base_config));

  // table_id 2 does not match either of the two declared tables.
  tpu::TPUEmbeddingConfiguration bad_table_id = base_config;
  bad_table_id.mutable_feature_descriptor(0)->set_table_id(2);
  EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&bad_table_id),
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("Invalid table_id")));

  // A feature with no input_shape entries at all.
  tpu::TPUEmbeddingConfiguration empty_shape = base_config;
  empty_shape.mutable_feature_descriptor(0)->clear_input_shape();
  EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&empty_shape),
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("The input_shape field cannot be empty")));

  // A negative dimension size in input_shape.
  tpu::TPUEmbeddingConfiguration negative_dim = base_config;
  negative_dim.mutable_feature_descriptor(0)->set_input_shape(0, -5);
  EXPECT_THAT(PopulateMissingFieldsInTPUEmbeddingConfig(&negative_dim),
              StatusIs(tensorflow::error::INVALID_ARGUMENT,
                       HasSubstr("The input_shape dimension sizes must all")));

  // Both features now point at table 0, leaving table T1 without any feature.
  tpu::TPUEmbeddingConfiguration unused_table = base_config;
  unused_table.mutable_feature_descriptor(1)->set_table_id(0);
  EXPECT_THAT(
      PopulateMissingFieldsInTPUEmbeddingConfig(&unused_table),
      StatusIs(tensorflow::error::INVALID_ARGUMENT,
               HasSubstr("No feature_descriptor fields found for table: T1")));
}
284 
285 }  // namespace
286 }  // namespace tensorflow
287