/external/tensorflow/tensorflow/compiler/xla/client/

sharding_builder.cc
    21  OpSharding Replicate() {    in Replicate()
    22  OpSharding result;    in Replicate()
    23  result.set_type(OpSharding::REPLICATED);    in Replicate()
    27  OpSharding AssignDevice(int device) {    in AssignDevice()
    28  OpSharding result;    in AssignDevice()
    29  result.set_type(OpSharding::MAXIMAL);    in AssignDevice()
    35  OpSharding Tile(const Shape& tile_shape,    in Tile()
    37  OpSharding result;    in Tile()
    38  result.set_type(OpSharding::OTHER);    in Tile()
    49  OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) {    in Tile1D()
    [all …]

sharding_builder.h
    34  OpSharding Replicate();
    37  OpSharding AssignDevice(int device);
    45  OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment);
    51  OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles);
    54  OpSharding Tuple(const ShapeTree<OpSharding>& shardings);
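A minimal sketch of how the helpers declared above might be combined, assuming they live in namespace xla::sharding_builder (as the xla_builder.cc call sites below suggest) and using the standard ShapeUtil::MakeShape helper; the f32[16] shape and the 4-way split are only an illustration of Tile1D's arguments, not code from this repository.

#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::OpSharding BuildExampleShardings() {
  // MAXIMAL: pin an operand to a single device (device 0).
  xla::OpSharding on_device_0 = xla::sharding_builder::AssignDevice(0);

  // REPLICATED: keep a full copy of the operand on every device.
  xla::OpSharding replicated = xla::sharding_builder::Replicate();

  // OTHER (tiled): split a rank-1 f32[16] value four ways along its only
  // dimension (assumption: tile_shape is the full operand shape that Tile1D divides).
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {16});
  xla::OpSharding tiled = xla::sharding_builder::Tile1D(shape, 4);

  (void)on_device_0;
  (void)replicated;
  return tiled;
}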
xla_builder.h
    159  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }    in SetSharding()
    195  const absl::optional<OpSharding>& sharding() const { return sharding_; }    in sharding()
    739  absl::optional<OpSharding> sharding_;
    1055  absl::optional<OpSharding> sharding)    in XlaScopedShardingAssignment()
    1067  void SetSharding(const absl::optional<OpSharding>& sharding) {    in SetSharding()
    1076  absl::optional<OpSharding> prev_sharding_;

xla_builder.cc
    1352  sharding()->type() == OpSharding::OTHER) {    in Infeed()
    1358  if (sharding() && sharding()->type() == OpSharding::REPLICATED) {    in Infeed()
    1373  OpSharding sharding = sharding_builder::AssignDevice(0);    in Infeed()
    1385  if (sharding() && sharding()->type() == OpSharding::TUPLE) {    in Infeed()
    1388  OpSharding infeed_instruction_sharding = *sharding();    in Infeed()
    1426  sharding()->type() == OpSharding::OTHER) {    in InfeedWithToken()
    1432  if (sharding() && sharding()->type() == OpSharding::REPLICATED) {    in InfeedWithToken()
    1484  XlaScopedShardingAssignment scoped_sharding(this, OpSharding());    in Outfeed()
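The Outfeed call site at line 1484 shows the scoped-assignment pattern: construct an XlaScopedShardingAssignment, build ops, and let the destructor restore the previous sharding. A rough sketch of using it from client code, assuming the free function xla::Add from xla_builder.h; the op being built is only an illustration.

#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp AddPinnedToDevice0(xla::XlaBuilder* builder, xla::XlaOp lhs, xla::XlaOp rhs) {
  // Every instruction created while scoped_sharding is alive carries the
  // MAXIMAL (device 0) sharding; the previous sharding is restored on exit.
  xla::XlaScopedShardingAssignment scoped_sharding(
      builder, xla::sharding_builder::AssignDevice(0));
  return xla::Add(lhs, rhs);
}

The guard object makes the annotation exception-safe and keeps the builder's ambient sharding consistent no matter how the scope is exited.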
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/

xla_sharding.py
    47  proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
    59  proto=xla_data_pb2.OpSharding(
    60  type=xla_data_pb2.OpSharding.MAXIMAL,
    87  proto=xla_data_pb2.OpSharding(
    88  type=xla_data_pb2.OpSharding.OTHER,
    119  proto=xla_data_pb2.OpSharding(
    120  type=xla_data_pb2.OpSharding.OTHER,
    137  proto = xla_data_pb2.OpSharding(
    138  type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
    155  proto = xla_data_pb2.OpSharding()
    [all …]
/external/tensorflow/tensorflow/compiler/tf2xla/

sharding_util.cc
    29  xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(    in GetShardingFromNodeDef()
    32  return absl::optional<xla::OpSharding>();    in GetShardingFromNodeDef()
    35  xla::OpSharding sharding;    in GetShardingFromNodeDef()
    42  return absl::optional<xla::OpSharding>(sharding);    in GetShardingFromNodeDef()
    52  xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(    in ParseShardingFromDevice()
    54  absl::optional<xla::OpSharding> explicit_sharding) {    in ParseShardingFromDevice()
    69  return absl::optional<xla::OpSharding>();    in ParseShardingFromDevice()
    75  return absl::optional<xla::OpSharding>(    in ParseShardingFromDevice()
    80  xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(    in ParseShardingFromDevice()
    83  TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,    in ParseShardingFromDevice()
    [all …]

sharding_util.h
    36  xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
    38  absl::optional<xla::OpSharding> explicit_sharding = absl::nullopt);
    40  xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
    43  xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
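The StatusOr<absl::optional<...>> return type distinguishes "no sharding requested" (empty optional) from a parse error. A sketch of the core idea only: map a device string with a trailing core number to a single-device sharding. The device-name format shown and the helper name are assumptions for illustration, not the exact tf2xla parsing code.

#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"

absl::optional<xla::OpSharding> ShardingForCoreDevice(absl::string_view device_name) {
  // e.g. "/device:TPU_REPLICATED_CORE:2" -> MAXIMAL sharding pinned to core 2.
  const size_t pos = device_name.rfind(':');
  int core = 0;
  if (pos == absl::string_view::npos ||
      !absl::SimpleAtoi(device_name.substr(pos + 1), &core)) {
    return absl::nullopt;  // No usable core id: leave the node unsharded.
  }
  return xla::sharding_builder::AssignDevice(core);
}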
sharding_util_test.cc
    26  [](absl::optional<xla::OpSharding> sharding) -> int64 {    in TEST()
    28  sharding.value().type() == xla::OpSharding::MAXIMAL) {    in TEST()

xla_compiler.cc
    83  std::pair<std::map<int, xla::OpSharding>, std::map<int, xla::OpSharding>>>
    86  [](const Node* n) -> xla::StatusOr<absl::optional<xla::OpSharding>> {    in ComputeArgAndRetvalShardings()
    92  std::map<int, xla::OpSharding> arg_shardings;    in ComputeArgAndRetvalShardings()
    93  std::map<int, xla::OpSharding> retval_shardings;    in ComputeArgAndRetvalShardings()
    159  const std::map<int, xla::OpSharding>& arg_shardings,    in BuildComputation()
    160  const std::map<int, xla::OpSharding>& retval_shardings,    in BuildComputation()
    197  std::unordered_map<int, xla::OpSharding> retval_index_and_sharding;    in BuildComputation()
    223  ? absl::optional<xla::OpSharding>()    in BuildComputation()
    300  builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()    in BuildComputation()
    365  const xla::OpSharding& sub_op_sharding = iter->second;    in BuildComputation()
    [all …]
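Lines 223 and 300 show the lookup pattern BuildComputation applies to these per-index maps: a missing entry means the argument or return value gets no explicit sharding. Isolated as a standalone helper for clarity; the names here are illustrative, not taken from xla_compiler.cc.

#include <map>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

absl::optional<xla::OpSharding> ShardingForArg(
    const std::map<int, xla::OpSharding>& arg_shardings, int arg_index) {
  const auto it = arg_shardings.find(arg_index);
  if (it == arg_shardings.end()) {
    return absl::nullopt;  // No explicit sharding: leave the parameter alone.
  }
  return it->second;  // Sharding recorded for this argument position.
}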
xla_compilation_device.cc
    99  absl::optional<xla::OpSharding> op_sharding =    in Compute()

xla_compiler.h
    461  const std::map<int, xla::OpSharding>& arg_shardings,

tf2xla_util.cc
    502  absl::optional<xla::OpSharding> sharding,    in SetNodeShardingFromNeighbors()
    506  if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {    in SetNodeShardingFromNeighbors()
/external/tensorflow/tensorflow/compiler/xla/service/

hlo_sharding.cc
    361  const OpSharding& proto) {    in FromProto()
    362  if (proto.type() == OpSharding::TUPLE) {    in FromProto()
    365  for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {    in FromProto()
    371  } else if (proto.type() == OpSharding::REPLICATED) {    in FromProto()
    377  TF_RET_CHECK(proto.type() != OpSharding::MAXIMAL)    in FromProto()
    407  OpSharding HloSharding::ToProto() const {    in ToProto()
    408  OpSharding result;    in ToProto()
    414  result.set_type(OpSharding::TUPLE);    in ToProto()
    425  result.set_type(OpSharding::REPLICATED);    in ToProto()
    427  result.set_type(OpSharding::MAXIMAL);    in ToProto()
    [all …]

hlo_sharding.h
    81  static StatusOr<HloSharding> FromProto(const OpSharding& proto);
    87  OpSharding ToProto() const;
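A sketch of round-tripping between the proto and the HloSharding class through the two declarations above. It reuses the sharding_builder helper from the client listing, and spells out the error handling instead of relying on a status macro; treat it as an illustration, not code from hlo_sharding_test.cc.

#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<xla::OpSharding> RoundTripReplicated() {
  // Build a REPLICATED proto, parse it, and serialize it back out.
  const xla::OpSharding proto = xla::sharding_builder::Replicate();
  xla::StatusOr<xla::HloSharding> parsed = xla::HloSharding::FromProto(proto);
  if (!parsed.ok()) {
    return parsed.status();  // FromProto rejects malformed protos.
  }
  return parsed.ValueOrDie().ToProto();
}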
hlo_sharding_test.cc
    134  OpSharding proto;    in TEST_F()
    135  proto.set_type(OpSharding::TUPLE);    in TEST_F()

hlo_parser.cc
    279  bool ParseSharding(OpSharding* sharding);
    281  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
    690  optional<OpSharding> sharding;    in ParseInstructionRhs()
    1871  bool HloParserImpl::ParseSharding(OpSharding* sharding) {    in ParseSharding()
    1895  sharding->set_type(OpSharding::TUPLE);    in ParseSharding()
    1933  bool HloParserImpl::ParseSingleSharding(OpSharding* sharding,    in ParseSingleSharding()
    2007  sharding->set_type(OpSharding::REPLICATED);    in ParseSingleSharding()
    2013  sharding->set_type(OpSharding::MAXIMAL);    in ParseSingleSharding()
    2026  sharding->set_type(OpSharding::OTHER);    in ParseSingleSharding()
    2089  optional<OpSharding> entry_sharding;    in ParseDomain()
    [all …]

hlo.proto
    171  xla.OpSharding sharding = 40;
    199  xla.OpSharding domain_entry_sharding = 54;
    200  xla.OpSharding domain_exit_sharding = 55;
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/

tpu_rewrite_pass.cc
    231  xla::OpSharding sharding;    in SetMetadataProtoFromLaunchFuncOp()
    232  sharding.set_type(xla::OpSharding::MAXIMAL);    in SetMetadataProtoFromLaunchFuncOp()
    242  xla::OpSharding sharding;    in SetMetadataProtoFromLaunchFuncOp()
    243  sharding.set_type(xla::OpSharding::MAXIMAL);    in SetMetadataProtoFromLaunchFuncOp()

/external/tensorflow/tensorflow/core/protobuf/tpu/

compile_metadata.proto
    33  xla.OpSharding sharding = 4;
    69  xla.OpSharding sharding = 1;
/external/tensorflow/tensorflow/compiler/xla/python/

types.h
    481  struct type_caster<xla::OpSharding> {
    483  PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));
    518  xla::OpSharding* sharding = value.add_tuple_shardings();

xla.cc
    1056  py::enum_<OpSharding::Type>(m, "OpSharding_Type")    in PYBIND11_MODULE()
    1057  .value("REPLICATED", OpSharding::REPLICATED)    in PYBIND11_MODULE()
    1058  .value("MAXIMAL", OpSharding::MAXIMAL)    in PYBIND11_MODULE()
    1059  .value("TUPLE", OpSharding::TUPLE)    in PYBIND11_MODULE()
    1060  .value("OTHER", OpSharding::OTHER);    in PYBIND11_MODULE()

xla_client.py
    1839  class OpSharding(object):

xla_client_test.py
    2041  sharding = xla_client.OpSharding()
/external/tensorflow/tensorflow/compiler/xla/

xla_data.proto
    594  message OpSharding {
    622  repeated OpSharding tuple_shardings = 5;
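A sketch of populating the raw OpSharding proto directly, mirroring what the helper builders above do. Only the type enum and the tuple_shardings field (field 5) appear verbatim in this listing; the tile_assignment_dimensions and tile_assignment_devices field names are assumptions about the rest of the message.

#include "tensorflow/compiler/xla/xla_data.pb.h"

xla::OpSharding TwoWayTiledTupleSharding() {
  // OTHER = tiled: a rank-1 tile assignment placing two tiles on two devices.
  xla::OpSharding tiled;
  tiled.set_type(xla::OpSharding::OTHER);
  tiled.add_tile_assignment_dimensions(2);  // Two tiles along dimension 0.
  tiled.add_tile_assignment_devices(0);     // Tile 0 -> device 0.
  tiled.add_tile_assignment_devices(1);     // Tile 1 -> device 1.

  // TUPLE shardings carry one element sharding per leaf via tuple_shardings.
  xla::OpSharding tuple_sharding;
  tuple_sharding.set_type(xla::OpSharding::TUPLE);
  *tuple_sharding.add_tuple_shardings() = tiled;
  tuple_sharding.add_tuple_shardings()->set_type(xla::OpSharding::REPLICATED);
  return tuple_sharding;
}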