• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2#include <ATen/core/Tensor.h>
3#include <ATen/native/mps/OperationUtils.h>
4
5#ifndef AT_PER_OPERATOR_HEADERS
6#include <ATen/Functions.h>
7#include <ATen/NativeFunctions.h>
8#else
9#include <ATen/ops/add.h>
10#include <ATen/ops/lerp_native.h>
11#endif
12
namespace at::native {

// Structured out= kernel for torch.lerp on MPS:
//   out = self + weight * (end - self)
// evaluated elementwise by a cached MPSGraph (operands broadcast by the graph ops).
//
// For real dtypes the graph uses the numerically safer two-branch form
//   |weight| < 0.5 ? self + weight * (end - self)
//                  : end - (end - self) * (1 - weight)
// so that weight == 1 reproduces `end` exactly instead of accumulating rounding
// error in `self + (end - self)`. This mirrors ATen's CPU/CUDA lerp helper.
// Complex dtypes keep the single-branch formula (abs/compare are not a
// meaningful ordering there).
//
// Params: self (start), end, weight, and the pre-sized output tensor `out`;
// all four must live on the same MPS device (checked below).
TORCH_IMPL_FUNC(lerp_Tensor_mps)(const Tensor& self, const Tensor& end, const Tensor& weight, const Tensor& out) {
  TORCH_CHECK(out.is_mps(), "lerp_Tensor_mps: expected output tensor on MPS device, but got ", out.device());
  std::array<TensorArg, 4> args{{{out, "out", 0}, {self, "self", 1}, {end, "end", 2}, {weight, "weight", 3}}};
  checkAllSameGPU(__func__, args);

  // Nothing to write: skip building/dispatching an MPSGraph over zero-element tensors.
  if (out.numel() == 0) {
    return;
  }

  using namespace mps;
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* selfTensor_ = nil;
    MPSGraphTensor* endTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  // Decided outside the graph builder; the dtype is part of the cache key
  // (via getTensorsStringKey), so real and complex graphs never collide.
  const bool useTwoBranchForm = !isComplexType(weight.scalar_type());

  @autoreleasepool {
    string key = "lerp_Tensor_mps" + getTensorsStringKey({self, end, weight});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto graph) {
      auto selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
      auto endTensor = mpsGraphRankedPlaceHolder(mpsGraph, end);
      auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);

      // distance = end - self
      auto distance = [mpsGraph subtractionWithPrimaryTensor:endTensor secondaryTensor:selfTensor name:nil];
      // fromStart = self + weight * distance
      auto weightedDistance = [mpsGraph multiplicationWithPrimaryTensor:weightTensor
                                                        secondaryTensor:distance
                                                                   name:nil];
      MPSGraphTensor* output = [mpsGraph additionWithPrimaryTensor:selfTensor
                                                   secondaryTensor:weightedDistance
                                                              name:nil];

      if (useTwoBranchForm) {
        // Constants share weight's dtype so the arithmetic below needs no casts.
        auto one = [mpsGraph constantWithScalar:1.0 dataType:getMPSDataType(weight)];
        auto half = [mpsGraph constantWithScalar:0.5 dataType:getMPSDataType(weight)];

        // fromEnd = end - distance * (1 - weight)
        auto complement = [mpsGraph subtractionWithPrimaryTensor:one secondaryTensor:weightTensor name:nil];
        auto remaining = [mpsGraph multiplicationWithPrimaryTensor:distance secondaryTensor:complement name:nil];
        auto fromEnd = [mpsGraph subtractionWithPrimaryTensor:endTensor secondaryTensor:remaining name:nil];

        // Pick fromStart where |weight| < 0.5, fromEnd elsewhere.
        auto absWeight = [mpsGraph absoluteWithTensor:weightTensor name:nil];
        auto isSmallWeight = [mpsGraph lessThanWithPrimaryTensor:absWeight secondaryTensor:half name:nil];
        output = [mpsGraph selectWithPredicateTensor:isSmallWeight
                                 truePredicateTensor:output
                                falsePredicateTensor:fromEnd
                                                name:nil];
      }

      graph->selfTensor_ = selfTensor;
      graph->endTensor_ = endTensor;
      graph->weightTensor_ = weightTensor;
      graph->outputTensor_ = output;
    });

    auto selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
    auto endPlaceholder = Placeholder(cachedGraph->endTensor_, end);
    auto weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight);
    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, endPlaceholder, weightPlaceholder);
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

} // namespace at::native
50