Home
last modified time | relevance | path

Searched defs:arg_sharding (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
tpu_sharding_identification_pass.cc:168 auto arg_sharding = GetXlaShardingFromArg(arg); in IdentifyXlaShardingForComputationInputs() local
/external/tensorflow/tensorflow/compiler/tf2xla/
xla_compiler.cc:835 const absl::optional<xla::HloSharding>& arg_sharding, in XLAShapeForArgument()
1026 auto arg_sharding = arg_shardings.find((*input_to_args)[i]); in BuildArguments() local
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
compile_mlir_util.cc:113 absl::optional<xla::HloSharding> arg_sharding; in GetXlaInputShapes() local
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
distributed_tpu_rewrite_pass.cc:1919 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem, in AssignArgsAndRetvalsToCores()
2311 const std::vector<xla::OpSharding>& arg_sharding, in BuildCompileNode()
4258 std::vector<xla::OpSharding> arg_sharding; in RewriteTPUReplicateNode() local
/external/tensorflow/tensorflow/compiler/mlir/xla/
mlir_hlo_to_hlo.cc:1486 for (auto arg_sharding : llvm::enumerate(arg_shardings)) { in SetEntryTupleShardings() local