Home
last modified time | relevance | path

Searched refs:OpSharding (Results 1 – 24 of 24) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/client/
Dsharding_builder.cc21 OpSharding Replicate() { in Replicate()
22 OpSharding result; in Replicate()
23 result.set_type(OpSharding::REPLICATED); in Replicate()
27 OpSharding AssignDevice(int device) { in AssignDevice()
28 OpSharding result; in AssignDevice()
29 result.set_type(OpSharding::MAXIMAL); in AssignDevice()
35 OpSharding Tile(const Shape& tile_shape, in Tile()
37 OpSharding result; in Tile()
38 result.set_type(OpSharding::OTHER); in Tile()
49 OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { in Tile1D()
[all …]
Dsharding_builder.h34 OpSharding Replicate();
37 OpSharding AssignDevice(int device);
45 OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment);
51 OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles);
54 OpSharding Tuple(const ShapeTree<OpSharding>& shardings);
Dxla_builder.h159 void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } in SetSharding()
195 const absl::optional<OpSharding>& sharding() const { return sharding_; } in sharding()
739 absl::optional<OpSharding> sharding_;
1055 absl::optional<OpSharding> sharding) in XlaScopedShardingAssignment()
1067 void SetSharding(const absl::optional<OpSharding>& sharding) { in SetSharding()
1076 absl::optional<OpSharding> prev_sharding_;
Dxla_builder.cc1352 sharding()->type() == OpSharding::OTHER) { in Infeed()
1358 if (sharding() && sharding()->type() == OpSharding::REPLICATED) { in Infeed()
1373 OpSharding sharding = sharding_builder::AssignDevice(0); in Infeed()
1385 if (sharding() && sharding()->type() == OpSharding::TUPLE) { in Infeed()
1388 OpSharding infeed_instruction_sharding = *sharding(); in Infeed()
1426 sharding()->type() == OpSharding::OTHER) { in InfeedWithToken()
1432 if (sharding() && sharding()->type() == OpSharding::REPLICATED) { in InfeedWithToken()
1484 XlaScopedShardingAssignment scoped_sharding(this, OpSharding()); in Outfeed()
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
Dxla_sharding.py47 proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
59 proto=xla_data_pb2.OpSharding(
60 type=xla_data_pb2.OpSharding.MAXIMAL,
87 proto=xla_data_pb2.OpSharding(
88 type=xla_data_pb2.OpSharding.OTHER,
119 proto=xla_data_pb2.OpSharding(
120 type=xla_data_pb2.OpSharding.OTHER,
137 proto = xla_data_pb2.OpSharding(
138 type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
155 proto = xla_data_pb2.OpSharding()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/
Dsharding_util.cc29 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef( in GetShardingFromNodeDef()
32 return absl::optional<xla::OpSharding>(); in GetShardingFromNodeDef()
35 xla::OpSharding sharding; in GetShardingFromNodeDef()
42 return absl::optional<xla::OpSharding>(sharding); in GetShardingFromNodeDef()
52 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice()
54 absl::optional<xla::OpSharding> explicit_sharding) { in ParseShardingFromDevice()
69 return absl::optional<xla::OpSharding>(); in ParseShardingFromDevice()
75 return absl::optional<xla::OpSharding>( in ParseShardingFromDevice()
80 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice( in ParseShardingFromDevice()
83 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, in ParseShardingFromDevice()
[all …]
Dsharding_util.h36 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
38 absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
40 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
43 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
Dsharding_util_test.cc26 [](absl::optional<xla::OpSharding> sharding) -> int64 { in TEST()
28 sharding.value().type() == xla::OpSharding::MAXIMAL) { in TEST()
Dxla_compiler.cc83 std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
86 [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> { in ComputeArgAndRetvalShardings()
92 std::map<int, xla::OpSharding> arg_shardings; in ComputeArgAndRetvalShardings()
93 std::map<int, xla::OpSharding> retval_shardings; in ComputeArgAndRetvalShardings()
159 const std::map<int, xla::OpSharding>& arg_shardings, in BuildComputation()
160 const std::map<int, xla::OpSharding>& retval_shardings, in BuildComputation()
197 std::unordered_map<int, xla::OpSharding> retval_index_and_sharding; in BuildComputation()
223 ? absl::optional<xla::OpSharding>() in BuildComputation()
300 builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>() in BuildComputation()
365 const xla::OpSharding& sub_op_sharding = iter->second; in BuildComputation()
[all …]
Dxla_compilation_device.cc99 absl::optional<xla::OpSharding> op_sharding = in Compute()
Dxla_compiler.h461 const std::map<int, xla::OpSharding>& arg_shardings,
Dtf2xla_util.cc502 absl::optional<xla::OpSharding> sharding, in SetNodeShardingFromNeighbors()
506 if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) { in SetNodeShardingFromNeighbors()
/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding.cc361 const OpSharding& proto) { in FromProto()
362 if (proto.type() == OpSharding::TUPLE) { in FromProto()
365 for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { in FromProto()
371 } else if (proto.type() == OpSharding::REPLICATED) { in FromProto()
377 TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL) in FromProto()
407 OpSharding HloSharding::ToProto() const { in ToProto()
408 OpSharding result; in ToProto()
414 result.set_type(OpSharding::TUPLE); in ToProto()
425 result.set_type(OpSharding::REPLICATED); in ToProto()
427 result.set_type(OpSharding::MAXIMAL); in ToProto()
[all …]
Dhlo_sharding.h81 static StatusOr<HloSharding> FromProto(const OpSharding& proto);
87 OpSharding ToProto() const;
Dhlo_sharding_test.cc134 OpSharding proto; in TEST_F()
135 proto.set_type(OpSharding::TUPLE); in TEST_F()
Dhlo_parser.cc279 bool ParseSharding(OpSharding* sharding);
281 bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
690 optional<OpSharding> sharding; in ParseInstructionRhs()
1871 bool HloParserImpl::ParseSharding(OpSharding* sharding) { in ParseSharding()
1895 sharding->set_type(OpSharding::TUPLE); in ParseSharding()
1933 bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, in ParseSingleSharding()
2007 sharding->set_type(OpSharding::REPLICATED); in ParseSingleSharding()
2013 sharding->set_type(OpSharding::MAXIMAL); in ParseSingleSharding()
2026 sharding->set_type(OpSharding::OTHER); in ParseSingleSharding()
2089 optional<OpSharding> entry_sharding; in ParseDomain()
[all …]
Dhlo.proto171 xla.OpSharding sharding = 40;
199 xla.OpSharding domain_entry_sharding = 54;
200 xla.OpSharding domain_exit_sharding = 55;
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dtpu_rewrite_pass.cc231 xla::OpSharding sharding; in SetMetadataProtoFromLaunchFuncOp()
232 sharding.set_type(xla::OpSharding::MAXIMAL); in SetMetadataProtoFromLaunchFuncOp()
242 xla::OpSharding sharding; in SetMetadataProtoFromLaunchFuncOp()
243 sharding.set_type(xla::OpSharding::MAXIMAL); in SetMetadataProtoFromLaunchFuncOp()
/external/tensorflow/tensorflow/core/protobuf/tpu/
Dcompile_metadata.proto33 xla.OpSharding sharding = 4;
69 xla.OpSharding sharding = 1;
/external/tensorflow/tensorflow/compiler/xla/python/
Dtypes.h481 struct type_caster<xla::OpSharding> {
483 PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));
518 xla::OpSharding* sharding = value.add_tuple_shardings();
Dxla.cc1056 py::enum_<OpSharding::Type>(m, "OpSharding_Type") in PYBIND11_MODULE()
1057 .value("REPLICATED", OpSharding::REPLICATED) in PYBIND11_MODULE()
1058 .value("MAXIMAL", OpSharding::MAXIMAL) in PYBIND11_MODULE()
1059 .value("TUPLE", OpSharding::TUPLE) in PYBIND11_MODULE()
1060 .value("OTHER", OpSharding::OTHER); in PYBIND11_MODULE()
Dxla_client.py1839 class OpSharding(object): class
Dxla_client_test.py2041 sharding = xla_client.OpSharding()
/external/tensorflow/tensorflow/compiler/xla/
Dxla_data.proto594 message OpSharding { message
622 repeated OpSharding tuple_shardings = 5;