// NOTE: the header targets below are assumed; this translation unit needs the
// spectral-op helpers (fft_norm_mode), the MPSGraph FFT declarations, and the
// MPS operation utilities.
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#endif

#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
@implementation FakeMPSGraphFFTDescriptor
+ (nullable instancetype)descriptor {
  // Redispatch the constructor to the actual implementation
  id desc = NSClassFromString(@"MPSGraphFFTDescriptor");
  return (FakeMPSGraphFFTDescriptor*)[desc descriptor];
}

- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  return self;
}
@end
#endif

namespace at::native {
namespace {

// Map ATen's fft_norm_mode onto the corresponding MPSGraph scaling mode.
MPSGraphFFTScalingMode normalization_to_ScalingMode(int64_t normalization) {
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none:
      return MPSGraphFFTScalingModeNone;
    case fft_norm_mode::by_n:
      return MPSGraphFFTScalingModeSize;
    case fft_norm_mode::by_root_n:
      return MPSGraphFFTScalingModeUnitary;
    default:
      break;
  }
  TORCH_CHECK(false, "Unsupported normalization type ", normalization);
}

// Convert a list of transform axes into the NSArray form MPSGraph expects.
NSArray<NSNumber*>* IntArrayToNSArray(IntArrayRef arr) {
  auto rc = [NSMutableArray<NSNumber*> arrayWithCapacity:arr.size()];
  for (const auto idx : c10::irange(arr.size())) {
    rc[idx] = [NSNumber numberWithInteger:arr[idx]];
  }
  return rc;
}

} // anonymous namespace

Tensor _fft_c2r_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex());
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  return _fft_c2r_mps_out(self, dim, normalization, last_dim_size, out);
}

Tensor _fft_r2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = input_sizes[last_dim] / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }
  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  return _fft_r2c_mps_out(self, dim, normalization, onesided, out);
}

Tensor _fft_c2c_mps(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  if (dim.empty()) {
    return self.clone();
  }
  auto out = at::empty(self.sizes(), self.options());
  return _fft_c2c_mps_out(self, dim, normalization, forward, out);
}

using namespace mps;
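// Usage sketch (assumed dispatch path, for illustration only): these kernels back
// the at::fft_* / torch.fft.* entry points when the input lives on the MPS device,
// roughly as follows:
//
//   auto x = at::randn({8}, at::device(at::kMPS).dtype(at::kFloat));
//   auto X = at::fft_rfft(x);      // real -> complex, reaches _fft_r2c_mps (onesided)
//   auto y = at::fft_irfft(X, 8);  // complex -> real, reaches _fft_c2r_mps
//   auto Z = at::fft_fft(X);       // complex -> complex, reaches _fft_c2c_mps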
// TODO: Investigate numerical discrepancies, see https://github.com/pytorch/pytorch/issues/120237
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(onesided);
  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto descriptor = [MPSGraphFFTDescriptor descriptor];
      descriptor.scalingMode = normalization_to_ScalingMode(normalization);
      MPSGraphTensor* outputTensor;
      if (onesided) {
        // Return only unique results:
        outputTensor = [mpsGraph realToHermiteanFFTWithTensor:inputTensor
                                                         axes:IntArrayToNSArray(dim)
                                                   descriptor:descriptor
                                                         name:nil];
      } else {
        // Return with Hermitean conjugate results:
        auto useDataType =
            (inputTensor.dataType == MPSDataTypeFloat16) ? MPSDataTypeComplexFloat16 : MPSDataTypeComplexFloat32;
        auto cTensor = [mpsGraph castTensor:inputTensor toType:useDataType name:nil];
        outputTensor = [mpsGraph fastFourierTransformWithTensor:cTensor
                                                           axes:IntArrayToNSArray(dim)
                                                     descriptor:descriptor
                                                           name:nil];
      }
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}

Tensor& _fft_c2r_mps_out(const Tensor& self,
                         IntArrayRef dim,
                         int64_t normalization,
                         int64_t last_dim_size,
                         Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(last_dim_size);
  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto descriptor = [MPSGraphFFTDescriptor descriptor];
      descriptor.scalingMode = normalization_to_ScalingMode(normalization);
      descriptor.inverse = YES;
      descriptor.roundToOddHermitean = ((last_dim_size % 2) == 1) ? YES : NO;
      auto outputTensor = [mpsGraph HermiteanToRealFFTWithTensor:inputTensor
                                                            axes:IntArrayToNSArray(dim)
                                                      descriptor:descriptor
                                                            name:nil];
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}

Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(forward);
  @autoreleasepool {
    auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto descriptor = [MPSGraphFFTDescriptor descriptor];
      descriptor.scalingMode = normalization_to_ScalingMode(normalization);
      descriptor.inverse = !forward;
      auto outputTensor = [mpsGraph fastFourierTransformWithTensor:inputTensor
                                                              axes:IntArrayToNSArray(dim)
                                                        descriptor:descriptor
                                                              name:nil];
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return out;
}

} // namespace at::native