Lines Matching defs:group
12 def broadcast(tensor, src, group=group.WORLD): argument
32 def gather(tensor, dst=0, group=group.WORLD): argument
47 def scatter(tensors, src=0, group=group.WORLD): argument
67 def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): argument
88 def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): argument
107 def all_gather(tensor, group=group.WORLD): argument
122 def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): argument
158 def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): argument
179 group=group.WORLD, argument
205 def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): argument
228 def forward(ctx, src, group, tensor): argument
248 def forward(ctx, dst, group, tensor): argument
273 def forward(ctx, src, group, *tensors): argument
291 def forward(ctx, src, op, group, tensor): argument
305 def forward(ctx, op, group, tensor, *input_tensor_list): argument
320 def forward(ctx, group, tensor): argument
349 def forward(ctx, output_tensor, input_tensor, group): argument
376 def forward(ctx, group, out_tensor_list, *tensors): argument
411 def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): argument
443 def forward(ctx, op, group, tensor): argument