Home
last modified time | relevance | path

Searched refs:OpSharding (Results 1 – 25 of 43) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/xla/client/
Dsharding_builder.cc21 OpSharding Replicate() { in Replicate()
22 OpSharding result; in Replicate()
23 result.set_type(OpSharding::REPLICATED); in Replicate()
27 OpSharding Manual() { in Manual()
28 OpSharding result; in Manual()
29 result.set_type(OpSharding::MANUAL); in Manual()
33 OpSharding AssignDevice(int device) { in AssignDevice()
34 OpSharding result; in AssignDevice()
35 result.set_type(OpSharding::MAXIMAL); in AssignDevice()
41 OpSharding Tile(const Shape& tile_shape, in Tile()
[all …]
Dsharding_builder.h34 OpSharding Replicate();
37 OpSharding Manual();
40 OpSharding AssignDevice(int device);
48 OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment);
54 OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles);
57 OpSharding Tuple(const ShapeTree<OpSharding>& shardings);
/external/tensorflow/tensorflow/compiler/tf2xla/
Dsharding_util.cc37 void AssignOpMetadataToSharding(xla::OpSharding& sharding, in AssignOpMetadataToSharding()
40 if (sharding.type() == xla::OpSharding::TUPLE) { in AssignOpMetadataToSharding()
56 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice()
58 absl::optional<xla::OpSharding> explicit_sharding, in ParseShardingFromDevice()
74 return absl::optional<xla::OpSharding>(); in ParseShardingFromDevice()
84 return absl::optional<xla::OpSharding>(sharding); in ParseShardingFromDevice()
88 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice()
91 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice()
100 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice()
106 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice()
[all …]
Dsharding_util_test.cc28 [](absl::optional<xla::OpSharding> sharding) -> int64 { in TEST()
30 sharding.value().type() == xla::OpSharding::MAXIMAL) { in TEST()
60 : public ::testing::TestWithParam<xla::OpSharding> {};
77 auto check_metadata = [](const xla::OpSharding& sharding) { in TEST_P()
86 const std::function<xla::StatusOr<absl::optional<xla::OpSharding>>()>& in TEST_P()
93 if (sharding->type() == xla::OpSharding::TUPLE) { in TEST_P()
129 xla::OpSharding CreateTupleSharding() { in CreateTupleSharding()
130 xla::OpSharding sharding; in CreateTupleSharding()
131 sharding.set_type(xla::OpSharding::TUPLE); in CreateTupleSharding()
132 sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED); in CreateTupleSharding()
[all …]
Dsharding_util.h36 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
38 absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt,
41 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
44 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
47 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
53 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
Dxla_compiler.cc90 std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
93 [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> { in ComputeArgAndRetvalShardings()
100 std::map<int, xla::OpSharding> arg_shardings; in ComputeArgAndRetvalShardings()
101 std::map<int, xla::OpSharding> retval_shardings; in ComputeArgAndRetvalShardings()
167 const std::map<int, xla::OpSharding>& arg_shardings, in BuildComputation()
168 const std::map<int, xla::OpSharding>& retval_shardings, in BuildComputation()
201 std::unordered_map<int, xla::OpSharding> retval_index_and_sharding; in BuildComputation()
225 absl::optional<xla::OpSharding> sharding = in BuildComputation()
226 it == retval_shardings.end() ? absl::optional<xla::OpSharding>() in BuildComputation()
327 ? absl::optional<xla::OpSharding>() in BuildComputation()
[all …]
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
Dxla_sharding.py47 proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
57 proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))
69 proto=xla_data_pb2.OpSharding(
70 type=xla_data_pb2.OpSharding.MAXIMAL,
97 proto=xla_data_pb2.OpSharding(
98 type=xla_data_pb2.OpSharding.OTHER,
122 proto=xla_data_pb2.OpSharding(
123 type=xla_data_pb2.OpSharding.OTHER,
159 proto=xla_data_pb2.OpSharding(
160 type=xla_data_pb2.OpSharding.OTHER,
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dxla_sharding_util.cc146 const mlir::Location& location, const xla::OpSharding& input_sharding, in HandleTileShardedInputs()
203 bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { in UnsupportedPartitionedShardingType()
204 return sharding != xla::OpSharding::REPLICATED && in UnsupportedPartitionedShardingType()
205 sharding != xla::OpSharding::OTHER; in UnsupportedPartitionedShardingType()
240 xla::OpSharding sharding; in ExtractInputsForLogicalDevices()
266 if (input_sharding_type == xla::OpSharding::REPLICATED) { in ExtractInputsForLogicalDevices()
272 assert(input_sharding_type == xla::OpSharding::OTHER); in ExtractInputsForLogicalDevices()
286 if (input_sharding_type == xla::OpSharding::OTHER) { in ExtractInputsForLogicalDevices()
300 } else if (input_sharding_type == xla::OpSharding::REPLICATED) { in ExtractInputsForLogicalDevices()
303 assert(input_sharding_type == xla::OpSharding::MAXIMAL); in ExtractInputsForLogicalDevices()
[all …]
Dxla_sharding_util.h49 mlir::SmallVector<xla::OpSharding, 4>* output_sharding_list);
55 const int core_id, llvm::ArrayRef<xla::OpSharding> output_sharding_config,
63 llvm::ArrayRef<xla::OpSharding> output_sharding_config,
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dspmd_manual_sharding_ops.cc42 xla::OpSharding sharding; in Compile()
51 if (sharding.type() == xla::OpSharding::OTHER) { in Compile()
73 xla::OpSharding manual; in Compile()
74 manual.set_type(xla::OpSharding::MANUAL); in Compile()
105 xla::OpSharding sharding; in Compile()
116 xla::OpSharding manual; in Compile()
117 manual.set_type(xla::OpSharding::MANUAL); in Compile()
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
Dinfeed_op.cc52 absl::optional<xla::OpSharding> sharding) { in UpdateInfeedLayout()
53 if (sharding && sharding->type() == xla::OpSharding::OTHER) { in UpdateInfeedLayout()
131 absl::optional<xla::OpSharding> sharding; in Compile()
133 sharding = b->sharding()->type() == xla::OpSharding::TUPLE in Compile()
/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding.cc474 const OpSharding& proto) { in FromProto()
477 if (proto.type() == OpSharding::TUPLE) { in FromProto()
482 for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { in FromProto()
488 } else if (proto.type() == OpSharding::REPLICATED) { in FromProto()
490 } else if (proto.type() == OpSharding::MANUAL) { in FromProto()
496 TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL) in FromProto()
529 OpSharding HloSharding::ToProto() const { in ToProto()
530 OpSharding result; in ToProto()
537 result.set_type(OpSharding::TUPLE); in ToProto()
553 result.set_type(OpSharding::REPLICATED); in ToProto()
[all …]
Dhlo_sharding_test.cc94 OpSharding proto; in TEST_F()
95 proto.set_type(OpSharding::TUPLE); in TEST_F()
97 tiled->set_type(OpSharding::OTHER); in TEST_F()
105 replicated->set_type(OpSharding::REPLICATED); in TEST_F()
108 manual->set_type(OpSharding::MANUAL); in TEST_F()
169 OpSharding proto; in TEST_F()
170 proto.set_type(OpSharding::TUPLE); in TEST_F()
Dhlo_sharding.h107 static StatusOr<HloSharding> FromProto(const OpSharding& proto);
113 OpSharding ToProto() const;
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
Ddistributed_tpu_rewrite_pass.cc605 const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) { in GetDimensionIndicesAndNumSplitsFromSharding()
792 const xla::OpSharding& sharding, int orig_arg_num, DataType dtype, in CreateOrGetSplitNodesForInputSharding()
1020 const xla::OpSharding& sharding, DataType dtype, in CreateConcatNodesForRetval()
1186 explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding) in NodeAndSharding()
1190 xla::OpSharding sharding;
1199 if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) { in ParseAndValidateSharding()
1209 if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) { in ParseAndValidateSharding()
1219 xla::OpSharding result_value = result->value().sharding; in ParseAndValidateSharding()
1224 xla::OpSharding sharding = node_and_sharding.sharding; in ParseAndValidateSharding()
1265 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseInputShardingFromAdjacentNode()
[all …]
Ddistributed_tpu_rewrite_pass.h312 std::vector<::xla::OpSharding>* arg_sharding,
314 std::vector<::xla::OpSharding>* retval_sharding,
358 const std::vector<::xla::OpSharding>& arg_sharding,
361 const std::vector<::xla::OpSharding>& retval_sharding,
435 const std::vector<::xla::OpSharding>& arg_shardings,
436 const std::vector<::xla::OpSharding>& retval_shardings,
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_support.cc227 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AddVariableUpdatesToCores()
229 } else if (sharding.type() == xla::OpSharding::OTHER) { in AddVariableUpdatesToCores()
239 TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED); in AddVariableUpdatesToCores()
248 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AddVariableUpdatesToCores()
252 } else if (sharding.type() == xla::OpSharding::OTHER) { in AddVariableUpdatesToCores()
258 TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED); in AddVariableUpdatesToCores()
285 if (retval.sharding().type() == xla::OpSharding::MAXIMAL) { in ComputeOutputShapesForEachCore()
288 } else if (retval.sharding().type() == xla::OpSharding::OTHER) { in ComputeOutputShapesForEachCore()
297 TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED) in ComputeOutputShapesForEachCore()
Dtpu_compile_op_common.cc100 if (proto_arg.sharding().type() == xla::OpSharding::MAXIMAL) { in SetPerCoreArgShapes()
106 } else if (proto_arg.sharding().type() == xla::OpSharding::OTHER) { in SetPerCoreArgShapes()
120 TF_RET_CHECK(proto_arg.sharding().type() == xla::OpSharding::REPLICATED) in SetPerCoreArgShapes()
155 if (proto_retval.sharding().type() == xla::OpSharding::MAXIMAL) { in AssignReturnValueToCore()
160 } else if (proto_retval.sharding().type() == xla::OpSharding::OTHER) { in AssignReturnValueToCore()
167 xla::OpSharding::REPLICATED) in AssignReturnValueToCore()
423 auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status { in AssignDevicesToArgsAndRetvals()
424 if (sharding.type() == xla::OpSharding::MAXIMAL) { in AssignDevicesToArgsAndRetvals()
429 TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED || in AssignDevicesToArgsAndRetvals()
430 sharding.type() == xla::OpSharding::OTHER) in AssignDevicesToArgsAndRetvals()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dmlir_hlo_to_hlo.cc344 static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef( in CreateOpShardingFromStringRef()
346 xla::OpSharding sharding_proto; in CreateOpShardingFromStringRef()
354 static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute( in CreateOpShardingFromAttribute()
403 llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) { in AllOptionalShardingsAreSet()
405 [](const absl::optional<xla::OpSharding>& sharding) { in AllOptionalShardingsAreSet()
413 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings, in ExtractShardingsFromFunction()
414 llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) { in ExtractShardingsFromFunction()
416 absl::optional<xla::OpSharding>()); in ExtractShardingsFromFunction()
423 absl::optional<xla::OpSharding>()); in ExtractShardingsFromFunction()
486 llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
[all …]
/external/tensorflow/tensorflow/compiler/xla/python/
Dxla_compiler.cc563 py::enum_<OpSharding::Type>(m, "OpSharding_Type") in BuildXlaCompilerSubmodule()
564 .value("REPLICATED", OpSharding::REPLICATED) in BuildXlaCompilerSubmodule()
565 .value("MAXIMAL", OpSharding::MAXIMAL) in BuildXlaCompilerSubmodule()
566 .value("TUPLE", OpSharding::TUPLE) in BuildXlaCompilerSubmodule()
567 .value("OTHER", OpSharding::OTHER); in BuildXlaCompilerSubmodule()
Dtypes.h437 struct type_caster<xla::OpSharding> {
439 PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));
474 xla::OpSharding* sharding = value.add_tuple_shardings();
/external/tensorflow/tensorflow/compiler/xla/pjrt/
Dutils.cc32 const OpSharding& sharding) { in GetShardedShape()
33 if (sharding.type() == OpSharding::TUPLE) { in GetShardedShape()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/
Dresult-sharding.mlir9 // The following xla::OpSharding protos are used:
Dargument-sharding.mlir9 // The following xla::OpSharding protos are used:
/external/tensorflow/tensorflow/core/protobuf/tpu/
Dcompile_metadata.proto33 xla.OpSharding sharding = 4;
72 xla.OpSharding sharding = 1;

12