// Copyright © 2022 Apple Inc.
//
// MPS (Metal Performance Shaders) implementations of gather / scatter /
// scatter_add. Each op builds (or looks up) a cached MPSGraph keyed on the
// tensor shapes/dtypes and the gather/scatter axis, then runs it on the
// current MPS stream.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/gather_native.h>
#include <ATen/ops/scatter_add_native.h>
#include <ATen/ops/scatter_native.h>
#endif

namespace at::native {

// gather(self, dim, index) for MPS.
// output[i][j][k] = self[index[i][j][k]][j][k] (for dim == 0), etc.
// `sparse_grad` is rejected (unsupported on MPS); complex inputs are rejected.
TORCH_IMPL_FUNC(gather_out_mps)
(const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) {
  using namespace mps;

  // Nothing to gather from / into; structured-kernel output is already sized.
  if (self_arg.numel() == 0 || index.numel() == 0) {
    return;
  }
  // Promote 0-dim tensors to 1-dim so shape arithmetic below is uniform.
  auto self = self_arg.dim() == 0 ? self_arg.view({1}) : self_arg;
  dim = at::maybe_wrap_dim(dim, self.dim());

  TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet")
  TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type");
  TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor");
  TORCH_CHECK(!self.is_complex(), "gather(): Yet not supported for complex");

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* indexTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    MPSShape* input_shape = getMPSShape(self);
    MPSShape* index_shape = getMPSShape(index);
    uint32_t num_input_dims = [input_shape count];
    uint32_t num_index_dims = [index_shape count];
    TORCH_CHECK(num_input_dims == num_index_dims, "Input and index must have same rank")

    // gatherAlongAxis requires the updates tensor to match the index shape on
    // every axis except `dim`, so if index is smaller anywhere else we must
    // slice the input down first.
    bool needSlice = false;

    for (const auto i : c10::irange(num_input_dims)) {
      TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
                  "Index dim must not exceed input dim except at gathering axis")
      if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
        needSlice = true;
    }
    auto input_type = getMPSDataType(self);
    auto output_type = getMPSDataType(output);
    // Remap uint8 -> int8 for the graph computation; presumably the MPSGraph
    // gather kernel lacks a uint8 path (TODO(review): confirm). Same bit
    // pattern, so the reinterpretation is lossless for gather.
    if (input_type == MPSDataTypeUInt8) {
      input_type = MPSDataTypeInt8;
    }
    if (output_type == MPSDataTypeUInt8) {
      output_type = MPSDataTypeInt8;
    }
    // Cache key: op + shapes/dtypes of all operands + the gather axis.
    string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, getMPSShape(self));
      MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);

      MPSGraphTensor* getInput = inputTensor;

      // Slice into the input tensor IF NEEDED
      if (needSlice) {
        NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* ends = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];

        for (const auto i : c10::irange(num_input_dims)) {
          // All strides are 1
          strides[i] = @1;
          // All starts are 0
          starts[i] = @0;
          // Trim every non-gather axis down to the index extent; the gather
          // axis keeps the full input extent.
          ends[i] = (i != dim) ? index_shape[i] : input_shape[i];
        }

        getInput = [mpsGraph sliceTensor:inputTensor starts:starts ends:ends strides:strides name:nil];
      }

      // MPSGraph gather requires Int32 indices; PyTorch indices are Int64.
      MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor
                                                      toType:MPSDataTypeInt32
                                                        name:(NSString* _Nonnull)nil];

      C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wobjc-method-access")
      MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis:(NSInteger)dim
                                             withUpdatesTensor:getInput
                                                 indicesTensor:castIndexTensor
                                                          name:nil];
      C10_DIAGNOSTIC_POP()
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->indexTensor_ = indexTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    // Feed placeholders with the (possibly reinterpreted int8) dtypes chosen above.
    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type);
    Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, output_type);

    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, indexPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

// Shared implementation for all scatter variants (set / add / multiply /
// amax / amin). Scatters `src` values into a copy of `self` along `dim` at
// positions given by `index`, writing the result to `output`.
//
// Two complications are handled below:
//  - If `index` is smaller than `src`, `src` is sliced down to the index
//    shape before scattering.
//  - If `index` is smaller than `self` on a non-scatter axis, the scatter is
//    performed on a sliced view of the input and the partial result is then
//    written back into the full-size input via scatterND using explicitly
//    materialized coordinates.
static void scatter_mps_general(const Tensor& self_arg,
                                int64_t dim,
                                const Tensor& index,
                                const Tensor& src,
                                const Tensor& output,
                                const std::string& func_name,
                                const c10::string_view reduce) {
  using namespace mps;

  // Empty index/src means nothing to scatter; empty self means nothing to
  // scatter into.
  if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
    return;
  }
  // Promote 0-dim tensors to 1-dim so shape arithmetic below is uniform.
  auto self = self_arg.dim() == 0 ? self_arg.view({1}) : self_arg;
  dim = at::maybe_wrap_dim(dim, self.dim());

  TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(),
              "scatter(): self, src and output must have the same scalar type");
  TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor");
  TORCH_CHECK(!self.is_complex(), "scatter(): Yet not supported for complex");

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* indexTensor_ = nil;
    MPSGraphTensor* srcTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    MPSShape* input_shape = getMPSShape(self);
    MPSShape* index_shape = getMPSShape(index);
    MPSShape* src_shape = getMPSShape(src);
    uint32_t num_input_dims = [input_shape count];
    uint32_t num_index_dims = [index_shape count];
    uint32_t num_src_dims = [src_shape count];

    TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims,
                "Input, index and src must have same rank")

    // Do we need to slice into the src tensor?
    bool needSlice = false;
    // Do we need to scatter on a sliced view of the input (index smaller than
    // input on some non-scatter axis)?
    bool inputNeedSlice = false;
    bool needsCast = false;

    for (const auto i : c10::irange(num_input_dims)) {
      TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue],
                  "Index dim must not exceed input dim except at gathering axis")
      TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue],
                  "Index dim must not exceed src dim at any axis")
      if ([index_shape[i] intValue] < [src_shape[i] intValue])
        needSlice = true;
      if (i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
        inputNeedSlice = true;
    }
    TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS")

    // For reduction modes (and for uint8/bool data in any mode) the
    // computation is done in Float32/Int32 and cast back at the end —
    // presumably the MPSGraph scatter reductions don't support the narrower
    // types directly (TODO(review): confirm).
    MPSDataType src_type = getMPSDataType(src);
    if (reduce != "set" || src_type == MPSDataTypeUInt8 || src_type == MPSDataTypeBool) {
      src_type = isFloatingType(src.scalar_type()) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
      needsCast = true;
    }

    // Cache key: caller name + shapes/dtypes + scatter axis + reduce mode.
    string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
        std::string(reduce);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
      MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src);

      MPSGraphTensor* outputTensor = nil;
      MPSGraphTensor* castSrcTensor = srcTensor;
      MPSGraphTensor* castInputTensor = inputTensor;

      if (needsCast) {
        castSrcTensor = [mpsGraph castTensor:srcTensor toType:src_type name:@"cast"];
        castInputTensor = [mpsGraph castTensor:inputTensor toType:src_type name:@"cast"];
      }
      // MPSGraph scatter requires Int32 indices; PyTorch indices are Int64.
      MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor toType:MPSDataTypeInt32 name:@"cast"];

      MPSGraphTensor* slicedSrc = castSrcTensor;
      MPSGraphTensor* slicedInput = castInputTensor;

      // Use in case input needs to be smaller to get scatter
      NSMutableArray<NSNumber*>* scatterInputShape = [NSMutableArray arrayWithArray:input_shape];

      // Slice into the src or input tensors IF NEEDED
      if (needSlice || inputNeedSlice) {
        NSMutableArray<NSNumber*>* starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
        NSMutableArray<NSNumber*>* ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];

        for (const auto i : c10::irange(num_input_dims)) {
          strides[i] = @1;
          starts[i] = @0;
          // src is trimmed to the index shape on every axis; the input view
          // keeps its full extent on the scatter axis only.
          ends_src[i] = index_shape[i];
          scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i];
        }
        if (needSlice) {
          slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts ends:ends_src strides:strides name:nil];
        }
        if (inputNeedSlice) {
          slicedInput = [mpsGraph sliceTensor:castInputTensor
                                       starts:starts
                                         ends:scatterInputShape
                                      strides:strides
                                         name:nil];
        }
      }
      // Map the PyTorch reduce string onto the MPSGraph scatter mode;
      // anything unrecognized falls through to plain "set".
      MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet;

      if (reduce == "sum" || reduce == "add")
        scatter_mode = MPSGraphScatterModeAdd;
      else if (reduce == "prod" || reduce == "multiply")
        scatter_mode = MPSGraphScatterModeMul;
      else if (reduce == "amax")
        scatter_mode = MPSGraphScatterModeMax;
      else if (reduce == "amin")
        scatter_mode = MPSGraphScatterModeMin;

      // Scatter this into the input with set mode
      C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wobjc-method-access")
      MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis:(NSInteger)dim
                                                  withDataTensor:slicedInput
                                                   updatesTensor:slicedSrc
                                                   indicesTensor:castIndexTensor
                                                            mode:scatter_mode
                                                            name:nil];
      C10_DIAGNOSTIC_POP()
      if (inputNeedSlice) {
        // The scatter above only produced a result the size of the sliced
        // input view; write that partial result back into the full-size input
        // via scatterND, using explicitly built per-axis coordinates.

        // Make an array of scatter indices tensors
        NSMutableArray<MPSGraphTensor*>* indicesTensors =
            [NSMutableArray<MPSGraphTensor*> arrayWithCapacity:num_input_dims];

        // 1. Concatenate the coord tensors
        // 2. Flatten the values
        // 3. Scatter into input with add mode

        std::vector<int> shape_data(num_input_dims);

        for (const auto i : c10::irange(num_input_dims)) {
          shape_data[i] = {[scatterInputShape[i] intValue]};
        }

        MPSGraphTensor* scatterInputShapeTensor =
            [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)]
                                 shape:@[ [NSNumber numberWithUnsignedInt:num_input_dims] ]
                              dataType:MPSDataTypeInt32];

        // One coordinate tensor per axis, flattened to a column vector.
        for (const auto i : c10::irange(num_input_dims)) {
          MPSGraphTensor* axisTensor = [mpsGraph constantWithScalar:i dataType:MPSDataTypeInt32];
          MPSGraphTensor* scatter_currentIndexTensor = [mpsGraph coordinateAlongAxisTensor:axisTensor
                                                                           withShapeTensor:scatterInputShapeTensor
                                                                                      name:nil];
          scatter_currentIndexTensor = [mpsGraph reshapeTensor:scatter_currentIndexTensor
                                                     withShape:@[ @-1, @1 ]
                                                          name:nil];
          indicesTensors[i] = scatter_currentIndexTensor;
        }

        MPSGraphTensor* scatter_fullIndexTensor = [mpsGraph concatTensors:indicesTensors
                                                                dimension:(NSInteger)1
                                                                     name:nil];

        MPSGraphTensor* flatValuesTensor = [mpsGraph reshapeTensor:scatterTensor withShape:@[ @-1 ] name:nil];

        outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor
                                           updatesTensor:flatValuesTensor
                                           indicesTensor:scatter_fullIndexTensor
                                         batchDimensions:0
                                                    mode:MPSGraphScatterModeSet
                                                    name:nil];
      } else {
        outputTensor = scatterTensor;
      }
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->srcTensor_ = srcTensor;
      newCachedGraph->indexTensor_ = indexTensor;
      // Undo the widening cast (if any) so the output dtype matches `output`.
      newCachedGraph->outputTensor_ =
          needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor;
    });

    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
    Placeholder srcPlaceholder = Placeholder(cachedGraph->srcTensor_, src, src_shape);
    Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, srcPlaceholder, indexPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

// scatter(self, dim, index, src) — plain overwrite ("set") mode.
TORCH_IMPL_FUNC(scatter_src_out_mps)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_src_out_mps", "set");
}

// scatter(self, dim, index, value) — scalar value broadcast into an
// index-shaped src tensor, then scattered in "set" mode.
TORCH_IMPL_FUNC(scatter_value_out_mps)
(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value, const Tensor& output) {
  Tensor src =
      at::empty(index.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, self.suggest_memory_format());
  src.fill_(value);
  scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_out_mps", "set");
}

// scatter(self, dim, index, src, reduce) — tensor src with a reduction mode
// ("add"/"multiply"/"amax"/"amin"; "mean" is rejected inside the helper).
TORCH_IMPL_FUNC(scatter_reduce_out_mps)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Tensor& src,
 const c10::string_view reduce,
 const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_reduce_out_mps", reduce);
}

// scatter(self, dim, index, value, reduce) — scalar value with a reduction mode.
TORCH_IMPL_FUNC(scatter_value_reduce_out_mps)
(const Tensor& self,
 int64_t dim,
 const Tensor& index,
 const Scalar& value,
 const c10::string_view reduce,
 const Tensor& output) {
  Tensor src =
      at::empty(index.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, self.suggest_memory_format());
  src.fill_(value);
  scatter_mps_general(self, dim, index, const_cast<Tensor&>(src), output, "scatter_value_reduce_out_mps", reduce);
}

// scatter_add(self, dim, index, src) — scatter with additive reduction.
TORCH_IMPL_FUNC(scatter_add_mps_out)
(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, const Tensor& output) {
  scatter_mps_general(self, dim, index, src, output, "scatter_add_mps_out", "add");
}

} // namespace at::native