Home
last modified time | relevance | path

Searched refs:retval_sharding (Results 1 – 3 of 3) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
tpu_sharding_identification_pass.cc:285 if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) { in IdentifyXlaShardingForComputationOutputs() local
286 sharding_for_rets.push_back(retval_sharding.getValue()); in IdentifyXlaShardingForComputationOutputs()
288 builder->getStringAttr(retval_sharding.getValue())); in IdentifyXlaShardingForComputationOutputs()
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
distributed_tpu_rewrite_pass.h:314 std::vector<::xla::OpSharding>* retval_sharding,
361 const std::vector<::xla::OpSharding>& retval_sharding,
distributed_tpu_rewrite_pass.cc:1920 std::vector<xla::OpSharding>* retval_sharding, in AssignArgsAndRetvalsToCores() argument
2117 retval_sharding->resize(retvals.size()); in AssignArgsAndRetvalsToCores()
2209 (*retval_sharding)[i] = node_and_sharding->sharding; in AssignArgsAndRetvalsToCores()
2216 absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) { in AssignArgsAndRetvalsToCores()
2226 arg_fast_mem, retval_sharding, arg_names); in AssignArgsAndRetvalsToCores()
2314 const std::vector<xla::OpSharding>& retval_sharding, in BuildCompileNode() argument
2333 !absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) { in BuildCompileNode()
2409 const int num_retvals = retval_sharding.size(); in BuildCompileNode()
2411 *proto.add_retvals()->mutable_sharding() = retval_sharding[i]; in BuildCompileNode()
4261 std::vector<xla::OpSharding> retval_sharding; in RewriteTPUReplicateNode() local
[all …]