1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/core/lib/math/math_util.h"
21 #include "tensorflow/core/platform/errors.h"
22 #include "tensorflow/core/platform/status.h"
23
24 namespace tensorflow {
25 namespace {
26
27 // Validates that the batch_size_per_tensor_core and
28 // TableDescriptor.num_features fields have been populated correctly in the TPU
29 // embedding configuration.
ValidateBatchSizeAndFeatureCounts(const tpu::TPUEmbeddingConfiguration & config)30 Status ValidateBatchSizeAndFeatureCounts(
31 const tpu::TPUEmbeddingConfiguration& config) {
32 if (config.batch_size_per_tensor_core() <= 0) {
33 return errors::InvalidArgument(absl::StrFormat(
34 "Invalid batch_size_per_tensor_core: %d found in the TPU embedding "
35 "configuration. Valid values are >0.",
36 config.batch_size_per_tensor_core()));
37 }
38 for (const auto& table_config : config.table_descriptor()) {
39 if (table_config.num_features() <= 0) {
40 return errors::InvalidArgument(absl::StrFormat(
41 "Invalid num_features: %d found for table: %s in the TPU embedding "
42 "configuration. Valid values are >0.",
43 table_config.num_features(), table_config.name()));
44 }
45 } // table_config
46 return OkStatus();
47 }
48
49 // Validates that the batch_size_per_tensor_core and
50 // TableDescriptor.num_features fields are NOT populated in the TPU embedding
51 // configuration when the feature descriptor fields are filled in.
ValidateBatchSizeAndFeatureCountsAreEmpty(const tpu::TPUEmbeddingConfiguration & config)52 Status ValidateBatchSizeAndFeatureCountsAreEmpty(
53 const tpu::TPUEmbeddingConfiguration& config) {
54 if (config.batch_size_per_tensor_core() != 0) {
55 return errors::InvalidArgument(
56 "Invalid TPU embedding configuration. The batch_size_per_tensor_core "
57 "field must NOT be populated when the feature_descriptor fields are "
58 "filled in.");
59 }
60 for (const auto& table_config : config.table_descriptor()) {
61 if (table_config.num_features() != 0) {
62 return errors::InvalidArgument(absl::StrFormat(
63 "Invalid TPU embedding configuration. The "
64 "TableDescriptor.num_features field must NOT be populated when the "
65 "feature_descriptor fields are filled in, num_features is set to %d "
66 "for table %s.",
67 table_config.num_features(), table_config.name()));
68 }
69 } // table_config
70 return OkStatus();
71 }
72
73 // Validates that the feature_descriptor fields have been correctly filled in.
74 // All tables must have at least one input feature.
ValidateFeatureDescriptors(const tpu::TPUEmbeddingConfiguration & config)75 Status ValidateFeatureDescriptors(
76 const tpu::TPUEmbeddingConfiguration& config) {
77 const int table_count = config.table_descriptor_size();
78 std::vector<bool> tables_present(table_count, false);
79
80 for (const auto& feature_config : config.feature_descriptor()) {
81 const int table_id = feature_config.table_id();
82 const auto& input_shape = feature_config.input_shape();
83 if (table_id < 0 || table_id >= table_count) {
84 return errors::InvalidArgument(absl::StrFormat(
85 "Invalid table_id: %d found in feature_descriptor: %s, all table_ids "
86 "must be in the range[0, %d)",
87 table_id, feature_config.ShortDebugString(), table_count));
88 }
89 if (input_shape.empty()) {
90 return errors::InvalidArgument(absl::StrFormat(
91 "The input_shape field cannot be empty in feature_descriptor: %s",
92 feature_config.ShortDebugString()));
93 }
94 for (const int dim_size : input_shape) {
95 if (dim_size <= 0) {
96 return errors::InvalidArgument(absl::StrFormat(
97 "The input_shape dimension sizes must all be >0 in "
98 "feature_descriptor: %s, found dimension size set to %d",
99 feature_config.ShortDebugString(), dim_size));
100 }
101 }
102 tables_present[table_id] = true;
103 } // feature_config
104
105 for (int table_id = 0; table_id < table_count; ++table_id) {
106 if (!tables_present[table_id]) {
107 return errors::InvalidArgument(absl::StrFormat(
108 "No feature_descriptor fields found for table: %s (ID: %d) in "
109 "the TPU embedding configuration.",
110 config.table_descriptor(table_id).name(), table_id));
111 }
112 }
113 return OkStatus();
114 }
115
116 // Populates the feature_descriptor fields with default values when they have
117 // not been filled in by the user.
PopulateFeatureDescriptors(tpu::TPUEmbeddingConfiguration * config)118 void PopulateFeatureDescriptors(tpu::TPUEmbeddingConfiguration* config) {
119 for (int table_id = 0; table_id < config->table_descriptor_size();
120 ++table_id) {
121 tpu::TPUEmbeddingConfiguration::FeatureDescriptor* feature_descriptor =
122 config->add_feature_descriptor();
123 feature_descriptor->set_table_id(table_id);
124 feature_descriptor->add_input_shape(
125 config->batch_size_per_tensor_core() *
126 config->table_descriptor(table_id).num_features());
127 } // table_id
128 }
129
130 // Computes the input feature batch size based on the input feature shape. As
131 // we treat the last dimension as the reduction dimension, the batch size should
132 // be the product of all the axes except the last one.
ComputeInputFeatureBatchSizes(const tpu::TPUEmbeddingConfiguration & config)133 std::vector<int> ComputeInputFeatureBatchSizes(
134 const tpu::TPUEmbeddingConfiguration& config) {
135 std::vector<int32> input_feature_batch_sizes;
136 for (int i = 0; i < config.feature_descriptor_size(); ++i) {
137 const int32 batch_size =
138 absl::c_accumulate(config.feature_descriptor(i).input_shape(),
139 /*init=*/1, std::multiplies<>());
140 input_feature_batch_sizes.push_back(batch_size);
141 }
142 return input_feature_batch_sizes;
143 }
144
145 // Computes the TensorCore batch size as the GCD of all input feature batch
146 // sizes.
ComputeBatchSizePerTensorCore(absl::Span<const int> input_feature_batch_sizes)147 int ComputeBatchSizePerTensorCore(
148 absl::Span<const int> input_feature_batch_sizes) {
149 uint32_t batch_size = input_feature_batch_sizes[0];
150 for (const uint32_t input_feature_batch_size : input_feature_batch_sizes) {
151 batch_size =
152 tensorflow::MathUtil::GCD(batch_size, input_feature_batch_size);
153 }
154 return batch_size;
155 }
156
157 // Computes the TPU feature counts per user table as the sum of the TPU feature
158 // counts of the constituent input features. The TPU feature count for an input
159 // feature is the ratio of the batch size for that input feature to the batch
160 // size per TensorCore.
ComputeTpuFeatureCounts(const tpu::TPUEmbeddingConfiguration & config,absl::Span<const int> input_feature_batch_sizes,int batch_size_per_tensor_core)161 std::vector<int> ComputeTpuFeatureCounts(
162 const tpu::TPUEmbeddingConfiguration& config,
163 absl::Span<const int> input_feature_batch_sizes,
164 int batch_size_per_tensor_core) {
165 DCHECK_EQ(input_feature_batch_sizes.size(), config.feature_descriptor_size());
166 std::vector<int> tpu_feature_counts(config.table_descriptor_size(), 0);
167 for (int i = 0; i < config.feature_descriptor_size(); ++i) {
168 DCHECK_EQ(input_feature_batch_sizes[i] % batch_size_per_tensor_core, 0);
169 tpu_feature_counts[config.feature_descriptor(i).table_id()] +=
170 (input_feature_batch_sizes[i] / batch_size_per_tensor_core);
171 }
172 return tpu_feature_counts;
173 }
174
175 // Populates default values for batch_size_per_tensor_core and
176 // TableDescriptor.num_features when they have not been filled in by the user.
177 // The batch_size_per_tensor_core is computed as the GCD of the batch sizes of
178 // all input features.
PopulateBatchSizeAndFeatureCounts(tpu::TPUEmbeddingConfiguration * config)179 void PopulateBatchSizeAndFeatureCounts(tpu::TPUEmbeddingConfiguration* config) {
180 const std::vector<int> input_feature_batch_sizes =
181 ComputeInputFeatureBatchSizes(*config);
182 const int batch_size_per_tensor_core =
183 ComputeBatchSizePerTensorCore(input_feature_batch_sizes);
184 const std::vector<int> tpu_feature_counts = ComputeTpuFeatureCounts(
185 *config, input_feature_batch_sizes, batch_size_per_tensor_core);
186 config->set_batch_size_per_tensor_core(batch_size_per_tensor_core);
187 for (int table_id = 0; table_id < config->table_descriptor_size();
188 ++table_id) {
189 auto* table_config = config->mutable_table_descriptor(table_id);
190 table_config->set_num_features(tpu_feature_counts[table_id]);
191 } // table_id
192 }
193
194 } // namespace
195
PopulateMissingFieldsInTPUEmbeddingConfig(tpu::TPUEmbeddingConfiguration * config)196 Status PopulateMissingFieldsInTPUEmbeddingConfig(
197 tpu::TPUEmbeddingConfiguration* config) {
198 if (config->feature_descriptor_size() == 0) {
199 // If the feature_descriptor list is empty, validate that the batch size and
200 // feature counts have been set properly. then, populate the
201 // feature_descriptor with appropriate values.
202 TF_RETURN_IF_ERROR(ValidateBatchSizeAndFeatureCounts(*config));
203 PopulateFeatureDescriptors(config);
204 } else {
205 // If the feature_descriptor list is non-empty, validate that the batch size
206 // and feature counts have NOT been populated. Also, validate that the
207 // feature descriptors have been set properly. Then, populate the batch size
208 // and feature counts with appropriate values.
209 TF_RETURN_IF_ERROR(ValidateBatchSizeAndFeatureCountsAreEmpty(*config));
210 TF_RETURN_IF_ERROR(ValidateFeatureDescriptors(*config));
211 PopulateBatchSizeAndFeatureCounts(config);
212 }
213 return OkStatus();
214 }
215
216 } // namespace tensorflow
217