
Searched refs: var_dtype (Results 1 – 11 of 11) sorted by relevance

/external/tensorflow/tensorflow/python/keras/optimizer_v2/
adam.py
  134: def _prepare_local(self, var_device, var_dtype, apply_state):
  135: super(Adam, self)._prepare_local(var_device, var_dtype, apply_state)
  137: local_step = math_ops.cast(self.iterations + 1, var_dtype)
  138: beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
  139: beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
  142: lr = (apply_state[(var_device, var_dtype)]['lr_t'] *
  144: apply_state[(var_device, var_dtype)].update(
  148: self.epsilon, var_dtype),
  167: var_device, var_dtype = var.device, var.dtype.base_dtype
  168: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  [all …]
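
Note: the adam.py hits above show _prepare_local casting the step counter and beta hyperparameters to var_dtype and then bias-correcting the learning rate before caching it per (var_device, var_dtype). As a rough illustration of the scalar math being cached, here is a minimal sketch; the function name adam_bias_corrected_lr is hypothetical, while lr_t, beta_1_t and beta_2_t follow the snippet.

    import math

    def adam_bias_corrected_lr(lr_t, beta_1_t, beta_2_t, iterations):
        # Mirrors what adam.py's _prepare_local precomputes:
        # local_step = iterations + 1, then bias-correct the learning rate.
        local_step = iterations + 1
        beta_1_power = beta_1_t ** local_step
        beta_2_power = beta_2_t ** local_step
        return lr_t * math.sqrt(1 - beta_2_power) / (1 - beta_1_power)

    # At the first step the correction factor is sqrt(1 - 0.999) / (1 - 0.9) ≈ 0.316
    print(adam_bias_corrected_lr(0.001, 0.9, 0.999, 0))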
nadam.py
   91: var_dtype = var_list[0].dtype.base_dtype
   96: dtype=var_dtype,
  109: def _prepare_local(self, var_device, var_dtype, apply_state):
  110: lr_t = array_ops.identity(self._get_hyper('learning_rate', var_dtype))
  111: beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
  112: beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
  113: local_step = math_ops.cast(self.iterations + 1, var_dtype)
  114: next_step = math_ops.cast(self.iterations + 2, var_dtype)
  116: decay_base = math_ops.cast(0.96, var_dtype)
  123: m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t
  [all …]
adamax.py
  113: def _prepare_local(self, var_device, var_dtype, apply_state):
  114: super(Adamax, self)._prepare_local(var_device, var_dtype, apply_state)
  116: local_step = math_ops.cast(self.iterations + 1, var_dtype)
  117: beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
  118: beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
  120: lr_t = apply_state[(var_device, var_dtype)]['lr_t']
  122: apply_state[(var_device, var_dtype)].update(
  126: self.epsilon, var_dtype),
  134: var_device, var_dtype = var.device, var.dtype.base_dtype
  135: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  [all …]
gradient_descent.py
  127: def _prepare_local(self, var_device, var_dtype, apply_state):
  128: super(SGD, self)._prepare_local(var_device, var_dtype, apply_state)
  129: apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
  130: self._get_hyper("momentum", var_dtype))
  133: var_device, var_dtype = var.device, var.dtype.base_dtype
  134: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  135: or self._fallback_apply_state(var_device, var_dtype))
  160: var_device, var_dtype = var.device, var.dtype.base_dtype
  161: coefficients = (kwargs.get("apply_state", {}).get((var_device, var_dtype))
  162: or self._fallback_apply_state(var_device, var_dtype))
  [all …]
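
Note: the apply hits above (gradient_descent.py 133-135 and 160-162) repeat the same lookup idiom found in the other optimizers: take the per-(device, dtype) coefficient dict out of apply_state if the caller supplied one, otherwise rebuild it via _fallback_apply_state. A hedged sketch of that idiom as a free function; lookup_coefficients is a hypothetical name and it leans on the private _fallback_apply_state hook shown in the optimizer_v2.py results below.

    def lookup_coefficients(optimizer, var, apply_state=None):
        # Key by (device, base dtype); fall back to freshly prepared
        # per-variable state when no apply_state was passed in.
        var_device, var_dtype = var.device, var.dtype.base_dtype
        return ((apply_state or {}).get((var_device, var_dtype))
                or optimizer._fallback_apply_state(var_device, var_dtype))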
ftrl.py
  139: def _prepare_local(self, var_device, var_dtype, apply_state):
  140: super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
  141: apply_state[(var_device, var_dtype)].update(
  144: self._get_hyper('learning_rate_power', var_dtype)),
  146: self._get_hyper('l1_regularization_strength', var_dtype)),
  148: self._get_hyper('l2_regularization_strength', var_dtype)),
  149: beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
  151: self._l2_shrinkage_regularization_strength, var_dtype)))
  154: var_device, var_dtype = var.device, var.dtype.base_dtype
  155: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  [all …]
adadelta.py
  100: def _prepare_local(self, var_device, var_dtype, apply_state):
  101: super(Adadelta, self)._prepare_local(var_device, var_dtype, apply_state)
  102: apply_state[(var_device, var_dtype)].update(
  105: self.epsilon, var_dtype),
  106: rho=array_ops.identity(self._get_hyper('rho', var_dtype))))
  118: var_device, var_dtype = var.device, var.dtype.base_dtype
  119: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  120: or self._fallback_apply_state(var_device, var_dtype))
  135: var_device, var_dtype = var.device, var.dtype.base_dtype
  136: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  [all …]
adagrad.py
   86: def _prepare_local(self, var_device, var_dtype, apply_state):
   87: super(Adagrad, self)._prepare_local(var_device, var_dtype, apply_state)
   88: apply_state[(var_device, var_dtype)].update(
   91: self.epsilon, var_dtype),
   92: neg_lr_t=-apply_state[(var_device, var_dtype)]['lr_t'],
  128: var_device, var_dtype = var.device, var.dtype.base_dtype
  129: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  130: or self._fallback_apply_state(var_device, var_dtype))
  142: var_device, var_dtype = var.device, var.dtype.base_dtype
  143: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  [all …]
rmsprop.py
  163: def _prepare_local(self, var_device, var_dtype, apply_state):
  164: super(RMSprop, self)._prepare_local(var_device, var_dtype, apply_state)
  166: rho = array_ops.identity(self._get_hyper("rho", var_dtype))
  167: apply_state[(var_device, var_dtype)].update(
  169: neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
  171: self.epsilon, var_dtype),
  173: momentum=array_ops.identity(self._get_hyper("momentum", var_dtype)),
  177: var_device, var_dtype = var.device, var.dtype.base_dtype
  178: coefficients = ((apply_state or {}).get((var_device, var_dtype))
  179: or self._fallback_apply_state(var_device, var_dtype))
  [all …]
optimizer_v2.py
  933: var_dtype = var.dtype.base_dtype
  935: keys.add((var_device, var_dtype))
  938: for var_device, var_dtype in keys:
  939: apply_state[(var_device, var_dtype)] = {}
  941: self._prepare_local(var_device, var_dtype, apply_state)
  945: def _prepare_local(self, var_device, var_dtype, apply_state):
  947: lr_t = array_ops.identity(self._decayed_lr(var_dtype))
  948: apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
  950: def _fallback_apply_state(self, var_device, var_dtype):
  952: apply_state = {(var_device, var_dtype): {}}
  [all …]
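
Note: the optimizer_v2.py hits show the base-class side of the pattern: the prepare step collects the set of (var_device, var_dtype) keys, the default _prepare_local caches the decayed learning rate as 'lr_t', and _fallback_apply_state rebuilds that state when no apply_state is passed through. A minimal sketch of a custom optimizer wired into these same hooks follows; ScaledSGD, its scale hyperparameter and the assign_sub update are illustrative assumptions (sparse updates omitted), while the hook names come from the results above.

    from tensorflow.python.keras.optimizer_v2 import optimizer_v2
    from tensorflow.python.ops import array_ops

    class ScaledSGD(optimizer_v2.OptimizerV2):
      """Toy optimizer: var <- var - lr_t * scale_t * grad."""

      def __init__(self, learning_rate=0.01, scale=1.0, name='ScaledSGD', **kwargs):
        super(ScaledSGD, self).__init__(name, **kwargs)
        self._set_hyper('learning_rate', learning_rate)
        self._set_hyper('scale', scale)

      def _prepare_local(self, var_device, var_dtype, apply_state):
        # Base class caches the decayed learning rate as 'lr_t'; add our own
        # coefficient, cast to the variable's dtype, under the same key.
        super(ScaledSGD, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]['scale_t'] = array_ops.identity(
            self._get_hyper('scale', var_dtype))

      def _resource_apply_dense(self, grad, var, apply_state=None):
        # Same lookup-or-fallback idiom as the optimizers in these results.
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))
        return var.assign_sub(coefficients['lr_t'] * coefficients['scale_t'] * grad)

      def get_config(self):
        config = super(ScaledSGD, self).get_config()
        config.update({
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'scale': self._serialize_hyperparameter('scale'),
        })
        return config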
/external/tensorflow/tensorflow/python/keras/mixed_precision/
autocast_variable_test.py
  440: var_dtype = None
  442: nonlocal var_dtype
  443: var_dtype = x._cast_dtype
  447: self.assertEqual(var_dtype, dtypes.float32)
/external/tensorflow/tensorflow/python/distribute/coordinator/
cluster_coordinator_test.py
  604: var_dtype = dtypes.float32
  609: initial_value=0.0, dtype=var_dtype, name=var_name)
  622: var._type_spec = tensor_spec.TensorSpec(var_shape, var_dtype, var_name)