/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace tpu {

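// Returns the canonical internal name of the given optimization algorithm;
// this is the string used in error messages later in this file.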
string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "BoundedAdagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "StochasticGradientDescent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "CenteredRMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDLAdagradLight";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "ProximalAdagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "OnlineYogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "ProximalYogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "FrequencyEstimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "*** Not set ***";
  }
  return "*** Not set ***";
}

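// Returns a human-readable name for the given optimization algorithm, suitable
// for user-facing messages.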
string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "Bounded Adagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "stochastic gradient descent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "centered RMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDL Adagrad Light";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "proximal Adagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "online Yogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "proximal Yogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "frequency estimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "unknown (not specified)";
  }
  return "unknown (not specified)";
}

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
                                      int* count) {
  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kBoundedAdagrad:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kStochasticGradientDescent:
      *count = 0;
      return Status::OK();
    case OptimizationAlgorithm::kFtrl:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kAdam:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kMomentum:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kRmsProp:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kCenteredRmsProp:
      *count = 3;
      return Status::OK();
    case OptimizationAlgorithm::kMdlAdagradLight:
      *count = 3;
      return Status::OK();
    case OptimizationAlgorithm::kAdadelta:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kProximalAdagrad:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kOnlineYogi:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kProximalYogi:
      *count = 2;
      return Status::OK();
    case OptimizationAlgorithm::kFrequencyEstimator:
      *count = 1;
      return Status::OK();
    case OptimizationAlgorithm::kUserDefinedProgram: {
      const xla::ProgramShapeProto& program_shape =
          params.user_defined_program().program().host_program_shape();

      const int num_inputs = program_shape.parameters_size();
      const int num_outputs = program_shape.result().tuple_shapes_size();

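      // The validation below relies only on the program's shape: it must take
      // at least two inputs and have one or two more inputs than outputs. The
      // outputs are taken to be the updated embedding table followed by the
      // updated auxiliary parameters, so the auxiliary count is the number of
      // outputs minus one; the extra input(s) presumably carry the gradient
      // and any optional additional operand.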
      if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
                               (num_inputs != num_outputs + 2))) {
        return errors::InvalidArgument(
            "User-defined TPU embedding optimizer program must have at least "
            "two inputs, and the number of outputs must be 1 or 2 less than "
            "the number of inputs. Received ",
            num_inputs, " input(s) and ", num_outputs, " output(s).");
      }

      *count = num_outputs - 1;

      return Status::OK();
    }
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return errors::InvalidArgument("No optimization algorithm specified");
  }
  return errors::InvalidArgument("No optimization algorithm specified");
}

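// Determines whether gradient accumulation can be enabled for the given
// optimizer: it is supported only if one additional state variable (for the
// accumulated gradients) still fits within kMaxAuxiliaryParameterCount.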
Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                      GradientAccumulationSupport* support) {
  int auxiliary_parameter_count;
  TF_RETURN_IF_ERROR(
      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
  *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                 ? GradientAccumulationSupport::kSupported
                 : GradientAccumulationSupport::kNotSupported;
  return Status::OK();
}

namespace {
// Make a normal state variable specification. Please refer to
// //tensorflow/core/protobuf/tpu/optimization_parameters.proto
// (StateVariableSpecification message) for instructions on how to set the
// padding_initial_value field.
StateVariableSpecification MakeStandardStateVariableSpecification(
    const string& name, double padding_initial_value) {
  StateVariableSpecification result;
  result.set_name(name);
  result.mutable_user_defined()->set_padding_initial_value(
      padding_initial_value);
  return result;
}
}  // namespace

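// Fills in the complete list of state variables for the given optimizer,
// starting with the embedding parameters themselves, followed by the
// algorithm's auxiliary variables and, if requested, a trailing
// "gradient_accumulators" variable.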
Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters& params, bool use_gradient_accumulation,
    std::vector<StateVariableSpecification>* state_variables) {
  // The order of the returned parameters needs to match the offsets used by
  // the algorithm implementations in test_util.cc and
  // address_handler_program_creator.cc.
  // The first parameter set is always the weights themselves.
  auto add_state_variable = [&](const std::string& name, float value) {
    state_variables->push_back(
        MakeStandardStateVariableSpecification(name, value));
  };
  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      break;
    }
    case OptimizationAlgorithm::kBoundedAdagrad: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      break;
    }
    case OptimizationAlgorithm::kStochasticGradientDescent: {
      add_state_variable("parameters", 0.0);
      break;
    }
    case OptimizationAlgorithm::kFtrl: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      add_state_variable("linears", 0.0);
      break;
    }
    case OptimizationAlgorithm::kAdam: {
      add_state_variable("parameters", 0.0);
      add_state_variable("momenta", 0.0);
      add_state_variable("velocities", 0.0);
      break;
    }
    case OptimizationAlgorithm::kMomentum: {
      add_state_variable("parameters", 0.0);
      add_state_variable("momenta", 0.0);
      break;
    }
    case OptimizationAlgorithm::kRmsProp: {
      add_state_variable("parameters", 0.0);
      add_state_variable("ms", 1.0);
      add_state_variable("mom", 0.0);
      break;
    }
    case OptimizationAlgorithm::kCenteredRmsProp: {
      add_state_variable("parameters", 0.0);
      add_state_variable("ms", 1.0);
      add_state_variable("mom", 0.0);
      add_state_variable("mg", 0.0);
      break;
    }
    case OptimizationAlgorithm::kMdlAdagradLight: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      add_state_variable("weights", 0.0);
      add_state_variable("benefits", 0.0);
      break;
    }
    case OptimizationAlgorithm::kAdadelta: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.0);
      add_state_variable("updates", 0.0);
      break;
    }
    case OptimizationAlgorithm::kProximalAdagrad: {
      add_state_variable("parameters", 0.0);
      add_state_variable("accumulators", 0.1);
      break;
    }
    case OptimizationAlgorithm::kOnlineYogi: {
      add_state_variable("parameters", 0.0);
      add_state_variable("vs", 0.1);
      add_state_variable("linears", 0.0);
      break;
    }
    case OptimizationAlgorithm::kProximalYogi: {
      add_state_variable("parameters", 0.0);
      add_state_variable("v", 0.1);
      add_state_variable("m", 0.0);
      break;
    }
    case OptimizationAlgorithm::kFrequencyEstimator: {
      add_state_variable("parameters", 0.0);
      add_state_variable("last_hit_step", 0);
      break;
    }
    case OptimizationAlgorithm::kUserDefinedProgram: {
      add_state_variable("parameters",
                         params.user_defined_program().padding_values(0));
      int num_slots = -1;
      TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
      if (num_slots + 1 !=
          params.user_defined_program().padding_values_size()) {
        return errors::InvalidArgument(
            "Number of slots does not agree with the number of padding values "
            "specified.");
      }
      for (int i = 0; i < num_slots; ++i) {
        add_state_variable(absl::StrCat("Slot_", i),
                           params.user_defined_program().padding_values(i + 1));
      }
      break;
    }
    case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
      return errors::InvalidArgument("No optimization algorithm specified");
    }
  }
  // This needs to be last so that the save/restore ops do not need to know
  // about gradient accumulation.
  if (use_gradient_accumulation) {
    StateVariableSpecification gradient_acc;
    gradient_acc.set_name("gradient_accumulators");
    gradient_acc.mutable_fill_with_constant()->set_initial_value(
        GradientAccumulatorInitialValue());
    state_variables->push_back(std::move(gradient_acc));
  }
  if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
    return errors::InvalidArgument(
        "Optimization algorithm ",
        GetOptimizationAlgorithmName(params.parameters_case()),
        " does not support gradient accumulation because it "
        "already has too many other accumulators");
  }
  return Status::OK();
}

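// Returns every optimization algorithm that has a corresponding parameters
// case in the OptimizationParameters proto (i.e. everything except
// PARAMETERS_NOT_SET).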
std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
  return {
      OptimizationAlgorithm::kAdagrad,
      OptimizationAlgorithm::kBoundedAdagrad,
      OptimizationAlgorithm::kStochasticGradientDescent,
      OptimizationAlgorithm::kFtrl,
      OptimizationAlgorithm::kAdam,
      OptimizationAlgorithm::kMomentum,
      OptimizationAlgorithm::kRmsProp,
      OptimizationAlgorithm::kCenteredRmsProp,
      OptimizationAlgorithm::kMdlAdagradLight,
      OptimizationAlgorithm::kAdadelta,
      OptimizationAlgorithm::kProximalAdagrad,
      OptimizationAlgorithm::kOnlineYogi,
      OptimizationAlgorithm::kProximalYogi,
      OptimizationAlgorithm::kFrequencyEstimator,
      OptimizationAlgorithm::kUserDefinedProgram,
  };
}

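// Shape function for the TPU embedding parameter load ops: exactly one of
// table_id and table_name must be set, and the parameter tensor and every
// auxiliary tensor must be rank 2 with compatible shapes.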
Status LoadOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));

  // Verify shapes have rank 2 and are compatible when they are
  // required to be valid.
  shape_inference::ShapeHandle parameter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
  for (int j = 1; j < c->num_inputs(); ++j) {
    shape_inference::ShapeHandle accumulator_j_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
    shape_inference::ShapeHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
  }

  return Status::OK();
}

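// Shape function for the corresponding retrieve ops: the same table_id /
// table_name check applies, and each output is assigned an unknown rank-2
// shape.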
Status RetrieveOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
  for (int j = 0; j < c->num_outputs(); ++j) {
    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
                         2, c->UnknownDim())));
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow