Searched refs:output_shards (Results 1 – 4 of 4) sorted by relevance
337 std::vector<Tensor> output_shards(num_batches * num_returns); in batchedTensorForLoopFallback() local
371 output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor(); in batchedTensorForLoopFallback()
378 auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches); in batchedTensorForLoopFallback()
467 std::vector<Tensor> output_shards(num_components * num_returns); in batchedNestedTensorForLoopFallback() local
494 output_shards[num_components * return_idx + component_idx] = returns[return_idx].toTensor(); in batchedNestedTensorForLoopFallback()
502 auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_components); in batchedNestedTensorForLoopFallback()
321 std::vector<Tensor> output_shards(num_batches * num_returns); in batchedTensorForLoopFallback() local
351 output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor(); in batchedTensorForLoopFallback()
358 auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches); in batchedTensorForLoopFallback()
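
The first two results show the vmap for-loop fallback pattern: the operator is run once per batch element, each return value is dropped into a flat output_shards buffer laid out return-major, and a MatrixRef view with rows of num_batches entries is then used to regroup the shards per return value so each row can be reassembled into one batched output. The sketch below restates that collect-then-stack idea in Python; the name for_loop_fallback and the simplified single-op signature are illustrative, not the actual C++ fallback.

import torch

def for_loop_fallback(op, batched_inputs, num_returns):
    """Run `op` once per batch element and restack the outputs."""
    num_batches = batched_inputs[0].size(0)
    # Flat buffer laid out return-major, mirroring the C++ indexing
    # output_shards[num_batches * return_idx + linear_idx].
    output_shards = [None] * (num_batches * num_returns)
    for linear_idx in range(num_batches):
        returns = op(*(inp[linear_idx] for inp in batched_inputs))
        if num_returns == 1:
            returns = (returns,)
        for return_idx in range(num_returns):
            output_shards[num_batches * return_idx + linear_idx] = returns[return_idx]
    # The C++ code views this flat buffer as a MatrixRef with rows of length
    # num_batches; here we slice row by row and stack along a new batch dim.
    return [
        torch.stack(output_shards[r * num_batches:(r + 1) * num_batches])
        for r in range(num_returns)
    ]

# Example: applying an op sample by sample and restacking the results.
x = torch.randn(4, 3)
(out,) = for_loop_fallback(torch.sin, [x], num_returns=1)
assert torch.allclose(out, torch.sin(x))
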
400 std::vector<int32> output_shards = output_layout->num_shards(); in ExpandOp() local
401 if (operand_shards.size() != output_shards.size()) { in ExpandOp()
411 if (output_shards[dim_index] > static_multiples[dim_index]) in ExpandOp()
417 if (static_multiples[dim_index] % output_shards[dim_index] != 0) in ExpandOp()
432 output_shards[dim_index]); in ExpandOp()
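
The third result appears to be a DTensor SPMD expander's ExpandOp validating an output layout against static tile multiples: the operand and output layouts must have the same rank, a dimension may not have more output shards than its multiple, and each multiple must divide evenly by the shard count so every shard gets an integral local multiple. The Python sketch below is an assumption about the intent of those checks; check_tile_sharding is a made-up name, and the real code reports failures through TensorFlow Status errors rather than exceptions.

def check_tile_sharding(operand_shards, output_shards, static_multiples):
    """Validate per-dimension shard counts against static tile multiples and
    return the per-shard (local) multiples."""
    if len(operand_shards) != len(output_shards):
        raise ValueError("operand and output layouts must have the same rank")
    local_multiples = []
    for dim_index, multiple in enumerate(static_multiples):
        shards = output_shards[dim_index]
        if shards > multiple:
            raise ValueError(
                f"dim {dim_index}: {shards} output shards exceed multiple {multiple}")
        if multiple % shards != 0:
            raise ValueError(
                f"dim {dim_index}: multiple {multiple} not divisible by {shards} shards")
        # Each shard only has to tile by its share of the multiple.
        local_multiples.append(multiple // shards)
    return local_multiples

# Example: 2 output shards on dim 0 halve that dimension's tile multiple.
assert check_tile_sharding([1, 1], [2, 1], [4, 3]) == [2, 3]
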
330 output_shards = [output.chunk(group.size()) for output in outputs]
335 mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank])
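
The last result is Python: each output buffer is chunked across the process group up front so every rank's matmul result can be written straight into its own shard with out=, avoiding an extra allocation and copy into the final buffer. Below is a hedged, single-process sketch of that pattern; matmul_into_rank_shards and A_shards are illustrative names, a plain torch.mm stands in for mm_out_op and its **kwargs, and the surrounding distributed loop that supplies group and rank is omitted.

import torch

def matmul_into_rank_shards(A_shards, B, outputs, mm_out_op=torch.mm):
    """Compute shard @ B for every rank's shard of A and write each result
    directly into that rank's slice of the preallocated output buffer(s)."""
    group_size = len(A_shards)
    # Pre-chunk every output along dim 0: one view per rank, per output.
    output_shards = [output.chunk(group_size) for output in outputs]
    for rank, shard in enumerate(A_shards):
        for idx in range(len(outputs)):
            # out= writes into the chunk view in place.
            mm_out_op(shard, B, out=output_shards[idx][rank])
    return outputs

# Example with two "ranks" and a single output buffer.
A_shards = [torch.randn(4, 3), torch.randn(4, 3)]
B = torch.randn(3, 5)
outputs = [torch.empty(8, 5)]
matmul_into_rank_shards(A_shards, B, outputs)
assert torch.allclose(outputs[0][:4], A_shards[0] @ B)
assert torch.allclose(outputs[0][4:], A_shards[1] @ B)
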