Home
last modified time | relevance | path

Searched refs:grad_fn (Results 1 – 17 of 17) sorted by relevance

/external/tensorflow/tensorflow/cc/gradients/
Dgrad_testutil.cc26 ops::GradFunc grad_fn; in CallGradFunction() local
28 op.node()->type_string(), &grad_fn)); in CallGradFunction()
29 TF_RETURN_IF_ERROR(grad_fn(scope, op, grad_inputs, grad_outputs)); in CallGradFunction()
/external/tensorflow/tensorflow/python/ops/
Dcustom_gradient.py330 result, grad_fn = f(*args)
379 grad_argspec = tf_inspect.getfullargspec(grad_fn)
398 input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
403 input_grads = grad_fn(*result_grads)
438 result, grad_fn = f(*args, **kwargs)
448 grad_argspec = tf_inspect.getfullargspec(grad_fn)
468 input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
473 input_grads = grad_fn(*result_grads)
Dgradients_util.py322 def _MaybeCompile(scope, op, func, grad_fn): argument
340 return grad_fn() # Exit early
356 return grad_fn()
598 grad_fn = None
608 grad_fn = ops.get_gradient_function(op)
634 grad_fn = func_call.python_grad_func
656 if (grad_fn or is_func_call) and has_out_grads:
662 (not grad_fn and is_func_call)
680 if grad_fn:
684 lambda: grad_fn(op, *out_grads))
Dgradient_checker_v2.py169 grad_fn = _prepare(lambda dy, *xs: grad_fn_unprep(*xs, dy=dy),
174 grad = _to_numpy(grad_fn(dy_data, *xs)[0])
190 grad = _to_numpy(grad_fn(dy_data, *xs)[0])
Dgradient_checker_v2_test.py224 def grad_fn(dy): function
229 return y, grad_fn
246 def grad_fn(dy): function
251 return y, grad_fn
/external/tensorflow/tensorflow/python/eager/
Dbackprop.py148 grad_fn = ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access
149 if grad_fn is None:
159 return grad_fn(mock_op, *out_grads)
161 return grad_fn(mock_op, *out_grads)
228 def grad_fn(*args, **kwds): function
256 return grad_fn
301 def grad_fn(*args, **kwds): function
305 return grad_fn
Dfunction_gradients_test.py429 grad_fn = backprop.implicit_grad(sum_gather)
430 gradient = grad_fn()
Dbackprop_test.py226 grad_fn = backprop.gradients_function(f)
227 self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
1603 def grad_fn(x): function
1606 grad_ops = grad_fn.get_concrete_function(
/external/tensorflow/tensorflow/cc/framework/
Dgradients.cc375 ops::GradFunc grad_fn; in IsPrimitiveOpWithNoGrad() local
376 Status s = registry_->Lookup(opname, &grad_fn); in IsPrimitiveOpWithNoGrad()
377 return s.ok() && (grad_fn == nullptr); in IsPrimitiveOpWithNoGrad()
384 ops::GradFunc grad_fn; in CallGradFunction() local
385 TF_RETURN_IF_ERROR(registry_->Lookup(op.node()->type_string(), &grad_fn)); in CallGradFunction()
386 TF_RETURN_IF_ERROR(grad_fn(scope_, op, grad_inputs, grad_outputs)); in CallGradFunction()
/external/tensorflow/tensorflow/python/distribute/
Dstrategy_test_lib.py155 grad_fn = backprop.implicit_grad(loss)
156 grad_fn = optimizer.get_filtered_grad_fn(grad_fn)
166 g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
213 grad_fn = backprop.implicit_grad(loss)
223 g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
Dcollective_all_reduce_strategy_test.py134 def grad_fn(x): function
151 g_v = d.extended.call_for_each_replica(grad_fn, args=[one])
Dparameter_server_strategy_test.py463 def grad_fn(x): function
480 g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
/external/tensorflow/tensorflow/python/kernel_tests/
Dcholesky_op_test.py357 def _BenchmarkGrad(grad_fn, name, device): argument
366 grad = grad_fn(l, grad_matrix)
Dwhile_v2_test.py117 …def grad_fn(dy, variables=None): # pylint: disable=invalid-name, unused-argument, redefined-outer… function
120 return v * v * m, grad_fn
Dtensor_array_ops_test.py1015 grad_fn = backprop.gradients_function(func)
1016 v0_grad, state0_grad, var_grad = grad_fn(v0, state0, var, dy=grad_val)
/external/tensorflow/tensorflow/python/training/
Doptimizer.py49 def get_filtered_grad_fn(grad_fn): argument
60 return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]
/external/tensorflow/tensorflow/compiler/tests/
Deager_test.py163 grad_fn = backprop.gradients_function(f)
164 self.assertAllEqual(2., grad_fn(1., dy=2.)[0])