Searched refs:triton_kernel (Results 1 – 4 of 4) sorted by relevance
/external/pytorch/torch/_library/ |
D | triton.py | 162 def capture_triton(triton_kernel: Callable, /) -> Callable: 227 if not isinstance(triton_kernel, (JITFunction, Autotuner)): 232 return triton_kernel 233 return TraceableTritonKernelWrapper(triton_kernel, None, None)
|
/external/pytorch/torch/_inductor/ |
D | wrapper_benchmark.py | 84 triton_kernel = get_triton_kernel(kernel_mod) 90 for arg_name in triton_kernel.fn.arg_names 94 num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) 127 len(triton_kernel.launchers) == 1 129 launcher = triton_kernel.launchers[0]
|
/external/pytorch/aten/src/ATen/native/sparse/ |
D | SparseBlasImpl.cpp | 114 …const auto triton_kernel = triton_schema.value().typed<Tensor&(const Tensor&, const Tensor&, Tenso… in _compressed_row_strided_mm_out() local 115 if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) { in _compressed_row_strided_mm_out() 116 return triton_kernel.call(compressed, strided, result); in _compressed_row_strided_mm_out() 243 …const auto triton_kernel = triton_schema.value().typed<Tensor&(const Tensor&, const Tensor&, const… in _compressed_row_strided_addmm_out() local 244 if (triton_kernel.hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) { in _compressed_row_strided_addmm_out() 246 return triton_kernel.call(self, mat1, mat2, beta, alpha, result); in _compressed_row_strided_addmm_out()
|
/external/pytorch/torch/_inductor/codegen/ |
D | triton_combo_kernel.py | 457 def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: 458 sub_kernel = triton_kernel
|