• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <ATen/Context.h>
2 #include <torch/library.h>
3 
4 #include <ATen/ExpandUtils.h>
5 #include <ATen/NativeFunctions.h>
6 #include <ATen/core/jit_type.h>
7 #include <c10/core/DefaultDtype.h>
8 #include <c10/util/irange.h>
9 #include <torch/csrc/api/include/torch/utils.h>
10 #include <torch/csrc/jit/ir/ir.h>
11 #include <torch/csrc/jit/runtime/custom_operator.h>
12 #include <torch/csrc/jit/runtime/operator.h>
13 #include <torch/csrc/jit/runtime/vararg_functions.h>
14 
15 #include <ATen/InitialTensorOptions.h>
16 #include <c10/core/ScalarType.h>
17 #include <torch/csrc/jit/frontend/error_report.h>
18 
19 #include <sstream>
20 
21 namespace torch::jit {
22 
23 namespace {
24 
aliasAnalysisFromSchema()25 c10::AliasAnalysisKind aliasAnalysisFromSchema() {
26   return c10::AliasAnalysisKind::FROM_SCHEMA;
27 }
28 
aliasAnalysisConservative()29 c10::AliasAnalysisKind aliasAnalysisConservative() {
30   return c10::AliasAnalysisKind::CONSERVATIVE;
31 }
32 
checkListInputType(const c10::TypePtr & elem_type,bool empty_list)33 void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) {
34   if (!elem_type->isSubtypeOf(*NumberType::get()) &&
35       !elem_type->isSubtypeOf(*BoolType::get())) {
36     std::stringstream error;
37     error << "Input must be of ints, floats, or bools, "
38           << "got " << elem_type->repr_str();
39     // special case empty list torch.tensor([])
40     if (elem_type->isSubtypeOf(*TensorType::get())) {
41       if (empty_list) {
42         error << "\nEmpty lists default to List[Tensor]. Add a variable "
43                  "annotation to the assignment to create an empty list "
44                  "of another type (torch.jit.annotate(List[T, []]) where T "
45                  "is the type of elements in the list for Python 2)";
46       }
47     }
48     throw std::runtime_error(error.str());
49   }
50 }
51 
castTensorTo(at::Tensor self,const IValue & dtype,const IValue & device)52 at::Tensor castTensorTo(
53     at::Tensor self,
54     const IValue& dtype,
55     const IValue& device) {
56   at::ScalarType scalar_type =
57       dtype.isNone() ? self.scalar_type() : dtype.toScalarType();
58   c10::Device dev = device.isNone() ? self.device() : device.toDevice();
59   if (scalar_type != self.scalar_type() || dev != self.device()) {
60     self = self.to(dev, scalar_type);
61   }
62   return self;
63 }
64 
compute_sizes(const IValue & seq)65 std::vector<int64_t> compute_sizes(const IValue& seq) {
66   std::vector<int64_t> sizes;
67   auto seq_recur = seq.toList();
68   while (true) {
69     sizes.push_back(seq_recur.size());
70     if (seq_recur.empty() || !seq_recur.get(0).isList()) {
71       break;
72     }
73     seq_recur = seq_recur.get(0).toList();
74   }
75   return sizes;
76 }
77 
checkSequenceSize(int64_t n,int64_t dim,int64_t seq_size)78 void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
79   if (seq_size != n) {
80     AT_ERROR(
81         "Expected sequence of length ",
82         n,
83         " at dim ",
84         dim,
85         " (got ",
86         seq_size,
87         ")");
88   }
89 }
90 
91 template <typename DTYPE>
storeLastDimension(char * data,const std::vector<int64_t> & sizes,const c10::ArrayRef<int64_t> & strides,int64_t dim,int elementSize,at::ArrayRef<IValue> obj)92 void storeLastDimension(
93     char* data,
94     const std::vector<int64_t>& sizes,
95     const c10::ArrayRef<int64_t>& strides,
96     int64_t dim,
97     int elementSize,
98     at::ArrayRef<IValue> obj) {
99   auto n = sizes[dim];
100   auto seq_size = obj.size();
101   checkSequenceSize(n, dim, seq_size);
102   for (const auto i : c10::irange(n)) {
103     *(DTYPE*)data = obj[i].to<DTYPE>();
104     data += strides[dim] * elementSize;
105   }
106 }
107 
storeLastDimensionFloat(char * data,const std::vector<int64_t> & sizes,const c10::ArrayRef<int64_t> & strides,int64_t dim,int elementSize,at::ArrayRef<IValue> obj)108 void storeLastDimensionFloat(
109     char* data,
110     const std::vector<int64_t>& sizes,
111     const c10::ArrayRef<int64_t>& strides,
112     int64_t dim,
113     int elementSize,
114     at::ArrayRef<IValue> obj) {
115   auto n = sizes[dim];
116   auto seq_size = obj.size();
117   checkSequenceSize(n, dim, seq_size);
118   for (const auto i : c10::irange(n)) {
119     *(float*)data = static_cast<float>(obj[i].to<double>());
120     data += strides[dim] * elementSize;
121   }
122 }
123 
storeLastDimensionHalf(char * data,const std::vector<int64_t> & sizes,const c10::ArrayRef<int64_t> & strides,int64_t dim,int elementSize,at::ArrayRef<IValue> obj)124 void storeLastDimensionHalf(
125     char* data,
126     const std::vector<int64_t>& sizes,
127     const c10::ArrayRef<int64_t>& strides,
128     int64_t dim,
129     int elementSize,
130     at::ArrayRef<IValue> obj) {
131   auto n = sizes[dim];
132   auto seq_size = obj.size();
133   checkSequenceSize(n, dim, seq_size);
134   for (const auto i : c10::irange(n)) {
135     *(at::Half*)data = at::convert<at::Half, double>(obj[i].to<double>());
136     data += strides[dim] * elementSize;
137   }
138 }
139 
140 // reference python implementation recursive_store in tensor_new.cpp
recursiveStore(char * data,const std::vector<int64_t> & sizes,const c10::ArrayRef<int64_t> & strides,int64_t dim,int tenElementSize,const IValue & obj)141 void recursiveStore(
142     char* data,
143     const std::vector<int64_t>& sizes,
144     const c10::ArrayRef<int64_t>& strides,
145     int64_t dim,
146     int tenElementSize,
147     const IValue& obj) {
148   auto ndim = sizes.size();
149   auto n = sizes[dim];
150   auto seq = obj.toListRef();
151   checkSequenceSize(n, dim, seq.size());
152   if (dim + 1 < static_cast<long>(ndim)) {
153     for (const auto i : c10::irange(n)) {
154       recursiveStore(data, sizes, strides, dim + 1, tenElementSize, seq[i]);
155       data += strides[dim] * tenElementSize;
156     }
157   } else {
158     if (obj.isIntList()) {
159       storeLastDimension<int64_t>(
160           data, sizes, strides, dim, tenElementSize, seq);
161     } else if (obj.isBoolList()) {
162       storeLastDimension<bool>(data, sizes, strides, dim, tenElementSize, seq);
163     } else if (obj.isDoubleList()) {
164       if (tenElementSize ==
165           static_cast<int>(elementSize(at::ScalarType::Double))) {
166         storeLastDimension<double>(
167             data, sizes, strides, dim, tenElementSize, seq);
168       } else if (
169           tenElementSize ==
170           static_cast<int>(elementSize(at::ScalarType::Float))) {
171         storeLastDimensionFloat(data, sizes, strides, dim, tenElementSize, seq);
172       } else if (
173           tenElementSize ==
174           static_cast<int>(elementSize(at::ScalarType::Half))) {
175         storeLastDimensionHalf(data, sizes, strides, dim, tenElementSize, seq);
176       } else {
177         TORCH_INTERNAL_ASSERT(false);
178       }
179     } else {
180       TORCH_INTERNAL_ASSERT(false);
181     }
182   }
183 }
184 
// Shared implementation of aten::tensor(t[] ...) and aten::as_tensor.list.
// Pops (data, dtype, device[, requires_grad]) from the stack, materializes a
// CPU tensor from the nested list, casts it to the requested dtype/device,
// and pushes the result.
template <bool if_set_requires_grad>
void createTensorFromList(Stack& stack) {
  // torch.tensor has a fourth requires_grad arg but torch.as_tensor not, so
  // we use the template arg to distinguish between these two cases
  bool requires_grad = false;
  IValue data;
  IValue dtype;
  IValue device;
  if (if_set_requires_grad) {
    pop(stack, data, dtype, device, requires_grad);
  } else {
    pop(stack, data, dtype, device);
  }
  // Drill down to the innermost element type of the (possibly nested) list.
  auto elem_type = data.type();
  while (elem_type->isSubtypeOf(AnyListType::get())) {
    elem_type = elem_type->containedType(0);
  }
  auto sizes = compute_sizes(data);
  // sizes == {0} means the input was the literal empty list.
  checkListInputType(elem_type, sizes.size() == 1 && sizes[0] == 0);
  at::ScalarType initial_scalar_type = scalarTypeFromJitType(*elem_type);
  // Float lists arrive as doubles; store them using the global default dtype
  // (typically float32) to match eager-mode torch.tensor behavior.
  if (initial_scalar_type == at::ScalarType::Double) {
    initial_scalar_type = typeMetaToScalarType(c10::get_default_dtype());
  }

  auto tensor =
      at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type));

  if (tensor.numel() != 0) {
    recursiveStore(
        (char*)tensor.data_ptr(),
        sizes,
        tensor.strides(),
        0,
        tensor.element_size(),
        data);
  }

  tensor = castTensorTo(tensor, dtype, device);
  auto default_type = at::typeMetaToScalarType(at::get_default_dtype());

  // Warn on the empty-list case when no dtype was given and the resulting
  // dtype diverges from what eager Python would produce.
  if (dtype.isNone() && tensor.scalar_type() != default_type &&
      tensor.numel() == 0) {
    TORCH_WARN(
        "Creating a tensor from an empty ",
        elem_type->repr_str(),
        "list will create a tensor of default floating point type  (currently ",
        default_type,
        ") in python but a tensor of type ",
        elem_type->repr_str(),
        " in torchscript.\n",
        "Pass in a dtype argument to ensure consistent behavior");
  }
  if (if_set_requires_grad) {
    tensor.set_requires_grad(requires_grad);
  }
  push(stack, std::move(tensor));
}
242 
// Registers special-cased JIT operators that need hand-written,
// stack-manipulating implementations: tensor construction from scalars and
// (nested) lists, size inference, and several no-grad initialization helpers.
RegisterOperators reg({
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]"),
        [](Stack& stack) {
          RECORD_FUNCTION("split_with_sizes", last(stack, 3));

          auto result = at::split_with_sizes(
              (std::move(peek(stack, 0, 3))).toTensor(),
              (std::move(peek(stack, 1, 3))).toDimVector(),
              (std::move(peek(stack, 2, 3))).toInt());
          drop(stack, 3);
          pack(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),

// Expands to a pair of operators for one scalar overload:
//   aten::tensor.<type>    — also accepts and sets requires_grad
//   aten::as_tensor.<type> — no requires_grad argument
// `tensor_creation_op` is an expression evaluated with `scalar_val` in scope.
// NOTE: no comments inside the macro body — a `//` would swallow the
// trailing line-continuation backslash.
#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op)       \
  OperatorGenerator(                                                            \
      TORCH_SELECTIVE_SCHEMA(                                                   \
          "aten::tensor." #operator_type "(" #operator_type                     \
          " t, *, ScalarType? dtype=None, Device? device=None"                  \
          ", bool requires_grad=False) -> Tensor"),                             \
      [](Stack& stack) {                                                        \
        c_type scalar_val;                                                      \
        IValue dtype;                                                           \
        IValue device;                                                          \
        bool requires_grad;                                                     \
        pop(stack, scalar_val, dtype, device, requires_grad);                   \
        auto tensor = tensor_creation_op;                                       \
        tensor = castTensorTo(tensor, dtype, device);                           \
        tensor.set_requires_grad(requires_grad);                                \
        push(stack, std::move(tensor));                                         \
      },                                                                        \
      aliasAnalysisFromSchema()),                                               \
      OperatorGenerator(                                                        \
          TORCH_SELECTIVE_SCHEMA(                                               \
              "aten::as_tensor." #operator_type "(" #operator_type              \
              " t, *, ScalarType? dtype=None, Device? device=None) -> Tensor"), \
          [](Stack& stack) {                                                    \
            c_type scalar_val;                                                  \
            IValue dtype;                                                       \
            IValue device;                                                      \
            pop(stack, scalar_val, dtype, device);                              \
            auto tensor = tensor_creation_op;                                   \
            tensor = castTensorTo(tensor, dtype, device);                       \
            push(stack, std::move(tensor));                                     \
          },                                                                    \
          aliasAnalysisFromSchema()),

    // Scalar overloads; each invocation contributes two operators to the
    // initializer list (the trailing comma lives inside the macro).
    DEFINE_TORCH_TENSOR_OP(
        bool,
        bool,
        at::empty({}, at::CPU(at::kBool).options()).fill_(scalar_val))
        DEFINE_TORCH_TENSOR_OP(
            float,
            double,
            at::native::scalar_tensor(
                scalar_val,
                typeMetaToScalarType(c10::get_default_dtype()),
                std::nullopt /* layout */,
                at::kCPU,
                std::nullopt /* pin_memory*/))
            DEFINE_TORCH_TENSOR_OP(
                int,
                int64_t,
                at::scalar_to_tensor(scalar_val))
                DEFINE_TORCH_TENSOR_OP(
                    complex,
                    c10::complex<double>,
                    at::native::scalar_tensor(
                        scalar_val,
                        typeMetaToScalarType(c10::get_default_complex_dtype()),
                        std::nullopt /* layout */,
                        at::kCPU,
                        std::nullopt /* pin_memory */))

    // reference python implementation: internal_new_from_data in
    // tensor_new.cpp
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::_infer_size(int[] a, int[] b) -> int[]"),
        [](Stack& stack) {
          auto a = pop(stack);
          auto b = pop(stack);
          push(stack, at::infer_size(a.toDimVector(), b.toDimVector()));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor"),
        [](Stack& stack) {
          at::Tensor weight;
          at::Tensor input;
          double max_norm = 0;
          double norm_type = 0;
          pop(stack, weight, input, max_norm, norm_type);

          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor result =
              at::embedding_renorm_(weight, input, max_norm, norm_type);
          push(stack, std::move(result));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor"),
        createTensorFromList<true>,
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(a|b)"),
        [](Stack& stack) {
          auto device = pop(stack).toOptional<c10::Device>();
          auto dtype = pop(stack).toOptional<at::ScalarType>();
          at::Tensor data = pop(stack).toTensor();
          at::ScalarType scalar_type =
              dtype ? dtype.value() : data.scalar_type();
          c10::Device dev = device ? device.value() : data.device();

          // Only converts when necessary; otherwise the input tensor is
          // returned as-is (aliasing, per the (a|b) schema annotation).
          if (scalar_type != data.scalar_type() || dev != data.device()) {
            data = data.to(
                dev, scalar_type, /*non_blocking=*/false, /*copy=*/false);
          }
          push(stack, std::move(data));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::as_tensor.list(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor"),
        createTensorFromList<false>,
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, "
            "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)"),
        // NOTE(review): body is intentionally empty — the inputs appear to be
        // passed through unchanged on the stack; confirm callers rely on that.
        [](Stack& stack) {},
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::_get_tracing_state() -> bool"),
        // Scripted code is never tracing.
        [](Stack& stack) { push(stack, false); },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"),
        // Always true inside TorchScript.
        [](Stack& stack) { push(stack, true); },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("aten::has_torch_function(...) -> bool"),
        // __torch_function__ overrides do not apply in scripted code.
        [](Stack& stack) { push(stack, false); },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b, Generator? generator=None) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          std::optional<at::Generator> generator =
              pop(stack).toOptional<at::Generator>();

          double a = 0;
          double b = 0;
          pop(stack, tensor, a, b);
          push(stack, tensor.uniform_(a, b, generator));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std, Generator? generator=None) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          double mean = 0;
          double std = 0;
          std::optional<at::Generator> generator =
              pop(stack).toOptional<at::Generator>();

          pop(stack, tensor, mean, std);
          push(stack, tensor.normal_(mean, std, generator));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_fill_(Tensor(a!) tensor, float val) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          double val = 0;
          pop(stack, tensor, val);
          push(stack, at::fill_(tensor, val));
        },
        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA(
            "aten::_no_grad_zero_(Tensor(a!) tensor) -> Tensor(a!)"),
        [](Stack& stack) {
          // TODO: remove when script supports setting grad mode
          torch::NoGradGuard no_grad;

          at::Tensor tensor;
          pop(stack, tensor);
          push(stack, at::zero_(tensor));
        },
        aliasAnalysisFromSchema()),
    // The remaining ops read/write global state, hence conservative aliasing.
    Operator(
        "aten::is_grad_enabled() -> bool",
        [](Stack& stack) { push(stack, torch::GradMode::is_enabled()); },
        aliasAnalysisConservative()),
    Operator(
        "aten::set_grad_enabled(bool val) -> ()",
        [](Stack& stack) { torch::GradMode::set_enabled(pop(stack).toBool()); },
        aliasAnalysisConservative()),
    Operator(
        "aten::_get_cpu_capability() -> str",
        [](Stack& stack) { push(stack, at::get_cpu_capability()); },
        aliasAnalysisConservative()),
});
465 } // namespace
466 } // namespace torch::jit
467