Searched refs:arg_shardings (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | xla_compiler.cc | 100 std::map<int, xla::OpSharding> arg_shardings; in ComputeArgAndRetvalShardings() local 109 arg_shardings[index] = std::move(*sharding); in ComputeArgAndRetvalShardings() 119 return std::make_pair(std::move(arg_shardings), std::move(retval_shardings)); in ComputeArgAndRetvalShardings() 167 const std::map<int, xla::OpSharding>& arg_shardings, in BuildComputation() argument 287 auto it = arg_shardings.find(resource->arg_num()); in BuildComputation() 326 auto sharding = it == arg_shardings.end() in BuildComputation() 340 if (it != arg_shardings.end()) { in BuildComputation() 956 const std::map<int, xla::OpSharding>& arg_shardings, in BuildArguments() argument 1026 auto arg_sharding = arg_shardings.find((*input_to_args)[i]); in BuildArguments() 1028 if (arg_sharding != arg_shardings.end()) { in BuildArguments() [all …]
|
D | xla_compiler.h | 307 const std::map<int, xla::OpSharding>& arg_shardings,
|
/external/tensorflow/tensorflow/compiler/mlir/xla/ |
D | mlir_hlo_to_hlo.cc | 413 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings, in ExtractShardingsFromFunction() argument 415 arg_shardings->resize(function.getNumArguments(), in ExtractShardingsFromFunction() 420 (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); in ExtractShardingsFromFunction() 486 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings, 518 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings, 1404 llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings; in RunOnFunction() local 1431 ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); in RunOnFunction() 1435 arg_shardings, ret_shardings, &computation))) { in RunOnFunction() 1481 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings, in SetEntryTupleShardings() argument 1483 if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) { in SetEntryTupleShardings() [all …]
|
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/ |
D | distributed_tpu_rewrite_pass.cc | 2793 const std::vector<xla::OpSharding>& arg_shardings, in CreateOrGetPerHostVariableCopy() argument 2819 if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) { in CreateOrGetPerHostVariableCopy() 2892 const std::vector<xla::OpSharding>& arg_shardings, in BuildExecuteNodes() argument 2992 arg_shardings.size() - params_info.NumGuaranteedConstants(); in BuildExecuteNodes() 2997 const auto& sharding = arg_shardings[i]; in BuildExecuteNodes() 3150 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { in BuildExecuteNodes() 3166 const xla::OpSharding& sharding = arg_shardings[orig_arg_num]; in BuildExecuteNodes() 3183 arg_shardings[orig_arg_num].type() == in BuildExecuteNodes() 3245 arg_shardings, replicate_node, enable_xla_param_broadcast_, in BuildExecuteNodes() 3249 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { in BuildExecuteNodes() [all …]
|
D | distributed_tpu_rewrite_pass.h | 435 const std::vector<::xla::OpSharding>& arg_shardings,
|