Home
last modified time | relevance | path

Searched refs:partitioned_input (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dtpu_resource_partitioning.cc92 auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>( in PartitionResourceReadsWrites() local
94 if (!partitioned_input || in PartitionResourceReadsWrites()
95 !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes())) in PartitionResourceReadsWrites()
100 partitioned_output_types.reserve(partitioned_input.N()); in PartitionResourceReadsWrites()
101 for (Type input_type : partitioned_input.inputs().getTypes()) in PartitionResourceReadsWrites()
105 partitioned_input.partition_dimAttr(), in PartitionResourceReadsWrites()
106 partitioned_input._XlaShardingAttr()); in PartitionResourceReadsWrites()
108 llvm::zip(partitioned_input.inputs(), partitioned_output.output())) in PartitionResourceReadsWrites()
119 auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>( in PartitionResourceReadsWrites() local
121 if (!partitioned_input || !partitioned_input.output().hasOneUse() || in PartitionResourceReadsWrites()
[all …]
Dtpu_reorder_replicate_and_partitioned_inputs.cc54 auto partitioned_input = in ReorderReplicateAndPartitionedInputs() local
57 partitioned_input._XlaSharding(); in ReorderReplicateAndPartitionedInputs()
58 int64_t op_partition_dim = partitioned_input.partition_dim(); in ReorderReplicateAndPartitionedInputs()
61 return partitioned_input->emitOpError() in ReorderReplicateAndPartitionedInputs()
64 if (partitioned_input.getNumOperands() != num_cores_per_replica) in ReorderReplicateAndPartitionedInputs()
65 return partitioned_input->emitOpError() in ReorderReplicateAndPartitionedInputs()
67 << partitioned_input.getNumOperands(); in ReorderReplicateAndPartitionedInputs()
129 getFunction()->walk([](TF::TPUPartitionedInputOp partitioned_input) { in runOnFunction() argument
130 if (partitioned_input->use_empty()) partitioned_input->erase(); in runOnFunction()
Dtpu_sharding_identification_pass.cc63 if (auto partitioned_input = in GetXlaShardingFromOperand() local
66 return partitioned_input._XlaSharding(); in GetXlaShardingFromOperand()
198 if (auto partitioned_input = in GetXlaShardingFromResult() local
201 return partitioned_input._XlaSharding(); in GetXlaShardingFromResult()
Dtf_passes.td579 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", par…
580 …%computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla…
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dfft_handler.cc370 auto partitioned_input = in HandleFft() local
382 auto result = partitioned_input.hlo(); in HandleFft()
384 partitioned_input.hlo(), num_partitions_, hlo->sharding(), in HandleFft()
385 partitioned_input.state().collective_ops_creator, in HandleFft()
386 partitioned_input.state().next_channel_id, in HandleFft()
387 partitioned_input.state().partition_id, partitioned_input.state().b); in HandleFft()
397 partitioned_input.state().b); in HandleFft()
403 result, num_partitions_, partitioned_input.state().collective_ops_creator, in HandleFft()
404 partitioned_input.state().next_channel_id, partitioned_input.state().b); in HandleFft()
406 result = SliceValidData(result, partitioned_input.hlo()->shape(), &b_); in HandleFft()
[all …]
Dspmd_partitioner.cc1671 auto partitioned_input = GetPartitionedHlo(input).PadWithValue( in HandleSort() local
1693 shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()})); in HandleSort()
1834 auto partitioned_input = GetPartitionedHlo(input).PadWithValue( in HandleCustomCall() local
1848 hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()})); in HandleCustomCall()
2360 const auto& partitioned_input = in HandleDynamicUpdateSlice() local
2375 const auto& partitioned_shape = partitioned_input->shape(); in HandleDynamicUpdateSlice()
2441 partitioned_shape, partitioned_input, replicate_update, new_indices)); in HandleDynamicUpdateSlice()
2449 dus, partitioned_input)); in HandleDynamicUpdateSlice()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dxla_sharding_util.cc257 if (auto partitioned_input = in ExtractInputsForLogicalDevices() local
269 partitioned_input.getOperand(index_and_inputs.index())); in ExtractInputsForLogicalDevices()
273 if (partitioned_input.inputs().size() != num_cores_per_replica) in ExtractInputsForLogicalDevices()
274 return tiled_sharding_mismatched(partitioned_input.inputs().size()); in ExtractInputsForLogicalDevices()
280 partitioned_input.inputs()[i]); in ExtractInputsForLogicalDevices()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/
Dtpu_rewrite.mlir1224 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i…
1236 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @…
1262 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {_XlaSharding = "\08\03\1A\02\01\02\…
1274 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @…
1297 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i…
1299 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @…
1320 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i…
1322 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @…
1343 …%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i…
1344 …%computation = "tf_device.cluster_func"(%partitioned_input) {_tpu_replicate = "cluster0", func = @…