Searched refs:retval_sharding (Results 1 – 3 of 3) sorted by relevance
285 if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) { in IdentifyXlaShardingForComputationOutputs() local286 sharding_for_rets.push_back(retval_sharding.getValue()); in IdentifyXlaShardingForComputationOutputs()288 builder->getStringAttr(retval_sharding.getValue())); in IdentifyXlaShardingForComputationOutputs()
314 std::vector<::xla::OpSharding>* retval_sharding,361 const std::vector<::xla::OpSharding>& retval_sharding,
1920 std::vector<xla::OpSharding>* retval_sharding, in AssignArgsAndRetvalsToCores() argument2117 retval_sharding->resize(retvals.size()); in AssignArgsAndRetvalsToCores()2209 (*retval_sharding)[i] = node_and_sharding->sharding; in AssignArgsAndRetvalsToCores()2216 absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) { in AssignArgsAndRetvalsToCores()2226 arg_fast_mem, retval_sharding, arg_names); in AssignArgsAndRetvalsToCores()2314 const std::vector<xla::OpSharding>& retval_sharding, in BuildCompileNode() argument2333 !absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) { in BuildCompileNode()2409 const int num_retvals = retval_sharding.size(); in BuildCompileNode()2411 *proto.add_retvals()->mutable_sharding() = retval_sharding[i]; in BuildCompileNode()4261 std::vector<xla::OpSharding> retval_sharding; in RewriteTPUReplicateNode() local[all …]