1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17
18 #include <cmath>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/core/lib/core/bits.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32
33 namespace xla {
34
// Out-of-line definitions for the static constexpr key-name members declared
// in the header (required for ODR-use prior to C++17 inline variables).
constexpr const char HloCostAnalysis::kFlopsKey[];
constexpr const char HloCostAnalysis::kTranscendentalsKey[];
constexpr const char HloCostAnalysis::kBytesAccessedKey[];
constexpr const char HloCostAnalysis::kOptimalSecondsKey[];
39
// Convenience constructor: delegates with an empty set of per-second rates.
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
    : HloCostAnalysis(shape_size, {}) {}
42
// Constructs an analysis that sizes shapes with `shape_size` and converts
// accumulated properties to time using `per_second_rates` (see Postprocess).
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
                                 const Properties& per_second_rates)
    : shape_size_(shape_size), per_second_rates_(per_second_rates) {}
46
Preprocess(const HloInstruction * hlo)47 Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
48 // Set current instruction cost values to reasonable default values. Each
49 // handler can overwrite these values. In Postprocess, these values are
50 // accumulated and written to the per-instruction maps.
51 current_properties_.clear();
52 current_should_compute_bottleneck_time_ = true;
53
54 // The default number of bytes accessed for an instruction is the sum of the
55 // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
56 // handle opaque types.
57 float bytes_accessed = GetShapeSize(hlo->shape());
58 SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
59 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
60 const HloInstruction* operand = hlo->operand(i);
61 bytes_accessed += GetShapeSize(operand->shape());
62 SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
63 }
64 current_properties_[kBytesAccessedKey] = bytes_accessed;
65
66 return Status::OK();
67 }
68
// Finalizes the properties gathered for `hlo`: derives the optimal-seconds
// estimate from the per-second rates, records the per-instruction entry, and
// folds everything into the running totals.
Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest property
    // given the per-second rate of each property. A property with no
    // configured rate divides by INFINITY and contributes zero seconds.
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kOptimalSecondsKey) {
        optimal_seconds = std::max(
            optimal_seconds,
            property.second /
                GetProperty(property.first, per_second_rates_, INFINITY));
      }
    }
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  // Each instruction must be recorded at most once; a duplicate visit is a
  // caller bug.
  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
  for (const auto& property : current_properties_) {
    properties_sum_[property.first] += property.second;
  }

  return Status::OK();
}
92
HandleElementwiseOp(const HloInstruction * hlo_instruction)93 Status HloCostAnalysis::HandleElementwiseOp(
94 const HloInstruction* hlo_instruction) {
95 const auto& shape = hlo_instruction->shape();
96 // For element-wise operations, the number of computations is the same as the
97 // number of elements in the output shape.
98 auto computation_count = ShapeUtil::ElementsIn(shape);
99 auto opcode = hlo_instruction->opcode();
100 // We treat transcendental operations separately since one transcendental
101 // operation can correspond to several floating point ops.
102 // kLogistic is included in "trascendental" as it is implemented using
103 // trascendental ops (tanh or exp).
104 if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
105 opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
106 opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt ||
107 opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
108 opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
109 opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
110 opcode == HloOpcode::kAtan2) {
111 current_properties_[kTranscendentalsKey] = computation_count;
112 } else {
113 // Note: transcendental operations are considered a separate category from
114 // FLOPs.
115 current_properties_[kFlopsKey] = computation_count;
116 }
117 return Status::OK();
118 }
119
GetProperty(const string & key,const Properties & properties,const float default_value)120 /*static*/ float HloCostAnalysis::GetProperty(const string& key,
121 const Properties& properties,
122 const float default_value) {
123 auto key_value = properties.find(key);
124 return key_value == properties.end() ? default_value : key_value->second;
125 }
126
GetPropertyForHlo(const HloInstruction & hlo,const string & key,const HloToProperties & hlo_to_properties)127 /*static*/ float HloCostAnalysis::GetPropertyForHlo(
128 const HloInstruction& hlo, const string& key,
129 const HloToProperties& hlo_to_properties) {
130 auto it = hlo_to_properties.find(&hlo);
131 if (it == hlo_to_properties.end()) {
132 return 0.0f;
133 } else {
134 return GetProperty(key, it->second);
135 }
136 }
137
GetShapeSize(const Shape & shape) const138 int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
139 if (!LayoutUtil::HasLayout(shape)) {
140 return 0;
141 }
142 return shape_size_(shape);
143 }
144
// Estimates how many bytes are read through a fusion parameter (or
// get-tuple-element) by inspecting its users inside the fusion: users that
// only touch a slice of the parameter are charged the slice size, while
// ordinary users share a single full read.
int64 HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  int64_t size = 0;
  bool seen_trivial_user = false;
  CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
                           hlo->opcode() == HloOpcode::kGetTupleElement));
  for (const HloInstruction* user : hlo->users()) {
    switch (user->opcode()) {
      case HloOpcode::kFusion: {
        // Recurse into the nested fusion via the matching inner parameter(s).
        for (int64_t idx : user->OperandIndices(hlo)) {
          size += FusionParameterReadBytes(user->fused_parameter(idx));
        }
        break;
      }
      case HloOpcode::kSlice:
        // Only the sliced portion of the parameter is read.
        size += GetShapeSize(user->shape());
        break;
      case HloOpcode::kDynamicSlice:
        // Operand 0 is the sliced tensor; other operands (start indices) are
        // read in full.
        size += hlo == user->operand(0) ? GetShapeSize(user->shape())
                                        : GetShapeSize(hlo->shape());
        break;
      case HloOpcode::kDynamicUpdateSlice:
        // Uses the same shape as 'update' which is operand 1.
        size += hlo == user->operand(0)
                    ? GetShapeSize(user->operand(1)->shape())
                    : GetShapeSize(hlo->shape());
        break;
      case HloOpcode::kBroadcast:
      case HloOpcode::kReshape:
        // The parameter is read once in full regardless of the output shape.
        size += GetShapeSize(hlo->shape());
        break;
      default:
        // Other instructions reading this parameter are assumed to be able to
        // share the read from memory.
        if (!seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
    }
  }
  return size;
}
187
// The following handlers all cost their instruction as a generic element-wise
// operation (one computation per output element).

Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}
207
// Parameters are treated as free: zero bytes accessed and zero time.
Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

// Constants are treated as free: zero bytes accessed and zero time.
Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

// Iota keeps the default costs installed by Preprocess.
Status HloCostAnalysis::HandleIota(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

// Select is an element-wise choice between its operands.
Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

// TupleSelect keeps the default costs installed by Preprocess.
Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) {
  return Status::OK();
}

// Reverse keeps the default costs installed by Preprocess.
Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return Status::OK();
}
251
HandleSlice(const HloInstruction * slice)252 Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
253 current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
254 SetOutputBytesAccessed(GetShapeSize(slice->shape()));
255 SetOperandBytesAccessed(0, GetShapeSize(slice->shape()));
256 return Status::OK();
257 }
258
HandleDynamicSlice(const HloInstruction * dynamic_slice)259 Status HloCostAnalysis::HandleDynamicSlice(
260 const HloInstruction* dynamic_slice) {
261 current_properties_[kBytesAccessedKey] =
262 GetShapeSize(dynamic_slice->shape()) * 2 +
263 GetShapeSize(dynamic_slice->operand(1)->shape());
264 SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
265 SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
266 SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
267 return Status::OK();
268 }
269
HandleDynamicUpdateSlice(const HloInstruction * dynamic_update_slice)270 Status HloCostAnalysis::HandleDynamicUpdateSlice(
271 const HloInstruction* dynamic_update_slice) {
272 current_properties_[kBytesAccessedKey] =
273 GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
274 GetShapeSize(dynamic_update_slice->operand(2)->shape());
275 // Operand 0 aliases with the output.
276 SetOutputBytesAccessed(
277 GetShapeSize(dynamic_update_slice->operand(1)->shape()));
278 SetOperandBytesAccessed(0, 0);
279 SetOperandBytesAccessed(
280 1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
281 SetOperandBytesAccessed(
282 2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
283 return Status::OK();
284 }
285
HandleTuple(const HloInstruction * tuple)286 Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
287 // The tuple instruction only gathers pointers from inputs (it doesn't iterate
288 // through them). The memory touched is then only the size of the output
289 // index table of the tuple.
290
291 current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
292 SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
293 for (int i = 0; i < tuple->operand_count(); ++i) {
294 SetOperandBytesAccessed(i, 0);
295 }
296 return Status::OK();
297 }
298
// Concatenate keeps the default costs installed by Preprocess.
Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return Status::OK();
}

// Convert is an element-wise type cast.
Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

// Copy keeps the default costs installed by Preprocess (input + output bytes).
Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
  // Domain does not have any computation or data transfer.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < domain->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
322
// Costs a dot as one FMA per output element per contracted element.
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  const Shape& lhs_shape = dot->operand(0)->shape();
  const Shape& dot_shape = dot->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  // Count of elements along the reduction dimension (last dimension for the
  // rhs).
  int64_t reduction_width = 1;
  for (auto dim : dnums.lhs_contracting_dimensions()) {
    reduction_width *= lhs_shape.dimensions(dim);
  }
  // Each output element requires reduction_width FMA operations.
  current_properties_[kFlopsKey] =
      kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
  return Status::OK();
}
338
// Costs an infeed as writing every array leaf of its (possibly nested tuple)
// output shape; records both per-leaf and total output bytes.
Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
  // Count nested infeed output tuples.
  int64_t size = 0;
  for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
    size += GetShapeSize(indexed_shape.shape);
    // Per-leaf entry, keyed by the shape index within the output tuple.
    SetOutputBytesAccessed(indexed_shape.index,
                           GetShapeSize(indexed_shape.shape));
  }
  // Aggregate entry for the whole output.
  SetOutputBytesAccessed(size);
  current_properties_[kBytesAccessedKey] = size;
  return Status::OK();
}
351
// Costs an outfeed as reading every array leaf of each operand; records both
// per-leaf and per-operand totals.
Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
  // Count nested outfeed operand tuples.
  current_properties_[kBytesAccessedKey] = 0;
  for (int64_t i = 0; i < outfeed->operand_count(); ++i) {
    const HloInstruction* operand = outfeed->operand(i);
    int64_t size = 0;
    for (const auto& indexed_shape :
         ShapeUtil::GetLeafShapes(operand->shape())) {
      size += GetShapeSize(indexed_shape.shape);
      // Per-leaf entry, keyed by operand number and shape index.
      SetOperandBytesAccessed(i, indexed_shape.index,
                              GetShapeSize(indexed_shape.shape));
    }
    // Aggregate entry for the whole operand.
    SetOperandBytesAccessed(i, size);
    current_properties_[kBytesAccessedKey] += size;
  }
  return Status::OK();
}
369
// Costs a map as (cost of the mapped computation) x (output element count).
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
  // Compute properties of the mapped function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(map->to_apply()));

  // Compute the cost of all elements for this Map operation.
  const int64_t element_count = ShapeUtil::ElementsIn(map->shape());
  for (const auto& property : sub_properties) {
    // Bytes-accessed keys (prefix match) are skipped so the byte defaults
    // from Preprocess are kept.
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return Status::OK();
}
384
// Costs a reduce as (cost of the reducer) x (number of reducer applications).
Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
  HloComputation* function = reduce->to_apply();
  // Compute the cost of the user function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this Reduce operation.
  // This counts the number of times the reduction function is applied, so it
  // does not need to be multiplied by the number of input tensors - that's
  // already "priced in" by the sub-computation doing more work.
  auto arg = reduce->operand(0);
  // For variadic reduce the result is a tuple; use the first element's shape.
  auto output_shape = reduce->shape().IsArray()
                          ? reduce->shape()
                          : reduce->shape().tuple_shapes(0);
  int64_t reduction_count =
      ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
  for (const auto& property : sub_properties) {
    // Bytes-accessed keys (prefix match) are skipped so the byte defaults
    // from Preprocess are kept.
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}
408
// Costs a reduce-window as (cost of the reducer) x (window_size - 1
// reductions per output element).
Status HloCostAnalysis::HandleReduceWindow(
    const HloInstruction* reduce_window) {
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  // Compute the properties of the reduction function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this ReduceWindow operation. For each
  // output element there are window_size - 1 reductions to perform.
  int64_t window_element_count = 1;
  for (const auto& dimension : window.dimensions()) {
    window_element_count *= dimension.size();
  }

  // For variadic reduce-window the result is a tuple; use the first element.
  const int64_t output_element_count =
      ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
                                ? reduce_window->shape()
                                : reduce_window->shape().tuple_shapes(0));
  const int64_t reduction_count =
      (window_element_count - 1) * output_element_count;
  for (const auto& property : sub_properties) {
    // Bytes-accessed keys (prefix match) are skipped so the byte defaults
    // from Preprocess are kept.
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}
437
// Costs select-and-scatter by pricing the select and scatter subcomputations
// separately: per scatter-source element, window_size - 1 select applications
// and one scatter application.
Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of the select and scatter subcomputations.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform and
  // 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64_t window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64_t select_count =
      source_element_count * (window_element_count - 1);
  for (const auto& property : select_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return Status::OK();
}
471
Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

// The handlers below keep the default costs installed by Preprocess
// (input + output bytes, no flops).

Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleCopyStart(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleCopyDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return Status::OK();
}
535
// A transpose that is effectively a bitcast is free; otherwise the default
// costs installed by Preprocess stand.
Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
  if (transpose->IsEffectiveBitcast()) {
    return HandleBitcast(transpose);
  }
  return Status::OK();
}
542
Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < token->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleAddDependency(
    const HloInstruction* add_dependency) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < add_dependency->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
569
// Estimates convolution FLOPs: for each spatial dimension it counts the
// number of (kernel position, output position) pairs that land on a valid
// (in-bounds, non-dilated) input element, then multiplies the per-dimension
// counts with the feature/batch extents and charges kFmaFlops per position.
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  Window window = convolution->window();
  const auto& result_shape = convolution->shape();
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();

  const auto& dnums = convolution->convolution_dimension_numbers();

  const int64_t input_batch_dim = dnums.input_batch_dimension();
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  const int64_t output_feature_dim = dnums.output_feature_dimension();
  const int64_t input_feature =
      ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
  const int64_t output_feature =
      ShapeUtil::GetDimension(result_shape, output_feature_dim);
  const int64_t batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);

  // Per-spatial-dimension extents of the kernel, output, and input.
  DimensionVector kernel_limits;
  DimensionVector output_limits;
  DimensionVector input_limits;
  if (window.dimensions().empty()) {
    // Degenerate 0-d convolution: model it as a single unit-size dimension.
    window = window_util::MakeWindow({1});
    kernel_limits.push_back(1);
    output_limits.push_back(1);
    input_limits.push_back(1);
  } else {
    for (int64_t spatial_dimension = 0;
         spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
      // Spatial dimension number for kernel (rhs).
      const int64_t kernel_spatial_dim =
          dnums.kernel_spatial_dimensions(spatial_dimension);
      const int64_t kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
      kernel_limits.push_back(kernel_limit);

      // Spatial dimension number for output.
      const int64_t output_spatial_dim =
          dnums.output_spatial_dimensions(spatial_dimension);
      const int64_t output_limit = result_shape.dimensions(output_spatial_dim);
      output_limits.push_back(output_limit);

      // Spatial dimension number for input (lhs).
      const int64_t input_spatial_dim =
          dnums.input_spatial_dimensions(spatial_dimension);
      const int64_t input_limit = lhs_shape.dimensions(input_spatial_dim);
      input_limits.push_back(input_limit);
    }
  }

  // Number of valid positions per spatial dimension; their product scales the
  // FMA count below.
  DimensionVector valid_position_counts;

  // Loop over each spatial dimension.
  for (int64_t spatial_dimension = 0;
       spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
    const auto& window_dim = window.dimensions(spatial_dimension);
    // These two conditions will create an N^2 iteration pattern with only N
    // valid elements. This is a performance optimization and produces the same
    // result as the whole loop.
    if (input_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        input_limits[spatial_dimension] == window_dim.base_dilation() &&
        window_dim.window_dilation() == 1 &&
        std::max<int64>(1, input_limits[spatial_dimension] - 1) ==
            window_dim.stride() &&
        window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
      valid_position_counts.push_back(input_limits[spatial_dimension]);
      continue;
    }

    if (input_limits[spatial_dimension] == 1 &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        window_dim.window_dilation() == 1 && window_dim.base_dilation() == 1 &&
        window_dim.stride() == 1 &&
        window_dim.padding_high() == output_limits[spatial_dimension] - 1 &&
        window_dim.padding_low() == output_limits[spatial_dimension] - 1) {
      valid_position_counts.push_back(output_limits[spatial_dimension]);
      continue;
    }

    // General case: enumerate every (kernel, output) index pair and count the
    // ones that map to a real input element.
    int64_t valid_position_count = 0;
    // Loop over each point in the kernel.
    for (int64_t kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
         ++kernel_idx) {
      // Loop over each point in the output.
      for (int64_t output_idx = 0;
           output_idx < output_limits[spatial_dimension]; ++output_idx) {
        // Calculate lhs (input) index without taking base dilation into
        // account.
        const int64_t undilated_index =
            output_idx * window_dim.stride() - window_dim.padding_low() +
            kernel_idx * window_dim.window_dilation();

        // Calculate the actual lhs (input) index after dilation. Avoid the
        // division as an optimization.
        const int64_t lhs_spatial_index =
            window_dim.base_dilation() > 1
                ? undilated_index / window_dim.base_dilation()
                : undilated_index;

        // Skip if the lhs (input) index is to be dilated.
        if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
          continue;
        }

        // Skip if input index is not in bound.
        if (lhs_spatial_index < 0 ||
            lhs_spatial_index >= input_limits[spatial_dimension]) {
          continue;
        }

        valid_position_count += 1;
      }
    }
    valid_position_counts.push_back(valid_position_count);
  }

  // Grouped convolutions divide the feature/batch work across groups.
  const int64_t fma_count =
      (input_feature / convolution->feature_group_count()) * output_feature *
      (batch / convolution->batch_group_count()) *
      Product(valid_position_counts);
  current_properties_[kFlopsKey] = fma_count * kFmaFlops;
  return Status::OK();
}
694
// Models an FFT as N * log2(N)-style work over the real-valued element count,
// charging 4 FMAs per complex multiply.
Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
  // IRFFT-style ops take a tuple operand; use its first element's shape.
  auto real_shape =
      fft->operand(0)->shape().IsTuple()
          ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
          : fft->operand(0)->shape();
  constexpr int kFmaPerComplexMul = 4;
  int64_t log_factors = 1;
  for (int64_t dim : fft->fft_length()) {
    // NOTE(review): Log2Floor(1) == 0, so a unit-length FFT axis zeroes the
    // whole product — presumably intended as "free"; confirm if it matters.
    log_factors *= tensorflow::Log2Floor(dim);
  }
  current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
                                   ShapeUtil::ElementsIn(real_shape);
  return Status::OK();
}
709
HandleTriangularSolve(const HloInstruction * hlo)710 Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
711 // Half of operand 0 is read.
712 float bytes_accessed = GetShapeSize(hlo->shape());
713 SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
714 bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
715 SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
716 bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
717 SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(1)->shape()));
718 current_properties_[kBytesAccessedKey] = bytes_accessed;
719
720 const Shape& a_shape = hlo->operand(0)->shape();
721 const Shape& b_shape = hlo->operand(1)->shape();
722 // Estimate as batch * mn^2 / 2 flops.
723 int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
724 elems *= ShapeUtil::ElementsIn(b_shape);
725 current_properties_[kFlopsKey] = kFmaFlops * elems;
726 return Status::OK();
727 }
728
// Costs a Cholesky factorization: half of the input matrix is read, half of
// the (same-shaped) output is written, and flops are batch * n^3 / 3.
Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
  // Half of operand 0 is read and half of the output will be written.
  float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  // Estimate as batch * n^3 / 3 flops.
  int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(a_shape);
  current_properties_[kFlopsKey] = elems / 3;
  return Status::OK();
}
744
HandleAllGather(const HloInstruction *)745 Status HloCostAnalysis::HandleAllGather(const HloInstruction* /*hlo*/) {
746 return Status::OK();
747 }
748
HandleAllGatherStart(const HloInstruction * hlo)749 Status HloCostAnalysis::HandleAllGatherStart(const HloInstruction* hlo) {
750 return HandleAllGather(hlo);
751 }
752
HandleAllGatherDone(const HloInstruction *)753 Status HloCostAnalysis::HandleAllGatherDone(const HloInstruction* /*hlo*/) {
754 return Status::OK();
755 }
756
HandleAllReduce(const HloInstruction * crs)757 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
758 // We assume 2 replicas, so that each output element is the sum of two input
759 // elements.
760 //
761 // TODO(b/33004697): Compute correct cost here, taking the actual number of
762 // replicas into account.
763 double flops = 0.0;
764 int64_t output_bytes_accessed = 0;
765 ShapeUtil::ForEachSubshape(
766 crs->shape(), [&](const Shape& subshape, const ShapeIndex&) {
767 if (subshape.IsArray()) {
768 flops += ShapeUtil::ElementsIn(subshape);
769 output_bytes_accessed += GetShapeSize(subshape);
770 }
771 });
772 int64_t bytes_accessed = output_bytes_accessed;
773 for (const HloInstruction* operand : crs->operands()) {
774 bytes_accessed += GetShapeSize(operand->shape());
775 }
776 current_properties_[kFlopsKey] = flops;
777 SetOutputBytesAccessed(output_bytes_accessed);
778 current_properties_[kBytesAccessedKey] = bytes_accessed;
779 return Status::OK();
780 }
781
HandleReduceScatter(const HloInstruction * hlo)782 Status HloCostAnalysis::HandleReduceScatter(const HloInstruction* hlo) {
783 return Status::OK();
784 }
785
// An all-reduce-start is costed like the all-reduce it begins.
Status HloCostAnalysis::HandleAllReduceStart(const HloInstruction* hlo) {
  return HandleAllReduce(hlo);
}
789
// The cost is attributed to the start op; the done op itself is free.
Status HloCostAnalysis::HandleAllReduceDone(const HloInstruction* /*hlo*/) {
  return Status::OK();
}
793
HandleAllToAll(const HloInstruction * hlo)794 Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
795 return Status::OK();
796 }
797
// Collective-permute cost is not modeled; no properties are recorded for it.
Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
  return Status::OK();
}
801
// Collective-permute-start cost is not modeled; no properties are recorded.
Status HloCostAnalysis::HandleCollectivePermuteStart(
    const HloInstruction* /*hlo*/) {
  return Status::OK();
}
806
// Collective-permute-done cost is not modeled; no properties are recorded.
Status HloCostAnalysis::HandleCollectivePermuteDone(
    const HloInstruction* /*hlo*/) {
  return Status::OK();
}
811
// Partition-id produces a scalar with no compute; treated as free.
Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
  return Status::OK();
}
815
// Replica-id produces a scalar with no compute; treated as free.
Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
  return Status::OK();
}
819
HandleRng(const HloInstruction * random)820 Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
821 // TODO(b/26346211): Implement better estimates for the RNG cost, since the
822 // cost changes with the implementation and the distribution. For now, assume
823 // the cost of each RNG is same as a transcendental operation.
824 current_properties_[kTranscendentalsKey] =
825 ShapeUtil::ElementsIn(random->shape());
826 return Status::OK();
827 }
828
HandleRngBitGenerator(const HloInstruction * random)829 Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
830 // TODO(b/26346211): Implement better estimates for the RNG cost, since the
831 // cost changes with the implementation and the distribution. For now, assume
832 // the cost of each RNG is same as a transcendental operation.
833 current_properties_[kTranscendentalsKey] =
834 ShapeUtil::ElementsInRecursive(random->shape());
835 return Status::OK();
836 }
837
HandleRngGetAndUpdateState(const HloInstruction * random)838 Status HloCostAnalysis::HandleRngGetAndUpdateState(
839 const HloInstruction* random) {
840 return Status::OK();
841 }
842
// Costs a fusion node: flops/transcendentals come from the fused computation,
// while bytes accessed are recomputed from the fusion's external inputs and
// outputs only (intermediate values inside the fusion are assumed to stay in
// fast memory).
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  // A custom fusion wrapping a gather or scatter is costed as that underlying
  // instruction instead of by aggregating the fused computation.
  if (fusion->IsCustomFusion()) {
    for (const HloInstruction* hlo :
         fusion->fused_instructions_computation()->instructions()) {
      if (hlo->opcode() == HloOpcode::kGather) {
        return HandleGather(hlo);
      }
      if (hlo->opcode() == HloOpcode::kScatter) {
        return HandleScatter(hlo);
      }
    }
  }
  // Start from the fused computation's aggregate properties; the
  // bytes-accessed entries are overwritten below.
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(fusion->fused_instructions_computation()));

  // Fusion nodes that produce a tuple also produce the entries in the tuple.
  // Ignore the memory accessed inside fused ops, since fusion is supposed to
  // prevent intermediate data from touching slow memory.
  current_properties_[kBytesAccessedKey] = 0;
  ShapeUtil::ForEachSubshape(
      fusion->shape(),
      [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
        if (!subshape.IsArray()) {
          return;
        }
        if (shape_index.empty()) {
          // Special case: when the root is a dynamic-update-slice, only the
          // update slice (operand 1 of the DUS) is written, so charge its
          // size instead of the whole output.
          if (fusion->fused_expression_root()->opcode() ==
              HloOpcode::kDynamicUpdateSlice) {
            int64_t size = GetShapeSize(
                fusion->fused_expression_root()->operand(1)->shape());
            current_properties_[kBytesAccessedKey] += size;
            SetOutputBytesAccessed(shape_index, size);
            return;
          }
        } else if (shape_index.size() == 1) {
          // Same special case for a tuple root whose element at this index is
          // a dynamic-update-slice.
          if (fusion->fused_expression_root()->opcode() == HloOpcode::kTuple &&
              fusion->fused_expression_root()
                      ->operand(shape_index[0])
                      ->opcode() == HloOpcode::kDynamicUpdateSlice) {
            int64_t size = GetShapeSize(fusion->fused_expression_root()
                                            ->operand(shape_index[0])
                                            ->operand(1)
                                            ->shape());
            current_properties_[kBytesAccessedKey] += size;
            SetOutputBytesAccessed(shape_index, size);
            return;
          }
        }
        // Default: the whole array leaf is written once.
        current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
        SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
      });

  if (fusion->shape().IsTuple()) {
    // Propagate and accumulate the output tuple bytes from the tuple
    // subshapes. This ensures we have the correct output bytes accessed for
    // the shape index {}.
    std::function<float(const Shape&, const ShapeIndex&)>
        propagate_output_size_to_parent;
    propagate_output_size_to_parent = [&](const Shape& shape,
                                          const ShapeIndex& shape_index) {
      // If an entry for this index was already recorded (array leaves above),
      // return it; otherwise sum the children recursively.
      auto output_bytes_it =
          current_properties_.find(GetOutputBytesAccessedKey(shape_index));
      if (output_bytes_it != current_properties_.end()) {
        return output_bytes_it->second;
      }
      float bytes_accessed = 0;
      for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
        const Shape& subshape = shape.tuple_shapes(i);
        ShapeIndex subshape_index(shape_index);
        subshape_index.push_back(i);
        bytes_accessed +=
            propagate_output_size_to_parent(subshape, subshape_index);
      }
      SetOutputBytesAccessed(shape_index, bytes_accessed);
      return bytes_accessed;
    };
    // Drop the stale root entry so it is recomputed from the leaves.
    current_properties_.erase(
        current_properties_.find(GetOutputBytesAccessedKey()));
    propagate_output_size_to_parent(fusion->shape(), {});
  }

  for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) {
    const HloInstruction* operand = fusion->fused_parameter(i);
    int64_t operand_size = 0;
    if (!fusion->shape().IsTuple()) {
      operand_size = FusionParameterReadBytes(operand);
    } else {
      // If the fusion parameter is a tuple type, find the gte for the leaf
      // shape and calculate the bytes accessed for those array types.
      for (const auto& indexed_shape :
           ShapeUtil::GetLeafShapes(operand->shape())) {
        const HloInstruction* gte = operand;
        // Walk down the tuple index chain via the matching get-tuple-element
        // users. NOTE(review): if no matching GTE user exists at some level,
        // `gte` stays at the tuple node — presumably acceptable for the
        // estimate; confirm against FusionParameterReadBytes.
        for (int64_t index : indexed_shape.index) {
          for (const HloInstruction* user : gte->users()) {
            if (user->opcode() == HloOpcode::kGetTupleElement &&
                user->tuple_index() == index) {
              gte = user;
              break;
            }
          }
        }
        int64_t size = FusionParameterReadBytes(gte);
        operand_size += size;
        SetOperandBytesAccessed(i, indexed_shape.index, size);
      }
    }
    current_properties_[kBytesAccessedKey] += operand_size;
    SetOperandBytesAccessed(i, operand_size);
  }

  return Status::OK();
}
957
// A call's cost is exactly the cost of its called computation: run the
// nested analysis and adopt its aggregate properties wholesale.
Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
  TF_ASSIGN_OR_RETURN(current_properties_,
                      ProcessSubcomputation(call->to_apply()));
  // Do not derive bottleneck time from the aggregated properties.
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}
964
HandleCustomCall(const HloInstruction * custom_call)965 Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
966 // Mark applicable fields as "unknown", since we don't know what CustomCall
967 // does. This is better than returning an error, which would stop iteration,
968 // and therefore would prevent us from getting *any* stats for a computation
969 // which contains a CustomCall.
970 current_properties_[kOptimalSecondsKey] = -1;
971 current_properties_[kBytesAccessedKey] = -1;
972 SetOutputBytesAccessed(-1);
973 for (int i = 0; i < custom_call->operand_count(); ++i) {
974 SetOperandBytesAccessed(i, -1);
975 }
976 current_properties_[kFlopsKey] = -1;
977 current_should_compute_bottleneck_time_ = false;
978 return Status::OK();
979 }
980
HandleSort(const HloInstruction * sort)981 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
982 // This assumes a comparison based N*log(N) algorithm. As for all ops, the
983 // actual properties of the op depend on the backend implementation.
984 int64_t elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
985 current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
986 return Status::OK();
987 }
988
HandleWhile(const HloInstruction * xla_while)989 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
990 // Since the number of iterations of the while node will not always be
991 // something that we can statically analyze, we cannot precisely compute the
992 // cost of a while node. For now compute the cost of a single iteration.
993 TF_ASSIGN_OR_RETURN(const Properties body_properties,
994 ProcessSubcomputation(xla_while->while_body()));
995
996 TF_ASSIGN_OR_RETURN(const Properties condition_properties,
997 ProcessSubcomputation(xla_while->while_condition()));
998
999 current_properties_.clear();
1000 for (const auto& property : body_properties) {
1001 current_properties_[property.first] += property.second;
1002 }
1003 for (const auto& property : condition_properties) {
1004 current_properties_[property.first] += property.second;
1005 }
1006 current_should_compute_bottleneck_time_ = false;
1007
1008 return Status::OK();
1009 }
1010
Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
  // Compute the cost of the branch computations and take the maximum from
  // those for each property.
  TF_ASSIGN_OR_RETURN(
      const Properties branch0_computation_properties,
      ProcessSubcomputation(conditional->branch_computation(0)));
  current_properties_ = branch0_computation_properties;
  for (int j = 1; j < conditional->branch_count(); ++j) {
    TF_ASSIGN_OR_RETURN(
        const Properties branch_computation_properties,
        ProcessSubcomputation(conditional->branch_computation(j)));
    for (const auto& property : branch_computation_properties) {
      // InsertIfNotPresent returns false when the key already exists; in that
      // case keep the maximum of the values seen so far across branches.
      if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
                                               property)) {
        auto& current_property = current_properties_[property.first];
        current_property = std::max(current_property, property.second);
      }
    }
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}
1034
HandleGather(const HloInstruction * gather)1035 Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
1036 // Gather doesn't read the whole input buffer, it's equivalent to a copy the
1037 // size of the output shape and a read of the gather indices.
1038 int64_t output_size = GetShapeSize(gather->shape());
1039 current_properties_[kBytesAccessedKey] =
1040 output_size * 2 + GetShapeSize(gather->operand(1)->shape());
1041 SetOperandBytesAccessed(0, output_size);
1042 SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
1043 SetOutputBytesAccessed(output_size);
1044 // Gather does not issue any flops.
1045 return Status::OK();
1046 }
1047
HandleScatter(const HloInstruction * scatter)1048 Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
1049 // Scatter accesses the equivalent of 3 update shapes (input, output, and
1050 // updates), and the scatter indices.
1051 int64_t update_size = GetShapeSize(scatter->operand(2)->shape());
1052 current_properties_[kBytesAccessedKey] =
1053 update_size * 3 + GetShapeSize(scatter->operand(1)->shape());
1054 SetOperandBytesAccessed(0, update_size);
1055 SetOperandBytesAccessed(1, GetShapeSize(scatter->operand(1)->shape()));
1056 SetOperandBytesAccessed(2, update_size);
1057 SetOutputBytesAccessed(update_size);
1058 const int64_t element_count =
1059 ShapeUtil::ElementsIn(scatter->operand(2)->shape());
1060 TF_ASSIGN_OR_RETURN(const Properties sub_properties,
1061 ProcessSubcomputation(scatter->to_apply()));
1062 for (const auto& property : sub_properties) {
1063 if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
1064 current_properties_[property.first] = property.second * element_count;
1065 }
1066 }
1067 return Status::OK();
1068 }
1069
// Get-dimension-size is metadata-only; treated as free.
Status HloCostAnalysis::HandleGetDimensionSize(
    const HloInstruction* /*get_size*/) {
  return Status::OK();
}
1074
// Set-dimension-size is metadata-only; treated as free.
Status HloCostAnalysis::HandleSetDimensionSize(
    const HloInstruction* /*set_size*/) {
  return Status::OK();
}
1079
// No post-traversal work is needed; totals are accumulated during the visit.
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  return Status::OK();
}
1083
// Total flop count accumulated over all visited instructions.
float HloCostAnalysis::flop_count() const {
  return GetProperty(kFlopsKey, properties_sum_);
}
1087
// Total transcendental-op count accumulated over all visited instructions.
float HloCostAnalysis::transcendental_count() const {
  return GetProperty(kTranscendentalsKey, properties_sum_);
}
1091
// Total bytes accessed accumulated over all visited instructions.
float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}
1095
// Total optimal-seconds estimate accumulated over all visited instructions.
float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}
1099
// Flop count recorded for a single instruction.
int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}
1103
// Transcendental-op count recorded for a single instruction.
int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}
1107
// Total bytes accessed recorded for a single instruction.
int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}
1111
// Bytes accessed by `hlo` for operand `operand_num` at subshape `index`
// (empty index = the whole operand).
int64 HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
                                              int64_t operand_num,
                                              ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOperandBytesAccessedKey(operand_num, index),
                           hlo_properties_);
}
1118
// Bytes written by `hlo` at output subshape `index` (empty index = the whole
// output).
int64 HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
                                             ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOutputBytesAccessedKey(index),
                           hlo_properties_);
}
1124
// Optimal-seconds estimate recorded for a single instruction.
float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}
1128
GetBytesRead(const HloInstruction & hlo,absl::optional<int64> memory_space) const1129 int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
1130 absl::optional<int64> memory_space) const {
1131 int64_t bytes_read = 0;
1132 for (int operand_number = 0; operand_number < hlo.operand_count();
1133 ++operand_number) {
1134 for (const ShapeUtil::IndexedShape& indexed_shape :
1135 ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) {
1136 absl::optional<int64> index_memory_space;
1137 if (indexed_shape.shape.has_layout()) {
1138 index_memory_space = indexed_shape.shape.layout().memory_space();
1139 }
1140 if (!memory_space || memory_space == index_memory_space) {
1141 bytes_read +=
1142 operand_bytes_accessed(hlo, operand_number, indexed_shape.index);
1143 }
1144 }
1145 }
1146 return bytes_read;
1147 }
1148
GetBytesWritten(const HloInstruction & hlo,absl::optional<int64> memory_space) const1149 int64 HloCostAnalysis::GetBytesWritten(
1150 const HloInstruction& hlo, absl::optional<int64> memory_space) const {
1151 int64_t bytes_written = 0;
1152 for (const ShapeUtil::IndexedShape& indexed_shape :
1153 ShapeUtil::GetLeafShapes(hlo.shape())) {
1154 absl::optional<int64> index_memory_space;
1155 if (indexed_shape.shape.has_layout()) {
1156 index_memory_space = indexed_shape.shape.layout().memory_space();
1157 }
1158 if (!memory_space || memory_space == index_memory_space) {
1159 bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
1160 }
1161 }
1162 return bytes_written;
1163 }
1164
// Runs a fresh nested cost analysis over `computation` and returns its
// aggregate properties.
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
    HloComputation* computation) {
  auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
  visitor->ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->Accept(visitor.get()));
  // Merge per-instruction results into our table. Note map-style insert keeps
  // any entry already present rather than overwriting it.
  hlo_properties_.insert(visitor->hlo_properties_.begin(),
                         visitor->hlo_properties_.end());
  return visitor->properties();
}
1174
// Creates the analysis instance used for subcomputations (see
// ProcessSubcomputation). WrapUnique is used so a non-public constructor can
// be invoked here.
std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis(
    const ShapeSizeFunction& shape_size, const Properties& per_second_rates) {
  return absl::WrapUnique(new HloCostAnalysis(shape_size, per_second_rates));
}
1179
SetOperandBytesAccessed(int64_t operand_num,float value)1180 void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
1181 float value) {
1182 current_properties_[GetOperandBytesAccessedKey(operand_num).c_str()] = value;
1183 }
1184
SetOperandBytesAccessed(int64_t operand_num,ShapeIndex index,float value)1185 void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
1186 ShapeIndex index, float value) {
1187 current_properties_[GetOperandBytesAccessedKey(operand_num, index).c_str()] =
1188 value;
1189 }
1190
// Records the total bytes written to the current instruction's output.
void HloCostAnalysis::SetOutputBytesAccessed(float value) {
  current_properties_[GetOutputBytesAccessedKey()] = value;
}
1194
// Records the bytes written at output subshape `index` of the current
// instruction.
void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
  current_properties_[GetOutputBytesAccessedKey(index)] = value;
}
1198
// Builds the per-operand property key: the bytes-accessed key followed by
// " operand <n> <index>".
/*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
    int64_t operand_num, ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
                      index.ToString());
}
1204
// Builds the per-output property key: the bytes-accessed key followed by
// " output <index>".
/*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
    ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
}
1209
1210 } // namespace xla
1211