1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/utils.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/client/executable_build_options.h"
20 #include "tensorflow/compiler/xla/client/xla_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo.pb.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
24 #include "tensorflow/compiler/xla/shape.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace xla {
29
30 namespace {
GetShardedShape(const Shape & shape,const OpSharding & sharding)31 StatusOr<Shape> GetShardedShape(const Shape& shape,
32 const OpSharding& sharding) {
33 if (sharding.type() == OpSharding::TUPLE) {
34 if (!shape.IsTuple()) {
35 return InvalidArgument(
36 "Got tuple OpSharding (%s) for non-tuple shape (%s)",
37 sharding.DebugString(), shape.ToString());
38 }
39 if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
40 return InvalidArgument(
41 "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
42 " (OpSharding: %s, shape: %s)",
43 sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
44 sharding.DebugString(), shape.ToString());
45 }
46 std::vector<Shape> sharded_subshapes;
47 for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
48 TF_ASSIGN_OR_RETURN(
49 Shape sharded_subshape,
50 GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
51 sharded_subshapes.emplace_back(std::move(sharded_subshape));
52 }
53 return ShapeUtil::MakeTupleShape(sharded_subshapes);
54 }
55 TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
56 HloSharding::FromProto(sharding));
57 return hlo_sharding.TileShape(shape);
58 }
59
GetShardedShape(const HloInstructionProto & instr)60 StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
61 const Shape unsharded_shape(instr.shape());
62 Shape sharded_shape;
63 if (instr.has_sharding()) {
64 TF_ASSIGN_OR_RETURN(sharded_shape,
65 GetShardedShape(unsharded_shape, instr.sharding()));
66 } else {
67 sharded_shape = unsharded_shape;
68 }
69 LayoutUtil::ClearLayout(&sharded_shape);
70 return sharded_shape;
71 }
72
73 // Returns sharded (argument shapes, result shape) without layouts.
GetShardedProgramShapes(const XlaComputation & computation,const ProgramShape & program_shape)74 StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
75 const XlaComputation& computation, const ProgramShape& program_shape) {
76 std::vector<Shape> arg_shapes;
77 arg_shapes.resize(program_shape.parameters_size());
78 Shape result_shape;
79 for (const HloComputationProto& comp : computation.proto().computations()) {
80 if (comp.id() != computation.proto().entry_computation_id()) {
81 continue;
82 }
83 for (const HloInstructionProto& instr : comp.instructions()) {
84 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
85 if (instr.parameter_number() >= program_shape.parameters_size()) {
86 return InvalidArgument(
87 "Got invalid parameter number %d, expected %d parameters",
88 instr.parameter_number(), program_shape.parameters_size());
89 }
90 TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
91 GetShardedShape(instr));
92 }
93 if (instr.id() == comp.root_id()) {
94 if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
95 return InvalidArgument("Found multiple root instructions");
96 }
97 TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
98 }
99 }
100 }
101 for (int i = 0; i < arg_shapes.size(); ++i) {
102 if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
103 return InvalidArgument("Couldn't find parameter %d", i);
104 }
105 }
106 if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
107 return InvalidArgument("Couldn't find root instruction");
108 }
109 return std::make_pair(arg_shapes, result_shape);
110 }
111 } // namespace
112
ParseDeviceAssignmentCompileOptions(bool compile_portable_executable,ExecutableBuildOptions * build_options,std::function<StatusOr<DeviceAssignment> (int,int)> GetDefaultDeviceAssignmentFunction,int * num_replicas,int * num_partitions,std::shared_ptr<DeviceAssignment> * device_assignment)113 Status ParseDeviceAssignmentCompileOptions(
114 bool compile_portable_executable, ExecutableBuildOptions* build_options,
115 std::function<StatusOr<DeviceAssignment>(int, int)>
116 GetDefaultDeviceAssignmentFunction,
117 int* num_replicas, int* num_partitions,
118 std::shared_ptr<DeviceAssignment>* device_assignment) {
119 if (compile_portable_executable) {
120 if (build_options->has_device_assignment()) {
121 return InvalidArgument(
122 "CompileOptions requests portable executable but "
123 "ExecutableBuildOptions includes a device assignment");
124 }
125 *num_replicas = 1;
126 *num_partitions = 1;
127 } else {
128 if (!build_options->has_device_assignment()) {
129 VLOG(2) << "Compile using default device_assignment.";
130 TF_ASSIGN_OR_RETURN(
131 DeviceAssignment device_assignment,
132 GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
133 build_options->num_partitions()));
134 build_options->set_device_assignment(device_assignment);
135 }
136 VLOG(2) << "Compile device_assignment:\n"
137 << build_options->device_assignment().ToString();
138 *num_replicas = build_options->device_assignment().replica_count();
139 *num_partitions = build_options->device_assignment().computation_count();
140 *device_assignment =
141 std::make_shared<DeviceAssignment>(build_options->device_assignment());
142 }
143 return Status::OK();
144 }
145
DetermineArgumentLayoutsFromCompileOptions(const XlaComputation & computation,std::function<StatusOr<Shape> (Shape)> choose_compact_layout_for_shape_function,absl::optional<std::vector<Shape>> & argument_layouts,ExecutableBuildOptions * build_options,std::vector<const Shape * > * argument_layout_pointers)146 Status DetermineArgumentLayoutsFromCompileOptions(
147 const XlaComputation& computation,
148 std::function<StatusOr<Shape>(Shape)>
149 choose_compact_layout_for_shape_function,
150 absl::optional<std::vector<Shape>>& argument_layouts,
151 ExecutableBuildOptions* build_options,
152 std::vector<const Shape*>* argument_layout_pointers) {
153 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
154 computation.GetProgramShape());
155 if (!argument_layouts) {
156 argument_layouts.emplace(program_shape.parameters());
157 for (Shape& shape : *argument_layouts) {
158 LayoutUtil::ClearLayout(&shape);
159 }
160 } else if (argument_layouts->size() != program_shape.parameters_size()) {
161 return InvalidArgument(
162 "CompileOptions specify %d argument layouts, but computation has %d "
163 "arguments",
164 argument_layouts->size(), program_shape.parameters_size());
165 }
166 argument_layout_pointers->reserve(argument_layouts->size());
167
168 // Assign a default layout based on `sharded_shape` to any array subshapes in
169 // `dst_shape` that are missing layouts.
170 auto assign_layouts = [&choose_compact_layout_for_shape_function](
171 const Shape& sharded_shape, Shape* dst_shape) {
172 return ShapeUtil::ForEachMutableSubshapeWithStatus(
173 dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
174 if (subshape->IsArray() && !subshape->has_layout()) {
175 CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
176 const Shape& sharded_subshape =
177 ShapeUtil::GetSubshape(sharded_shape, idx);
178 LayoutUtil::SetToDefaultLayout(subshape);
179 TF_ASSIGN_OR_RETURN(
180 Shape layout,
181 choose_compact_layout_for_shape_function(sharded_subshape));
182 *subshape->mutable_layout() = layout.layout();
183 }
184 return Status::OK();
185 });
186 };
187 TF_ASSIGN_OR_RETURN(auto sharded_shapes,
188 GetShardedProgramShapes(computation, program_shape));
189
190 CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
191 for (int i = 0; i < argument_layouts->size(); ++i) {
192 Shape* layout = &(*argument_layouts)[i];
193 argument_layout_pointers->push_back(layout);
194 TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
195 }
196
197 Shape result_layout;
198 if (build_options->result_layout()) {
199 result_layout = *build_options->result_layout();
200 } else {
201 result_layout = program_shape.result();
202 LayoutUtil::ClearLayout(&result_layout);
203 }
204 TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
205 build_options->set_result_layout(result_layout);
206 return Status::OK();
207 }
208
GetParametersThatMustBeDonated(const HloModule & module,bool tuple_inputs)209 StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
210 const HloModule& module, bool tuple_inputs) {
211 HloComputation* computation = module.entry_computation();
212 int number_of_parameters = [&]() -> int {
213 if (tuple_inputs) {
214 CHECK_EQ(computation->num_parameters(), 1);
215 const Shape& input_tuple_shape =
216 computation->parameter_instruction(0)->shape();
217 CHECK(input_tuple_shape.IsTuple());
218 return input_tuple_shape.tuple_shapes_size();
219 } else {
220 return computation->num_parameters();
221 }
222 }();
223 // If any buffer in a parameter is aliased we will donate the entire input
224 // parameter.
225 absl::flat_hash_set<int> parameters_to_donate;
226 const HloInputOutputAliasConfig& config = module.input_output_alias_config();
227 TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
228 [&](const ShapeIndex& output_index,
229 const HloInputOutputAliasConfig::Alias& alias) {
230 if (tuple_inputs) {
231 if (alias.parameter_number != 0) {
232 return InvalidArgument(
233 "Unexpected parameter number %d in alias config with tupled "
234 "inputs",
235 alias.parameter_number);
236 }
237 const ShapeIndex& index = alias.parameter_index;
238 if (!index.empty()) {
239 int this_parameter = index.data()[0];
240 if (this_parameter >= number_of_parameters) {
241 return InvalidArgument(
242 "Unexpected parameter index %s in alias config with tupled "
243 "inputs and %d parameters",
244 index.ToString(), number_of_parameters);
245 }
246 parameters_to_donate.insert(this_parameter);
247 }
248 } else {
249 int this_parameter = alias.parameter_number;
250 if (this_parameter >= number_of_parameters) {
251 return InvalidArgument(
252 "Unexpected parameter number %d in alias config without tupled "
253 "inputs and %d parameters",
254 this_parameter, number_of_parameters);
255 }
256 parameters_to_donate.insert(this_parameter);
257 }
258 return Status::OK();
259 }));
260 return parameters_to_donate;
261 }
262
263 } // namespace xla
264