1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS 2#include <ATen/native/mps/OperationUtils.h> 3#ifndef AT_PER_OPERATOR_HEADERS 4#include <ATen/Functions.h> 5#include <ATen/NativeFunctions.h> 6#else 7#include <ATen/ops/eye_native.h> 8#endif 9 10// Steps to add op for MPS backend: 11// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key 12// 2. Define the function interface for the MPS backend similar to other 13// backends depending on whether its structured or non-structured 14// 3. Add boiler-plate error checking code as expected for the Op 15// 4. The code structure roughly follows the pattern 16// a) get the MPS stream handle to encode work onto 17// b) get an instance of MPSGraphCache and create a key unique to the Graph 18// needed for implementing this Op. Any shape, dataType or parameter 19// passed to the MPSGraph during its construction will need to be included 20// here. 21// c) Create the graph using make_mps_graph() and add operations to the 22// instance of MPSGraph. This is if the Cache->lookup() fails. 23// d) Store the MPSGraphTensors for inputs and output which are needed at 24// runtime. 25// e) Use the CachedGraph instance's inputs and output to create Placeholders 26// You will need to pass in Tensor to create MPSGraphTensorData objects. 27// f) Using MPSGraphTensor and MPSGraphTensorData instances create a feeds 28// dictionary. 29// g) Then call runMPSGraph() with input params and return the result. 
//

namespace at::native {

Tensor& eye_out_mps(int64_t n, Tensor& result) {
  // Square case: the column count `m` defaults to `n`.
  return eye_out_mps(n, n, result);
}

using namespace mps;

Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
  // Boiler-plate argument validation, mirroring the CPU/CUDA counterparts.
  TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
  TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

  result.resize_({n, m});
  result.zero_();

  // Nothing to encode for an empty output.
  if (result.numel() == 0) {
    return result;
  }

  // Stream onto which the graph work is encoded.
  MPSStream* mpsStream = getCurrentMPSStream();
  const auto resultDType = result.scalar_type();

  // Cached-graph struct, derived from MPSCachedGraph so the compiled MPSGraph
  // can be reused instead of being rebuilt on every call. This op has no input
  // tensors, so only the output tensor needs to be stored. For unary/binary
  // ops, prefer `MPSUnaryCachedGraph` / `MPSBinaryCachedGraph` (and their Grad
  // variants) from `OperationUtils.h` over a hand-rolled struct like this one.
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    // The key uniquely identifies the graph: any shape/dtype/parameter that
    // influences graph construction must be folded in so a lookup only hits
    // when an identical graph was already compiled.
    string key = "eye_out_mps:" + getTensorsStringKey({result});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      // Start from a constant tensor of ones shaped like the output...
      MPSGraphTensor* allOnes = [mpsGraph constantWithScalar:1.0f
                                                       shape:getMPSShape(result)
                                                    dataType:getMPSDataType(resultDType)];

      // ...then keep only the main diagonal (numLower == numUpper == 0).
      // MPSGraph API reference:
      // https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
      MPSGraphTensor* diagonal = [mpsGraph bandPartWithTensor:allOnes numLower:0 numUpper:0 name:nil];

      // Make sure the graph output dtype matches the requested result dtype.
      if ([diagonal dataType] != getMPSDataType(resultDType)) {
        diagonal = castMPSTensor(mpsGraph, diagonal, resultDType);
      }
      newCachedGraph->outputTensor_ = diagonal;
    });

    // Bind the cached graph's output tensor to the actual result Tensor.
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    // There are no inputs, so the feeds dictionary stays nil.
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
    auto results = dictionaryFromPlaceholders(outputPlaceholder);

    // Execute the graph on the stream.
    runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
  }

  return result;
}

} // namespace at::native