#pragma once #include #include #include #include namespace at::native::onednn{ TORCH_API sycl::event matmul( at::Tensor& result, const at::Tensor& mat1, const at::Tensor& mat2, const at::Tensor& b_raw, bool m2_trans, Attr attr, const std::vector& deps = {}); TORCH_API sycl::event convolution( at::Tensor& dst, const at::Tensor& src, const at::Tensor& weight, const at::Tensor& bia, IntArrayRef padding_front_top_left, IntArrayRef padding_back_bottom_right, IntArrayRef stride, IntArrayRef dilation, int64_t groups, Attr& attr, const std::vector& deps = {}); TORCH_API sycl::event convolution_backward_weights( at::Tensor& diff_weight, at::Tensor& diff_bia, const at::Tensor& diff_dst, const at::Tensor& src, IntArrayRef diff_weight_aten_size, IntArrayRef padding_front_top_left, IntArrayRef padding_back_bottom_right, IntArrayRef stride, IntArrayRef dilation, int64_t groups, const std::vector& deps = {}); TORCH_API sycl::event convolution_backward_data( at::Tensor& diff_src, const at::Tensor& diff_dst, const at::Tensor& weight, IntArrayRef padding_front_top_left, IntArrayRef padding_back_bottom_right, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined, const std::vector& deps = {}); TORCH_API sycl::event deconvolution( at::Tensor& dst, const at::Tensor& src, const at::Tensor& weight, const at::Tensor& bia, IntArrayRef stride, IntArrayRef padding, IntArrayRef dst_padding, IntArrayRef dilation, int64_t groups, Attr& attr, const std::vector& deps = {}); TORCH_API sycl::event deconvolution_backward_data( at::Tensor& diff_src, const at::Tensor& diff_dst, const at::Tensor& weight, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups, bool bias_defined, const std::vector& deps = {}); TORCH_API sycl::event deconvolution_backward_weights( at::Tensor& diff_weight, at::Tensor& diff_bia, const at::Tensor& diff_dst, const at::Tensor& src, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups, const std::vector& deps = {}); dnnl::memory::dims conv_dst_size( int64_t ndim, IntArrayRef src_tz, IntArrayRef wgh_tz, IntArrayRef padding_front_top_left, IntArrayRef padding_back_bottom_right, IntArrayRef stride, IntArrayRef dilation); dnnl::memory::dims deconv_dst_size( IntArrayRef src_size, IntArrayRef wgh_size, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, IntArrayRef dst_padding, int64_t groups); } // namespace at::native::onednn