// NOTE(review): the header names were missing from the text under review and
// are restored here from the symbols this file uses — confirm against the
// build configuration.
#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at::native::metal {

// Splits a 4-D, batch-size-1 Metal tensor into exactly two tensors along the
// channel dimension (dim 1).
//
// Only this restricted form is supported:
//   - chunks must be 2 and dim must be 1,
//   - input must be 4-D with input.size(0) == 1,
//   - input.size(1) must be at least 2.
//
// The first output receives ceil(C / 2) channels and the second receives the
// remaining C - ceil(C / 2), where C = input.size(1). Both outputs are backed
// by temporary storage allocated on the input's command buffer.
//
// Raises (via TORCH_CHECK) when any of the preconditions above is violated.
//
// TODO: [T87567124] Fully implement chunk in Metal shader
static std::vector<Tensor> chunk(
    const Tensor& input,
    int64_t chunks,
    int64_t dim) {
  TORCH_CHECK(
      chunks == 2 && dim == 1,
      "Metal chunk only supports chunks == 2 along dim == 1, got chunks = ",
      chunks,
      ", dim = ",
      dim);
  TORCH_CHECK(
      input.dim() == 4,
      "Metal chunk expects a 4-D input, got ",
      input.dim(),
      " dimensions");
  TORCH_CHECK(
      input.size(0) == 1,
      "Metal chunk expects batch size 1, got ",
      input.size(0));

  const int64_t dim_size = input.size(dim);
  // With fewer than two channels the split arithmetic below would give the
  // second output as many channels as the first (1 for a one-channel input),
  // i.e. channels the input does not have. Fail loudly instead of returning a
  // tensor whose contents are not backed by the input.
  TORCH_CHECK(
      dim_size >= 2,
      "Metal chunk needs at least 2 channels to produce 2 chunks, got ",
      dim_size);

  // Ceil-divide so the first chunk takes the extra channel when dim_size is
  // odd; the second chunk takes whatever is left.
  const int64_t split_size = (dim_size + chunks - 1) / chunks;
  const int64_t last_split_size = dim_size - split_size;

  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);

  auto outputSize1 = {
      input.size(0), split_size, input.size(2), input.size(3)};
  auto outputSize2 = {
      input.size(0), last_split_size, input.size(2), input.size(3)};
  MetalTensorImplStorage mt1(outputSize1);
  MetalTensorImplStorage mt2(outputSize2);
  mt1.texture()->allocateTemporaryStorage(outputSize1, commandBuffer);
  mt2.texture()->allocateTemporaryStorage(outputSize2, commandBuffer);
  MPSImage* Y1 = mt1.texture()->image();
  MPSImage* Y2 = mt2.texture()->image();

  // The "split_channels" pipeline is specialized on the feature-channel
  // counts of the input and of both outputs, in that order.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:"split_channels"
                     Constants:@[
                       @(X.featureChannels),
                       @(Y1.featureChannels),
                       @(Y2.featureChannels)
                     ]];

  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  // Texture slots: 0 = input, 1 = first chunk, 2 = second chunk.
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y1 texture] atIndex:1];
  [encoder setTexture:[Y2 texture] atIndex:2];

  // The dispatch grid is derived from the input image X.
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  auto output1 = makeTensor(std::move(mt1), input.options());
  auto output2 = makeTensor(std::move(mt2), input.options());
  return {output1, output2};
}

// Registers the implementation above as the Metal-backend kernel for
// aten::chunk.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::chunk"), TORCH_FN(chunk));
}

} // namespace at::native::metal