Searched refs:partitioned_input (Results 1 – 8 of 8) sorted by relevance
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | tpu_resource_partitioning.cc | 92 auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>( in PartitionResourceReadsWrites() local 94 if (!partitioned_input || in PartitionResourceReadsWrites() 95 !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes())) in PartitionResourceReadsWrites() 100 partitioned_output_types.reserve(partitioned_input.N()); in PartitionResourceReadsWrites() 101 for (Type input_type : partitioned_input.inputs().getTypes()) in PartitionResourceReadsWrites() 105 partitioned_input.partition_dimAttr(), in PartitionResourceReadsWrites() 106 partitioned_input._XlaShardingAttr()); in PartitionResourceReadsWrites() 108 llvm::zip(partitioned_input.inputs(), partitioned_output.output())) in PartitionResourceReadsWrites() 119 auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>( in PartitionResourceReadsWrites() local 121 if (!partitioned_input || !partitioned_input.output().hasOneUse() || in PartitionResourceReadsWrites() [all …]
|
D | tpu_reorder_replicate_and_partitioned_inputs.cc | 54 auto partitioned_input = in ReorderReplicateAndPartitionedInputs() local 57 partitioned_input._XlaSharding(); in ReorderReplicateAndPartitionedInputs() 58 int64_t op_partition_dim = partitioned_input.partition_dim(); in ReorderReplicateAndPartitionedInputs() 61 return partitioned_input->emitOpError() in ReorderReplicateAndPartitionedInputs() 64 if (partitioned_input.getNumOperands() != num_cores_per_replica) in ReorderReplicateAndPartitionedInputs() 65 return partitioned_input->emitOpError() in ReorderReplicateAndPartitionedInputs() 67 << partitioned_input.getNumOperands(); in ReorderReplicateAndPartitionedInputs() 129 getFunction()->walk([](TF::TPUPartitionedInputOp partitioned_input) { in runOnFunction() argument 130 if (partitioned_input->use_empty()) partitioned_input->erase(); in runOnFunction()
|
D | tpu_sharding_identification_pass.cc | 63 if (auto partitioned_input = in GetXlaShardingFromOperand() local 66 return partitioned_input._XlaSharding(); in GetXlaShardingFromOperand() 198 if (auto partitioned_input = in GetXlaShardingFromResult() local 201 return partitioned_input._XlaSharding(); in GetXlaShardingFromResult()
|
D | tf_passes.td | 579 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", par… 580 …%computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla…
|
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | fft_handler.cc | 370 auto partitioned_input = in HandleFft() local 382 auto result = partitioned_input.hlo(); in HandleFft() 384 partitioned_input.hlo(), num_partitions_, hlo->sharding(), in HandleFft() 385 partitioned_input.state().collective_ops_creator, in HandleFft() 386 partitioned_input.state().next_channel_id, in HandleFft() 387 partitioned_input.state().partition_id, partitioned_input.state().b); in HandleFft() 397 partitioned_input.state().b); in HandleFft() 403 result, num_partitions_, partitioned_input.state().collective_ops_creator, in HandleFft() 404 partitioned_input.state().next_channel_id, partitioned_input.state().b); in HandleFft() 406 result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_); in HandleFft() [all …]
|
D | spmd_partitioner.cc | 1671 auto partitioned_input = GetPartitionedHlo(input).PadWithValue( in HandleSort() local 1693 shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()})); in HandleSort() 1834 auto partitioned_input = GetPartitionedHlo(input).PadWithValue( in HandleCustomCall() local 1848 hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()})); in HandleCustomCall() 2360 const auto& partitioned_input = in HandleDynamicUpdateSlice() local 2375 const auto& partitioned_shape = partitioned_input->shape(); in HandleDynamicUpdateSlice() 2441 partitioned_shape, partitioned_input, replicate_update, new_indices)); in HandleDynamicUpdateSlice() 2449 dus, partitioned_input)); in HandleDynamicUpdateSlice()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/ |
D | xla_sharding_util.cc | 257 if (auto partitioned_input = in ExtractInputsForLogicalDevices() local 269 partitioned_input.getOperand(index_and_inputs.index())); in ExtractInputsForLogicalDevices() 273 if (partitioned_input.inputs().size() != num_cores_per_replica) in ExtractInputsForLogicalDevices() 274 return tiled_sharding_mismatched(partitioned_input.inputs().size()); in ExtractInputsForLogicalDevices() 280 partitioned_input.inputs()[i]); in ExtractInputsForLogicalDevices()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/ |
D | tpu_rewrite.mlir | 1224 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i… 1236 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @… 1262 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {_XlaSharding = "\08\03\1A\02\01\02\… 1274 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @… 1297 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i… 1299 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @… 1320 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i… 1322 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @… 1343 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i… 1344 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @…
|