/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/utils.h"

#include <algorithm>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/numbers.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace xla {

namespace {

// Returns the shape of one shard of `shape` under `sharding`. Tuple shardings
// are applied element-wise; all other shardings are applied via
// HloSharding::TileShape.
StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  if (sharding.type() == OpSharding::TUPLE) {
    if (!shape.IsTuple()) {
      return InvalidArgument(
          "Got tuple OpSharding (%s) for non-tuple shape (%s)",
          sharding.DebugString(), shape.ToString());
    }
    if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
      return InvalidArgument(
          "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
          " (OpSharding: %s, shape: %s)",
          sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
          sharding.DebugString(), shape.ToString());
    }
    std::vector<Shape> sharded_subshapes;
    for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
      TF_ASSIGN_OR_RETURN(
          Shape sharded_subshape,
          GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
      sharded_subshapes.emplace_back(std::move(sharded_subshape));
    }
    return ShapeUtil::MakeTupleShape(sharded_subshapes);
  }
  TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                      HloSharding::FromProto(sharding));
  return hlo_sharding.TileShape(shape);
}
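// Illustrative example for GetShardedShape above (not exercised in this
// file): tiling an f32[8,16] array two ways along dimension 0, i.e. an
// OpSharding with tile_assignment_dimensions {2, 1}, yields a per-shard
// shape of f32[4,16]; a tuple sharding applies its element shardings
// pointwise to the tuple elements.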

// Returns the per-shard shape of `instr`'s result (or its unsharded shape if
// the instruction carries no sharding), with the layout cleared.
StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  const Shape unsharded_shape(instr.shape());
  Shape sharded_shape;
  if (instr.has_sharding()) {
    TF_ASSIGN_OR_RETURN(sharded_shape,
                        GetShardedShape(unsharded_shape, instr.sharding()));
  } else {
    sharded_shape = unsharded_shape;
  }
  LayoutUtil::ClearLayout(&sharded_shape);
  return sharded_shape;
}

// Returns sharded (argument shapes, result shape) without layouts.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation, const ProgramShape& program_shape) {
  std::vector<Shape> arg_shapes;
  arg_shapes.resize(program_shape.parameters_size());
  Shape result_shape;
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() != computation.proto().entry_computation_id()) {
      continue;
    }
    for (const HloInstructionProto& instr : comp.instructions()) {
      if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
        if (instr.parameter_number() >= program_shape.parameters_size()) {
          return InvalidArgument(
              "Got invalid parameter number %d, expected %d parameters",
              instr.parameter_number(), program_shape.parameters_size());
        }
        TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
                            GetShardedShape(instr));
      }
      if (instr.id() == comp.root_id()) {
        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument("Found multiple root instructions");
        }
        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
      }
    }
  }
  for (int i = 0; i < arg_shapes.size(); ++i) {
    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
      return InvalidArgument("Couldn't find parameter %d", i);
    }
  }
  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("Couldn't find root instruction");
  }
  return std::make_pair(arg_shapes, result_shape);
}
}  // namespace

Status ParseDeviceAssignmentCompileOptions(
    bool compile_portable_executable, ExecutableBuildOptions* build_options,
    std::function<StatusOr<DeviceAssignment>(int, int)>
        GetDefaultDeviceAssignmentFunction,
    int* num_replicas, int* num_partitions,
    std::shared_ptr<DeviceAssignment>* device_assignment) {
  if (compile_portable_executable) {
    if (build_options->has_device_assignment()) {
      return InvalidArgument(
          "CompileOptions requests portable executable but "
          "ExecutableBuildOptions includes a device assignment");
    }
    *num_replicas = 1;
    *num_partitions = 1;
  } else {
    if (!build_options->has_device_assignment()) {
      VLOG(2) << "Compile using default device_assignment.";
      TF_ASSIGN_OR_RETURN(
          DeviceAssignment device_assignment,
          GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
                                             build_options->num_partitions()));
      build_options->set_device_assignment(device_assignment);
    }
    VLOG(2) << "Compile device_assignment:\n"
            << build_options->device_assignment().ToString();
    *num_replicas = build_options->device_assignment().replica_count();
    *num_partitions = build_options->device_assignment().computation_count();
    *device_assignment =
        std::make_shared<DeviceAssignment>(build_options->device_assignment());
  }
  return Status::OK();
}
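// Illustrative call sequence for ParseDeviceAssignmentCompileOptions above
// (local variable names are hypothetical, not part of this API): a PJRT
// client compiling a non-portable executable might do
//
//   int num_replicas, num_partitions;
//   std::shared_ptr<DeviceAssignment> device_assignment;
//   TF_RETURN_IF_ERROR(ParseDeviceAssignmentCompileOptions(
//       /*compile_portable_executable=*/false, &build_options,
//       [&](int replicas, int partitions) {
//         return client->GetDefaultDeviceAssignment(replicas, partitions);
//       },
//       &num_replicas, &num_partitions, &device_assignment));
//
// after which `build_options` carries a concrete device assignment.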

Status DetermineArgumentLayoutsFromCompileOptions(
    const XlaComputation& computation,
    std::function<StatusOr<Shape>(Shape)>
        choose_compact_layout_for_shape_function,
    absl::optional<std::vector<Shape>>& argument_layouts,
    ExecutableBuildOptions* build_options,
    std::vector<const Shape*>* argument_layout_pointers) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  if (!argument_layouts) {
    argument_layouts.emplace(program_shape.parameters());
    for (Shape& shape : *argument_layouts) {
      LayoutUtil::ClearLayout(&shape);
    }
  } else if (argument_layouts->size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "CompileOptions specify %d argument layouts, but computation has %d "
        "arguments",
        argument_layouts->size(), program_shape.parameters_size());
  }
  argument_layout_pointers->reserve(argument_layouts->size());

  // Assign a default layout based on `sharded_shape` to any array subshapes in
  // `dst_shape` that are missing layouts.
  auto assign_layouts = [&choose_compact_layout_for_shape_function](
                            const Shape& sharded_shape, Shape* dst_shape) {
    return ShapeUtil::ForEachMutableSubshapeWithStatus(
        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
          if (subshape->IsArray() && !subshape->has_layout()) {
            CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
            const Shape& sharded_subshape =
                ShapeUtil::GetSubshape(sharded_shape, idx);
            LayoutUtil::SetToDefaultLayout(subshape);
            TF_ASSIGN_OR_RETURN(
                Shape layout,
                choose_compact_layout_for_shape_function(sharded_subshape));
            *subshape->mutable_layout() = layout.layout();
          }
          return Status::OK();
        });
  };
  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
                      GetShardedProgramShapes(computation, program_shape));

  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
  for (int i = 0; i < argument_layouts->size(); ++i) {
    Shape* layout = &(*argument_layouts)[i];
    argument_layout_pointers->push_back(layout);
    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
  }

  Shape result_layout;
  if (build_options->result_layout()) {
    result_layout = *build_options->result_layout();
  } else {
    result_layout = program_shape.result();
    LayoutUtil::ClearLayout(&result_layout);
  }
  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
  build_options->set_result_layout(result_layout);
  return Status::OK();
}
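// Illustrative behaviour of DetermineArgumentLayoutsFromCompileOptions above:
// if the caller supplies no argument_layouts, each parameter shape from the
// program is taken with its layout cleared; every array subshape that is
// still missing a layout is then passed, in its *per-shard* form, to
// `choose_compact_layout_for_shape_function` (typically a backend layout
// query), and the chosen layout is copied onto that subshape. The result
// layout is handled the same way unless build_options already fixes one.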

StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  std::vector<int> parameters_to_donate;
  parameters_to_donate.reserve(computation->num_parameters());
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            int this_parameter = index.data()[0];
            if (this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.push_back(this_parameter);
          }
        } else {
          int this_parameter = alias.parameter_number;
          if (this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without "
                "tupled inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.push_back(this_parameter);
        }
        return Status::OK();
      }));
  absl::c_sort(parameters_to_donate);
  return parameters_to_donate;
}
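// Worked example for ComputeParametersThatMustBeDonated above (illustrative):
// with tuple_inputs=true and an alias config that maps some output index to
// parameter 0 at parameter index {2}, element 2 of the single input tuple is
// considered donated, so the function returns {2}. With tuple_inputs=false,
// the same kind of alias on parameter 3 would yield {3}.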

int DefaultThreadPoolSize() {
  // Google's CI system exposes an environment variable NPROC that describes
  // a CPU reservation for tests.
  // TODO(phawkins): expose a better thought-out set of knobs to control
  // parallelism.
  const char* nproc_str = std::getenv("NPROC");
  int nproc = 0;
  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
    return std::max(0, nproc);
  }
  return tensorflow::port::MaxParallelism();
}
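// For example, running a test with NPROC=8 in the environment makes
// DefaultThreadPoolSize return 8; when NPROC is unset or unparsable, the
// host's available parallelism (tensorflow::port::MaxParallelism()) is used
// instead.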

bool HasMajorToMinorLayout(PrimitiveType type, absl::Span<int64_t const> dims,
                           absl::Span<int64_t const> byte_strides) {
  CHECK_EQ(dims.size(), byte_strides.size());
  // If the array is size 0, the strides are irrelevant.
  if (absl::c_find(dims, 0) != dims.end()) {
    return true;
  }
  int64_t stride = primitive_util::ByteWidth(type);
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    // If a dimension is of size 1, its stride is irrelevant.
    if (dims[i] != 1) {
      if (byte_strides[i] != stride) {
        return false;
      }
      stride *= dims[i];
    }
  }
  return true;
}
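// Worked example for HasMajorToMinorLayout above (illustrative): for F32
// (4-byte elements) with dims {2, 3}, dense major-to-minor (row-major)
// strides are {12, 4}, so the function returns true; column-major strides
// {4, 8} make it return false.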

}  // namespace xla