Lines Matching refs:ScaledTensor
218 class ScaledTensor(torch.Tensor): class
249 return ScaledTensor(
259 return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)
1300 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
1301 sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
1305 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
1306 sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
1310 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
1314 sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
1319 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
1320 sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
1324 sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
1328 sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
1693 x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)