• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#pragma once

#include <vector>

#include <ATen/ATen.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>

8 namespace at::native::onednn{
9 
10 TORCH_API sycl::event matmul(
11     at::Tensor& result,
12     const at::Tensor& mat1,
13     const at::Tensor& mat2,
14     const at::Tensor& b_raw,
15     bool m2_trans,
16     Attr attr,
17     const std::vector<sycl::event>& deps = {});
18 
19 TORCH_API sycl::event convolution(
20     at::Tensor& dst,
21     const at::Tensor& src,
22     const at::Tensor& weight,
23     const at::Tensor& bia,
24     IntArrayRef padding_front_top_left,
25     IntArrayRef padding_back_bottom_right,
26     IntArrayRef stride,
27     IntArrayRef dilation,
28     int64_t groups,
29     Attr& attr,
30     const std::vector<sycl::event>& deps = {});
31 
32 TORCH_API sycl::event convolution_backward_weights(
33     at::Tensor& diff_weight,
34     at::Tensor& diff_bia,
35     const at::Tensor& diff_dst,
36     const at::Tensor& src,
37     IntArrayRef diff_weight_aten_size,
38     IntArrayRef padding_front_top_left,
39     IntArrayRef padding_back_bottom_right,
40     IntArrayRef stride,
41     IntArrayRef dilation,
42     int64_t groups,
43     const std::vector<sycl::event>& deps = {});
44 
45 TORCH_API sycl::event convolution_backward_data(
46     at::Tensor& diff_src,
47     const at::Tensor& diff_dst,
48     const at::Tensor& weight,
49     IntArrayRef padding_front_top_left,
50     IntArrayRef padding_back_bottom_right,
51     IntArrayRef stride,
52     IntArrayRef dilation,
53     int64_t groups,
54     bool bias_defined,
55     const std::vector<sycl::event>& deps = {});
56 
57 TORCH_API sycl::event deconvolution(
58     at::Tensor& dst,
59     const at::Tensor& src,
60     const at::Tensor& weight,
61     const at::Tensor& bia,
62     IntArrayRef stride,
63     IntArrayRef padding,
64     IntArrayRef dst_padding,
65     IntArrayRef dilation,
66     int64_t groups,
67     Attr& attr,
68     const std::vector<sycl::event>& deps = {});
69 
70 TORCH_API sycl::event deconvolution_backward_data(
71     at::Tensor& diff_src,
72     const at::Tensor& diff_dst,
73     const at::Tensor& weight,
74     IntArrayRef stride,
75     IntArrayRef padding,
76     IntArrayRef dilation,
77     int64_t groups,
78     bool bias_defined,
79     const std::vector<sycl::event>& deps = {});
80 
81 TORCH_API sycl::event deconvolution_backward_weights(
82     at::Tensor& diff_weight,
83     at::Tensor& diff_bia,
84     const at::Tensor& diff_dst,
85     const at::Tensor& src,
86     IntArrayRef stride,
87     IntArrayRef padding,
88     IntArrayRef dilation,
89     int64_t groups,
90     const std::vector<sycl::event>& deps = {});
91 
92 dnnl::memory::dims conv_dst_size(
93     int64_t ndim,
94     IntArrayRef src_tz,
95     IntArrayRef wgh_tz,
96     IntArrayRef padding_front_top_left,
97     IntArrayRef padding_back_bottom_right,
98     IntArrayRef stride,
99     IntArrayRef dilation);
100 
101 dnnl::memory::dims deconv_dst_size(
102     IntArrayRef src_size,
103     IntArrayRef wgh_size,
104     IntArrayRef padding,
105     IntArrayRef stride,
106     IntArrayRef dilation,
107     IntArrayRef dst_padding,
108     int64_t groups);
109 
110 } // namespace at::native::onednn
111