/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <cmath>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

constexpr char HloCostAnalysis::kFlopsKey[];
constexpr char HloCostAnalysis::kTranscendentalsKey[];
constexpr char HloCostAnalysis::kBytesAccessedKey[];
constexpr char HloCostAnalysis::kOptimalSecondsKey[];

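// per_second_rates maps property keys (e.g. flops, bytes accessed) to the
// rate at which the target hardware can process that property; Postprocess
// divides by these rates to derive kOptimalSecondsKey. With the default empty
// table, missing rates read as INFINITY and the time estimate stays at zero.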
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
    : HloCostAnalysis(shape_size, {}) {}

HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
                                 const Properties& per_second_rates)
    : shape_size_(shape_size), per_second_rates_(per_second_rates) {}

Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
  // Set current instruction cost values to reasonable default values. Each
  // handler can overwrite these values. In Postprocess, these values are
  // accumulated and written to the per-instruction maps.
  current_properties_.clear();
  current_should_compute_bottleneck_time_ = true;

  // The default number of bytes accessed for an instruction is the sum of the
  // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does
  // not handle opaque types.
  float bytes_accessed = GetShapeSize(hlo->shape());
  for (const HloInstruction* operand : hlo->operands()) {
    bytes_accessed += GetShapeSize(operand->shape());
  }
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  return Status::OK();
}

Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest
    // property given the per-second rate of each property.
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kOptimalSecondsKey) {
        optimal_seconds = std::max(
            optimal_seconds,
            property.second /
                GetProperty(property.first, per_second_rates_, INFINITY));
      }
    }
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
  for (const auto& property : current_properties_) {
    properties_sum_[property.first] += property.second;
  }

  return Status::OK();
}

Status HloCostAnalysis::HandleElementwiseOp(
    const HloInstruction* hlo_instruction) {
  const auto& shape = hlo_instruction->shape();
  // For element-wise operations, the number of computations is the same as
  // the number of elements in the output shape.
  auto computation_count = ShapeUtil::ElementsIn(shape);
  auto opcode = hlo_instruction->opcode();
  // We treat transcendental operations separately since one transcendental
  // operation can correspond to several floating point ops.
  if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
      opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt ||
      opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
      opcode == HloOpcode::kSin || opcode == HloOpcode::kCos) {
    current_properties_[kTranscendentalsKey] = computation_count;
  } else {
    // Note: transcendental operations are considered a separate category from
    // FLOPs.
    current_properties_[kFlopsKey] = computation_count;
  }
  return Status::OK();
}

/*static*/ float HloCostAnalysis::GetProperty(const string& key,
                                              const Properties& properties,
                                              const float default_value) {
  auto key_value = properties.find(key);
  return key_value == properties.end() ? default_value : key_value->second;
}

/*static*/ float HloCostAnalysis::GetPropertyForHlo(
    const HloInstruction& hlo, const string& key,
    const HloToProperties& hlo_to_properties) {
  auto it = hlo_to_properties.find(&hlo);
  if (it == hlo_to_properties.end()) {
    return 0.0f;
  } else {
    return GetProperty(key, it->second);
  }
}

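// Shapes without a layout are treated as occupying zero bytes; all other
// shapes are sized by the caller-provided ShapeSizeFunction.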
int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
  if (!LayoutUtil::HasLayout(shape)) {
    return 0;
  }
  return shape_size_(shape);
}

Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleIota(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return Status::OK();
}

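// Slice-like ops read an output-sized region of the operand and write the
// output, so bytes accessed is twice the output (or update) size.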
Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
  current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicSlice(
    const HloInstruction* dynamic_slice) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(dynamic_slice->shape()) * 2;
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicUpdateSlice(
    const HloInstruction* dynamic_update_slice) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2;
  return Status::OK();
}

Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
  // The tuple instruction only gathers pointers from inputs (it doesn't
  // iterate through them). The memory touched is then only the size of the
  // output index table of the tuple.
  current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
  return Status::OK();
}

Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
  // Domain does not have any computation or data transfer.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

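// Dot cost model: each output element is produced by reduction_width fused
// multiply-adds. For example, an [M, K] dot [K, N] contracting over K yields
// kFmaFlops * M * N * K flops.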
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  const Shape& lhs_shape = dot->operand(0)->shape();
  const Shape& dot_shape = dot->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  // Count of elements along the contracting (reduction) dimensions of the
  // lhs.
  int64 reduction_width = 1;
  for (auto dim : dnums.lhs_contracting_dimensions()) {
    reduction_width *= lhs_shape.dimensions(dim);
  }
  // Each output element requires reduction_width FMA operations.
  current_properties_[kFlopsKey] =
      kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
  return Status::OK();
}

Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
  return Status::OK();
}

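// Map applies the mapped computation once per element of the output shape, so
// every sub-computation property except bytes accessed is scaled by the
// element count; bytes accessed keeps the operand+output default from
// Preprocess.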
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
  // Compute properties of the mapped function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessNestedSubcomputation(map->to_apply()));

  // Compute the cost of all elements for this Map operation.
  const int64 element_count = ShapeUtil::ElementsIn(map->shape());
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return Status::OK();
}

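// Reduce applies its reduction function once per element folded away, i.e.
// ElementsIn(input) - ElementsIn(output) times. For example, reducing a
// [10, 20] array to [10] applies the function 200 - 10 = 190 times.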
Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
  HloComputation* function = reduce->to_apply();
  // Compute the cost of the user function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessNestedSubcomputation(function));

  // Compute the cost of all elements for this Reduce operation.
  // This counts the number of times the reduction function is applied, so it
  // does not need to be multiplied by the number of input tensors - that's
  // already "priced in" by the sub-computation doing more work.
  auto arg = reduce->operand(0);
  auto output_shape = reduce->shape().IsArray()
                          ? reduce->shape()
                          : reduce->shape().tuple_shapes(0);
  int64 reduction_count =
      ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleReduceWindow(
    const HloInstruction* reduce_window) {
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  // Compute the properties of the reduction function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessNestedSubcomputation(function));

  // Compute the cost of all elements for this ReduceWindow operation. For
  // each output element there are window_size - 1 reductions to perform.
  int64 window_element_count = 1;
  for (const auto& dimension : window.dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 output_element_count =
      ShapeUtil::ElementsIn(reduce_window->shape());
  const int64 reduction_count =
      (window_element_count - 1) * output_element_count;
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of the select and scatter functions.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessNestedSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessNestedSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform
  // and 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64 window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 select_count = source_element_count * (window_element_count - 1);
  for (const auto& property : select_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return Status::OK();
}

Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleAddDependency(
    const HloInstruction* add_dependency) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

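// Convolution cost model: for each spatial dimension, count the
// (kernel position, output position) pairs that map to an in-bounds input
// element (skipping base-dilation holes and padding). The FMA count is then
//   (input_feature / feature_group_count) * output_feature *
//   (batch / batch_group_count) * Product(valid_position_counts),
// and flops = kFmaFlops * fma_count.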
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  Window window = convolution->window();
  const auto& result_shape = convolution->shape();
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();

  const auto& dnums = convolution->convolution_dimension_numbers();

  const int64 input_batch_dim = dnums.input_batch_dimension();
  const int64 input_feature_dim = dnums.input_feature_dimension();
  const int64 output_feature_dim = dnums.output_feature_dimension();
  const int64 input_feature =
      ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
  const int64 output_feature =
      ShapeUtil::GetDimension(result_shape, output_feature_dim);
  const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);

  DimensionVector kernel_limits;
  DimensionVector output_limits;
  DimensionVector input_limits;
  if (window.dimensions().empty()) {
    window = window_util::MakeWindow({1});
    kernel_limits.push_back(1);
    output_limits.push_back(1);
    input_limits.push_back(1);
  } else {
    for (int64 spatial_dimension = 0;
         spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
      // Spatial dimension number for kernel (rhs).
      const int64 kernel_spatial_dim =
          dnums.kernel_spatial_dimensions(spatial_dimension);
      const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
      kernel_limits.push_back(kernel_limit);

      // Spatial dimension number for output.
      const int64 output_spatial_dim =
          dnums.output_spatial_dimensions(spatial_dimension);
      const int64 output_limit = result_shape.dimensions(output_spatial_dim);
      output_limits.push_back(output_limit);

      // Spatial dimension number for input (lhs).
      const int64 input_spatial_dim =
          dnums.input_spatial_dimensions(spatial_dimension);
      const int64 input_limit = lhs_shape.dimensions(input_spatial_dim);
      input_limits.push_back(input_limit);
    }
  }

  DimensionVector valid_position_counts;

  // Loop over each spatial dimension.
  for (int64 spatial_dimension = 0;
       spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
    int64 valid_position_count = 0;
    // Loop over each point in the kernel.
    for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
         ++kernel_idx) {
      // Loop over each point in the output.
      for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension];
           ++output_idx) {
        // Calculate lhs (input) index without taking base dilation into
        // account.
        const auto& window_dim = window.dimensions(spatial_dimension);
        const int64 undilated_index = output_idx * window_dim.stride() -
                                      window_dim.padding_low() +
                                      kernel_idx * window_dim.window_dilation();

        // Calculate the actual lhs (input) index after dilation. Avoid the
        // division as an optimization.
        const int64 lhs_spatial_index =
            window_dim.base_dilation() > 1
                ? undilated_index / window_dim.base_dilation()
                : undilated_index;

        // Skip if the lhs (input) index is to be dilated.
        if (undilated_index !=
            lhs_spatial_index * window_dim.base_dilation()) {
          continue;
        }

        // Skip if the input index is out of bounds.
        if (lhs_spatial_index < 0 ||
            lhs_spatial_index >= input_limits[spatial_dimension]) {
          continue;
        }

        valid_position_count += 1;
      }
    }
    valid_position_counts.push_back(valid_position_count);
  }

  const int64 fma_count = (input_feature / convolution->feature_group_count()) *
                          output_feature *
                          (batch / convolution->batch_group_count()) *
                          Product(valid_position_counts);
  current_properties_[kFlopsKey] = fma_count * kFmaFlops;
  return Status::OK();
}

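// FFT cost model: assume an N*log2(N)-style algorithm over the transformed
// dimensions, with 4 FMAs per complex multiply:
//   flops = kFmaFlops * 4 * ElementsIn(shape) * Product_d(log2(fft_length[d]))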
Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
  auto real_shape =
      fft->operand(0)->shape().IsTuple()
          ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
          : fft->operand(0)->shape();
  constexpr int kFmaPerComplexMul = 4;
  int64 log_factors = 1;
  for (int64 dim : fft->fft_length()) {
    log_factors *= tensorflow::Log2Floor(dim);
  }
  current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
                                   ShapeUtil::ElementsIn(real_shape);
  return Status::OK();
}

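// Only half of the 'a' operand's size is counted toward bytes accessed,
// consistent with reading a single triangular half of the matrix.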
Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
  float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  const Shape& b_shape = hlo->operand(1)->shape();
  // Estimate as batch * mn^2 / 2 flops.
  int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(b_shape);
  current_properties_[kFlopsKey] = kFmaFlops * elems;
  return Status::OK();
}

Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
  float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  // Estimate as batch * n^3 / 3 flops.
  int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(a_shape);
  current_properties_[kFlopsKey] = elems / 3;
  return Status::OK();
}

Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
  // We assume 2 replicas, so that each output element is the sum of two input
  // elements.
  //
  // TODO(b/33004697): Compute correct cost here, taking the actual number of
  // replicas into account.
  double flops = 0.0;
  ShapeUtil::ForEachSubshape(crs->shape(),
                             [&](const Shape& subshape, const ShapeIndex&) {
                               if (subshape.IsArray()) {
                                 flops += ShapeUtil::ElementsIn(subshape);
                               }
                             });
  current_properties_[kFlopsKey] = flops;
  return Status::OK();
}

Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
  return Status::OK();
}

Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now,
  // assume the cost of each RNG is the same as a transcendental operation.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsIn(random->shape());
  return Status::OK();
}

Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessNestedSubcomputation(fusion->fused_instructions_computation()));

  // Fusion nodes that produce a tuple also produce the entries in the tuple.
  // Ignore the memory accessed inside fused ops, since fusion is supposed to
  // prevent intermediate data from touching slow memory.
  current_properties_[kBytesAccessedKey] = 0;
  ShapeUtil::ForEachSubshape(
      fusion->shape(),
      [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
        current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
      });

  for (const HloInstruction* operand : fusion->operands()) {
    current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape());
  }

  return Status::OK();
}

Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
  TF_ASSIGN_OR_RETURN(current_properties_,
                      ProcessUnnestedSubcomputation(call->to_apply()));
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}

Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
  // Mark applicable fields as "unknown", since we don't know what CustomCall
  // does. This is better than returning an error, which would stop iteration,
  // and therefore would prevent us from getting *any* stats for a computation
  // which contains a CustomCall.
  current_properties_[kOptimalSecondsKey] = -1;
  current_properties_[kBytesAccessedKey] = -1;
  current_properties_[kFlopsKey] = -1;
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}

Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
  // This assumes a comparison-based N*log(N) algorithm. As for all ops, the
  // actual properties of the op depend on the backend implementation.
  int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
  current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
  return Status::OK();
}

Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
  // Since the number of iterations of the while node will not always be
  // something that we can statically analyze, we cannot precisely compute the
  // cost of a while node. For now compute the cost of a single iteration.
  TF_ASSIGN_OR_RETURN(const Properties body_properties,
                      ProcessUnnestedSubcomputation(xla_while->while_body()));

  TF_ASSIGN_OR_RETURN(
      const Properties condition_properties,
      ProcessUnnestedSubcomputation(xla_while->while_condition()));

  current_properties_.clear();
  for (const auto& property : body_properties) {
    current_properties_[property.first] += property.second;
  }
  for (const auto& property : condition_properties) {
    current_properties_[property.first] += property.second;
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}

Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
  // Compute the cost of the branch computations and take the maximum from
  // those for each property.
  TF_ASSIGN_OR_RETURN(
      const Properties branch0_computation_properties,
      ProcessUnnestedSubcomputation(conditional->branch_computation(0)));
  current_properties_ = branch0_computation_properties;
  for (int j = 1; j < conditional->branch_count(); ++j) {
    TF_ASSIGN_OR_RETURN(
        const Properties branch_computation_properties,
        ProcessUnnestedSubcomputation(conditional->branch_computation(j)));
    for (const auto& property : branch_computation_properties) {
      if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
                                               property)) {
        auto& current_property = current_properties_[property.first];
        current_property = std::max(current_property, property.second);
      }
    }
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}

Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
  // Gather doesn't read the whole input buffer; it's equivalent to a copy the
  // size of the output shape and a read of the gather indices.
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(gather->shape()) * 2 +
      GetShapeSize(gather->operand(1)->shape());
  // Gather does not issue any flops.
  return Status::OK();
}

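// Scatter's bytes accessed is counted as two passes over the updates operand
// plus a read of the scatter indices; the update computation is applied once
// per update element, so its properties are scaled by that element count.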
Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
  current_properties_[kBytesAccessedKey] =
      GetShapeSize(scatter->operand(2)->shape()) * 2 +
      GetShapeSize(scatter->operand(1)->shape());
  const int64 element_count =
      ShapeUtil::ElementsIn(scatter->operand(2)->shape());
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessNestedSubcomputation(scatter->to_apply()));
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleGetDimensionSize(
    const HloInstruction* /*get_size*/) {
  return Status::OK();
}

Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  return Status::OK();
}

float HloCostAnalysis::flop_count() const {
  return GetProperty(kFlopsKey, properties_sum_);
}

float HloCostAnalysis::transcendental_count() const {
  return GetProperty(kTranscendentalsKey, properties_sum_);
}

float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}

float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}

int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}

int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}

int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}

float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}

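// Both helpers run a fresh visitor over the subcomputation and return its
// aggregate properties. The "unnested" variant additionally merges the
// per-instruction results into this analysis' hlo_properties_ map, while the
// nested variant keeps them local to the subcomputation.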
StatusOr<HloCostAnalysis::Properties>
HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) {
  HloCostAnalysis visitor(shape_size_, per_second_rates_);
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.properties();
}

StatusOr<HloCostAnalysis::Properties>
HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) {
  HloCostAnalysis visitor(shape_size_, per_second_rates_);
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  hlo_properties_.insert(visitor.hlo_properties_.begin(),
                         visitor.hlo_properties_.end());
  return visitor.properties();
}

}  // namespace xla
