#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/add.h>
#include <ATen/ops/lerp_native.h>
#endif

namespace at::native {

// MPS implementation of the structured kernel lerp.Tensor(_out):
//   out = lerp(self, end, weight)
//
// For real dtypes this mirrors the two-branch formula used by the CPU/CUDA lerp
// kernels (ATen/native/Lerp.h):
//   |weight| <  0.5 : self + weight * (end - self)
//   |weight| >= 0.5 : end  - (end - self) * (1 - weight)
// The second branch makes weight == 1 produce `end` exactly, which the single
// formula `self + weight * (end - self)` cannot guarantee in floating point.
// Complex weights keep the single formula.
//
// self, end, weight: input tensors; they must live on the same GPU as `out`.
// out:               destination tensor; must be an MPS tensor.
TORCH_IMPL_FUNC(lerp_Tensor_mps)(const Tensor& self, const Tensor& end, const Tensor& weight, const Tensor& out) {
  TORCH_CHECK(out.is_mps(), "lerp_Tensor_mps: expected `out` to be an MPS tensor");
  std::array<TensorArg, 4> args{{{out, "out", 0}, {self, "self", 1}, {end, "end", 2}, {weight, "weight", 3}}};
  checkAllSameGPU(__func__, args);

  // Nothing to compute: skip building/running a graph over zero-element buffers.
  if (out.numel() == 0) {
    return;
  }

  using namespace mps;
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* selfTensor_ = nil;
    MPSGraphTensor* endTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    // The key encodes the input tensors (via getTensorsStringKey), so the
    // complex/real branch below is cached separately per input dtype.
    string key = "lerp_Tensor_mps" + getTensorsStringKey({self, end, weight});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto graph) {
      auto selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto endTensor = mpsGraphRankedPlaceHolder(mpsGraph, end);
      auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);

      // end - self
      auto distance = [mpsGraph subtractionWithPrimaryTensor:endTensor secondaryTensor:selfTensor name:nil];
      // self + weight * (end - self): exact at weight == 0.
      auto weightedDistance = [mpsGraph multiplicationWithPrimaryTensor:weightTensor
                                                        secondaryTensor:distance
                                                                   name:nil];
      auto fromSelf = [mpsGraph additionWithPrimaryTensor:selfTensor secondaryTensor:weightedDistance name:nil];

      MPSGraphTensor* output = nil;
      if (c10::isComplexType(weight.scalar_type())) {
        // Keep the single-formula path for complex inputs.
        output = fromSelf;
      } else {
        auto half = [mpsGraph constantWithScalar:0.5 dataType:weightTensor.dataType];
        auto one = [mpsGraph constantWithScalar:1.0 dataType:weightTensor.dataType];

        // |weight| < 0.5 selects the self-anchored formula.
        auto absWeight = [mpsGraph absoluteWithTensor:weightTensor name:nil];
        auto isSmallWeight = [mpsGraph lessThanWithPrimaryTensor:absWeight secondaryTensor:half name:nil];

        // end - (end - self) * (1 - weight): exact at weight == 1.
        auto complement = [mpsGraph subtractionWithPrimaryTensor:one secondaryTensor:weightTensor name:nil];
        auto scaledDistance = [mpsGraph multiplicationWithPrimaryTensor:distance
                                                        secondaryTensor:complement
                                                                   name:nil];
        auto fromEnd = [mpsGraph subtractionWithPrimaryTensor:endTensor secondaryTensor:scaledDistance name:nil];

        output = [mpsGraph selectWithPredicateTensor:isSmallWeight
                                 truePredicateTensor:fromSelf
                                falsePredicateTensor:fromEnd
                                                name:nil];
      }

      graph->selfTensor_ = selfTensor;
      graph->endTensor_ = endTensor;
      graph->weightTensor_ = weightTensor;
      graph->outputTensor_ = output;
    });

    auto selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
    auto endPlaceholder = Placeholder(cachedGraph->endTensor_, end);
    auto weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, endPlaceholder, weightPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

} // namespace at::native