// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at::functorch {

namespace {
std::tuple<Tensor, std::optional<int64_t>>
clone_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    std::optional<MemoryFormat> memory_format) {
  // Memory format support is a little tricky because vmap is allowed to move
  // around batch dimensions and some memory formats are rank-dependent.
  // Another weird case is:
  // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
  //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
  //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
  //   and N>1 batch dims?
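  // Illustrative example only (B is a hypothetical batch size): cloning to
  // channels_last under vmap currently trips the check below rather than
  // guessing a layout.
  // >>> vmap(lambda t: t.clone(memory_format=torch.channels_last))(torch.randn(B, 3, 8, 8))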
  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
      || memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
      "memory_format torch.preserve_format or torch.contiguous_format (got ",
      *memory_format, ")");

  if (memory_format == MemoryFormat::Contiguous) {
    // There is an ambiguity here when the batch dims are not at the front of
    // the tensor.
    // >>> x = torch.randn(3, B0, 5)
    // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
    // >>> y[0].is_contiguous()
    // ???
    // Should we make the whole tensor contiguous, or should we
    // make the non-batch dims contiguous? We've chosen the latter because
    // philosophically vmap hides the batch dims and operates on a per-sample level.
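    // Concretely, with the example above every slice y[i] comes out contiguous
    // (so y[0].is_contiguous() is True), because we move the batch dim to the
    // front and then make that whole physical tensor contiguous below.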
    auto self_ = moveBatchDimToFront(self, self_bdim);
    auto result = at::clone(self_, memory_format);
    return std::make_tuple(result, 0);
  }

  TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
  auto result = at::clone(self, memory_format);
  return std::make_tuple(result, self_bdim);
}

std::tuple<Tensor, std::optional<int64_t>>
view_as_complex_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim) {
  // guard against the user passing in a batch of scalar tensors with batch
  // size equal to 2.
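  // For illustration: without this check,
  // >>> vmap(torch.view_as_complex)(torch.randn(2))
  // would silently consume the size-2 batch dim as the (real, imag) pair;
  // with the check it errors out here instead.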
  TORCH_CHECK(self.sym_sizes().size() > 1, "Input tensor must have one or more dimensions");

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::view_as_complex(self_);
  return std::make_tuple(result, 0);
}

} // namespace

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {

#define UNARY_POINTWISE_ALL2(op, overload) \
  POINTWISE_BOXED2(op ## _, overload); \
  VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
#define UNARY_POINTWISE_ALL(op) \
  POINTWISE_BOXED(op ## _); \
  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
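  // Rough sketch of what these expand to: UNARY_POINTWISE_ALL(abs) registers the
  // boxed pointwise rule for the in-place at::abs_ plus the basic unary batch
  // rule for at::abs, while UNARY_POINTWISE (from the included helpers) presumably
  // registers only the out-of-place op.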

  UNARY_POINTWISE(view_as_real);
  VMAP_SUPPORT(view_as_complex, view_as_complex_batch_rule);
  VMAP_SUPPORT(clone, clone_batch_rule);

  UNARY_POINTWISE(_to_copy);
  UNARY_POINTWISE(alias);
  UNARY_POINTWISE_ALL(abs);
  UNARY_POINTWISE_ALL(acos);
  UNARY_POINTWISE_ALL(acosh);
  UNARY_POINTWISE(angle);
  UNARY_POINTWISE_ALL(asin);
  UNARY_POINTWISE_ALL(asinh);
  UNARY_POINTWISE_ALL(atan);
  UNARY_POINTWISE_ALL(atanh);
  UNARY_POINTWISE_ALL(bitwise_not);
  UNARY_POINTWISE_ALL(ceil);
  UNARY_POINTWISE_ALL(cos);
  UNARY_POINTWISE_ALL(cosh);
  UNARY_POINTWISE(_conj);
  UNARY_POINTWISE_ALL(deg2rad);
  UNARY_POINTWISE(detach);
  UNARY_POINTWISE_ALL(digamma);
  UNARY_POINTWISE_ALL(erf);
  UNARY_POINTWISE_ALL(exp);
  UNARY_POINTWISE_ALL(expm1);
  UNARY_POINTWISE_ALL(floor);
  UNARY_POINTWISE_ALL(frac);
  UNARY_POINTWISE(isnan);
  UNARY_POINTWISE(isinf);
  UNARY_POINTWISE(isposinf);
  UNARY_POINTWISE(isneginf);
  UNARY_POINTWISE_ALL(lgamma);
  UNARY_POINTWISE_ALL(log);
  UNARY_POINTWISE_ALL(log10);
  UNARY_POINTWISE_ALL(log1p);
  UNARY_POINTWISE_ALL(log2);
  UNARY_POINTWISE_ALL(logical_not);
  UNARY_POINTWISE_ALL(logit);
  UNARY_POINTWISE_ALL(mish);
  UNARY_POINTWISE_ALL(mvlgamma);
  UNARY_POINTWISE_ALL(nan_to_num);
  UNARY_POINTWISE_ALL(neg);
  UNARY_POINTWISE_ALL(rad2deg);
  UNARY_POINTWISE_ALL(reciprocal);
  UNARY_POINTWISE_ALL(round);
  UNARY_POINTWISE_ALL2(round, decimals);
  UNARY_POINTWISE_ALL(rsqrt);
  UNARY_POINTWISE_ALL(sgn);
  UNARY_POINTWISE_ALL(sign);
  UNARY_POINTWISE(signbit);
  UNARY_POINTWISE_ALL(sin);
  UNARY_POINTWISE_ALL(sinc);
  UNARY_POINTWISE_ALL(sinh);
  UNARY_POINTWISE_ALL(sqrt);
  UNARY_POINTWISE_ALL(tan);
  UNARY_POINTWISE_ALL(threshold);
  UNARY_POINTWISE_ALL(trunc);

  // special-related
  UNARY_POINTWISE_ALL(i0);
  UNARY_POINTWISE_ALL(erfc);
  UNARY_POINTWISE_ALL(erfinv);
  UNARY_POINTWISE_ALL(exp2);

  // torch.special.* functions
  UNARY_POINTWISE(special_entr);
  UNARY_POINTWISE(special_erfcx);
  UNARY_POINTWISE(special_i0e);
  UNARY_POINTWISE(special_i1);
  UNARY_POINTWISE(special_i1e);
  UNARY_POINTWISE(special_ndtri);
  POINTWISE_BOXED(special_bessel_j0);
  POINTWISE_BOXED(special_spherical_bessel_j0);
  POINTWISE_BOXED(special_bessel_j1);
  POINTWISE_BOXED(special_modified_bessel_i0);
  POINTWISE_BOXED(special_modified_bessel_i1);
  POINTWISE_BOXED(special_scaled_modified_bessel_k0);
  POINTWISE_BOXED(special_modified_bessel_k0);
  POINTWISE_BOXED(special_scaled_modified_bessel_k1);
  POINTWISE_BOXED(special_modified_bessel_k1);
  POINTWISE_BOXED(special_bessel_y0);
  POINTWISE_BOXED(special_bessel_y1);

  // Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
  UNARY_POINTWISE_ALL(elu);
  UNARY_POINTWISE(hardshrink);
  UNARY_POINTWISE_ALL(hardsigmoid);
  UNARY_POINTWISE_ALL(hardtanh);
  UNARY_POINTWISE_ALL(hardswish);
  UNARY_POINTWISE_ALL(leaky_relu);
  UNARY_POINTWISE_ALL(relu);
  UNARY_POINTWISE_ALL(celu);
  UNARY_POINTWISE(gelu);
  UNARY_POINTWISE_ALL(sigmoid);
  UNARY_POINTWISE_ALL(silu);
  UNARY_POINTWISE(softplus);
  UNARY_POINTWISE(softshrink);
  UNARY_POINTWISE_ALL(tanh);

  POINTWISE_BOXED(fill_.Scalar);
  POINTWISE_BOXED(zero_);

#undef UNARY_POINTWISE
#undef UNARY_POINTWISE_ALL

}

#undef INVOKE
} // namespace at::functorch