Lines Matching refs:x_tensor
154 def _iter_tensor(x_tensor): argument
171 if _is_sparse_any_tensor(x_tensor):
182 x_nnz = x_tensor._nnz()
183 x_size = list(x_tensor.size())
184 if x_tensor.layout is torch.sparse_coo:
185 x_indices = x_tensor._indices().t()
186 x_values = x_tensor._values()
187 elif x_tensor.layout is torch.sparse_csr:
189 x_tensor.crow_indices(), x_tensor.col_indices()
191 x_values = x_tensor.values()
192 elif x_tensor.layout is torch.sparse_csc:
194 x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
196 x_values = x_tensor.values()
197 elif x_tensor.layout is torch.sparse_bsr:
198 x_block_values = x_tensor.values()
202 x_tensor.crow_indices(), x_tensor.col_indices()
205 .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
208 torch.where(torch.ones(x_blocksize, device=x_tensor.device))
215 elif x_tensor.layout is torch.sparse_bsc:
216 x_block_values = x_tensor.values()
220 x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
223 .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
226 torch.where(torch.ones(x_blocksize, device=x_tensor.device))
244 elif x_tensor.layout == torch._mkldnn: # type: ignore[attr-defined]
245 for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
248 x_tensor_dense = x_tensor.to_dense()
252 x_tensor = x_tensor.data
253 for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
254 yield x_tensor, x_idx, d_idx