Searched refs:tile_assignment (Results 1 – 7 of 7) sorted by relevance
64 def tile(cls, tile_assignment): argument81 if not isinstance(tile_assignment, _np.ndarray):83 dims = list(tile_assignment.shape)84 flattened_devices = tile_assignment.reshape(-1, order='C')179 def tile(tensor, tile_assignment): argument180 Sharding.tile(tile_assignment).apply_to_tensor(tensor)
53 static HloSharding Tile(const Array<int64>& tile_assignment) { in Tile() argument54 return HloSharding(tile_assignment); in Tile()205 const Array<int64>& tile_assignment() const { return tile_assignment_; } in tile_assignment() function235 explicit HloSharding(const Array<int64>& tile_assignment) in HloSharding() argument239 tile_assignment_(tile_assignment) {} in HloSharding()
399 Array<int64> tile_assignment( in FromProto() local403 proto.tile_assignment_devices().end(), tile_assignment.begin()); in FromProto()404 return HloSharding(tile_assignment); in FromProto()
36 const TileAssignment& tile_assignment) { in Tile() argument40 for (int64 dim : tile_assignment.dimensions()) { in Tile()43 for (uint32 device : tile_assignment) { in Tile()
45 OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment);
101 tile_assignment = np.arange(np.prod(dims)).reshape(dims)102 return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
584 // None of the above; tile_shape and tile_assignment are both used.