1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/Cross.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorMeta.h>
6 #include <ATen/WrapDimUtils.h>
7 #include <ATen/ExpandUtils.h>
8 #include <ATen/native/Resize.h>
9 
10 
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Functions.h>
13 #include <ATen/NativeFunctions.h>
14 #else
15 #include <ATen/ops/cross_native.h>
16 #include <ATen/ops/linalg_cross.h>
17 #include <ATen/ops/linalg_cross_native.h>
18 #endif
19 
20 namespace at::meta {
21 
// Meta function for linalg_cross: validates the two operands and declares
// the output tensor.
//
// Raises (via TORCH_CHECK) when `input` and `other` differ in their number of
// dimensions, or when either of them does not have length 3 along `dim`.
// The output is declared with the broadcast of both shapes and with the
// options of `input`.
TORCH_META_FUNC(linalg_cross)
(const Tensor & input, const Tensor & other, int64_t dim) {
  // Requiring equal rank rules out calls such as
  //   linalg.cross(torch.randn(2, 3), torch.randn(5, 2, 3), dim=2)
  TORCH_CHECK(input.dim() == other.dim(), "linalg.cross: inputs must have the same number of dimensions.");

  // Both operands must have exactly three entries along `dim`.
  const auto input_len = input.size(dim);
  const auto other_len = other.size(dim);
  TORCH_CHECK(input_len == 3 && other_len == 3, "linalg.cross: inputs dimension ", dim, " must have length 3. Got ", input_len, " and ", other_len);

  // Only the batch dimensions can still differ here (the `dim` axis is 3 in
  // both), so broadcasting the full shapes is the same as broadcasting the
  // batch dimensions.
  const auto broadcast_shape = infer_size(input.sizes(), other.sizes());

  set_output_raw_strided(0, broadcast_shape, {}, input.options());
}
37 
38 } // namespace at::meta
39 namespace at::native {
40 
// Defines the `cross_stub` dispatch stub. linalg_cross_out (further down in
// this file) invokes it with the device type of `input`, the output tensor,
// the two broadcast operands and the wrapped dim.
41 DEFINE_DISPATCH(cross_stub);
42 
_default_cross_dim(const std::optional<int64_t> & dimension,SymIntArrayRef sizes)43 static int64_t _default_cross_dim(const std::optional<int64_t> &dimension, SymIntArrayRef sizes) {
44   // If dimension is not given, it defaults to the first dimension found with the size 3.
45   // Note that this behaviour might be unexpected.
46   // _default_cross_dim is called internally inside the cross implementation to calculate
47   // the dim and finally cross delegates to the linalg_cross implementation with this dim
48   if(dimension.has_value()) {
49     return *dimension;
50   }
51 
52   for(auto i : c10::irange(sizes.size())) {
53     if(sizes[i] == 3) {
54       return i;
55     }
56   }
57   TORCH_CHECK(false, "no dimension of size 3 in input");
58 }
59 
// Computes the cross product of `input` and `other` by delegating to
// at::linalg_cross.
//
// When `dimension` is not provided, a one-time deprecation warning is emitted
// and the dimension is resolved by _default_cross_dim from the sizes of
// `input` only.
Tensor cross(const Tensor & input, const Tensor & other, const std::optional<int64_t> dimension) {
  const bool dim_was_given = dimension.has_value();
  if (!dim_was_given) {
    TORCH_WARN_ONCE(
      "Using torch.cross without specifying the dim arg is deprecated.\n",
      "Please either pass the dim explicitly or simply use torch.linalg.cross.\n",
      "The default value of dim will change to agree with that of linalg.cross in a future release."
    );
  }

  const int64_t resolved_dim = _default_cross_dim(dimension, input.sym_sizes());
  return at::linalg_cross(input, other, resolved_dim);
}
71 
// Out-variant of cross(): resolves the dimension with _default_cross_dim
// (from the sizes of `input`) and forwards to at::linalg_cross_out, returning
// whatever reference that call returns.
Tensor & cross_out(const Tensor & input, const Tensor & other, const std::optional<int64_t> dimension, Tensor & out) {
  return at::linalg_cross_out(
      out,
      input,
      other,
      _default_cross_dim(dimension, input.sym_sizes()));
}
76 
77 
// Implementation function for linalg_cross: expands both operands to the
// shape of `out` and hands the work to `cross_stub`, selected by the device
// type of `input`.
TORCH_IMPL_FUNC(linalg_cross_out)
(const Tensor & input, const Tensor & other, int64_t dim, const Tensor & out) {
  // Wrap `dim` against the rank of `input` before passing it to the stub.
  const int64_t wrapped_dim = maybe_wrap_dim(dim, input.dim());

  // Expand each operand to the shape of `out`.
  const auto target_shape = out.sizes();
  Tensor lhs = input.expand(target_shape);
  Tensor rhs = other.expand(target_shape);

  cross_stub(input.device().type(), out, lhs, rhs, wrapped_dim);
}
87 
88 } // namespace at::native
89