1//  Copyright © 2022 Apple Inc.
2#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3#include <ATen/native/mps/OperationUtils.h>
4
5#ifndef AT_PER_OPERATOR_HEADERS
6#include <ATen/Functions.h>
7#include <ATen/NativeFunctions.h>
8#else
9#include <ATen/ops/gather_native.h>
10#include <ATen/ops/scatter_add_native.h>
11#include <ATen/ops/scatter_native.h>
12#endif
13
14namespace at::native {
15
// MPS implementation of gather: output[i][j][k] = self[index[i][j][k]][j][k]
// for dim == 0 (analogously for other dims), built as a cached MPSGraph.
TORCH_IMPL_FUNC(gather_out_mps)
(const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) {
  using namespace mps;

  // Nothing to gather from or into; `output` is left untouched.
  if (self_arg.numel() == 0 || index.numel() == 0) {
    return;
  }
  // Promote a 0-d input to 1-d so the rank/shape logic below is uniform.
  auto self = self_arg.dim() == 0 ? self_arg.view({1}) : self_arg;
  dim = at::maybe_wrap_dim(dim, self.dim());

  TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet")
  TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type");
  TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor");
  TORCH_CHECK(!self.is_complex(), "gather(): Yet not supported for complex");

  // Graph I/O tensors captured in the MPS graph cache, keyed by `key` below.
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* indexTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    MPSShape* input_shape = getMPSShape(self);
    MPSShape* index_shape = getMPSShape(index);
    uint32_t num_input_dims = [input_shape count];
    uint32_t num_index_dims = [index_shape count];
    TORCH_CHECK(num_input_dims == num_index_dims, "Input and index must have same rank")

    // Determine if we need to slice into the input tensor
    bool needSlice = false;

    // gather() requires `index` to be no larger than `self` on every
    // non-gather axis; where it is strictly smaller, the input is sliced
    // down to the index extents before the gather op runs.
    for (const auto i : c10::irange(num_input_dims)) {
      TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
                  "Index dim must not exceed input dim except at gathering axis")
      if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
        needSlice = true;
    }
    // UInt8 is remapped to Int8 (same byte width, bits reinterpreted) —
    // presumably a workaround for missing UInt8 support in the MPSGraph
    // gather op; TODO confirm against MPSGraph availability notes.
    auto input_type = getMPSDataType(self);
    auto output_type = getMPSDataType(output);
    if (input_type == MPSDataTypeUInt8) {
      input_type = MPSDataTypeInt8;
    }
    if (output_type == MPSDataTypeUInt8) {
      output_type = MPSDataTypeInt8;
    }
    // Cache key: op name + tensor shapes/dtypes + gather axis.
    string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      // Input placeholder uses the (possibly Int8-remapped) input_type.
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, getMPSShape(self));
      MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);

      MPSGraphTensor* getInput = inputTensor;

      // Slice into the input tensor IF NEEDED
      if (needSlice) {
        NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* ends = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];

        for (const auto i : c10::irange(num_input_dims)) {
          // All strides are 1
          strides[i] = @1;
          // All starts are 0
          starts[i] = @0;
          // Clamp every non-gather axis to the index extent; keep the full
          // extent along the gather axis itself.
          ends[i] = (i != dim) ? index_shape[i] : input_shape[i];
        }

        getInput = [mpsGraph sliceTensor:inputTensor starts:starts ends:ends strides:strides name:nil];
      }

      // Indices are cast down to Int32 (PyTorch index tensors are Int64).
      MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor
                                                      toType:MPSDataTypeInt32
                                                        name:(NSString* _Nonnull)nil];

      C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wobjc-method-access")
      MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis:(NSInteger)dim
                                             withUpdatesTensor:getInput
                                                 indicesTensor:castIndexTensor
                                                          name:nil];
      C10_DIAGNOSTIC_POP()
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->indexTensor_ = indexTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    // The explicit dataType overrides below reinterpret UInt8 buffers as Int8
    // so they match the placeholders created inside the graph lambda.
    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type);
    Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, output_type);

    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, indexPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
}
109
// Shared implementation behind every MPS scatter variant (scatter, scatter_add,
// scatter_reduce and their Scalar-value forms). Writes `src` values into a copy
// of `self` at the positions selected by `index` along `dim`, combining with
// the existing value according to `reduce` ("set", "sum"/"add",
// "prod"/"multiply", "amax", "amin"). `func_name` only namespaces the
// graph-cache key so different callers do not share cache entries.
static void scatter_mps_general(const Tensor& self_arg,
                                int64_t dim,
                                const Tensor& index,
                                const Tensor& src,
                                const Tensor& output,
                                string func_name,
                                const c10::string_view reduce) {
  using namespace mps;

  // Nothing to scatter; `output` is left untouched.
  if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
    return;
  }
  // Promote a 0-d input to 1-d so the rank/shape logic below is uniform.
  auto self = self_arg.dim() == 0 ? self_arg.view({1}) : self_arg;
  dim = at::maybe_wrap_dim(dim, self.dim());

  TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(),
              "scatter(): self, src and output must have the same scalar type");
  TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor");
  TORCH_CHECK(!self.is_complex(), "scatter(): Yet not supported for complex");

  // Graph I/O tensors captured in the MPS graph cache, keyed by `key` below.
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* indexTensor_ = nil;
    MPSGraphTensor* srcTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    MPSShape* input_shape = getMPSShape(self);
    MPSShape* index_shape = getMPSShape(index);
    MPSShape* src_shape = getMPSShape(src);
    uint32_t num_input_dims = [input_shape count];
    uint32_t num_index_dims = [index_shape count];
    uint32_t num_src_dims = [src_shape count];

    TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims,
                "Input, index and src must have same rank")

    // Do we need to slice into the src tensor?
    bool needSlice = false;        // src larger than index on some axis
    bool inputNeedSlice = false;   // input larger than index on a non-dim axis
    bool needsCast = false;        // run the scatter in Float32/Int32, cast back at the end

    // `index` must fit inside both `src` and (except along `dim`) `input`;
    // any strictly-smaller axis means the larger tensor is sliced down to
    // the index extents before the scatter op runs.
    for (const auto i : c10::irange(num_input_dims)) {
      TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
                  "Index dim must not exceed input dim except at gathering axis")
      TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue],
                  "Index dim must not exceed input dim except at gathering axis")
      if ([index_shape[i] intValue] < [src_shape[i] intValue])
        needSlice = true;
      if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
        inputNeedSlice = true;
    }
    TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS")

    // Any reduction mode — or UInt8/Bool data even in "set" mode — is
    // computed in Float32 (floating inputs) or Int32 (integral inputs) and
    // cast back to the output dtype at the end of the graph.
    MPSDataType src_type = getMPSDataType(src);
    if (reduce != "set" || src_type == MPSDataTypeUInt8 || src_type == MPSDataTypeBool) {
      src_type = isFloatingType(src.scalar_type()) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
      needsCast = true;
    }

    // Cache key: caller name + tensor shapes/dtypes + axis + reduce mode.
    string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
        std::string(reduce);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
      MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);

      MPSGraphTensor* outputTensor = nil;
      MPSGraphTensor* castSrcTensor = srcTensor;
      MPSGraphTensor* castInputTensor = inputTensor;

      if (needsCast) {
        castSrcTensor = [mpsGraph castTensor:srcTensor toType:src_type name:@"cast"];
        castInputTensor = [mpsGraph castTensor:inputTensor toType:src_type name:@"cast"];
      }
      // Indices are cast down to Int32 (PyTorch index tensors are Int64).
      MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor toType:MPSDataTypeInt32 name:@"cast"];

      MPSGraphTensor* slicedSrc = castSrcTensor;
      MPSGraphTensor* slicedInput = castInputTensor;

      // Use in case input needs to be smaller to get scatter
      NSMutableArray<NSNumber*>* scatterInputShape = [NSMutableArray arrayWithArray:input_shape];

      // Slice into the src or input tensors IF NEEDED
      if (needSlice || inputNeedSlice) {
        NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];

        for (const auto i : c10::irange(num_input_dims)) {
          // All slices start at 0 with stride 1.
          strides[i] = @1;
          starts[i] = @0;
          // src is clipped to the index extents on every axis; the input is
          // clipped on every axis except the scatter axis.
          ends_src[i] = index_shape[i];
          scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i];
        }
        if (needSlice) {
          slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts ends:ends_src strides:strides name:nil];
        }
        if (inputNeedSlice) {
          slicedInput = [mpsGraph sliceTensor:castInputTensor
                                       starts:starts
                                         ends:scatterInputShape
                                      strides:strides
                                         name:nil];
        }
      }
      // Map the PyTorch reduce string onto the MPSGraph scatter mode;
      // unknown strings fall through to Set (overwrite) semantics.
      MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet;

      if (reduce == "sum" || reduce == "add")
        scatter_mode = MPSGraphScatterModeAdd;
      else if (reduce == "prod" || reduce == "multiply")
        scatter_mode = MPSGraphScatterModeMul;
      else if (reduce == "amax")
        scatter_mode = MPSGraphScatterModeMax;
      else if (reduce == "amin")
        scatter_mode = MPSGraphScatterModeMin;

      // Scatter this into the input with set mode
      C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wobjc-method-access")
      MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis:(NSInteger)dim
                                                  withDataTensor:slicedInput
                                                   updatesTensor:slicedSrc
                                                   indicesTensor:castIndexTensor
                                                            mode:scatter_mode
                                                            name:nil];
      C10_DIAGNOSTIC_POP()
      // If the scatter ran on a slice of the input, the result only covers
      // that sub-region; write it back into the full-size input via
      // scatterND using explicit per-element coordinates.
      if (inputNeedSlice) {
        // Make an array of scatter indices tensors
        NSMutableArray<MPSGraphTensor*>* indicesTensors =
            [NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];

        // 1. Concatenate the coord tensors
        // 2. Flatten the values
        // 3. Scatter into input with add mode

        // The sliced-region shape as plain ints, to feed a graph constant.
        std::vector<int> shape_data(num_input_dims);

        for (const auto i : c10::irange(num_input_dims)) {
          shape_data[i] = {[scatterInputShape[i] intValue]};
        }

        // 1-D Int32 constant holding scatterInputShape.
        MPSGraphTensor* scatterInputShapeTensor =
            [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
                                 shape:@[ [NSNumber numberWithUnsignedInt:num_input_dims] ]
                              dataType:MPSDataTypeInt32];

        // One coordinate tensor per axis, each reshaped to a column vector.
        for (const auto i : c10::irange(num_input_dims)) {
          MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i dataType:MPSDataTypeInt32];
          MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor:axisTensor
                                                                           withShapeTensor:scatterInputShapeTensor
                                                                                      name:nil];
          scatter_currentIndexTensor = [mpsGraph reshapeTensor:scatter_currentIndexTensor
                                                     withShape:@[ @-1, @1 ]
                                                          name:nil];
          indicesTensors[i] = scatter_currentIndexTensor;
        }

        // Columns side by side -> one (numel, rank) coordinate matrix.
        MPSGraphTensor* scatter_fullIndexTensor = [mpsGraph concatTensors:indicesTensors
                                                                dimension:(NSInteger)1
                                                                     name:nil];

        MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor withShape:@[ @-1 ] name:nil];

        // Overwrite ("Set") the covered coordinates of the full input with
        // the scattered sub-region; uncovered elements keep their values.
        outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor
                                           updatesTensor:flatValuesTensor
                                           indicesTensor:scatter_fullIndexTensor
                                         batchDimensions:0
                                                    mode:MPSGraphScatterModeSet
                                                    name:nil];
      } else {
        outputTensor = scatterTensor;
      }
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->srcTensor_ = srcTensor;
      newCachedGraph->indexTensor_ = indexTensor;
      // Undo the Float32/Int32 working cast, if any, before emitting.
      newCachedGraph->outputTensor_ =
          needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
    });

    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
    Placeholder srcPlaceholder = Placeholder(cachedGraph->srcTensor_, src, src_shape);
    Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, srcPlaceholder, indexPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
}
300
// scatter.src for MPS: delegate to the shared scatter path with "set"
// (overwrite) semantics.
TORCH_IMPL_FUNC(scatter_src_out_mps)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_src_out_mps", "set");
}
305
// scatter.value for MPS: materialize the scalar as a tensor shaped like
// `index` so the generic tensor-scatter path can be reused with "set"
// (overwrite) semantics.
TORCH_IMPL_FUNC(scatter_value_out_mps)
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& output) {
  Tensor src =
      at::empty(index.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, self.suggest_memory_format());
  src.fill_(value);
  // `src` is a non-const local, so the former const_cast<Tensor&> was a
  // misleading no-op; pass it directly.
  scatter_mps_general(self, dim, index, src, output, "scatter_value_out_mps", "set");
}
313
// scatter.reduce for MPS: delegate to the shared scatter path, forwarding the
// caller-supplied reduce mode (e.g. "add", "multiply") verbatim.
TORCH_IMPL_FUNC(scatter_reduce_out_mps)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& src,
 const c10::string_view reduce,
 const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce);
}
323
// scatter.value_reduce for MPS: materialize the scalar as a tensor shaped
// like `index` and reuse the generic tensor-scatter path, forwarding the
// caller-supplied reduce mode verbatim.
TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Scalar& value,
 const c10::string_view reduce,
 const Tensor& output) {
  Tensor src =
      at::empty(index.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, self.suggest_memory_format());
  src.fill_(value);
  // `src` is a non-const local, so the former const_cast<Tensor&> was a
  // misleading no-op; pass it directly.
  scatter_mps_general(self, dim, index, src, output, "scatter_value_reduce_out_mps", reduce);
}
336
// scatter_add for MPS: delegate to the shared scatter path with additive
// accumulation ("add" maps to MPSGraphScatterModeAdd).
TORCH_IMPL_FUNC(scatter_add_mps_out)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_add_mps_out", "add");
}
341
342} // namespace at::native
343