Home
last modified time | relevance | path

Searched refs:compute_dtype (Results 1 – 25 of 89) sorted by relevance

1234

/external/pytorch/torch/_inductor/codegen/
Dcpp_micro_gemm.py73 compute_dtype, argument
82 self.compute_dtype = compute_dtype
88 assert self.compute_dtype == torch.int32
97 "compute_dtype": self.compute_dtype,
101 "compute_t": DTYPE_TO_CPP[self.compute_dtype],
182 compute_dtype: torch.dtype
209 compute_dtype=None, argument
214 if compute_dtype is None:
215 compute_dtype = output_dtype
223 compute_dtype,
[all …]
Dcpp_gemm_template.py580 output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
591 compute_dtype=compute_dtype,
978 output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
989 compute_dtype=compute_dtype,
/external/tensorflow/tensorflow/python/keras/mixed_precision/
Dpolicy.py263 def compute_dtype(self): member in Policy
508 policy.compute_dtype != policy.variable_dtype)
511 if (policy is not None and policy.compute_dtype is not None and
512 not dtypes.as_dtype(policy.compute_dtype).is_floating):
Dlayer_test.py107 self.assertEqual(layer.compute_dtype, dtype)
197 self.assertEqual(layer.compute_dtype, dtypes.float64)
/external/executorch/kernels/optimized/blas/
DCPUBlas.cpp101 using acc_type = utils::compute_dtype<float>; in gemm()
142 using acc_type = utils::compute_dtype<float>; in gemm()
165 using acc_type = utils::compute_dtype<Half>; in gemm()
188 using acc_type = utils::compute_dtype<BFloat16>; in gemm()
DCPUBlas.h129 using acc_type = utils::compute_dtype<T>; in gemm()
/external/tensorflow/tensorflow/python/keras/layers/
Dembeddings.py192 if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
195 out = math_ops.cast(out, self._dtype_policy.compute_dtype)
/external/tensorflow/tensorflow/python/keras/engine/
Dbase_layer_v1.py426 self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
1778 if self._dtype_policy.compute_dtype:
1780 self._dtype_policy.compute_dtype)
1796 return self._dtype_policy.compute_dtype
1810 compute_dtype = self._compute_dtype
1811 if (self._autocast and compute_dtype and
1812 dtypes.as_dtype(compute_dtype).is_floating):
1818 x.dtype.base_dtype.name != compute_dtype):
1819 return math_ops.cast(x, compute_dtype)
1823 return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name)
[all …]
Dbase_layer.py637 self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
2303 if self._dtype_policy.compute_dtype:
2305 self._dtype_policy.compute_dtype)
2318 def compute_dtype(self): member in Layer
2337 return self._dtype_policy.compute_dtype
2342 return self._dtype_policy.compute_dtype
2631 if input_list and self._dtype_policy.compute_dtype is None:
/external/executorch/kernels/optimized/utils/
Dmath_utils.h42 using compute_dtype = typename ComputeDTypeTraits<T>::type;
/external/pytorch/torch/masked/
D_ops.py1602 compute_dtype = dtype
1603 if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
1604 compute_dtype = torch.float32
1624 total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
1627 …x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-unde…
1633 corresponding_real_dtype(compute_dtype)
1634 if compute_dtype.is_complex
1635 else compute_dtype
/external/pytorch/torch/_prims_common/
Dwrappers.py129 compute_dtype, result_dtype = utils.elementwise_dtypes(
135 x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
/external/executorch/kernels/optimized/cpu/
Dmoments_utils.h26 using acc_t = executorch::utils::compute_dtype<T>;
/external/tensorflow/tensorflow/tools/api/golden/v2/
Dtensorflow.metrics.-recall-at-precision.pbtxt17 name: "compute_dtype"
Dtensorflow.metrics.-one-hot-io-u.pbtxt18 name: "compute_dtype"
Dtensorflow.metrics.-specificity-at-sensitivity.pbtxt17 name: "compute_dtype"
Dtensorflow.metrics.-k-l-divergence.pbtxt19 name: "compute_dtype"
Dtensorflow.metrics.-log-cosh-error.pbtxt19 name: "compute_dtype"
Dtensorflow.metrics.-cosine-similarity.pbtxt19 name: "compute_dtype"
Dtensorflow.metrics.-io-u.pbtxt17 name: "compute_dtype"
Dtensorflow.metrics.-binary-accuracy.pbtxt19 name: "compute_dtype"
Dtensorflow.metrics.-squared-hinge.pbtxt19 name: "compute_dtype"
Dtensorflow.metrics.-accuracy.pbtxt19 name: "compute_dtype"
/external/tensorflow/tensorflow/tools/api/golden/v1/
Dtensorflow.layers.-dense.pbtxt18 name: "compute_dtype"
Dtensorflow.layers.-max-pooling2-d.pbtxt19 name: "compute_dtype"

1234