• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#include <gtest/gtest.h>
2#include <torch/torch.h>
3#import <Foundation/Foundation.h>
4#import <Metal/Metal.h>
5
// this sample custom kernel is taken from:
// https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu
//
// Metal Shading Language source kept as a raw string (delimiter MPS_ADD_ARRAYS)
// so the test below can hand it to -[MTLDevice newLibraryWithSource:options:error:].
// The add_arrays kernel stores inA[index] + inB[index] into result[index], where
// index is the thread's position in the grid. The kernel has no bounds check, so
// the dispatch must not launch more threads than the buffers have elements.
static const char* CUSTOM_KERNEL = R"MPS_ADD_ARRAYS(
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(device const float* inA,
                       device const float* inB,
                       device float* result,
                       uint index [[thread_position_in_grid]])
{
    result[index] = inA[index] + inB[index];
}
)MPS_ADD_ARRAYS";
19
// Reinterprets the data pointer of the tensor's storage as an id<MTLBuffer>.
// The pointer bits are returned as-is: no storage offset is applied and no
// ownership is transferred by this helper.
// NOTE(review): presumably only meaningful for tensors allocated on the MPS
// device, where the storage pointer is expected to be the buffer handle - confirm.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
  auto storageData = tensor.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, storageData);
}
23
// Compiles the add_arrays Metal kernel from CUSTOM_KERNEL at runtime, encodes it
// on the command buffer obtained from torch::mps, and verifies that the output
// tensor matches the same element-wise sum computed on the CPU.
TEST(MPSObjCInterfaceTest, MPSCustomKernel) {
  const unsigned int tensor_length = 100000U;

  // fail if mps isn't available
  ASSERT_TRUE(torch::mps::is_available());

  // Reference result computed on the CPU.
  torch::Tensor cpu_input1 = torch::randn({tensor_length}, at::device(at::kCPU));
  torch::Tensor cpu_input2 = torch::randn({tensor_length}, at::device(at::kCPU));
  torch::Tensor cpu_output = cpu_input1 + cpu_input2;

  // Same inputs copied to MPS, plus an uninitialized output the kernel fills in.
  torch::Tensor mps_input1 = cpu_input1.detach().to(at::kMPS);
  torch::Tensor mps_input2 = cpu_input2.detach().to(at::kMPS);
  torch::Tensor mps_output = torch::empty({tensor_length}, at::device(at::kMPS));

  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    // Check the device explicitly: messaging a nil device would yield a nil
    // library with a nil NSError, leaving the next check without a usable message.
    TORCH_CHECK(device, "Failed to get the default Metal device");

    NSError *error = nil;
    size_t numThreads = mps_output.numel();
    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
                                                              options: nil
                                                                error: &error];
    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);

    id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
    TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");

    id<MTLComputePipelineState> kernelPSO = [device newComputePipelineStateWithFunction: customFunction error: &error];
    TORCH_CHECK(kernelPSO, "Failed to create compute pipeline state, error: ", error.localizedDescription.UTF8String);

    // Get a reference of the MPSStream MTLCommandBuffer.
    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
    TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

    // Get a reference of the MPSStream dispatch_queue. This is used for CPU side synchronization while encoding.
    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
    dispatch_sync(serialQueue, ^(){
      // Start a compute pass.
      id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

      // Encode the pipeline state object and its parameters.
      // Buffer indices 0, 1, 2 follow the kernel's parameter order: inA, inB, result.
      [computeEncoder setComputePipelineState: kernelPSO];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_input1) offset:0 atIndex:0];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_input2) offset:0 atIndex:1];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_output) offset:0 atIndex:2];

      // One thread per output element; the kernel itself does no bounds checking.
      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

      // Calculate a thread group size, capped at what the pipeline allows.
      NSUInteger threadsPerGroupSize = std::min(kernelPSO.maxTotalThreadsPerThreadgroup, numThreads);
      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);

      // Encode the compute command.
      [computeEncoder dispatchThreads: gridSize threadsPerThreadgroup: threadGroupSize];
      [computeEncoder endEncoding];

      torch::mps::commit();
    });
  }
  // synchronize the MPS stream before reading back from MPS buffer
  torch::mps::synchronize();

  ASSERT_TRUE(at::allclose(cpu_output, mps_output.to(at::kCPU)));
}
87