#include #include #include namespace torch::jit { void tupleIndex(Stack& stack) { int64_t index = pop(stack).toInt(); auto tuple = pop(stack).toTuple(); auto norm_index = normalizeIndex(index, static_cast(tuple->elements().size())); if (norm_index < 0 || norm_index >= static_cast(tuple->elements().size())) { throw std::out_of_range("Tuple list index out of range"); } stack.emplace_back(tuple->elements()[norm_index]); } void raiseException(Stack& stack) { // this kernel supports RaiseException with only one argument: the error // DEPRECATED from bytecode_version 8; // Please do not make any changes to this to support BC throw JITException(pop(stack).toStringRef()); } void raiseExceptionWithMessage(Stack& stack) { // this kernel supports RaiseException with only two arguments: the error and // the message Please make changes only to this kernel std::optional qualified_class_name = pop(stack).toOptional(); std::string message; pop(stack, message); throw JITException(message, qualified_class_name); } void is(Stack& stack) { IValue self, obj; pop(stack, self, obj); push(stack, self.is(obj)); } void unInitialized(Stack& stack) { push(stack, IValue::uninitialized()); } void isNot(Stack& stack) { IValue self, obj; pop(stack, self, obj); push(stack, !self.is(obj)); } void aten_format(Stack& stack) { size_t num_inputs = pop(stack).toInt(); format(stack, num_inputs); } void size(Stack& stack) { auto t = std::move(pop(stack)).toTensor(); pack(stack, t.sizes().vec()); } void sym_size(Stack& stack) { auto t = std::move(pop(stack)).toTensor(); pack(stack, t.sym_sizes().vec()); } void sym_size_int(Stack& stack) { auto dim = pop(stack).toInt(); auto t = pop(stack).toTensor(); push(stack, t.sym_sizes()[dim]); } void sym_stride_int(Stack& stack) { auto dim = pop(stack).toInt(); auto t = pop(stack).toTensor(); push(stack, t.sym_strides()[dim]); } void sym_numel(Stack& stack) { auto t = std::move(pop(stack)).toTensor(); push(stack, t.sym_numel()); } void sym_storage_offset(Stack& stack) { auto t = std::move(pop(stack)).toTensor(); push(stack, t.sym_storage_offset()); } void sym_stride(Stack& stack) { auto t = std::move(pop(stack)).toTensor(); pack(stack, t.sym_strides().vec()); } void device(Stack& stack) { push(stack, pop(stack).toTensor().device()); } void device_with_index(Stack& stack) { std::string type = pop(stack).toStringRef(); auto index = pop(stack).toInt(); std::string device_str = fmt::format("{}:{}", type, index); auto device = c10::Device(device_str); push(stack, device); } void dtype(Stack& stack) { at::Tensor a; pop(stack, a); push(stack, static_cast(a.scalar_type())); } void layout(Stack& stack) { push(stack, pop(stack).toTensor().layout()); } void toPrimDType(Stack& stack) { bool non_blocking = false; bool copy = false; pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); std::optional device = std::nullopt; at::Tensor self = pop(stack).toTensor(); push(stack, to_dispatch(self, device, scalarType, non_blocking, copy)); } void dim(Stack& stack) { at::Tensor arg = pop(stack).toTensor(); push(stack, arg.dim()); } void _not(Stack& stack) { push(stack, !pop(stack).toBool()); } void boolTensor(Stack& stack) { at::Tensor a; pop(stack, a); push(stack, at::native::is_nonzero(a)); } void toList(Stack& stack) { int elem_ty_val = 0; int dim_val = 0; at::Tensor t; pop(stack, elem_ty_val); pop(stack, dim_val); pop(stack, t); // If the Tensor is not on the CPU, transfer it. if (!t.device().is_cpu()) { t = t.cpu(); } // Rebuild the output type using elem_ty_val and dim_val. Start // with the element type corresponding to elem_ty_val. at::TypePtr out_ty; if (elem_ty_val == 0) { out_ty = at::IntType::get(); } else if (elem_ty_val == 1) { out_ty = at::FloatType::get(); } else if (elem_ty_val == 2) { out_ty = at::BoolType::get(); } else if (elem_ty_val == 3) { out_ty = at::ComplexType::get(); } else { TORCH_CHECK( false, "Unsupported element type for tolist; only int, float, complex and bool are supported"); } // Check that type of the Tensor matches that of the annotation. // Make an exception for the case in which the annotated type is // float/complex and the Tensor data type is also float/complex; // the elements will be casted to double/c10::complex // later. TORCH_CHECK( (out_ty == at::FloatType::get() && t.is_floating_point()) || (out_ty == at::ComplexType::get() && t.is_complex()) || tryScalarTypeFromJitType(*out_ty) == t.scalar_type(), "Output annotation element type and runtime tensor element type must match for tolist(): ", *tryScalarTypeFromJitType(*out_ty), " vs ", t.scalar_type()); // Check that the dimension of the Tensor matches that of the // annotation. TORCH_CHECK( dim_val == t.dim(), "Output annotation list dimension and runtime tensor dimension must match for tolist()"); // Wrap out_ty in a ListType dim times. for (const auto i : c10::irange(dim_val)) { (void)i; // Suppress unused variable warning out_ty = at::ListType::create(out_ty); } int64_t dim = t.dim(); auto sizes = t.sizes(); auto strides = t.strides(); size_t element_size = t.element_size(); char* data = static_cast(t.data_ptr()); auto result = tensorToListRecursive( data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size); push(stack, std::move(result)); } void numToTensorScalar(Stack& stack) { at::Scalar s; pop(stack, s); push(stack, c10::scalar_to_tensor(s)); } void isCuda(Stack& stack) { at::Tensor a; pop(stack, a); push(stack, a.is_cuda()); } void numToTensorBool(Stack& stack) { bool b = false; pop(stack, b); push(stack, c10::scalar_to_tensor(b)); } void dictIndex(Stack& stack) { auto key = pop(stack); auto dict = pop(stack).toGenericDict(); auto value = dict.find(key); if (value == dict.end()) { AT_ERROR("KeyError: ", key); } push(stack, value->value()); } static const C10_UNUSED std::array op_reg = { mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex), mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor), mobile::prim_op_fn_register("aten::format", aten_format), mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar), mobile::prim_op_fn_register( "prim::RaiseException", raiseExceptionWithMessage), mobile::prim_op_fn_register("prim::device", device), mobile::prim_op_fn_register("prim::dtype", dtype), mobile::prim_op_fn_register("prim::layout", layout), mobile::prim_op_fn_register("aten::__not__", _not), mobile::prim_op_fn_register("aten::__is__", is), mobile::prim_op_fn_register("aten::__isnot__", isNot), mobile::prim_op_fn_register("aten::dim", dim), mobile::prim_op_fn_register("prim::Uninitialized", unInitialized), mobile::prim_op_fn_register("prim::is_cuda", isCuda), mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex), mobile::prim_op_fn_register("prim::unchecked_cast", noop), // TODO: (@pavithran) size is overloaded with int[] and Tensor // so this throws error expecting int not Tensor // mobile::prim_op_fn_register("aten::size", size) }; } // namespace torch::jit