Home
last modified time | relevance | path

Searched refs:arg_shardings (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_compiler.cc100 std::map<int, xla::OpSharding> arg_shardings; in ComputeArgAndRetvalShardings() local
109 arg_shardings[index] = std::move(*sharding); in ComputeArgAndRetvalShardings()
119 return std::make_pair(std::move(arg_shardings), std::move(retval_shardings)); in ComputeArgAndRetvalShardings()
167 const std::map<int, xla::OpSharding>& arg_shardings, in BuildComputation() argument
287 auto it = arg_shardings.find(resource->arg_num()); in BuildComputation()
326 auto sharding = it == arg_shardings.end() in BuildComputation()
340 if (it != arg_shardings.end()) { in BuildComputation()
956 const std::map<int, xla::OpSharding>& arg_shardings, in BuildArguments() argument
1026 auto arg_sharding = arg_shardings.find((*input_to_args)[i]); in BuildArguments()
1028 if (arg_sharding != arg_shardings.end()) { in BuildArguments()
[all …]
Dxla_compiler.h307 const std::map<int, xla::OpSharding>& arg_shardings,
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dmlir_hlo_to_hlo.cc413 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings, in ExtractShardingsFromFunction() argument
415 arg_shardings->resize(function.getNumArguments(), in ExtractShardingsFromFunction()
420 (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); in ExtractShardingsFromFunction()
486 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
518 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
1404 llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings; in RunOnFunction() local
1431 ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); in RunOnFunction()
1435 arg_shardings, ret_shardings, &computation))) { in RunOnFunction()
1481 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings, in SetEntryTupleShardings() argument
1483 if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) { in SetEntryTupleShardings()
[all …]
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
Ddistributed_tpu_rewrite_pass.cc2793 const std::vector<xla::OpSharding>& arg_shardings, in CreateOrGetPerHostVariableCopy() argument
2819 if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) { in CreateOrGetPerHostVariableCopy()
2892 const std::vector<xla::OpSharding>& arg_shardings, in BuildExecuteNodes() argument
2992 arg_shardings.size() - params_info.NumGuaranteedConstants(); in BuildExecuteNodes()
2997 const auto& sharding = arg_shardings[i]; in BuildExecuteNodes()
3150 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { in BuildExecuteNodes()
3166 const xla::OpSharding& sharding = arg_shardings[orig_arg_num]; in BuildExecuteNodes()
3183 arg_shardings[orig_arg_num].type() == in BuildExecuteNodes()
3245 arg_shardings, replicate_node, enable_xla_param_broadcast_, in BuildExecuteNodes()
3249 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { in BuildExecuteNodes()
[all …]
Ddistributed_tpu_rewrite_pass.h435 const std::vector<::xla::OpSharding>& arg_shardings,