
Searched refs:output_shardings (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
tpu_rewrite_pass.cc
253 auto output_shardings = in SetMetadataProtoRetvals() local
255 if (!output_shardings) in SetMetadataProtoRetvals()
259 if (output_shardings.size() != op.getNumResults()) in SetMetadataProtoRetvals()
262 op.getNumResults(), output_shardings.size())); in SetMetadataProtoRetvals()
265 for (auto output_sharding_and_idx : llvm::enumerate(output_shardings)) in SetMetadataProtoRetvals()
820 llvm::SmallVector<xla::OpSharding, 4> output_shardings; in Rewrite() local
822 num_cores_per_replica, cluster_func, &output_shardings); in Rewrite()
832 tpu_device_assignment.tpu_devices, output_shardings, in Rewrite()
842 cluster_func.getLoc(), output_shardings, cluster_to_core_index, in Rewrite()
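Hits 253 to 265 show the checks in SetMetadataProtoRetvals(). It returns early if the output shardings are absent. It also reports an error when their count differs from op.getNumResults(). Otherwise it walks the shardings with their indices (llvm::enumerate). The C++ below is a minimal standalone sketch of that pattern only. Sharding, RetvalMetadata, and SetRetvalsSketch are hypothetical stand-ins, not the TensorFlow, MLIR, or XLA types. The error strings are invented for illustration, and a plain index loop replaces llvm::enumerate.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for xla::OpSharding; only carries a label here.
struct Sharding {
  std::string kind;
};

// Hypothetical stand-in for one per-result entry in the metadata proto.
struct RetvalMetadata {
  std::size_t index;
  Sharding sharding;
};

// Sketch of the validation pattern seen in the search hits.
// Returns an empty string on success, otherwise an (invented) error message.
std::string SetRetvalsSketch(
    const std::optional<std::vector<Sharding>>& output_shardings,
    std::size_t num_results, std::vector<RetvalMetadata>* retvals) {
  // Shardings must be present at all.
  if (!output_shardings) return "missing output shardings";

  // There must be exactly one sharding per op result.
  if (output_shardings->size() != num_results) {
    return "expected " + std::to_string(num_results) +
           " output shardings, got " +
           std::to_string(output_shardings->size());
  }

  // Record each sharding together with its result index.
  for (std::size_t i = 0; i < output_shardings->size(); ++i) {
    retvals->push_back({i, (*output_shardings)[i]});
  }
  return "";
}
```

The sketch returns an error string only to stay self-contained. A real MLIR pass would more likely report the mismatch through the op's diagnostics.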