#include <ATen/ATen.h>
#include <ATen/native/nested/NestedTensorFactories.h>
#include <ATen/native/nested/NestedTensorUtils.h>

namespace at::native {

static TensorOptions verify_empty_parameters(
    const at::Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TensorOptions options_ = TensorOptions()
                               .dtype(dtype)
                               .layout(layout)
                               .device(device)
                               .pinned_memory(pin_memory)
                               .memory_format(optional_memory_format);

  TensorOptions options = self.options().merge_in(options_);
  auto memory_format =
      options_.memory_format_opt().value_or(MemoryFormat::Preserve);
  TORCH_CHECK(
      memory_format == MemoryFormat::Preserve ||
          memory_format == MemoryFormat::Contiguous,
      "empty_like_nested only supports memory format Preserve or Contiguous, but got ",
      memory_format,
      " instead.");

  TORCH_CHECK(
      !(options.layout() != kStrided && optional_memory_format.has_value()),
      "memory format option is only supported by strided tensors");
  return options;
}

Tensor empty_like_nested(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto options = verify_empty_parameters(
      self, dtype, layout, device, pin_memory, optional_memory_format);
  auto self_nt = get_nested_tensor_impl(self);
  auto memory_format =
      options.memory_format_opt().value_or(MemoryFormat::Preserve);
  if (memory_format == MemoryFormat::Contiguous) {
    auto nested_size = self_nt->get_nested_sizes().clone();
    int64_t buffer_size = get_numel_from_nested_size_tensor(nested_size);
    Tensor new_buffer = at::empty({buffer_size}, options);
    auto tensor = wrap_buffer(new_buffer, nested_size);
    return tensor;
  }
  // The fall-through path must be Preserve.
  TORCH_CHECK(
      memory_format == MemoryFormat::Preserve,
      "memory format option is only supported by strided tensors");
  // Since we clone sizes, strides, and offsets, it should be safe to use
  // get_unsafe_storage_as_tensor for the call to empty_like.
  Tensor new_buffer =
      at::empty_like(self_nt->get_unsafe_storage_as_tensor(), options);
  auto nested_size = self_nt->get_nested_sizes().clone();
  auto nested_strides = self_nt->get_nested_strides().clone();
  auto offsets = self_nt->get_storage_offsets().clone();
  auto tensor = wrap_buffer(new_buffer, nested_size, nested_strides, offsets);
  return tensor;
}
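// Usage sketch (illustrative comment only, not part of the upstream file):
// at::empty_like on a nested tensor dispatches to empty_like_nested above.
// Assumes a nested tensor built via at::_nested_tensor_from_tensor_list:
//
//   auto nt = at::_nested_tensor_from_tensor_list(
//       {at::randn({2, 3}), at::randn({4, 3})});
//   // Preserve (the default) clones the size/stride/offset metadata of nt;
//   // Contiguous allocates one flat buffer sized by the nested sizes
//   // (2*3 + 4*3 = 18 elements here).
//   auto out_preserve = at::empty_like(nt);
//   auto out_contig =
//       at::empty_like(nt, nt.options(), c10::MemoryFormat::Contiguous);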
// Take a Device that may not have device_index set (i.e., having it as -1,
// representing the current device) and return the corresponding Device
// according to the actual device at the time of this function call. No-op
// if the device_index is set.
static inline Device ensure_has_index(Device device) {
  if (device.is_cpu() || device.has_index()) {
    return device;
  }
  const c10::impl::DeviceGuardImplInterface* impl =
      c10::impl::getDeviceGuardImpl(device.type());
  return impl->getDevice();
}

Tensor _to_copy_nested(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    bool non_blocking,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !layout.has_value() || self.layout() == layout.value(),
      "to(options) doesn't support converting to a different layout, "
      "but got self.layout being ",
      self.layout(),
      " and options.layout set as ",
      layout.value());
  auto options =
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
          pin_memory);

  if (options.has_device()) {
    options = options.device(ensure_has_index(options.device()));
  }
  // memory_format is handled separately due to MemoryFormat::Preserve logic
  options = self.options().merge_in(options).memory_format(std::nullopt);
  auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);

  bool pin_out =
      (non_blocking && self.is_cuda() && options.device().is_cpu() &&
       (options.layout() == c10::kStrided));

  Tensor r = at::empty_like(self, dtype, layout, device, pin_out, memory_format);
  get_nested_tensor_impl(r)->get_buffer().copy_(
      get_nested_tensor_impl(self)->get_buffer(), non_blocking);
  return r;
}

Tensor& copy_nested_(Tensor& self, const Tensor& src, bool non_blocking) {
  const auto* nt_self = get_nested_tensor_impl(self);
  const auto* nt_src = get_nested_tensor_impl(src);
  TORCH_CHECK(
      at::equal(nt_self->get_nested_sizes(), nt_src->get_nested_sizes()),
      "copy_ only supports tensors that are the same size for Nested implementations");
  nt_self->get_buffer().copy_(nt_src->get_buffer(), non_blocking);
  return self;
}
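// Usage sketch (illustrative comment only): Tensor::to and Tensor::copy_ on
// nested tensors route through _to_copy_nested / copy_nested_ above, so a
// dtype or device conversion is a single copy of the flat buffer rather than
// a per-constituent loop:
//
//   auto nt = at::_nested_tensor_from_tensor_list(
//       {at::randn({2, 3}), at::randn({4, 3})});
//   auto nt_half = nt.to(at::kHalf);   // new buffer, converted dtype
//   nt.copy_(nt_half.to(at::kFloat));  // copy_ requires matching nested sizes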
Tensor clone_nested(
    const Tensor& self,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format =
      optional_memory_format.value_or(c10::MemoryFormat::Preserve);
  auto self_ptr = get_nested_tensor_impl(self);
  if (memory_format == c10::MemoryFormat::Preserve ||
      (memory_format == c10::MemoryFormat::Contiguous &&
       self.is_contiguous())) {
    const Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor(),
                  sizemat = self_ptr->get_nested_sizes(),
                  stridemat = self_ptr->get_nested_strides();
    const auto& offsets = self_ptr->get_storage_offsets();
    // TODO: The size and the stride do not necessarily need to be cloned,
    //       but it is more conservative.
    //       This is something we could revisit once we land a more
    //       efficient implementation of nested_sizes_ and nested_strides.
    return wrap_buffer(
        buffer.clone(), sizemat.clone(), stridemat.clone(), offsets.clone());
  }
  // Otherwise, the requested memory format is Contiguous and self is
  // noncontiguous.
  else if (memory_format == c10::MemoryFormat::Contiguous) {
    const Tensor& self_buffer = self_ptr->get_unsafe_storage_as_tensor(),
                  sizemat = self_ptr->get_nested_sizes();
    Tensor output_buffer = at::empty(self.numel(), self_buffer.options());
    Tensor output = wrap_buffer(output_buffer, sizemat);
    std::vector<Tensor> self_unbind = self.unbind(),
                        output_unbind = output.unbind();
    for (const int64_t i : c10::irange(self_ptr->size(0))) {
      output_unbind[i].copy_(self_unbind[i]);
    }
    return output;
  } else {
    TORCH_CHECK(
        false,
        "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
        memory_format);
  }
}

std::vector<at::Tensor> NestedTensor_unbind(
    const at::Tensor& self,
    int64_t dim) {
  TORCH_CHECK(
      dim == 0,
      "NestedTensor can only be unbound along dimension 0, ",
      "got dimension ",
      dim,
      " instead.");
  auto self_ptr = get_nested_tensor_impl(self);
  int64_t ntensors = self_ptr->size(0);
  std::vector<at::Tensor> result_tensors(ntensors);
  if (ntensors == 0) {
    return result_tensors;
  }
  // This returns a differentiable view of self as a regular tensor
  auto buffer = self.values();
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
                           strides = NestedTensor_get_strides(self_ptr);
  int64_t* offsets_ptr = self_ptr->get_storage_offsets().data_ptr<int64_t>();
  for (const int64_t i : c10::irange(ntensors)) {
    result_tensors[i] =
        buffer.as_strided(sizes[i], strides[i], offsets_ptr[i]);
  }
  return result_tensors;
}

// NOLINTNEXTLINE(performance-unnecessary-value-param)
Tensor narrow_nested_symint(
    const at::Tensor& self,
    int64_t dim,
    SymInt start,
    SymInt length) {
  TORCH_CHECK(
      dim == 0,
      "narrow(): only dim=0 supported for nested tensors, but got: ",
      dim);
  TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative");
  auto cur_size = self.sym_size(dim);
  TORCH_CHECK_INDEX(
      ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size)))
          .expect_true(__FILE__, __LINE__),
      "start out of range (expected to be in range of [",
      -cur_size,
      ", ",
      cur_size,
      "], but got ",
      start,
      ")");
  if (start < 0) {
    start = start + cur_size;
  }
  TORCH_SYM_CHECK(
      start.sym_le(cur_size - length),
      "start (",
      start,
      ") + length (",
      length,
      ") exceeds dimension size (",
      cur_size,
      ").");
  auto* nt_impl = get_nested_tensor_impl(self);
  TORCH_CHECK(
      self.is_contiguous(),
      "narrow(): only contiguous nested tensors supported");
  auto buffer = nt_impl->get_unsafe_storage_as_tensor();
  auto nested_sizes = nt_impl->get_nested_sizes();
  auto nested_strides = nt_impl->get_nested_strides();
  auto storage_offsets = nt_impl->get_storage_offsets();
  auto storage_offsets_ptr = storage_offsets.data_ptr<int64_t>();
  auto start_int = start.guard_int(__FILE__, __LINE__);
  auto length_int = length.guard_int(__FILE__, __LINE__);
  auto buffer_offset = storage_offsets_ptr[start_int];

  nested_sizes = nested_sizes.narrow(0, start_int, length_int);
  nested_strides = nested_strides.narrow(0, start_int, length_int);
  storage_offsets = storage_offsets.narrow(0, start_int, length_int);

  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      buffer.narrow(0, buffer_offset, buffer.numel() - buffer_offset),
      nested_sizes,
      nested_strides,
      storage_offsets);
}
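// Usage sketch (illustrative comment only): unbind and narrow both return
// views aliasing the nested tensor's buffer, so writes through them are
// visible in the original; narrow additionally requires a contiguous input:
//
//   auto nt = at::_nested_tensor_from_tensor_list(
//       {at::randn({2, 3}), at::randn({4, 3}), at::randn({1, 3})});
//   auto parts = nt.unbind();   // one strided view per constituent
//   parts[0].zero_();           // mutates nt's underlying buffer
//   auto head = nt.narrow(0, /*start=*/0, /*length=*/2);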
Tensor alias_nested(const Tensor& self) {
  auto* nt_impl = get_nested_tensor_impl(self);
  auto buffer = nt_impl->get_unsafe_storage_as_tensor();
  const auto& nested_sizes = nt_impl->get_nested_sizes();
  const auto& nested_strides = nt_impl->get_nested_strides();
  const auto& storage_offsets = nt_impl->get_storage_offsets();
  return at::detail::make_tensor<NestedTensorImpl>(
      c10::TensorImpl::VIEW,
      std::move(buffer),
      nested_sizes,
      nested_strides,
      storage_offsets);
}

} // namespace at::native
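// Usage sketch (illustrative comment only): alias_nested backs
// Tensor::alias() for nested tensors; the result is a zero-copy view that
// shares the buffer and the size/stride/offset metadata, so in-place
// mutations through the alias are observable through the original tensor.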