1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"

#include <utility>

#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
21
22 namespace tensorflow {
namespace {
// Marker looked for in a parsed device type. It is matched as a substring
// (absl::StrContains), not strictly as a suffix. When it is present, the
// device id is read as a core index within a replica.
constexpr char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE";
// Name of the node attribute holding a serialized xla::OpSharding proto.
constexpr char kShardingAttribute[] = "_XlaSharding";
}  // namespace
27
28 namespace {
CreateOpMetadata(const std::string & op_type,const std::string & op_name)29 xla::OpMetadata CreateOpMetadata(const std::string& op_type,
30 const std::string& op_name) {
31 xla::OpMetadata metadata;
32 metadata.set_op_type(op_type);
33 metadata.set_op_name(op_name);
34 return metadata;
35 }
36
AssignOpMetadataToSharding(xla::OpSharding & sharding,const string & op_type,const string & op_name)37 void AssignOpMetadataToSharding(xla::OpSharding& sharding,
38 const string& op_type, const string& op_name) {
39 auto metadata = CreateOpMetadata(op_type, op_name);
40 if (sharding.type() == xla::OpSharding::TUPLE) {
41 for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
42 *sharding_element.add_metadata() = metadata;
43 }
44 } else {
45 *sharding.add_metadata() = metadata;
46 }
47 }
48
CoreOutOfRangeError(int core,int num_cores_per_replica)49 Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
50 return errors::InvalidArgument(
51 "Invalid replicated core id: ", core,
52 "; num_cores_per_replica=", num_cores_per_replica);
53 }
54 } // namespace
55
ParseShardingFromDevice(const string & device_name,int num_cores_per_replica,absl::optional<xla::OpSharding> explicit_sharding,absl::optional<xla::OpMetadata> metadata)56 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
57 const string& device_name, int num_cores_per_replica,
58 absl::optional<xla::OpSharding> explicit_sharding,
59 absl::optional<xla::OpMetadata> metadata) {
60 if (device_name.empty()) {
61 return explicit_sharding;
62 }
63 DeviceNameUtils::ParsedName parsed_device;
64 if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
65 return errors::InvalidArgument("Malformed assigned device '", device_name,
66 "'");
67 }
68
69 if (explicit_sharding.has_value()) {
70 return explicit_sharding;
71 } else if (!parsed_device.has_type || !parsed_device.has_id ||
72 !absl::StrContains(parsed_device.type,
73 kDeviceSuffixReplicatedCore)) {
74 return absl::optional<xla::OpSharding>();
75 } else {
76 const int core = parsed_device.id;
77 if (core < 0 || core >= num_cores_per_replica) {
78 return CoreOutOfRangeError(core, num_cores_per_replica);
79 }
80 auto sharding = xla::sharding_builder::AssignDevice(core);
81 if (metadata.has_value()) {
82 *sharding.add_metadata() = metadata.value();
83 }
84 return absl::optional<xla::OpSharding>(sharding);
85 }
86 }
87
ParseShardingFromDevice(const NodeDef & node_def,int num_cores_per_replica,bool add_metadata)88 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
89 const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
90 const string& device_name = node_def.device();
91 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
92 GetShardingFromNodeDef(node_def, add_metadata));
93 return ParseShardingFromDevice(
94 device_name, num_cores_per_replica, sharding,
95 add_metadata ? absl::optional<xla::OpMetadata>(
96 CreateOpMetadata(node_def.op(), node_def.name()))
97 : absl::nullopt);
98 }
99
ParseShardingFromDevice(const Node & node,int num_cores_per_replica,bool add_metadata)100 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
101 const Node& node, int num_cores_per_replica, bool add_metadata) {
102 string device_name = node.assigned_device_name();
103 if (device_name.empty()) {
104 device_name = node.requested_device();
105 }
106 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
107 GetShardingFromNodeDef(node.def(), add_metadata));
108 return ParseShardingFromDevice(
109 device_name, num_cores_per_replica, sharding,
110 add_metadata ? absl::optional<xla::OpMetadata>(
111 CreateOpMetadata(node.type_string(), node.name()))
112 : absl::nullopt);
113 }
114
ParseShardingFromEdgeSource(const Edge & edge,int num_cores_per_replica,bool add_metadata)115 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
116 const Edge& edge, int num_cores_per_replica, bool add_metadata) {
117 if (edge.src() == nullptr) {
118 return tensorflow::errors::InvalidArgument(
119 "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
120 }
121 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
122 ParseShardingFromDevice(
123 *edge.src(), num_cores_per_replica, add_metadata));
124 if (sharding.has_value() &&
125 sharding.value().type() == xla::OpSharding::TUPLE) {
126 if (edge.src_output() < 0 ||
127 edge.src_output() >= sharding.value().tuple_shardings_size()) {
128 return tensorflow::errors::InvalidArgument(
129 "Tuple index out of bound: edge=", edge.DebugString(),
130 " sharding=", sharding->DebugString());
131 }
132 absl::optional<xla::OpSharding> subsharding =
133 sharding.value().tuple_shardings(edge.src_output());
134 return subsharding;
135 }
136 return sharding;
137 }
138
SetShardingDeviceAssignmentFromNode(const Node & src,Node * dst)139 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
140 string device_name = src.assigned_device_name();
141 if (device_name.empty()) {
142 device_name = src.requested_device();
143 }
144 dst->set_assigned_device_name(device_name);
145 if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) {
146 dst->AddAttr(kShardingAttribute, *attr);
147 }
148 }
149
GetShardingFromNodeDef(const NodeDef & node_def,bool add_metadata)150 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
151 const NodeDef& node_def, bool add_metadata) {
152 if (!HasNodeAttr(node_def, kShardingAttribute)) {
153 return absl::optional<xla::OpSharding>();
154 }
155 string value;
156 xla::OpSharding sharding;
157 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
158 if (!sharding.ParseFromString(value)) {
159 return xla::InvalidArgument(
160 "Experimental _XlaSharding attribute was not a valid encoded "
161 "xla::OpSharding proto.");
162 }
163 if (add_metadata) {
164 AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
165 }
166 return absl::optional<xla::OpSharding>(sharding);
167 }
168 } // namespace tensorflow
169