/external/tensorflow/tensorflow/compiler/xla/client/ |
D | sharding_builder.cc | 21 OpSharding Replicate() { in Replicate() 22 OpSharding result; in Replicate() 23 result.set_type(OpSharding::REPLICATED); in Replicate() 27 OpSharding Manual() { in Manual() 28 OpSharding result; in Manual() 29 result.set_type(OpSharding::MANUAL); in Manual() 33 OpSharding AssignDevice(int device) { in AssignDevice() 34 OpSharding result; in AssignDevice() 35 result.set_type(OpSharding::MAXIMAL); in AssignDevice() 41 OpSharding Tile(const Shape& tile_shape, in Tile() [all …]
|
D | sharding_builder.h | 34 OpSharding Replicate(); 37 OpSharding Manual(); 40 OpSharding AssignDevice(int device); 48 OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); 54 OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles); 57 OpSharding Tuple(const ShapeTree<OpSharding>& shardings);
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | sharding_util.cc | 37 void AssignOpMetadataToSharding(xla::OpSharding& sharding, in AssignOpMetadataToSharding() 40 if (sharding.type() == xla::OpSharding::TUPLE) { in AssignOpMetadataToSharding() 56 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice() 58 absl::optional<xla::OpSharding> explicit_sharding, in ParseShardingFromDevice() 74 return absl::optional<xla::OpSharding>(); in ParseShardingFromDevice() 84 return absl::optional<xla::OpSharding>(sharding); in ParseShardingFromDevice() 88 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice() 91 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice() 100 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice() 106 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice() [all …]
|
D | sharding_util_test.cc | 28 [](absl::optional<xla::OpSharding> sharding) -> int64 { in TEST() 30 sharding.value().type() == xla::OpSharding::MAXIMAL) { in TEST() 60 : public ::testing::TestWithParam<xla::OpSharding> {}; 77 auto check_metadata = [](const xla::OpSharding& sharding) { in TEST_P() 86 const std::function<xla::StatusOr<absl::optional<xla::OpSharding>>()>& in TEST_P() 93 if (sharding->type() == xla::OpSharding::TUPLE) { in TEST_P() 129 xla::OpSharding CreateTupleSharding() { in CreateTupleSharding() 130 xla::OpSharding sharding; in CreateTupleSharding() 131 sharding.set_type(xla::OpSharding::TUPLE); in CreateTupleSharding() 132 sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED); in CreateTupleSharding() [all …]
|
D | sharding_util.h | 36 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( 38 absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt, 41 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( 44 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( 47 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource( 53 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
|
D | xla_compiler.cc | 90 std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>> 93 [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> { in ComputeArgAndRetvalShardings() 100 std::map<int, xla::OpSharding> arg_shardings; in ComputeArgAndRetvalShardings() 101 std::map<int, xla::OpSharding> retval_shardings; in ComputeArgAndRetvalShardings() 167 const std::map<int, xla::OpSharding>& arg_shardings, in BuildComputation() 168 const std::map<int, xla::OpSharding>& retval_shardings, in BuildComputation() 201 std::unordered_map<int, xla::OpSharding> retval_index_and_sharding; in BuildComputation() 225 absl::optional<xla::OpSharding> sharding = in BuildComputation() 226 it == retval_shardings.end() ? absl::optional<xla::OpSharding>() in BuildComputation() 327 ? absl::optional<xla::OpSharding>() in BuildComputation() [all …]
|
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/ |
D | xla_sharding.py | 47 proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)) 57 proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL)) 69 proto=xla_data_pb2.OpSharding( 70 type=xla_data_pb2.OpSharding.MAXIMAL, 97 proto=xla_data_pb2.OpSharding( 98 type=xla_data_pb2.OpSharding.OTHER, 122 proto=xla_data_pb2.OpSharding( 123 type=xla_data_pb2.OpSharding.OTHER, 159 proto=xla_data_pb2.OpSharding( 160 type=xla_data_pb2.OpSharding.OTHER, [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/ |
D | xla_sharding_util.cc | 146 const mlir::Location& location, const xla::OpSharding& input_sharding, in HandleTileShardedInputs() 203 bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { in UnsupportedPartitionedShardingType() 204 return sharding != xla::OpSharding::REPLICATED && in UnsupportedPartitionedShardingType() 205 sharding != xla::OpSharding::OTHER; in UnsupportedPartitionedShardingType() 240 xla::OpSharding sharding; in ExtractInputsForLogicalDevices() 266 if (input_sharding_type == xla::OpSharding::REPLICATED) { in ExtractInputsForLogicalDevices() 272 assert(input_sharding_type == xla::OpSharding::OTHER); in ExtractInputsForLogicalDevices() 286 if (input_sharding_type == xla::OpSharding::OTHER) { in ExtractInputsForLogicalDevices() 300 } else if (input_sharding_type == xla::OpSharding::REPLICATED) { in ExtractInputsForLogicalDevices() 303 assert(input_sharding_type == xla::OpSharding::MAXIMAL); in ExtractInputsForLogicalDevices() [all …]
|
D | xla_sharding_util.h | 49 mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list); 55 const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config, 63 llvm::ArrayRef<xla::OpSharding> output_sharding_config,
|
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | spmd_manual_sharding_ops.cc | 42 xla::OpSharding sharding; in Compile() 51 if (sharding.type() == xla::OpSharding::OTHER) { in Compile() 73 xla::OpSharding manual; in Compile() 74 manual.set_type(xla::OpSharding::MANUAL); in Compile() 105 xla::OpSharding sharding; in Compile() 116 xla::OpSharding manual; in Compile() 117 manual.set_type(xla::OpSharding::MANUAL); in Compile()
|
/external/tensorflow/tensorflow/core/tpu/kernels/xla/ |
D | infeed_op.cc | 52 absl::optional<xla::OpSharding> sharding) { in UpdateInfeedLayout() 53 if (sharding && sharding->type() == xla::OpSharding::OTHER) { in UpdateInfeedLayout() 131 absl::optional<xla::OpSharding> sharding; in Compile() 133 sharding = b->sharding()->type() == xla::OpSharding::TUPLE in Compile()
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | hlo_sharding.cc | 474 const OpSharding& proto) { in FromProto() 477 if (proto.type() == OpSharding::TUPLE) { in FromProto() 482 for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { in FromProto() 488 } else if (proto.type() == OpSharding::REPLICATED) { in FromProto() 490 } else if (proto.type() == OpSharding::MANUAL) { in FromProto() 496 TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL) in FromProto() 529 OpSharding HloSharding::ToProto() const { in ToProto() 530 OpSharding result; in ToProto() 537 result.set_type(OpSharding::TUPLE); in ToProto() 553 result.set_type(OpSharding::REPLICATED); in ToProto() [all …]
|
D | hlo_sharding_test.cc | 94 OpSharding proto; in TEST_F() 95 proto.set_type(OpSharding::TUPLE); in TEST_F() 97 tiled->set_type(OpSharding::OTHER); in TEST_F() 105 replicated->set_type(OpSharding::REPLICATED); in TEST_F() 108 manual->set_type(OpSharding::MANUAL); in TEST_F() 169 OpSharding proto; in TEST_F() 170 proto.set_type(OpSharding::TUPLE); in TEST_F()
|
D | hlo_sharding.h | 107 static StatusOr<HloSharding> FromProto(const OpSharding& proto); 113 OpSharding ToProto() const;
|
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/ |
D | distributed_tpu_rewrite_pass.cc | 605 const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) { in GetDimensionIndicesAndNumSplitsFromSharding() 792 const xla::OpSharding& sharding, int orig_arg_num, DataType dtype, in CreateOrGetSplitNodesForInputSharding() 1020 const xla::OpSharding& sharding, DataType dtype, in CreateConcatNodesForRetval() 1186 explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding) in NodeAndSharding() 1190 xla::OpSharding sharding; 1199 if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) { in ParseAndValidateSharding() 1209 if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) { in ParseAndValidateSharding() 1219 xla::OpSharding result_value = result->value().sharding; in ParseAndValidateSharding() 1224 xla::OpSharding sharding = node_and_sharding.sharding; in ParseAndValidateSharding() 1265 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseInputShardingFromAdjacentNode() [all …]
|
D | distributed_tpu_rewrite_pass.h | 312 std::vector<::xla::OpSharding>* arg_sharding, 314 std::vector<::xla::OpSharding>* retval_sharding, 358 const std::vector<::xla::OpSharding>& arg_sharding, 361 const std::vector<::xla::OpSharding>& retval_sharding, 435 const std::vector<::xla::OpSharding>& arg_shardings, 436 const std::vector<::xla::OpSharding>& retval_shardings,
|
/external/tensorflow/tensorflow/core/tpu/kernels/ |
D | tpu_compile_op_support.cc | 227 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AddVariableUpdatesToCores() 229 } else if (sharding.type() == xla::OpSharding::OTHER) { in AddVariableUpdatesToCores() 239 TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED); in AddVariableUpdatesToCores() 248 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AddVariableUpdatesToCores() 252 } else if (sharding.type() == xla::OpSharding::OTHER) { in AddVariableUpdatesToCores() 258 TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED); in AddVariableUpdatesToCores() 285 if (retval.sharding().type() == xla::OpSharding::MAXIMAL) { in ComputeOutputShapesForEachCore() 288 } else if (retval.sharding().type() == xla::OpSharding::OTHER) { in ComputeOutputShapesForEachCore() 297 TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED) in ComputeOutputShapesForEachCore()
|
D | tpu_compile_op_common.cc | 100 if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) { in SetPerCoreArgShapes() 106 } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) { in SetPerCoreArgShapes() 120 TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED) in SetPerCoreArgShapes() 155 if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) { in AssignReturnValueToCore() 160 } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) { in AssignReturnValueToCore() 167 xla::OpSharding::REPLICATED) in AssignReturnValueToCore() 423 auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status { in AssignDevicesToArgsAndRetvals() 424 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AssignDevicesToArgsAndRetvals() 429 TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED || in AssignDevicesToArgsAndRetvals() 430 sharding.type() == xla::OpSharding::OTHER) in AssignDevicesToArgsAndRetvals()
|
/external/tensorflow/tensorflow/compiler/mlir/xla/ |
D | mlir_hlo_to_hlo.cc | 344 static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef( in CreateOpShardingFromStringRef() 346 xla::OpSharding sharding_proto; in CreateOpShardingFromStringRef() 354 static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute( in CreateOpShardingFromAttribute() 403 llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) { in AllOptionalShardingsAreSet() 405 [](const absl::optional<xla::OpSharding>& sharding) { in AllOptionalShardingsAreSet() 413 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings, in ExtractShardingsFromFunction() 414 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) { in ExtractShardingsFromFunction() 416 absl::optional<xla::OpSharding>()); in ExtractShardingsFromFunction() 423 absl::optional<xla::OpSharding>()); in ExtractShardingsFromFunction() 486 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings, [all …]
|
/external/tensorflow/tensorflow/compiler/xla/python/ |
D | xla_compiler.cc | 563 py::enum_<OpSharding::Type>(m, "OpSharding_Type") in BuildXlaCompilerSubmodule() 564 .value("REPLICATED", OpSharding::REPLICATED) in BuildXlaCompilerSubmodule() 565 .value("MAXIMAL", OpSharding::MAXIMAL) in BuildXlaCompilerSubmodule() 566 .value("TUPLE", OpSharding::TUPLE) in BuildXlaCompilerSubmodule() 567 .value("OTHER", OpSharding::OTHER); in BuildXlaCompilerSubmodule()
|
D | types.h | 437 struct type_caster<xla::OpSharding> { 439 PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding")); 474 xla::OpSharding* sharding = value.add_tuple_shardings();
|
/external/tensorflow/tensorflow/compiler/xla/pjrt/ |
D | utils.cc | 32 const OpSharding& sharding) { in GetShardedShape() 33 if (sharding.type() == OpSharding::TUPLE) { in GetShardedShape()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/ |
D | result-sharding.mlir | 9 // The following xla::OpSharding protos are used:
|
D | argument-sharding.mlir | 9 // The following xla::OpSharding protos are used:
|
/external/tensorflow/tensorflow/core/protobuf/tpu/ |
D | compile_metadata.proto | 33 xla.OpSharding sharding = 4; 72 xla.OpSharding sharding = 1;
|