1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2#include <ATen/native/mps/OperationUtils.h>
3#ifndef AT_PER_OPERATOR_HEADERS
4#include <ATen/Functions.h>
5#include <ATen/NativeFunctions.h>
6#else
7#include <ATen/ops/eye_native.h>
8#endif
9
10// Steps to add op for MPS backend:
11// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key
12// 2. Define the function interface for the MPS backend similar to other
13//    backends depending on whether it's structured or non-structured
14// 3. Add boiler-plate error checking code as expected for the Op
15// 4. The code structure roughly follows the pattern
16//    a) get the MPS stream handle to encode work onto
17//    b) get an instance of MPSGraphCache and create a key unique to the Graph
18//       needed for implementing this Op. Any shape, dataType or parameter
19//       passed to the MPSGraph during its construction will need to be included
20//       here.
21//    c) Create the graph using make_mps_graph() and add operations to the
22//       instance of MPSGraph. This is if the Cache->lookup() fails.
23//    d) Store the MPSGraphTensors for inputs and output which are needed at
24//       runtime.
25//    e) Use the CachedGraph instance's inputs and output to create Placeholders
26//       You will need to pass in Tensor to create MPSGraphTensorData objects.
27//    f) Using MPSGraphTensor and MPSGraphTensorData instances create a feeds
28//       dictionary.
29//    g) Then call runMPSGraph() with input params and return the result.
30//
31
32namespace at::native {
33
// Single-argument overload: writes the n x n identity matrix into `result`.
// Forwards to the two-argument overload below with m == n.
Tensor& eye_out_mps(int64_t n, Tensor& result) {
  // the default value of `m` equals `n`
  return eye_out_mps(n, n, result);
}
38
39using namespace mps;
40
// Writes an n x m matrix with ones on the main diagonal and zeros elsewhere
// into `result`, resizing it to {n, m}. Returns `result` (out-variant
// convention). Implemented by building an all-ones MPSGraph constant and
// keeping only its main diagonal via bandPart.
Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
  // This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts
  TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
  TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

  result.resize_({n, m});
  // NOTE(review): the graph below appears to write the entire {n, m} output,
  // which would make this zero-fill redundant — confirm before removing.
  result.zero_();

  // Handle empty outputs
  if (result.numel() == 0)
    return result;

  // Get MPS stream
  MPSStream* stream = getCurrentMPSStream();

  auto outputDataType = result.scalar_type();
  // Derive from MPSCachedGraph.
  // This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph
  // time and time again for the same operation. The keys of this structure are based on the inputs and outputs needed
  // for the operation; here, we don't have any input tensors, just an output tensor.
  // If the operator to be added is unary or binary, instead of creating a new CachedGraph struct yourself, please
  // consider using `MPSUnaryCachedGraph` or `MPSBinaryCachedGraph` and their corresponding Grad versions in
  // `OperationUtils.h`.
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
    // etc match the earlier created MPSGraph. The key includes the output's shape/dtype via getTensorsStringKey.
    string key = "eye_out_mps:" + getTensorsStringKey({result});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
      // An {n, m} tensor filled with 1s, in the output's dtype.
      MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f
                                                          shape:getMPSShape(result)
                                                       dataType:getMPSDataType(outputDataType)];

      // Here we can call the MPSGraph API needed to execute the operation.
      // The API details can be found here:
      // https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
      // bandPart with numLower == 0 and numUpper == 0 keeps only the main
      // diagonal of the ones tensor, producing the eye matrix.
      MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor numLower:0 numUpper:0 name:nil];

      // Cast back if the graph op produced a different dtype than requested.
      if ([outputTensor dataType] != getMPSDataType(outputDataType)) {
        outputTensor = castMPSTensor(mpsGraph, outputTensor, outputDataType);
      }
      newCachedGraph->outputTensor_ = outputTensor;
    });

    // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    // Create dictionary of inputs/feeds and outputs/results
    // In this case, there are no inputs, so the feeds are nil
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
    auto results = dictionaryFromPlaceholders(outputPlaceholder);

    // Run the graph
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }

  return result;
}
103
104} // namespace at::native
105