// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at::functorch {

namespace {
13 std::tuple<Tensor, std::optional<int64_t>>
clone_batch_rule(const Tensor & self,std::optional<int64_t> self_bdim,std::optional<MemoryFormat> memory_format)14 clone_batch_rule(
15     const Tensor& self,
16     std::optional<int64_t> self_bdim,
17     std::optional<MemoryFormat> memory_format) {
18   // Memory format support is a little tricky because vmap is allowed to move
19   // around batch dimensions and some memory formats are rank-dependent.
20   // Another weird case is:
21   // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
22   //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
23   //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
24   //   and N>1 batch dims?
25   TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
26       || memory_format == MemoryFormat::Contiguous,
27       "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
28       "memory_format torch.preserve_format or torch.contiguous_format (got ",
29       *memory_format, ")");
30 
31   if (memory_format == MemoryFormat::Contiguous) {
32     // There is an ambiguity here when the batch dims are not at the front of
33     // the tensor.
34     // >>> x = torch.randn(3, B0, 5)
35     // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
36     // >>> y[0].is_contiguous()
37     // ???
38     // Should we make the whole tensor contiguous, or should we
39     // make the non-batch dims contiguous? We've chosen the latter because
40     // philosophically vmap hides the batch dims and operates on a per-sample level.
41     auto self_ = moveBatchDimToFront(self, self_bdim);
42     auto result = at::clone(self_, memory_format);
43     return std::make_tuple(result, 0);
44   }
45 
46   TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
47   auto result = at::clone(self, memory_format);
48   return std::make_tuple(result, self_bdim);
49 }
51 std::tuple<Tensor, std::optional<int64_t>>
view_as_complex_batch_rule(const Tensor & self,std::optional<int64_t> self_bdim)52 view_as_complex_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim) {
53   // guard against the user passing in a batch of scalar tensors with batch
54   // size equal to 2.
55   TORCH_CHECK(self.sym_sizes().size() > 1, "Input tensor must have one or more dimensions");
56 
57   auto self_ = moveBatchDimToFront(self, self_bdim);
58   auto result = at::view_as_complex(self_);
59   return std::make_tuple(result, 0);
60 }

} // anonymous namespace

TORCH_LIBRARY_IMPL(aten,FuncTorchBatched,m)64 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
65 
66 #define UNARY_POINTWISE_ALL2(op, overload) \
67   POINTWISE_BOXED2(op ## _, overload); \
68   VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
69 #define UNARY_POINTWISE_ALL(op) \
70   POINTWISE_BOXED(op ## _); \
71   VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
72 
73   UNARY_POINTWISE(view_as_real);
74   VMAP_SUPPORT(view_as_complex, view_as_complex_batch_rule);
75   VMAP_SUPPORT(clone, clone_batch_rule);
76 
77   UNARY_POINTWISE(_to_copy);
78   UNARY_POINTWISE(alias);
79   UNARY_POINTWISE_ALL(abs);
80   UNARY_POINTWISE_ALL(acos);
81   UNARY_POINTWISE_ALL(acosh);
82   UNARY_POINTWISE(angle);
83   UNARY_POINTWISE_ALL(asin);
84   UNARY_POINTWISE_ALL(asinh);
85   UNARY_POINTWISE_ALL(atan);
86   UNARY_POINTWISE_ALL(atanh);
87   UNARY_POINTWISE_ALL(bitwise_not);
88   UNARY_POINTWISE_ALL(ceil);
89   UNARY_POINTWISE_ALL(cos);
90   UNARY_POINTWISE_ALL(cosh);
91   UNARY_POINTWISE(_conj);
92   UNARY_POINTWISE_ALL(deg2rad);
93   UNARY_POINTWISE(detach);
94   UNARY_POINTWISE_ALL(digamma);
95   UNARY_POINTWISE_ALL(erf);
96   UNARY_POINTWISE_ALL(exp);
97   UNARY_POINTWISE_ALL(expm1);
98   UNARY_POINTWISE_ALL(floor);
99   UNARY_POINTWISE_ALL(frac);
100   UNARY_POINTWISE(isnan);
101   UNARY_POINTWISE(isinf);
102   UNARY_POINTWISE(isposinf);
103   UNARY_POINTWISE(isneginf);
104   UNARY_POINTWISE_ALL(lgamma);
105   UNARY_POINTWISE_ALL(log);
106   UNARY_POINTWISE_ALL(log10);
107   UNARY_POINTWISE_ALL(log1p);
108   UNARY_POINTWISE_ALL(log2);
109   UNARY_POINTWISE_ALL(logical_not);
110   UNARY_POINTWISE_ALL(logit);
111   UNARY_POINTWISE_ALL(mish);
112   UNARY_POINTWISE_ALL(mvlgamma);
113   UNARY_POINTWISE_ALL(nan_to_num);
114   UNARY_POINTWISE_ALL(neg);
115   UNARY_POINTWISE_ALL(rad2deg);
116   UNARY_POINTWISE_ALL(reciprocal);
117   UNARY_POINTWISE_ALL(round);
118   UNARY_POINTWISE_ALL2(round, decimals);
119   UNARY_POINTWISE_ALL(rsqrt);
120   UNARY_POINTWISE_ALL(sgn);
121   UNARY_POINTWISE_ALL(sign);
122   UNARY_POINTWISE(signbit);
123   UNARY_POINTWISE_ALL(sin);
124   UNARY_POINTWISE_ALL(sinc);
125   UNARY_POINTWISE_ALL(sinh);
126   UNARY_POINTWISE_ALL(sqrt);
127   UNARY_POINTWISE_ALL(tan);
128   UNARY_POINTWISE_ALL(threshold);
129   UNARY_POINTWISE_ALL(trunc);
130 
131   // special-related
132   UNARY_POINTWISE_ALL(i0);
133   UNARY_POINTWISE_ALL(erfc);
134   UNARY_POINTWISE_ALL(erfinv);
135   UNARY_POINTWISE_ALL(exp2);
136 
137   // torch.special.* functions
138   UNARY_POINTWISE(special_entr);
139   UNARY_POINTWISE(special_erfcx);
140   UNARY_POINTWISE(special_i0e);
141   UNARY_POINTWISE(special_i1);
142   UNARY_POINTWISE(special_i1e);
143   UNARY_POINTWISE(special_ndtri);
144   POINTWISE_BOXED(special_bessel_j0);
145   POINTWISE_BOXED(special_spherical_bessel_j0);
146   POINTWISE_BOXED(special_bessel_j1);
147   POINTWISE_BOXED(special_modified_bessel_i0);
148   POINTWISE_BOXED(special_modified_bessel_i1);
149   POINTWISE_BOXED(special_scaled_modified_bessel_k0);
150   POINTWISE_BOXED(special_modified_bessel_k0);
151   POINTWISE_BOXED(special_scaled_modified_bessel_k1);
152   POINTWISE_BOXED(special_modified_bessel_k1);
153   POINTWISE_BOXED(special_bessel_y0);
154   POINTWISE_BOXED(special_bessel_y1);
155 
156   // Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
157   UNARY_POINTWISE_ALL(elu);
158   UNARY_POINTWISE(hardshrink);
159   UNARY_POINTWISE_ALL(hardsigmoid);
160   UNARY_POINTWISE_ALL(hardtanh);
161   UNARY_POINTWISE_ALL(hardswish);
162   UNARY_POINTWISE_ALL(leaky_relu);
163   UNARY_POINTWISE_ALL(relu);
164   UNARY_POINTWISE_ALL(celu);
165   UNARY_POINTWISE(gelu);
166   UNARY_POINTWISE_ALL(sigmoid);
167   UNARY_POINTWISE_ALL(silu);
168   UNARY_POINTWISE(softplus);
169   UNARY_POINTWISE(softshrink);
170   UNARY_POINTWISE_ALL(tanh);
171 
172   POINTWISE_BOXED(fill_.Scalar);
173   POINTWISE_BOXED(zero_);
174 
175 #undef UNARY_POINTWISE
176 #undef UNARY_POINTWISE_ALL
177 
178 }

#undef INVOKE
} // namespace at::functorch