• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"

#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 // Validates that the batch_size_per_tensor_core and
28 // TableDescriptor.num_features fields have been populated correctly in the TPU
29 // embedding configuration.
ValidateBatchSizeAndFeatureCounts(const tpu::TPUEmbeddingConfiguration & config)30 Status ValidateBatchSizeAndFeatureCounts(
31     const tpu::TPUEmbeddingConfiguration& config) {
32   if (config.batch_size_per_tensor_core() <= 0) {
33     return errors::InvalidArgument(absl::StrFormat(
34         "Invalid batch_size_per_tensor_core: %d found in the TPU embedding "
35         "configuration. Valid values are >0.",
36         config.batch_size_per_tensor_core()));
37   }
38   for (const auto& table_config : config.table_descriptor()) {
39     if (table_config.num_features() <= 0) {
40       return errors::InvalidArgument(absl::StrFormat(
41           "Invalid num_features: %d found for table: %s in the TPU embedding "
42           "configuration. Valid values are >0.",
43           table_config.num_features(), table_config.name()));
44     }
45   }  // table_config
46   return OkStatus();
47 }
48 
49 // Validates that the batch_size_per_tensor_core and
50 // TableDescriptor.num_features fields are NOT populated in the TPU embedding
51 // configuration when the feature descriptor fields are filled in.
ValidateBatchSizeAndFeatureCountsAreEmpty(const tpu::TPUEmbeddingConfiguration & config)52 Status ValidateBatchSizeAndFeatureCountsAreEmpty(
53     const tpu::TPUEmbeddingConfiguration& config) {
54   if (config.batch_size_per_tensor_core() != 0) {
55     return errors::InvalidArgument(
56         "Invalid TPU embedding configuration. The batch_size_per_tensor_core "
57         "field must NOT be populated when the feature_descriptor fields are "
58         "filled in.");
59   }
60   for (const auto& table_config : config.table_descriptor()) {
61     if (table_config.num_features() != 0) {
62       return errors::InvalidArgument(absl::StrFormat(
63           "Invalid TPU embedding configuration. The "
64           "TableDescriptor.num_features field must NOT be populated when the "
65           "feature_descriptor fields are filled in, num_features is set to %d "
66           "for table %s.",
67           table_config.num_features(), table_config.name()));
68     }
69   }  // table_config
70   return OkStatus();
71 }
72 
73 // Validates that the feature_descriptor fields have been correctly filled in.
74 // All tables must have at least one input feature.
ValidateFeatureDescriptors(const tpu::TPUEmbeddingConfiguration & config)75 Status ValidateFeatureDescriptors(
76     const tpu::TPUEmbeddingConfiguration& config) {
77   const int table_count = config.table_descriptor_size();
78   std::vector<bool> tables_present(table_count, false);
79 
80   for (const auto& feature_config : config.feature_descriptor()) {
81     const int table_id = feature_config.table_id();
82     const auto& input_shape = feature_config.input_shape();
83     if (table_id < 0 || table_id >= table_count) {
84       return errors::InvalidArgument(absl::StrFormat(
85           "Invalid table_id: %d found in feature_descriptor: %s, all table_ids "
86           "must be in the range[0, %d)",
87           table_id, feature_config.ShortDebugString(), table_count));
88     }
89     if (input_shape.empty()) {
90       return errors::InvalidArgument(absl::StrFormat(
91           "The input_shape field cannot be empty in feature_descriptor: %s",
92           feature_config.ShortDebugString()));
93     }
94     for (const int dim_size : input_shape) {
95       if (dim_size <= 0) {
96         return errors::InvalidArgument(absl::StrFormat(
97             "The input_shape dimension sizes must all be >0 in "
98             "feature_descriptor: %s, found dimension size set to %d",
99             feature_config.ShortDebugString(), dim_size));
100       }
101     }
102     tables_present[table_id] = true;
103   }  // feature_config
104 
105   for (int table_id = 0; table_id < table_count; ++table_id) {
106     if (!tables_present[table_id]) {
107       return errors::InvalidArgument(absl::StrFormat(
108           "No feature_descriptor fields found for table: %s (ID: %d) in "
109           "the TPU embedding configuration.",
110           config.table_descriptor(table_id).name(), table_id));
111     }
112   }
113   return OkStatus();
114 }
115 
116 // Populates the feature_descriptor fields with default values when they have
117 // not been filled in by the user.
PopulateFeatureDescriptors(tpu::TPUEmbeddingConfiguration * config)118 void PopulateFeatureDescriptors(tpu::TPUEmbeddingConfiguration* config) {
119   for (int table_id = 0; table_id < config->table_descriptor_size();
120        ++table_id) {
121     tpu::TPUEmbeddingConfiguration::FeatureDescriptor* feature_descriptor =
122         config->add_feature_descriptor();
123     feature_descriptor->set_table_id(table_id);
124     feature_descriptor->add_input_shape(
125         config->batch_size_per_tensor_core() *
126         config->table_descriptor(table_id).num_features());
127   }  // table_id
128 }
129 
130 // Computes the input feature batch size based on the input feature shape. As
131 // we treat the last dimension as the reduction dimension, the batch size should
132 // be the product of all the axes except the last one.
ComputeInputFeatureBatchSizes(const tpu::TPUEmbeddingConfiguration & config)133 std::vector<int> ComputeInputFeatureBatchSizes(
134     const tpu::TPUEmbeddingConfiguration& config) {
135   std::vector<int32> input_feature_batch_sizes;
136   for (int i = 0; i < config.feature_descriptor_size(); ++i) {
137     const int32 batch_size =
138         absl::c_accumulate(config.feature_descriptor(i).input_shape(),
139                            /*init=*/1, std::multiplies<>());
140     input_feature_batch_sizes.push_back(batch_size);
141   }
142   return input_feature_batch_sizes;
143 }
144 
145 // Computes the TensorCore batch size as the GCD of all input feature batch
146 // sizes.
ComputeBatchSizePerTensorCore(absl::Span<const int> input_feature_batch_sizes)147 int ComputeBatchSizePerTensorCore(
148     absl::Span<const int> input_feature_batch_sizes) {
149   uint32_t batch_size = input_feature_batch_sizes[0];
150   for (const uint32_t input_feature_batch_size : input_feature_batch_sizes) {
151     batch_size =
152         tensorflow::MathUtil::GCD(batch_size, input_feature_batch_size);
153   }
154   return batch_size;
155 }
156 
157 // Computes the TPU feature counts per user table as the sum of the TPU feature
158 // counts of the constituent input features. The TPU feature count for an input
159 // feature is the ratio of the batch size for that input feature to the batch
160 // size per TensorCore.
ComputeTpuFeatureCounts(const tpu::TPUEmbeddingConfiguration & config,absl::Span<const int> input_feature_batch_sizes,int batch_size_per_tensor_core)161 std::vector<int> ComputeTpuFeatureCounts(
162     const tpu::TPUEmbeddingConfiguration& config,
163     absl::Span<const int> input_feature_batch_sizes,
164     int batch_size_per_tensor_core) {
165   DCHECK_EQ(input_feature_batch_sizes.size(), config.feature_descriptor_size());
166   std::vector<int> tpu_feature_counts(config.table_descriptor_size(), 0);
167   for (int i = 0; i < config.feature_descriptor_size(); ++i) {
168     DCHECK_EQ(input_feature_batch_sizes[i] % batch_size_per_tensor_core, 0);
169     tpu_feature_counts[config.feature_descriptor(i).table_id()] +=
170         (input_feature_batch_sizes[i] / batch_size_per_tensor_core);
171   }
172   return tpu_feature_counts;
173 }
174 
175 // Populates default values for batch_size_per_tensor_core and
176 // TableDescriptor.num_features when they have not been filled in by the user.
177 // The batch_size_per_tensor_core is computed as the GCD of the batch sizes of
178 // all input features.
PopulateBatchSizeAndFeatureCounts(tpu::TPUEmbeddingConfiguration * config)179 void PopulateBatchSizeAndFeatureCounts(tpu::TPUEmbeddingConfiguration* config) {
180   const std::vector<int> input_feature_batch_sizes =
181       ComputeInputFeatureBatchSizes(*config);
182   const int batch_size_per_tensor_core =
183       ComputeBatchSizePerTensorCore(input_feature_batch_sizes);
184   const std::vector<int> tpu_feature_counts = ComputeTpuFeatureCounts(
185       *config, input_feature_batch_sizes, batch_size_per_tensor_core);
186   config->set_batch_size_per_tensor_core(batch_size_per_tensor_core);
187   for (int table_id = 0; table_id < config->table_descriptor_size();
188        ++table_id) {
189     auto* table_config = config->mutable_table_descriptor(table_id);
190     table_config->set_num_features(tpu_feature_counts[table_id]);
191   }  // table_id
192 }
193 
194 }  // namespace
195 
PopulateMissingFieldsInTPUEmbeddingConfig(tpu::TPUEmbeddingConfiguration * config)196 Status PopulateMissingFieldsInTPUEmbeddingConfig(
197     tpu::TPUEmbeddingConfiguration* config) {
198   if (config->feature_descriptor_size() == 0) {
199     // If the feature_descriptor list is empty, validate that the batch size and
200     // feature counts have been set properly. then, populate the
201     // feature_descriptor with appropriate values.
202     TF_RETURN_IF_ERROR(ValidateBatchSizeAndFeatureCounts(*config));
203     PopulateFeatureDescriptors(config);
204   } else {
205     // If the feature_descriptor list is non-empty, validate that the batch size
206     // and feature counts have NOT been populated. Also, validate that the
207     // feature descriptors have been set properly. Then, populate the batch size
208     // and feature counts with appropriate values.
209     TF_RETURN_IF_ERROR(ValidateBatchSizeAndFeatureCountsAreEmpty(*config));
210     TF_RETURN_IF_ERROR(ValidateFeatureDescriptors(*config));
211     PopulateBatchSizeAndFeatureCounts(config);
212   }
213   return OkStatus();
214 }
215 
216 }  // namespace tensorflow
217