/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
#include <string>
#include <vector>

#include "absl/base/casts.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

namespace tensorflow {
namespace tpu {

using OptimizationAlgorithm = OptimizationParameters::ParametersCase;

// Returns the name of the optimization algorithm.
string GetOptimizationAlgorithmName(OptimizationAlgorithm alg);

// Returns a user-friendly name for the optimization algorithm.
string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg);

// Returns all supported optimization algorithms.
std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms();
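
// Illustrative iteration (a sketch, not part of this header's API): listing
// every supported algorithm by name using the accessors above.
//
//   for (OptimizationAlgorithm alg : GetOptimizationAlgorithms()) {
//     LOG(INFO) << GetOptimizationAlgorithmName(alg);
//   }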

enum class GradientAccumulationSupport {
  // Accumulation cannot be used with this optimizer.
  kNotSupported,

  // Accumulation is allowed and changes optimizer behavior.
  kSupported,
};

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
                                      int *count);

// Returns whether (and how) an optimization algorithm supports gradient
// accumulation.
Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                      GradientAccumulationSupport *support);

// Returns the parameter specifications for the optimization algorithm (the
// main parameters first, followed by any auxiliary parameters such as Adagrad
// accumulators).
Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters &params, bool use_gradient_accumulation,
    std::vector<StateVariableSpecification> *state_variables);
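
// Illustrative usage (a sketch, not a prescribed pattern): querying gradient
// accumulation support for an algorithm and then building its state variable
// list. `params` is assumed to be a populated OptimizationParameters proto.
//
//   GradientAccumulationSupport support;
//   TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
//   std::vector<StateVariableSpecification> state_variables;
//   TF_RETURN_IF_ERROR(GetOptimizationAlgorithmStateVariables(
//       params, support == GradientAccumulationSupport::kSupported,
//       &state_variables));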

// Maximum value of auxiliary_parameter_count for any optimization algorithm.
static constexpr int kMaxAuxiliaryParameterCount = 3;

// Fill value for gradient accumulators. This is a denormal so that it is
// flushed to zero on current TPU platforms, and it must continue to have the
// following properties in the future:
//
// 1. Does not have the same bit pattern as a zero and can be distinguished
//    from it using integer operations.
// 2. Treated as zero by floating-point arithmetic operations (at least
//    addition and subtraction).
// 3. Cannot be produced by any floating-point arithmetic operation, including
//    those involving itself.
//
// It does not need to compare equal or unequal to zero in floating point. We
// need a non-zero value here because some optimization algorithms are not
// no-ops on zero gradients, so we must distinguish an accumulated gradient of
// zero from one that has been cleared after its gradients have already been
// applied to the parameters and accumulators.
inline float GradientAccumulatorInitialValue() {
  return absl::bit_cast<float, uint32>(1);
}
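
// Illustrative check (a sketch; the helper name is hypothetical): property 1
// above means the fill value can only be told apart from a true zero by its
// bit pattern, so a caller compares integer representations rather than
// floating-point values.
//
//   inline bool IsGradientAccumulatorFillValue(float v) {
//     return absl::bit_cast<uint32>(v) ==
//            absl::bit_cast<uint32>(GradientAccumulatorInitialValue());
//   }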

// Generic shape function for per-optimization-algorithm load ops.
class LoadOpShapeFunction {
 public:
  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;
};

// Generic shape function for per-optimization-algorithm retrieve ops.
class RetrieveOpShapeFunction {
 public:
  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;
};
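
// Illustrative registration (a sketch; the op name is hypothetical): these
// functors are intended to be passed to SetShapeFn when registering
// per-algorithm load/retrieve ops.
//
//   REGISTER_OP("LoadTPUEmbeddingExampleParameters")
//       ...  // inputs and attrs elided
//       .SetShapeFn(LoadOpShapeFunction());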

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_