• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from model import get_custom_op_library_path
2
3import torch
4
5
# Register the compiled extension's operators (e.g. custom::nonzero) with
# torch.ops; this must run before an abstract impl is attached to them.
torch.ops.load_library(path=get_custom_op_library_path())
7
8
@torch.library.impl_abstract("custom::nonzero")
def nonzero_abstract(x):
    """Abstract (fake) implementation of ``custom::nonzero``.

    Builds the output without reading ``x``'s data: an empty tensor of
    shape ``[nnz, x.dim()]`` with dtype ``torch.long``, created via
    ``x.new_empty``. Because ``nnz`` depends on the data, it is represented
    by a fresh dynamic (unbacked) size obtained from the abstract-impl
    context.

    Args:
        x: Input tensor; only its metadata (rank, via ``x.dim()``) and the
            ``new_empty`` factory are used here.

    Returns:
        A new empty ``torch.long`` tensor of shape ``[nnz, x.dim()]``.
    """
    n = x.dim()
    ctx = torch.library.get_ctx()
    # A nonzero op may legitimately find 0 or 1 matching elements.
    # ``create_unbacked_symint`` defaults to ``min=2`` (and is deprecated in
    # newer PyTorch), which would wrongly constrain ``nnz >= 2``. Prefer
    # ``new_dynamic_size`` (``min=0`` by default) and only fall back on older
    # releases that do not provide it.
    if hasattr(ctx, "new_dynamic_size"):
        nnz = ctx.new_dynamic_size()
    else:
        nnz = ctx.create_unbacked_symint()
    return x.new_empty([nnz, n], dtype=torch.long)
16