#include <ATen/Context.h>
#include <torch/library.h>

#include <ATen/ExpandUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/core/jit_type.h>
#include <c10/core/DefaultDtype.h>
#include <c10/util/irange.h>
#include <torch/csrc/api/include/torch/utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

#include <ATen/InitialTensorOptions.h>
#include <c10/core/ScalarType.h>
#include <torch/csrc/jit/frontend/error_report.h>

#include <sstream>

namespace torch::jit {

namespace {

c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

c10::AliasAnalysisKind aliasAnalysisConservative() {
  return c10::AliasAnalysisKind::CONSERVATIVE;
}

void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
  if (!elem_type->isSubtypeOf(*NumberType::get()) &&
      !elem_type->isSubtypeOf(*BoolType::get())) {
    std::stringstream error;
    error << "Input list must contain ints, floats, or bools, "
          << "got " << elem_type->repr_str();
    // Special-case the empty list, e.g. torch.tensor([])
    if (elem_type->isSubtypeOf(*TensorType::get())) {
      if (empty_list) {
        error << "\nEmpty lists default to List[Tensor]. Add a variable "
                 "annotation to the assignment to create an empty list "
                 "of another type (or, for Python 2, use "
                 "torch.jit.annotate(List[T], []) where T is the type of "
                 "elements in the list)";
      }
    }
    throw std::runtime_error(error.str());
  }
}

at::Tensor castTensorTo(
    at::Tensor self,
    const IValue& dtype,
    const IValue& device) {
  at::ScalarType scalar_type =
      dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
  c10::Device dev = device.isNone() ? self.device() : device.toDevice();
  if (scalar_type != self.scalar_type() || dev != self.device()) {
    self = self.to(dev, scalar_type);
  }
  return self;
}

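// Infers the shape of a (possibly nested) list by following the first element
// at each nesting level; e.g. [[1, 2], [3, 4]] yields sizes {2, 2}. Ragged
// inputs are caught later by checkSequenceSize during the element copy.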
std::vector<int64_t> compute_sizes(const IValue& seq) {
  std::vector<int64_t> sizes;
  auto seq_recur = seq.toList();
  while (true) {
    sizes.push_back(seq_recur.size());
    if (seq_recur.empty() || !seq_recur.get(0).isList()) {
      break;
    }
    seq_recur = seq_recur.get(0).toList();
  }
  return sizes;
}

void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
  if (seq_size != n) {
    AT_ERROR(
        "Expected sequence of length ",
        n,
        " at dim ",
        dim,
        " (got ",
        seq_size,
        ")");
  }
}

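// The storeLastDimension* helpers copy the innermost (last-dimension) list of
// scalars into the tensor's buffer, advancing by the last dimension's stride.
// The Float and Half variants exist because double-typed list elements may
// need to be narrowed to the tensor's actual element type.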
template <typename DTYPE>
void storeLastDimension(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int elementSize,
    at::ArrayRef<IValue> obj) {
  auto n = sizes[dim];
  auto seq_size = obj.size();
  checkSequenceSize(n, dim, seq_size);
  for (const auto i : c10::irange(n)) {
    *(DTYPE*)data = obj[i].to<DTYPE>();
    data += strides[dim] * elementSize;
  }
}

void storeLastDimensionFloat(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int elementSize,
    at::ArrayRef<IValue> obj) {
  auto n = sizes[dim];
  auto seq_size = obj.size();
  checkSequenceSize(n, dim, seq_size);
  for (const auto i : c10::irange(n)) {
    *(float*)data = static_cast<float>(obj[i].to<double>());
    data += strides[dim] * elementSize;
  }
}

void storeLastDimensionHalf(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int elementSize,
    at::ArrayRef<IValue> obj) {
  auto n = sizes[dim];
  auto seq_size = obj.size();
  checkSequenceSize(n, dim, seq_size);
  for (const auto i : c10::irange(n)) {
    *(at::Half*)data = at::convert<at::Half, double>(obj[i].to<double>());
    data += strides[dim] * elementSize;
  }
}

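// Walks the nested list one dimension per call, advancing the data pointer by
// the stride at each level, and defers the innermost copy to the
// storeLastDimension* helper selected by the list's element type and the
// tensor's element size.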
// Reference: the Python implementation recursive_store in tensor_new.cpp
void recursiveStore(
    char* data,
    const std::vector<int64_t>& sizes,
    const c10::ArrayRef<int64_t>& strides,
    int64_t dim,
    int tenElementSize,
    const IValue& obj) {
  auto ndim = sizes.size();
  auto n = sizes[dim];
  auto seq = obj.toListRef();
  checkSequenceSize(n, dim, seq.size());
  if (dim + 1 < static_cast<int64_t>(ndim)) {
    for (const auto i : c10::irange(n)) {
      recursiveStore(data, sizes, strides, dim + 1, tenElementSize, seq[i]);
      data += strides[dim] * tenElementSize;
    }
  } else {
    if (obj.isIntList()) {
      storeLastDimension<int64_t>(
          data, sizes, strides, dim, tenElementSize, seq);
    } else if (obj.isBoolList()) {
      storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
    } else if (obj.isDoubleList()) {
      if (tenElementSize ==
          static_cast<int>(elementSize(at::ScalarType::Double))) {
        storeLastDimension<double>(
            data, sizes, strides, dim, tenElementSize, seq);
      } else if (
          tenElementSize ==
          static_cast<int>(elementSize(at::ScalarType::Float))) {
        storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
      } else if (
          tenElementSize ==
          static_cast<int>(elementSize(at::ScalarType::Half))) {
        storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
      } else {
        TORCH_INTERNAL_ASSERT(false);
      }
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
}

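// Implements the list overloads of aten::tensor and aten::as_tensor. For
// example, in TorchScript
//   x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
// dispatches here (with if_set_requires_grad = true): the nested list is
// walked to infer the shape, copied into a freshly allocated tensor, and the
// result is then cast to the requested dtype/device.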
template <bool if_set_requires_grad>
void createTensorFromList(Stack& stack) {
  // torch.tensor has a fourth requires_grad arg, but torch.as_tensor does not,
  // so we use the template arg to distinguish between the two cases
  bool requires_grad = false;
  IValue data;
  IValue dtype;
  IValue device;
  if (if_set_requires_grad) {
    pop(stack, data, dtype, device, requires_grad);
  } else {
    pop(stack, data, dtype, device);
  }
  auto elem_type = data.type();
  while (elem_type->isSubtypeOf(AnyListType::get())) {
    elem_type = elem_type->containedType(0);
  }
  auto sizes = compute_sizes(data);
  checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
  at::ScalarType initial_scalar_type = scalarTypeFromJitType(*elem_type);
  if (initial_scalar_type == at::ScalarType::Double) {
    initial_scalar_type = typeMetaToScalarType(c10::get_default_dtype());
  }

  auto tensor =
      at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));

  if (tensor.numel() != 0) {
    recursiveStore(
        (char*)tensor.data_ptr(),
        sizes,
        tensor.strides(),
        0,
        tensor.element_size(),
        data);
  }

  tensor = castTensorTo(tensor, dtype, device);
  auto default_type = at::typeMetaToScalarType(at::get_default_dtype());

  if (dtype.isNone() && tensor.scalar_type() != default_type &&
      tensor.numel() == 0) {
    TORCH_WARN(
        "Creating a tensor from an empty ",
        elem_type->repr_str(),
        " list will create a tensor of default floating point type (currently ",
        default_type,
        ") in Python, but a tensor of type ",
        elem_type->repr_str(),
        " in TorchScript.\n",
        "Pass in a dtype argument to ensure consistent behavior");
  }
  if (if_set_requires_grad) {
    tensor.set_requires_grad(requires_grad);
  }
  push(stack, std::move(tensor));
}

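// Operator registrations. Each lambda implements the schema above it using the
// interpreter's stack calling convention: arguments are pushed left to right,
// peek(stack, i, N) reads the i-th of the last N values, drop/pop remove the
// arguments, and the results are pushed (or packed) back onto the stack.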
RegisterOperators reg({
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]"),
        [](Stack& stack) {
          RECORD_FUNCTION("split_with_sizes", last(stack, 3));

          auto result = at::split_with_sizes(
              (std::move(peek(stack, 0, 3))).toTensor(),
              (std::move(peek(stack, 1, 3))).toDimVector(),
              (std::move(peek(stack, 2, 3))).toInt());
          drop(stack, 3);
          pack(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),

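// Registers the scalar overloads aten::tensor.<type> and aten::as_tensor.<type>
// for a given schema type / C++ type pair; tensor_creation_op is the expression
// that builds the initial 0-dim tensor from scalar_val before it is cast to the
// requested dtype/device.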
#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op)  \
  OperatorGenerator(                                                        \
      TORCH_SELECTIVE_SCHEMA(                                               \
          "aten::tensor." #operator_type "(" #operator_type                 \
          " t, *, ScalarType? dtype=None, Device? device=None"              \
          ", bool requires_grad=False) -> Tensor"),                         \
      [](Stack& stack) {                                                    \
        c_type scalar_val;                                                  \
        IValue dtype;                                                       \
        IValue device;                                                      \
        bool requires_grad;                                                 \
        pop(stack, scalar_val, dtype, device, requires_grad);               \
        auto tensor = tensor_creation_op;                                   \
        tensor = castTensorTo(tensor, dtype, device);                       \
        tensor.set_requires_grad(requires_grad);                            \
        push(stack, std::move(tensor));                                     \
      },                                                                    \
      aliasAnalysisFromSchema()),                                           \
      OperatorGenerator(                                                    \
          TORCH_SELECTIVE_SCHEMA(                                           \
              "aten::as_tensor." #operator_type "(" #operator_type          \
              " t, *, ScalarType? dtype=None, Device? device=None) -> Tensor"), \
          [](Stack& stack) {                                                \
            c_type scalar_val;                                              \
            IValue dtype;                                                   \
            IValue device;                                                  \
            pop(stack, scalar_val, dtype, device);                          \
            auto tensor = tensor_creation_op;                               \
            tensor = castTensorTo(tensor, dtype, device);                   \
            push(stack, std::move(tensor));                                 \
          },                                                                \
          aliasAnalysisFromSchema()),

    DEFINE_TORCH_TENSOR_OP(
        bool,
        bool,
        at::empty({}, at::CPU(at::kBool).options()).fill_(scalar_val))
    DEFINE_TORCH_TENSOR_OP(
        float,
        double,
        at::native::scalar_tensor(
            scalar_val,
            typeMetaToScalarType(c10::get_default_dtype()),
            std::nullopt /* layout */,
            at::kCPU,
            std::nullopt /* pin_memory */))
    DEFINE_TORCH_TENSOR_OP(
        int,
        int64_t,
        at::scalar_to_tensor(scalar_val))
    DEFINE_TORCH_TENSOR_OP(
        complex,
        c10::complex<double>,
        at::native::scalar_tensor(
            scalar_val,
            typeMetaToScalarType(c10::get_default_complex_dtype()),
            std::nullopt /* layout */,
            at::kCPU,
            std::nullopt /* pin_memory */))

    // Reference: the Python implementation internal_new_from_data in
    // tensor_new.cpp
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::_infer_size(int[] a, int[] b) -> int[]"),
        [](Stack& stack) {
          auto a = pop(stack);
          auto b = pop(stack);
          push(stack, at::infer_size(a.toDimVector(), b.toDimVector()));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor"),
        [](Stack& stack) {
          at::Tensor weight;
          at::Tensor input;
          double max_norm = 0;
          double norm_type = 0;
          pop(stack, weight, input, max_norm, norm_type);

          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor result =
              at::embedding_renorm_(weight, input, max_norm, norm_type);
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor"),
        createTensorFromList<true>,
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(a|b)"),
        [](Stack& stack) {
          auto device = pop(stack).toOptional<c10::Device>();
          auto dtype = pop(stack).toOptional<at::ScalarType>();
          at::Tensor data = pop(stack).toTensor();
          at::ScalarType scalar_type =
              dtype ? dtype.value() : data.scalar_type();
          c10::Device dev = device ? device.value() : data.device();

          if (scalar_type != data.scalar_type() || dev != data.device()) {
            data = data.to(
                dev, scalar_type, /*non_blocking=*/false, /*copy=*/false);
          }
          push(stack, std::move(data));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::as_tensor.list(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor"),
        createTensorFromList<false>,
        aliasAnalysisFromSchema()),
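    // _pack_sequence is registered with an intentionally empty body: the four
    // inputs are left on the stack unchanged and become the four outputs.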
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, "
            "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)"),
        [](Stack& stack) {},
        aliasAnalysisFromSchema()),
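    // Inside TorchScript there is never an active tracer and we are always
    // scripting, so these introspection ops report constants; likewise
    // __torch_function__ overrides are a Python-only mechanism.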
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::_get_tracing_state() -> bool"),
        [](Stack& stack) { push(stack, false); },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"),
        [](Stack& stack) { push(stack, true); },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"),
        [](Stack& stack) { push(stack, false); },
        aliasAnalysisFromSchema()),
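    // The _no_grad_* in-place ops below wrap their mutation in a NoGradGuard so
    // that initialization routines (e.g. torch.nn.init) do not record autograd
    // history; the guard can go away once scripted code can toggle grad mode
    // directly.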
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b, Generator? generator=None) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          std::optional<at::Generator> generator =
              pop(stack).toOptional<at::Generator>();

          double a = 0;
          double b = 0;
          pop(stack, tensor, a, b);
          push(stack, tensor.uniform_(a, b, generator));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std, Generator? generator=None) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          double mean = 0;
          double std = 0;
          std::optional<at::Generator> generator =
              pop(stack).toOptional<at::Generator>();

          pop(stack, tensor, mean, std);
          push(stack, tensor.normal_(mean, std, generator));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_fill_(Tensor(a!) tensor, float val) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          double val = 0;
          pop(stack, tensor, val);
          push(stack, at::fill_(tensor, val));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_zero_(Tensor(a!) tensor) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          pop(stack, tensor);
          push(stack, at::zero_(tensor));
        },
        aliasAnalysisFromSchema()),
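    // The remaining ops read or mutate global autograd / backend state, so they
    // are registered with conservative alias analysis to keep the optimizer
    // from reordering or eliminating them.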
    Operator(
        "aten::is_grad_enabled() -> bool",
        [](Stack& stack) { push(stack, torch::GradMode::is_enabled()); },
        aliasAnalysisConservative()),
    Operator(
        "aten::set_grad_enabled(bool val) -> ()",
        [](Stack& stack) { torch::GradMode::set_enabled(pop(stack).toBool()); },
        aliasAnalysisConservative()),
    Operator(
        "aten::_get_cpu_capability() -> str",
        [](Stack& stack) { push(stack, at::get_cpu_capability()); },
        aliasAnalysisConservative()),
});
} // namespace
} // namespace torch::jit