Searched refs:arg_sharding (Results 1 – 8 of 8) sorted by relevance
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | tpu_sharding_identification_pass.cc | 168 auto arg_sharding = GetXlaShardingFromArg(arg); in IdentifyXlaShardingForComputationInputs() local 169 if (arg_sharding) { in IdentifyXlaShardingForComputationInputs() 170 sharding_for_args.push_back(arg_sharding.getValue()); in IdentifyXlaShardingForComputationInputs() 172 builder->getStringAttr(arg_sharding.getValue())); in IdentifyXlaShardingForComputationInputs()
|
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/ |
D | distributed_tpu_rewrite_pass.h | 312 std::vector<::xla::OpSharding>* arg_sharding, 358 const std::vector<::xla::OpSharding>& arg_sharding,
|
D | distributed_tpu_rewrite_pass.cc | 1919 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem, in AssignArgsAndRetvalsToCores() argument 1970 arg_sharding->resize(args.size()); in AssignArgsAndRetvalsToCores() 2102 (*arg_sharding)[i] = node_and_sharding->sharding; in AssignArgsAndRetvalsToCores() 2212 (absl::c_any_of(*arg_sharding, in AssignArgsAndRetvalsToCores() 2225 /*allow_parameter_replication_for_spmd=*/false, arg_sharding, in AssignArgsAndRetvalsToCores() 2311 const std::vector<xla::OpSharding>& arg_sharding, in BuildCompileNode() argument 2329 !absl::c_any_of(arg_sharding, in BuildCompileNode() 2358 TF_RET_CHECK(num_args == arg_sharding.size()) in BuildCompileNode() 2359 << num_args << " != " << arg_sharding.size(); in BuildCompileNode() 2406 *arg->mutable_sharding() = arg_sharding[i]; in BuildCompileNode() [all …]
|
/external/tensorflow/tensorflow/core/tpu/kernels/ |
D | tpu_compile_op_common.cc | 242 TF_ASSIGN_OR_RETURN(auto arg_sharding, in BuildComputationArgumentDescriptions() 245 arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape)); in BuildComputationArgumentDescriptions() 263 TF_ASSIGN_OR_RETURN(auto arg_sharding, in GetShardingInfo() 270 RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false, in GetShardingInfo()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/ |
D | compile_mlir_util.cc | 113 absl::optional<xla::HloSharding> arg_sharding; in GetXlaInputShapes() local 119 TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding)); in GetXlaInputShapes() 121 RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false, in GetXlaInputShapes()
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | xla_compiler.h | 231 const absl::optional<xla::HloSharding>& arg_sharding,
|
D | xla_compiler.cc | 835 const absl::optional<xla::HloSharding>& arg_sharding, in XLAShapeForArgument() argument 853 arg_sharding, /*use_fast_memory=*/false, in XLAShapeForArgument() 882 arg_sharding, arg.fast_mem, options_.shape_representation_fn, in XLAShapeForArgument() 1026 auto arg_sharding = arg_shardings.find((*input_to_args)[i]); in BuildArguments() local 1028 if (arg_sharding != arg_shardings.end()) { in BuildArguments() 1030 xla::HloSharding::FromProto(arg_sharding->second)); in BuildArguments()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/ |
D | mlir_hlo_to_hlo.cc | 1486 for (auto arg_sharding : llvm::enumerate(arg_shardings)) { in SetEntryTupleShardings() local 1487 auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value()); in SetEntryTupleShardings() 1494 shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]); in SetEntryTupleShardings() 1498 *sharding.add_tuple_shardings() = *arg_sharding.value(); in SetEntryTupleShardings()
|