#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include <numeric>
#include <string>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/linalg_vector_norm.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/renorm_native.h>
#endif

namespace at::native {
namespace {
using namespace mps;

// Metal kernel that turns per-slice norms into per-slice scale factors:
//   factor[i] = norm[i] > maxnorm ? maxnorm / (norm[i] + eps) : 1
// Instantiated for float and half only (see REGISTER_RENORM_OP below).
static MetalShaderLibrary lib(R"RENORM_METAL(
#include <metal_stdlib>
using namespace metal;

template <typename T>
kernel void renorm(constant T* norm [[buffer(0)]],
                   device T* factor [[buffer(1)]],
                   constant float& maxnorm [[buffer(2)]],
                   uint index [[thread_position_in_grid]]) {
  constexpr T eps = 1e-7;
  constexpr T one = 1;
  factor[index] = norm[index] > maxnorm ? maxnorm / (norm[index] + eps) : one;
}

#define REGISTER_RENORM_OP(DTYPE)                                      \
  template [[host_name("renorm_" #DTYPE)]] kernel void renorm<DTYPE>( \
      constant DTYPE* norm [[buffer(0)]],                              \
      device DTYPE* factor [[buffer(1)]],                              \
      constant float& maxnorm [[buffer(2)]],                           \
      uint index [[thread_position_in_grid]]);

REGISTER_RENORM_OP(float);
REGISTER_RENORM_OP(half);
)RENORM_METAL");

// MPS implementation of renorm, writing into `out`.
//
// Each slice along `dim` whose p-norm (taken over all other dimensions)
// exceeds `maxnorm` is scaled by maxnorm / (norm + 1e-7), which brings its norm
// to roughly `maxnorm`. Slices whose norm is <= maxnorm are scaled by 1.
//
// Steps:
//   1. norm   = p-norm over every dimension except `dim` (keepdim, so it broadcasts)
//   2. factor = Metal kernel applied elementwise to `norm`
//   3. out    = self * factor
void renorm_out_mps(const Tensor& self, const Scalar& p, int64_t dim, const Scalar& maxnorm, const Tensor& out) {
  // `lib` only instantiates float and half kernels. Reject other dtypes up front
  // instead of failing later with a missing-function error from the library.
  TORCH_CHECK(self.scalar_type() == kFloat || self.scalar_type() == kHalf,
              "renorm_out_mps: only float32 and float16 inputs are supported, but got ",
              self.scalar_type());

  const auto self_sizes = self.sizes();
  dim = c10::maybe_wrap_dim(dim, self_sizes.size());

  // Reduce over every dimension except `dim`.
  DimVector reduce_dims(self_sizes.size());
  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  reduce_dims.erase(reduce_dims.begin() + dim);

  const Tensor norm = at::linalg_vector_norm(self, p.toDouble(), reduce_dims, /*keepdim=*/true);
  const Tensor factor = at::empty(norm.sizes(), self.options());

  // Convert on the calling thread. Scalar::to<float>() throws on overflow, and an
  // exception must not be raised inside the dispatch_sync block below.
  const float maxnorm_value = maxnorm.to<float>();

  // With size(dim) == 0 there are no slices. Skip the launch rather than issue a
  // zero-width dispatch, which Metal API validation may reject.
  if (factor.numel() > 0) {
    const std::string key = "renorm_" + scalarToMetalTypeString(self);
    id<MTLComputePipelineState> renormPSO = lib.getPipelineStateForFunc(key);
    MPSStream* mpsStream = getCurrentMPSStream();

    dispatch_sync(mpsStream->queue(), ^() {
      @autoreleasepool {
        // Acquire the encoder inside the queue-serialized block so its creation
        // and reuse are serialized with other work submitted to this stream.
        id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();

        // this function call is a no-op if MPSProfiler is not enabled
        getMPSProfiler().beginProfileKernel(renormPSO, key, {norm});

        [computeEncoder setComputePipelineState:renormPSO];
        mtl_setBuffer(computeEncoder, norm, 0);
        mtl_setBuffer(computeEncoder, factor, 1);
        mtl_setBytes(computeEncoder, maxnorm_value, 2);
        mtl_dispatch1DJob(computeEncoder, renormPSO, norm.numel());

        getMPSProfiler().endProfileKernel(renormPSO);
      }
    });
  }

  // `factor` has size 1 in every reduced dimension, so it broadcasts over `self`.
  at::mul_outf(self, factor, const_cast<Tensor&>(out));
}
} // namespace

TORCH_IMPL_FUNC(renorm_out_mps)
(const Tensor& self, const Scalar& p, int64_t dim, const Scalar& maxnorm, const Tensor& out) {
  renorm_out_mps(self, p, dim, maxnorm, out);
}
} // namespace at::native