from model import get_custom_op_library_path

import torch


# Load the compiled custom-op library so that ``custom::nonzero`` is defined
# before a fake implementation is registered for it below.
torch.ops.load_library(get_custom_op_library_path())

# ``torch.library.impl_abstract`` was renamed to ``torch.library.register_fake``
# (PyTorch 2.4) and the old name now emits a deprecation warning. Prefer the new
# name, falling back to the old one on older PyTorch releases.
_register_fake = getattr(torch.library, "register_fake", None) or torch.library.impl_abstract


@_register_fake("custom::nonzero")
def nonzero_abstract(x):
    """Fake (abstract) implementation of ``custom::nonzero``.

    Used when tracing/compiling, where only metadata is available. It does not
    read ``x``'s data. It allocates an empty result whose first dimension is a
    fresh data-dependent symbolic size and whose second dimension is
    ``x.dim()``.

    NOTE(review): the shape presumably mirrors ``torch.nonzero``'s
    (num_nonzero, ndim) index output. Confirm against the C++ kernel.

    Args:
        x: Input tensor. Only its rank and device/metadata are used.

    Returns:
        An empty ``torch.long`` tensor of shape ``[nnz, x.dim()]``, where
        ``nnz`` is an unbacked (data-dependent) SymInt.
    """
    ctx = torch.library.get_ctx()
    # The element count depends on the data, so it cannot be computed here.
    # Request a new data-dependent size from the fake-impl context.
    # ``new_dynamic_size`` supersedes the deprecated ``create_unbacked_symint``.
    # The older method is used only when the newer one is unavailable.
    make_dynamic_size = getattr(ctx, "new_dynamic_size", None) or ctx.create_unbacked_symint
    nnz = make_dynamic_size()
    return x.new_empty([nnz, x.dim()], dtype=torch.long)