Searched defs:tile_assignment (Results 1 – 8 of 8) sorted by relevance
75 def tile(cls, tile_assignment): argument103 def partial_tile(cls, tile_assignment): argument280 tile_assignment, argument321 def partial_tile(tensor, tile_assignment, use_sharding_op=False): argument
303 Array<int64> tile_assignment = sharding.tile_assignment(); in TransposeSharding() local502 Array<int64> tile_assignment(tile_dims); in ReshapeToTileDimension() local663 Array<int64> tile_assignment = in GatherEffectiveOutputSharding() local774 Array<int64> tile_assignment = in ScatterEffectiveIndexSharding() local817 Array<int64> tile_assignment = in ScatterEffectiveDataSharding() local865 Array<int64> tile_assignment = operand_sharding.tile_assignment(); in PassthroughOperandToGatherOutputOrScatterUpdate() local908 Array<int64> tile_assignment = in PassthroughGatherOutputOrScatterUpdateToOperand() local967 Array<int64> tile_assignment = relevant_output_sharding.tile_assignment(); in GatherParallelDataOperandSharding() local
559 const auto& tile_assignment = lhs->sharding().tile_assignment(); in InferConvolutionShardingFromOperands() local942 const auto& tile_assignment = operand->sharding().tile_assignment(); in InferShardingFromOperands() local1182 const Array<int64>& tile_assignment = user.sharding().tile_assignment(); in GetShardingFromUser() local1244 const auto& tile_assignment = user.sharding().tile_assignment(); in GetShardingFromUser() local1353 auto tile_assignment = user_sharding.tile_assignment(); in GetShardingFromUser() local
284 const Array<int64>& tile_assignment() const { return tile_assignment_; } in tile_assignment() function
114 const Array<int64>& tile_assignment, in Subgroup()600 Array<int64> tile_assignment( in FromProto() local
2989 Array<int64> tile_assignment({2, 2, 2, 2}); local
42 const TileAssignment& tile_assignment) { in Tile()
1832 xla::Array<int64> tile_assignment({2}); in TEST_F() local