1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h> 5 #include <ATen/native/mkldnn/xpu/detail/Attr.h> 6 #include <ATen/native/mkldnn/xpu/detail/Utils.h> 7 8 namespace at::native::onednn{ 9 10 TORCH_API sycl::event matmul( 11 at::Tensor& result, 12 const at::Tensor& mat1, 13 const at::Tensor& mat2, 14 const at::Tensor& b_raw, 15 bool m2_trans, 16 Attr attr, 17 const std::vector<sycl::event>& deps = {}); 18 19 TORCH_API sycl::event convolution( 20 at::Tensor& dst, 21 const at::Tensor& src, 22 const at::Tensor& weight, 23 const at::Tensor& bia, 24 IntArrayRef padding_front_top_left, 25 IntArrayRef padding_back_bottom_right, 26 IntArrayRef stride, 27 IntArrayRef dilation, 28 int64_t groups, 29 Attr& attr, 30 const std::vector<sycl::event>& deps = {}); 31 32 TORCH_API sycl::event convolution_backward_weights( 33 at::Tensor& diff_weight, 34 at::Tensor& diff_bia, 35 const at::Tensor& diff_dst, 36 const at::Tensor& src, 37 IntArrayRef diff_weight_aten_size, 38 IntArrayRef padding_front_top_left, 39 IntArrayRef padding_back_bottom_right, 40 IntArrayRef stride, 41 IntArrayRef dilation, 42 int64_t groups, 43 const std::vector<sycl::event>& deps = {}); 44 45 TORCH_API sycl::event convolution_backward_data( 46 at::Tensor& diff_src, 47 const at::Tensor& diff_dst, 48 const at::Tensor& weight, 49 IntArrayRef padding_front_top_left, 50 IntArrayRef padding_back_bottom_right, 51 IntArrayRef stride, 52 IntArrayRef dilation, 53 int64_t groups, 54 bool bias_defined, 55 const std::vector<sycl::event>& deps = {}); 56 57 TORCH_API sycl::event deconvolution( 58 at::Tensor& dst, 59 const at::Tensor& src, 60 const at::Tensor& weight, 61 const at::Tensor& bia, 62 IntArrayRef stride, 63 IntArrayRef padding, 64 IntArrayRef dst_padding, 65 IntArrayRef dilation, 66 int64_t groups, 67 Attr& attr, 68 const std::vector<sycl::event>& deps = {}); 69 70 TORCH_API sycl::event deconvolution_backward_data( 71 at::Tensor& diff_src, 72 const at::Tensor& diff_dst, 73 const at::Tensor& weight, 74 IntArrayRef stride, 75 IntArrayRef padding, 76 IntArrayRef dilation, 77 int64_t groups, 78 bool bias_defined, 79 const std::vector<sycl::event>& deps = {}); 80 81 TORCH_API sycl::event deconvolution_backward_weights( 82 at::Tensor& diff_weight, 83 at::Tensor& diff_bia, 84 const at::Tensor& diff_dst, 85 const at::Tensor& src, 86 IntArrayRef stride, 87 IntArrayRef padding, 88 IntArrayRef dilation, 89 int64_t groups, 90 const std::vector<sycl::event>& deps = {}); 91 92 dnnl::memory::dims conv_dst_size( 93 int64_t ndim, 94 IntArrayRef src_tz, 95 IntArrayRef wgh_tz, 96 IntArrayRef padding_front_top_left, 97 IntArrayRef padding_back_bottom_right, 98 IntArrayRef stride, 99 IntArrayRef dilation); 100 101 dnnl::memory::dims deconv_dst_size( 102 IntArrayRef src_size, 103 IntArrayRef wgh_size, 104 IntArrayRef padding, 105 IntArrayRef stride, 106 IntArrayRef dilation, 107 IntArrayRef dst_padding, 108 int64_t groups); 109 110 } // namespace at::native::onednn 111