Home
last modified time | relevance | path

Searched refs:arg_sharding (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dtpu_sharding_identification_pass.cc168 auto arg_sharding = GetXlaShardingFromArg(arg); in IdentifyXlaShardingForComputationInputs() local
169 if (arg_sharding) { in IdentifyXlaShardingForComputationInputs()
170 sharding_for_args.push_back(arg_sharding.getValue()); in IdentifyXlaShardingForComputationInputs()
172 builder->getStringAttr(arg_sharding.getValue())); in IdentifyXlaShardingForComputationInputs()
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
Ddistributed_tpu_rewrite_pass.h312 std::vector<::xla::OpSharding>* arg_sharding,
358 const std::vector<::xla::OpSharding>& arg_sharding,
Ddistributed_tpu_rewrite_pass.cc1919 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem, in AssignArgsAndRetvalsToCores() argument
1970 arg_sharding->resize(args.size()); in AssignArgsAndRetvalsToCores()
2102 (*arg_sharding)[i] = node_and_sharding->sharding; in AssignArgsAndRetvalsToCores()
2212 (absl::c_any_of(*arg_sharding, in AssignArgsAndRetvalsToCores()
2225 /*allow_parameter_replication_for_spmd=*/false, arg_sharding, in AssignArgsAndRetvalsToCores()
2311 const std::vector<xla::OpSharding>& arg_sharding, in BuildCompileNode() argument
2329 !absl::c_any_of(arg_sharding, in BuildCompileNode()
2358 TF_RET_CHECK(num_args == arg_sharding.size()) in BuildCompileNode()
2359 << num_args << " != " << arg_sharding.size(); in BuildCompileNode()
2406 *arg->mutable_sharding() = arg_sharding[i]; in BuildCompileNode()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_common.cc242 TF_ASSIGN_OR_RETURN(auto arg_sharding, in BuildComputationArgumentDescriptions()
245 arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape)); in BuildComputationArgumentDescriptions()
263 TF_ASSIGN_OR_RETURN(auto arg_sharding, in GetShardingInfo()
270 RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false, in GetShardingInfo()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dcompile_mlir_util.cc113 absl::optional<xla::HloSharding> arg_sharding; in GetXlaInputShapes() local
119 TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding)); in GetXlaInputShapes()
121 RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false, in GetXlaInputShapes()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_compiler.h231 const absl::optional<xla::HloSharding>& arg_sharding,
Dxla_compiler.cc835 const absl::optional<xla::HloSharding>& arg_sharding, in XLAShapeForArgument() argument
853 arg_sharding, /*use_fast_memory=*/false, in XLAShapeForArgument()
882 arg_sharding, arg.fast_mem, options_.shape_representation_fn, in XLAShapeForArgument()
1026 auto arg_sharding = arg_shardings.find((*input_to_args)[i]); in BuildArguments() local
1028 if (arg_sharding != arg_shardings.end()) { in BuildArguments()
1030 xla::HloSharding::FromProto(arg_sharding->second)); in BuildArguments()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dmlir_hlo_to_hlo.cc1486 for (auto arg_sharding : llvm::enumerate(arg_shardings)) { in SetEntryTupleShardings() local
1487 auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value()); in SetEntryTupleShardings()
1494 shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]); in SetEntryTupleShardings()
1498 *sharding.add_tuple_shardings() = *arg_sharding.value(); in SetEntryTupleShardings()