#include #import #import #import #import #import #import #import #include #include #include
// NOTE(review): the header names of the directives above are missing in the
// reviewed copy of this file; they are left untouched here and must be
// restored from version control before this translation unit can build.

namespace at::native::metal {

// Returns a freshly allocated MPS mean-reduction kernel for the tensor axis
// `dim`, created on the shared MetalContext device:
//   3 -> MPSNNReduceRowMean
//   2 -> MPSNNReduceColumnMean
//   1 -> MPSNNReduceFeatureChannelsMean
// Any other index (including 0) yields nil; callers must handle that case.
API_AVAILABLE(ios(11.3), macos(10.13))
static inline MPSNNReduceUnary* kernelForReducedDim(int dim) {
  id device = [MetalContext sharedInstance].device;
  switch (dim) {
    case 3:
      return [[MPSNNReduceRowMean alloc] initWithDevice:device];
    case 2:
      return [[MPSNNReduceColumnMean alloc] initWithDevice:device];
    case 1:
      return [[MPSNNReduceFeatureChannelsMean alloc] initWithDevice:device];
    default:
      return nil;
  }
}

// Metal implementation of aten::mean.dim.
//
// input    - 4-d tensor whose first dimension has size 1 (both are enforced
//            with TORCH_CHECK below).
// opt_dims - axes to average over. Negative indices count from the back.
//            When no list (or an empty list) is supplied, the axes marked by
//            make_dim_mask() are reduced, so the encoded image always agrees
//            with the shape computed from the same mask.
// keepdim  - true keeps each reduced axis with size 1; false removes it.
// dtype    - accepted to match the operator schema; this implementation does
//            not read it.
//
// Returns a tensor built from the last reduction's destination image, viewed
// with the reduced shape. Raises (TORCH_CHECK) when the input is not 4-d, the
// batch size is not 1, an axis is out of range, no axis among 1, 2, 3 is
// reduced, or the OS is older than iOS 11.3.
static Tensor wrapper_mean_dim(
    const Tensor& input,
    OptionalIntArrayRef opt_dims,
    bool keepdim,
    std::optional<ScalarType> dtype) {
  if (@available(iOS 11.3, *)) {
    MPSImage* X = imageFromTensor(input);
    auto imageSize = input.sizes().vec();
    TORCH_CHECK(imageSize.size() == 4);
    // TODO: [T87340633] Support reducing the batch dimension
    TORCH_CHECK(imageSize[0] == 1);
    const int64_t ndim = input.dim();
    auto mask = make_dim_mask(opt_dims, ndim);

    // Axes to reduce, normalized to [0, ndim). The caller's order is kept
    // when a list is given so the encoding sequence of existing call sites
    // does not change.
    DimVector reduceDims;
    if (opt_dims.has_value() && !opt_dims.value().empty()) {
      for (int64_t dim : opt_dims.value()) {
        const int64_t wrapped = dim < 0 ? dim + ndim : dim;
        TORCH_CHECK(
            wrapped >= 0 && wrapped < ndim,
            "mean.dim: dimension ",
            dim,
            " is out of range for a ",
            ndim,
            "-d Metal tensor");
        reduceDims.push_back(wrapped);
      }
    } else {
      // No explicit axes: follow the mask that also drives the output shape.
      for (int64_t dim = 0; dim < ndim; ++dim) {
        if (mask[dim]) {
          reduceDims.push_back(dim);
        }
      }
    }

    MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
    MPSImage* Y = nil;
    for (int64_t dim : reduceDims) {
      imageSize[dim] = 1;
      MPSNNReduceUnary* kernel = kernelForReducedDim(static_cast<int>(dim));
      if (kernel) {
        // Chain the reductions: each pass reads the previous pass's output.
        Y = createTemporaryImage(commandBuffer, imageSize);
        [kernel encodeToCommandBuffer:commandBuffer.buffer
                          sourceImage:X
                     destinationImage:Y];
        X = Y;
      }
    }
    // Without this check a nil image would be installed in the output
    // texture whenever no kernel was encoded.
    TORCH_CHECK(
        Y != nil,
        "mean.dim on Metal must reduce at least one of dims 1, 2, 3; "
        "reducing only the batch dimension is not supported");

    MetalTensorImplStorage mt{imageSize};
    mt.texture()->setCommandBuffer(commandBuffer);
    mt.texture()->setImage(Y);

    // Walk from the back so erasing an axis does not shift the indices that
    // are still to be visited.
    auto shape = DimVector(input.sizes());
    for (int dim = shape.size() - 1; dim >= 0; dim--) {
      if (mask[dim]) {
        if (keepdim) {
          shape[dim] = 1;
        } else {
          shape.erase(shape.begin() + dim);
        }
      }
    }
    auto output = makeTensor(std::move(mt), input.options()).view(shape);
    return output;
  } else {
    // TODO: [T87350528] Fallback to shader kernels for 10.0 users
    TORCH_CHECK(
        false, "MPSNNReduceUnary is only available on iOS 11.3 and above");
  }
}

// Registers the wrapper as the Metal-backend kernel for aten::mean.dim.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::mean.dim"), TORCH_FN(wrapper_mean_dim));
}

} // namespace at::native::metal