/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace tpu {

string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "BoundedAdagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "StochasticGradientDescent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "CenteredRMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDLAdagradLight";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "ProximalAdagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "OnlineYogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "ProximalYogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "FrequencyEstimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "*** Not set ***";
  }
  return "*** Not set ***";
}

string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "Bounded Adagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "stochastic gradient descent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "centered RMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDL Adagrad Light";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "proximal Adagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "online Yogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "proximal Yogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "frequency estimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "unknown (not specified)";
  }
  return "unknown (not specified)";
}

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
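// For example, ADAM maintains momenta and velocities in addition to the
// embedding weights, so *count is set to 2 for it below.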
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
                                      int* count) {
  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kBoundedAdagrad:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kStochasticGradientDescent:
      *count = 0;
      return Status::OK();
    case OptimizationAlgorithm::kFtrl:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kAdam:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kMomentum:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kRmsProp:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kCenteredRmsProp:
      *count = 3;
      return Status::OK();
    case OptimizationAlgorithm::kMdlAdagradLight:
      *count = 3;
      return Status::OK();
    case OptimizationAlgorithm::kAdadelta:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kProximalAdagrad:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kOnlineYogi:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kProximalYogi:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kFrequencyEstimator:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kUserDefinedProgram: {
      const xla::ProgramShapeProto& program_shape =
          params.user_defined_program().program().host_program_shape();

      const int num_inputs = program_shape.parameters_size();
      const int num_outputs = program_shape.result().tuple_shapes_size();

      if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
                               (num_inputs != num_outputs + 2))) {
        return errors::InvalidArgument(
            "User-defined TPU embedding optimizer program must have at least "
            "two inputs and the number of outputs must be 1 or 2 less than the "
            "number of inputs. Received ",
            num_inputs, " input(s) and ", num_outputs, " output(s).");
      }

      *count = num_outputs - 1;

      return Status::OK();
    }
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return errors::InvalidArgument("No optimization algorithm specified");
  }
  return errors::InvalidArgument("No optimization algorithm specified");
}

Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                      GradientAccumulationSupport* support) {
  int auxiliary_parameter_count;
  TF_RETURN_IF_ERROR(
      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
  *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                 ? GradientAccumulationSupport::kSupported
                 : GradientAccumulationSupport::kNotSupported;
  return Status::OK();
}
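
// Illustrative usage sketch (not part of this file's API; the surrounding
// control flow is hypothetical):
//
//   GradientAccumulationSupport support;
//   TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
//   const bool use_gradient_accumulation =
//       (support == GradientAccumulationSupport::kSupported);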

namespace {
// Make a normal state variable specification. Please refer to
// //tensorflow/core/protobuf/tpu/optimization_parameters.proto
// (StateVariableSpecification message) for instructions on how to set the
// padding_initial_value field.
StateVariableSpecification MakeStandardStateVariableSpecification(
    const string& name, double padding_initial_value) {
  StateVariableSpecification result;
  result.set_name(name);
  result.mutable_user_defined()->set_padding_initial_value(
      padding_initial_value);
  return result;
}
}  // namespace

Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters& params, bool use_gradient_accumulation,
    std::vector<StateVariableSpecification>* state_variables) {
  // The order of the returned parameters needs to match the offsets used by
  // the algorithm implementations in test_util.cc and
  // address_handler_program_creator.cc.
  // The first parameter set is always the weights themselves.
  auto add_state_variable = [&](const std::string& name, float value) {
    state_variables->push_back(
        MakeStandardStateVariableSpecification(name, value));
  };
  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      break;
    }
    case OptimizationAlgorithm::kBoundedAdagrad: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      break;
    }
    case OptimizationAlgorithm::kStochasticGradientDescent: {
      add_state_variable("parameters", 0.0);
      break;
    }
    case OptimizationAlgorithm::kFtrl: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      add_state_variable("linears", 0.0);
      break;
    }
    case OptimizationAlgorithm::kAdam: {
      add_state_variable("parameters", 0.0);
      add_state_variable("momenta", 0.0);
      add_state_variable("velocities", 0.0);
      break;
    }
    case OptimizationAlgorithm::kMomentum: {
      add_state_variable("parameters", 0.0);
      add_state_variable("momenta", 0.0);
      break;
    }
    case OptimizationAlgorithm::kRmsProp: {
      add_state_variable("parameters", 0.0);
      add_state_variable("ms", 1.0);
      add_state_variable("mom", 0.0);
      break;
    }
    case OptimizationAlgorithm::kCenteredRmsProp: {
      add_state_variable("parameters", 0.0);
      add_state_variable("ms", 1.0);
      add_state_variable("mom", 0.0);
      add_state_variable("mg", 0.0);
      break;
    }
    case OptimizationAlgorithm::kMdlAdagradLight: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      add_state_variable("weights", 0.0);
      add_state_variable("benefits", 0.0);
      break;
    }
    case OptimizationAlgorithm::kAdadelta: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.0);
      add_state_variable("updates", 0.0);
      break;
    }
    case OptimizationAlgorithm::kProximalAdagrad: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      break;
    }
    case OptimizationAlgorithm::kOnlineYogi: {
      add_state_variable("parameters", 0.0);
      add_state_variable("vs", 0.1);
      add_state_variable("linears", 0.0);
      break;
    }
    case OptimizationAlgorithm::kProximalYogi: {
      add_state_variable("parameters", 0.0);
      add_state_variable("v", 0.1);
      add_state_variable("m", 0.0);
      break;
    }
    case OptimizationAlgorithm::kFrequencyEstimator: {
      add_state_variable("parameters", 0.0);
      add_state_variable("last_hit_step", 0);
      break;
    }
    case OptimizationAlgorithm::kUserDefinedProgram: {
      add_state_variable("parameters",
                         params.user_defined_program().padding_values(0));
      int num_slots = -1;
      TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
      if (num_slots + 1 !=
          params.user_defined_program().padding_values_size()) {
        return errors::InvalidArgument(
            "Number of slots does not agree with the number of padding values "
            "specified.");
      }
      for (int i = 0; i < num_slots; ++i) {
        add_state_variable(absl::StrCat("Slot_", i),
                           params.user_defined_program().padding_values(i + 1));
      }
      break;
    }
    case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
      return errors::InvalidArgument("No optimization algorithm specified");
    }
  }
  // This needs to be last so that the save/restore ops do not need to know
  // about gradient accumulation.
  if (use_gradient_accumulation) {
    StateVariableSpecification gradient_acc;
    gradient_acc.set_name("gradient_accumulators");
    gradient_acc.mutable_fill_with_constant()->set_initial_value(
        GradientAccumulatorInitialValue());
    state_variables->push_back(std::move(gradient_acc));
  }
  if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
    return errors::InvalidArgument(
        "Optimization algorithm ",
        GetOptimizationAlgorithmName(params.parameters_case()),
        " does not support gradient accumulation because it "
        "already has too many other accumulators");
  }
  return Status::OK();
}
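
// Illustrative example of the resulting ordering: for Adagrad with gradient
// accumulation enabled, the state variables produced above are, in order,
// "parameters", "accumulators", and "gradient_accumulators".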

std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
  return {
      OptimizationAlgorithm::kAdagrad,
      OptimizationAlgorithm::kBoundedAdagrad,
      OptimizationAlgorithm::kStochasticGradientDescent,
      OptimizationAlgorithm::kFtrl,
      OptimizationAlgorithm::kAdam,
      OptimizationAlgorithm::kMomentum,
      OptimizationAlgorithm::kRmsProp,
      OptimizationAlgorithm::kCenteredRmsProp,
      OptimizationAlgorithm::kMdlAdagradLight,
      OptimizationAlgorithm::kAdadelta,
      OptimizationAlgorithm::kProximalAdagrad,
      OptimizationAlgorithm::kOnlineYogi,
      OptimizationAlgorithm::kProximalYogi,
      OptimizationAlgorithm::kFrequencyEstimator,
      OptimizationAlgorithm::kUserDefinedProgram,
  };
}

Status LoadOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));

  // Verify shapes have rank 2 and are compatible when they are
  // required to be valid.
  shape_inference::ShapeHandle parameter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
  for (int j = 1; j < c->num_inputs(); ++j) {
    shape_inference::ShapeHandle accumulator_j_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
    shape_inference::ShapeHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
  }

  return Status::OK();
}

Status RetrieveOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
  for (int j = 0; j < c->num_outputs(); ++j) {
    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
                         2, c->UnknownDim())));
  }
  return Status::OK();
}
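
// Sketch of how these shape functors might be attached to TPU embedding
// load/retrieve op registrations elsewhere in the codebase. The op name and
// signature below are hypothetical; this file does not define any ops itself:
//
//   REGISTER_OP("LoadTPUEmbeddingExampleParameters")
//       .Input("parameters: float32")
//       .Input("accumulators: float32")
//       .Attr("table_id: int = -1")
//       .Attr("table_name: string = \"\"")
//       .Attr("num_shards: int")
//       .Attr("shard_id: int")
//       .SetShapeFn(LoadOpShapeFunction());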

}  // namespace tpu
}  // namespace tensorflow