1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17
18 #include <cmath>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/core/lib/core/bits.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32
33 namespace xla {
34
// Out-of-class definitions for the static constexpr property-key members
// (required for ODR-use of in-class-initialized constants before C++17).
constexpr const char HloCostAnalysis::kFlopsKey[];
constexpr const char HloCostAnalysis::kTranscendentalsKey[];
constexpr const char HloCostAnalysis::kBytesAccessedKey[];
constexpr const char HloCostAnalysis::kOptimalSecondsKey[];
39
// Constructs an analysis with the given shape-size function and no
// per-second rates (Postprocess then divides by an INFINITY default, so
// optimal-seconds stays zero).
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
    : HloCostAnalysis(shape_size, {}) {}

// Constructs an analysis with the given shape-size function and per-second
// rates used to estimate each instruction's optimal execution time.
HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
                                 const Properties& per_second_rates)
    : shape_size_(shape_size), per_second_rates_(per_second_rates) {}
46
Preprocess(const HloInstruction * hlo)47 Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
48 // Set current instruction cost values to reasonable default values. Each
49 // handler can overwrite these values. In Postprocess, these values are
50 // accumulated and written to the per-instruction maps.
51 current_properties_.clear();
52 current_should_compute_bottleneck_time_ = true;
53
54 // The default number of bytes accessed for an instruction is the sum of the
55 // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
56 // handle opaque types.
57 float bytes_accessed = GetShapeSize(hlo->shape());
58 SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
59 for (int64 i = 0; i < hlo->operand_count(); ++i) {
60 const HloInstruction* operand = hlo->operand(i);
61 bytes_accessed += GetShapeSize(operand->shape());
62 SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
63 }
64 current_properties_[kBytesAccessedKey] = bytes_accessed;
65
66 return Status::OK();
67 }
68
Postprocess(const HloInstruction * hlo)69 Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
70 if (current_should_compute_bottleneck_time_) {
71 // Compute the time as the time of the bottleneck, i.e. the slowest property
72 // given the per-second rate of each property.
73 float optimal_seconds = 0.0f;
74 for (const auto& property : current_properties_) {
75 if (property.first != kOptimalSecondsKey) {
76 optimal_seconds = std::max(
77 optimal_seconds,
78 property.second /
79 GetProperty(property.first, per_second_rates_, INFINITY));
80 }
81 }
82 current_properties_[kOptimalSecondsKey] = optimal_seconds;
83 }
84
85 TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
86 for (const auto& property : current_properties_) {
87 properties_sum_[property.first] += property.second;
88 }
89
90 return Status::OK();
91 }
92
HandleElementwiseOp(const HloInstruction * hlo_instruction)93 Status HloCostAnalysis::HandleElementwiseOp(
94 const HloInstruction* hlo_instruction) {
95 const auto& shape = hlo_instruction->shape();
96 // For element-wise operations, the number of computations is the same as the
97 // number of elements in the output shape.
98 auto computation_count = ShapeUtil::ElementsIn(shape);
99 auto opcode = hlo_instruction->opcode();
100 // We treat transcendental operations separately since one transcendental
101 // operation can correspond to several floating point ops.
102 // kLogistic is included in "trascendental" as it is implemented using
103 // trascendental ops (tanh or exp).
104 if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
105 opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
106 opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt ||
107 opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
108 opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
109 opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
110 opcode == HloOpcode::kAtan2) {
111 current_properties_[kTranscendentalsKey] = computation_count;
112 } else {
113 // Note: transcendental operations are considered a separate category from
114 // FLOPs.
115 current_properties_[kFlopsKey] = computation_count;
116 }
117 return Status::OK();
118 }
119
GetProperty(const string & key,const Properties & properties,const float default_value)120 /*static*/ float HloCostAnalysis::GetProperty(const string& key,
121 const Properties& properties,
122 const float default_value) {
123 auto key_value = properties.find(key);
124 return key_value == properties.end() ? default_value : key_value->second;
125 }
126
GetPropertyForHlo(const HloInstruction & hlo,const string & key,const HloToProperties & hlo_to_properties)127 /*static*/ float HloCostAnalysis::GetPropertyForHlo(
128 const HloInstruction& hlo, const string& key,
129 const HloToProperties& hlo_to_properties) {
130 auto it = hlo_to_properties.find(&hlo);
131 if (it == hlo_to_properties.end()) {
132 return 0.0f;
133 } else {
134 return GetProperty(key, it->second);
135 }
136 }
137
GetShapeSize(const Shape & shape) const138 int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
139 if (!LayoutUtil::HasLayout(shape)) {
140 return 0;
141 }
142 return shape_size_(shape);
143 }
144
// Returns the number of bytes read for a fused parameter (or
// get-tuple-element) by inspecting how its users inside the fusion consume
// it: users that read only a sub-region are charged that region, while all
// other ("trivial") users share a single full-size read.
int64 HloCostAnalysis::FusionParameterReadBytes(
    const HloInstruction* hlo) const {
  int64 size = 0;
  bool seen_trivial_user = false;
  CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
                           hlo->opcode() == HloOpcode::kGetTupleElement));
  for (const HloInstruction* user : hlo->users()) {
    switch (user->opcode()) {
      case HloOpcode::kFusion: {
        // Recurse into nested fusions through the matching fused parameters.
        for (int64 idx : user->OperandIndices(hlo)) {
          size += FusionParameterReadBytes(user->fused_parameter(idx));
        }
        break;
      }
      case HloOpcode::kSlice:
        // A slice reads only as many bytes as it produces.
        size += GetShapeSize(user->shape());
        break;
      case HloOpcode::kDynamicSlice:
        // Only the sliced region of operand 0 is read; when this parameter is
        // one of the start indices it is read in full.
        size += hlo == user->operand(0) ? GetShapeSize(user->shape())
                                        : GetShapeSize(hlo->shape());
        break;
      case HloOpcode::kDynamicUpdateSlice:
        // Uses the same shape as 'update' which is operand 1.
        size += hlo == user->operand(0)
                    ? GetShapeSize(user->operand(1)->shape())
                    : GetShapeSize(hlo->shape());
        break;
      case HloOpcode::kBroadcast:
      case HloOpcode::kReshape:
        // Broadcast/reshape read the parameter once at its own size.
        size += GetShapeSize(hlo->shape());
        break;
      default:
        // Other instructions reading this parameter are assumed to be able to
        // share the read from memory.
        if (!seen_trivial_user) {
          seen_trivial_user = true;
          size += GetShapeSize(hlo->shape());
        }
    }
  }
  return size;
}
187
// Unary element-wise ops: one operation per output element.
Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

// Binary element-wise ops: one operation per output element.
Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

// Comparisons are costed like any other element-wise operation.
Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

// Clamp is costed like any other element-wise operation.
Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

// ReducePrecision is costed like any other element-wise operation.
Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}
207
// Parameters only name incoming data: they are charged no bytes accessed and
// no time, and are excluded from bottleneck-time computation.
Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

// Constants are treated like parameters: no bytes accessed, no time, and no
// bottleneck-time computation.
Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

// Iota keeps the default costs computed in Preprocess.
Status HloCostAnalysis::HandleIota(const HloInstruction*) {
  return Status::OK();
}
227
Status HloCostAnalysis::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output, so it is charged no bytes accessed and no time.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
239
// Select is costed like any other element-wise operation.
Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

// TupleSelect keeps the default costs computed in Preprocess.
Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) {
  return Status::OK();
}

// Reverse keeps the default costs computed in Preprocess.
Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return Status::OK();
}
251
HandleSlice(const HloInstruction * slice)252 Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
253 current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
254 SetOutputBytesAccessed(GetShapeSize(slice->shape()));
255 SetOperandBytesAccessed(0, GetShapeSize(slice->shape()));
256 return Status::OK();
257 }
258
HandleDynamicSlice(const HloInstruction * dynamic_slice)259 Status HloCostAnalysis::HandleDynamicSlice(
260 const HloInstruction* dynamic_slice) {
261 current_properties_[kBytesAccessedKey] =
262 GetShapeSize(dynamic_slice->shape()) * 2 +
263 GetShapeSize(dynamic_slice->operand(1)->shape());
264 SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
265 SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
266 SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
267 return Status::OK();
268 }
269
HandleDynamicUpdateSlice(const HloInstruction * dynamic_update_slice)270 Status HloCostAnalysis::HandleDynamicUpdateSlice(
271 const HloInstruction* dynamic_update_slice) {
272 current_properties_[kBytesAccessedKey] =
273 GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
274 GetShapeSize(dynamic_update_slice->operand(2)->shape());
275 // Operand 0 aliases with the output.
276 SetOutputBytesAccessed(
277 GetShapeSize(dynamic_update_slice->operand(1)->shape()));
278 SetOperandBytesAccessed(0, 0);
279 SetOperandBytesAccessed(
280 1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
281 SetOperandBytesAccessed(
282 2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
283 return Status::OK();
284 }
285
HandleTuple(const HloInstruction * tuple)286 Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
287 // The tuple instruction only gathers pointers from inputs (it doesn't iterate
288 // through them). The memory touched is then only the size of the output
289 // index table of the tuple.
290
291 current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
292 SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
293 for (int i = 0; i < tuple->operand_count(); ++i) {
294 SetOperandBytesAccessed(i, 0);
295 }
296 return Status::OK();
297 }
298
// Concatenate keeps the default costs computed in Preprocess.
Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return Status::OK();
}

// Convert is costed like any other element-wise operation.
Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

// Copy keeps the default costs computed in Preprocess.
Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
  return Status::OK();
}
310
Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
  // Domain does not have any computation or data transfer: zero bytes for the
  // output and every operand, zero time, and no bottleneck computation.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < domain->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
322
HandleDot(const HloInstruction * dot)323 Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
324 const Shape& lhs_shape = dot->operand(0)->shape();
325 const Shape& dot_shape = dot->shape();
326 const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
327 // Count of elements along the reduction dimension (last dimension for the
328 // rhs).
329 int64 reduction_width = 1;
330 for (auto dim : dnums.lhs_contracting_dimensions()) {
331 reduction_width *= lhs_shape.dimensions(dim);
332 }
333 // Each output element requires reduction_width FMA operations.
334 current_properties_[kFlopsKey] =
335 kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
336 return Status::OK();
337 }
338
Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
  // Count nested infeed output tuples: record bytes per leaf shape index and
  // also the flat total for the whole output.
  int64 size = 0;
  for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
    size += GetShapeSize(indexed_shape.shape);
    SetOutputBytesAccessed(indexed_shape.index,
                           GetShapeSize(indexed_shape.shape));
  }
  SetOutputBytesAccessed(size);
  current_properties_[kBytesAccessedKey] = size;
  return Status::OK();
}

Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
  // Count nested outfeed operand tuples: record bytes per leaf shape index of
  // each operand, per-operand totals, and the overall total.
  current_properties_[kBytesAccessedKey] = 0;
  for (int64 i = 0; i < outfeed->operand_count(); ++i) {
    const HloInstruction* operand = outfeed->operand(i);
    int64 size = 0;
    for (const auto& indexed_shape :
         ShapeUtil::GetLeafShapes(operand->shape())) {
      size += GetShapeSize(indexed_shape.shape);
      SetOperandBytesAccessed(i, indexed_shape.index,
                              GetShapeSize(indexed_shape.shape));
    }
    SetOperandBytesAccessed(i, size);
    current_properties_[kBytesAccessedKey] += size;
  }
  return Status::OK();
}
369
HandleMap(const HloInstruction * map)370 Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
371 // Compute properties of the mapped function.
372 TF_ASSIGN_OR_RETURN(const Properties sub_properties,
373 ProcessSubcomputation(map->to_apply()));
374
375 // Compute the cost of all elements for this Map operation.
376 const int64 element_count = ShapeUtil::ElementsIn(map->shape());
377 for (const auto& property : sub_properties) {
378 if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
379 current_properties_[property.first] = property.second * element_count;
380 }
381 }
382 return Status::OK();
383 }
384
HandleReduce(const HloInstruction * reduce)385 Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
386 HloComputation* function = reduce->to_apply();
387 // Compute the cost of the user function.
388 TF_ASSIGN_OR_RETURN(const Properties sub_properties,
389 ProcessSubcomputation(function));
390
391 // Compute the cost of all elements for this Reduce operation.
392 // This counts the number of times the reduction function is applied, so it
393 // does not need to be multiplied by the number of input tensors - that's
394 // already "priced in" by the sub-computation doing more work.
395 auto arg = reduce->operand(0);
396 auto output_shape = reduce->shape().IsArray()
397 ? reduce->shape()
398 : reduce->shape().tuple_shapes(0);
399 int64 reduction_count =
400 ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
401 for (const auto& property : sub_properties) {
402 if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
403 current_properties_[property.first] = property.second * reduction_count;
404 }
405 }
406 return Status::OK();
407 }
408
HandleReduceWindow(const HloInstruction * reduce_window)409 Status HloCostAnalysis::HandleReduceWindow(
410 const HloInstruction* reduce_window) {
411 const Window& window = reduce_window->window();
412 auto function = reduce_window->to_apply();
413 // Compute the properties of the reduction function.
414 TF_ASSIGN_OR_RETURN(const Properties sub_properties,
415 ProcessSubcomputation(function));
416
417 // Compute the cost of all elements for this ReduceWindow operation. For each
418 // output element there are window_size - 1 reductions to perform.
419 int64 window_element_count = 1;
420 for (const auto& dimension : window.dimensions()) {
421 window_element_count *= dimension.size();
422 }
423
424 const int64 output_element_count =
425 ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
426 ? reduce_window->shape()
427 : reduce_window->shape().tuple_shapes(0));
428 const int64 reduction_count =
429 (window_element_count - 1) * output_element_count;
430 for (const auto& property : sub_properties) {
431 if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
432 current_properties_[property.first] = property.second * reduction_count;
433 }
434 }
435 return Status::OK();
436 }
437
Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of one application of the select function and one
  // application of the scatter function.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform and
  // 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64 window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 select_count = source_element_count * (window_element_count - 1);
  // Accumulate (+=) rather than assign so select and scatter contributions of
  // the same property key add up.
  for (const auto& property : select_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  SetOperandBytesAccessed(0, 0);
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
479
// The following handlers keep the default costs computed in Preprocess
// (bytes accessed = inputs + output; no flops).
Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleCopyStart(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleCopyDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
  return Status::OK();
}
519
// The batch-norm handlers currently keep the default costs from Preprocess.
Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return Status::OK();
}
534
Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
  // A transpose that is effectively a bitcast moves no data; otherwise the
  // defaults from Preprocess apply.
  if (transpose->IsEffectiveBitcast()) {
    return HandleBitcast(transpose);
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted, so it is charged no bytes and no time.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < token->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleAddDependency(
    const HloInstruction* add_dependency) {
  // This instruction is used to enforce ordering at compile time. No code is
  // emitted, so it is charged no bytes and no time.
  current_should_compute_bottleneck_time_ = false;
  current_properties_[kBytesAccessedKey] = 0;
  SetOutputBytesAccessed(0);
  for (int i = 0; i < add_dependency->operand_count(); ++i) {
    SetOperandBytesAccessed(i, 0);
  }
  current_properties_[kOptimalSecondsKey] = 0;
  return Status::OK();
}
568
// Estimates convolution flops as kFmaFlops per (valid spatial position x
// input features per group x output features x batches per group). Valid
// spatial positions are counted per dimension, taking stride, padding, and
// both base and window dilation into account.
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  auto lhs = convolution->operand(0);
  auto rhs = convolution->operand(1);
  Window window = convolution->window();
  const auto& result_shape = convolution->shape();
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();

  const auto& dnums = convolution->convolution_dimension_numbers();

  const int64 input_batch_dim = dnums.input_batch_dimension();
  const int64 input_feature_dim = dnums.input_feature_dimension();
  const int64 output_feature_dim = dnums.output_feature_dimension();
  const int64 input_feature =
      ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
  const int64 output_feature =
      ShapeUtil::GetDimension(result_shape, output_feature_dim);
  const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);

  DimensionVector kernel_limits;
  DimensionVector output_limits;
  DimensionVector input_limits;
  if (window.dimensions().empty()) {
    // Zero-dimensional (degenerate) convolution: treat it as a single 1x1
    // spatial position.
    window = window_util::MakeWindow({1});
    kernel_limits.push_back(1);
    output_limits.push_back(1);
    input_limits.push_back(1);
  } else {
    // Collect the kernel, output and input extents for each spatial
    // dimension.
    for (int64 spatial_dimension = 0;
         spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
      // Spatial dimension number for kernel (rhs).
      const int64 kernel_spatial_dim =
          dnums.kernel_spatial_dimensions(spatial_dimension);
      const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
      kernel_limits.push_back(kernel_limit);

      // Spatial dimension number for output.
      const int64 output_spatial_dim =
          dnums.output_spatial_dimensions(spatial_dimension);
      const int64 output_limit = result_shape.dimensions(output_spatial_dim);
      output_limits.push_back(output_limit);

      // Spatial dimension number for input (lhs).
      const int64 input_spatial_dim =
          dnums.input_spatial_dimensions(spatial_dimension);
      const int64 input_limit = lhs_shape.dimensions(input_spatial_dim);
      input_limits.push_back(input_limit);
    }
  }

  // Per spatial dimension, the number of (kernel position, output position)
  // pairs that land on a real (in-bounds, non-dilated) input element.
  DimensionVector valid_position_counts;

  // Loop over each spatial dimension.
  for (int64 spatial_dimension = 0;
       spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
    const auto& window_dim = window.dimensions(spatial_dimension);
    // These two conditions will create an N^2 iteration pattern with only N
    // valid elements. This is a performance optimization and produces the same
    // result as the whole loop.
    if (input_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        input_limits[spatial_dimension] == window_dim.base_dilation() &&
        window_dim.window_dilation() == 1 &&
        std::max<int64>(1, input_limits[spatial_dimension] - 1) ==
            window_dim.stride() &&
        window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
      valid_position_counts.push_back(input_limits[spatial_dimension]);
      continue;
    }

    if (input_limits[spatial_dimension] == 1 &&
        kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
        window_dim.window_dilation() == 1 && window_dim.base_dilation() == 1 &&
        window_dim.stride() == 1 &&
        window_dim.padding_high() == output_limits[spatial_dimension] - 1 &&
        window_dim.padding_low() == output_limits[spatial_dimension] - 1) {
      valid_position_counts.push_back(output_limits[spatial_dimension]);
      continue;
    }

    // General case: exhaustively enumerate (kernel, output) index pairs.
    int64 valid_position_count = 0;
    // Loop over each point in the kernel.
    for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
         ++kernel_idx) {
      // Loop over each point in the output.
      for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension];
           ++output_idx) {
        // Calculate lhs (input) index without taking base dilation into
        // account.
        const int64 undilated_index = output_idx * window_dim.stride() -
                                      window_dim.padding_low() +
                                      kernel_idx * window_dim.window_dilation();

        // Calculate the actual lhs (input) index after dilation. Avoid the
        // division as an optimization.
        const int64 lhs_spatial_index =
            window_dim.base_dilation() > 1
                ? undilated_index / window_dim.base_dilation()
                : undilated_index;

        // Skip if the lhs (input) index is to be dilated.
        if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
          continue;
        }

        // Skip if input index is not in bound.
        if (lhs_spatial_index < 0 ||
            lhs_spatial_index >= input_limits[spatial_dimension]) {
          continue;
        }

        valid_position_count += 1;
      }
    }
    valid_position_counts.push_back(valid_position_count);
  }

  // Features and batches are divided by their group counts since each group
  // only convolves its own slice of the input.
  const int64 fma_count = (input_feature / convolution->feature_group_count()) *
                          output_feature *
                          (batch / convolution->batch_group_count()) *
                          Product(valid_position_counts);
  current_properties_[kFlopsKey] = fma_count * kFmaFlops;
  return Status::OK();
}
693
HandleFft(const HloInstruction * fft)694 Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
695 auto real_shape =
696 fft->operand(0)->shape().IsTuple()
697 ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
698 : fft->operand(0)->shape();
699 constexpr int kFmaPerComplexMul = 4;
700 int64 log_factors = 1;
701 for (int64 dim : fft->fft_length()) {
702 log_factors *= tensorflow::Log2Floor(dim);
703 }
704 current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
705 ShapeUtil::ElementsIn(real_shape);
706 return Status::OK();
707 }
708
HandleTriangularSolve(const HloInstruction * hlo)709 Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
710 // Half of operand 0 is read.
711 float bytes_accessed = GetShapeSize(hlo->shape());
712 SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
713 bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
714 SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
715 bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
716 SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(1)->shape()));
717 current_properties_[kBytesAccessedKey] = bytes_accessed;
718
719 const Shape& a_shape = hlo->operand(0)->shape();
720 const Shape& b_shape = hlo->operand(1)->shape();
721 // Estimate as batch * mn^2 / 2 flops.
722 int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
723 elems *= ShapeUtil::ElementsIn(b_shape);
724 current_properties_[kFlopsKey] = kFmaFlops * elems;
725 return Status::OK();
726 }
727
Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
  // Half of operand 0 is read and half of the output will be written (the
  // output has the same shape as operand 0, so its size is used for both).
  float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
  SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  const Shape& a_shape = hlo->operand(0)->shape();
  // Estimate as batch * n^3 / 3 flops.
  int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
  elems *= ShapeUtil::ElementsIn(a_shape);
  current_properties_[kFlopsKey] = elems / 3;
  return Status::OK();
}

// AllGather keeps the default costs computed in Preprocess.
Status HloCostAnalysis::HandleAllGather(const HloInstruction* hlo) {
  return Status::OK();
}
747
HandleAllReduce(const HloInstruction * crs)748 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
749 // We assume 2 replicas, so that each output element is the sum of two input
750 // elements.
751 //
752 // TODO(b/33004697): Compute correct cost here, taking the actual number of
753 // replicas into account.
754 double flops = 0.0;
755 ShapeUtil::ForEachSubshape(crs->shape(),
756 [&](const Shape& subshape, const ShapeIndex&) {
757 if (subshape.IsArray()) {
758 flops += ShapeUtil::ElementsIn(subshape);
759 }
760 });
761 current_properties_[kFlopsKey] = flops;
762 return Status::OK();
763 }
764
Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
  // All-to-all is treated as free: pure data movement, no flops, and
  // communication cost is not modeled by this analysis.
  return Status::OK();
}
768
Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
  // Collective-permute is pure communication; no cost is modeled here.
  return Status::OK();
}
772
Status HloCostAnalysis::HandleCollectivePermuteStart(
    const HloInstruction* /*hlo*/) {
  // Async start of a collective-permute; no cost is modeled here.
  return Status::OK();
}
777
Status HloCostAnalysis::HandleCollectivePermuteDone(
    const HloInstruction* /*hlo*/) {
  // Async completion of a collective-permute; no cost is modeled here.
  return Status::OK();
}
782
Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
  // Produces a scalar id; considered free.
  return Status::OK();
}
786
Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
  // Produces a scalar id; considered free.
  return Status::OK();
}
790
Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is same as a transcendental operation.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsIn(random->shape());
  return Status::OK();
}
799
Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is same as a transcendental operation.
  // Note: unlike HandleRng, the result here may be a tuple (state + values),
  // hence the recursive element count.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsInRecursive(random->shape());
  return Status::OK();
}
808
Status HloCostAnalysis::HandleRngGetAndUpdateState(
    const HloInstruction* random) {
  // RNG state bookkeeping; considered free.
  return Status::OK();
}
813
Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  // A custom fusion wrapping a gather or scatter is costed as the wrapped
  // instruction itself (first one found in the fused computation).
  if (fusion->IsCustomFusion()) {
    for (const HloInstruction* hlo :
         fusion->fused_instructions_computation()->instructions()) {
      if (hlo->opcode() == HloOpcode::kGather) {
        return HandleGather(hlo);
      }
      if (hlo->opcode() == HloOpcode::kScatter) {
        return HandleScatter(hlo);
      }
    }
  }
  // Flops/transcendentals come from visiting the fused computation.
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(fusion->fused_instructions_computation()));

  // Fusion nodes that produce a tuple also produce the entries in the tuple.
  // Ignore the memory accessed inside fused ops, since fusion is supposed to
  // prevent intermediate data from touching slow memory. Recompute
  // bytes-accessed from the fusion's inputs and outputs only.
  current_properties_[kBytesAccessedKey] = 0;
  ShapeUtil::ForEachSubshape(
      fusion->shape(),
      [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
        if (!subshape.IsArray()) {
          return;
        }
        if (shape_index.empty()) {
          // Root is a dynamic-update-slice: only the update (operand 1) is
          // written, not the whole output buffer.
          if (fusion->fused_expression_root()->opcode() ==
              HloOpcode::kDynamicUpdateSlice) {
            int64 size = GetShapeSize(
                fusion->fused_expression_root()->operand(1)->shape());
            current_properties_[kBytesAccessedKey] += size;
            SetOutputBytesAccessed(shape_index, size);
            return;
          }
        } else if (shape_index.size() == 1) {
          // Tuple element produced by a dynamic-update-slice: likewise count
          // only the update operand's size for that element.
          if (fusion->fused_expression_root()->opcode() == HloOpcode::kTuple &&
              fusion->fused_expression_root()
                      ->operand(shape_index[0])
                      ->opcode() == HloOpcode::kDynamicUpdateSlice) {
            int64 size = GetShapeSize(fusion->fused_expression_root()
                                          ->operand(shape_index[0])
                                          ->operand(1)
                                          ->shape());
            current_properties_[kBytesAccessedKey] += size;
            SetOutputBytesAccessed(shape_index, size);
            return;
          }
        }
        // Default: the whole array subshape is written.
        current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
        SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
      });

  if (fusion->shape().IsTuple()) {
    // Propagate and accumulate the output tuple bytes from the tuple
    // subshapes. This ensures we have the correct output bytes accessed for
    // the shape index {}: each missing interior-node entry is the sum of its
    // children, computed recursively.
    std::function<float(const Shape&, const ShapeIndex&)>
        propagate_output_size_to_parent;
    propagate_output_size_to_parent = [&](const Shape& shape,
                                          const ShapeIndex& shape_index) {
      // Leaf entries (and any already-set entries) are returned as-is.
      auto output_bytes_it =
          current_properties_.find(GetOutputBytesAccessedKey(shape_index));
      if (output_bytes_it != current_properties_.end()) {
        return output_bytes_it->second;
      }
      float bytes_accessed = 0;
      for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
        const Shape& subshape = shape.tuple_shapes(i);
        ShapeIndex subshape_index(shape_index);
        subshape_index.push_back(i);
        bytes_accessed +=
            propagate_output_size_to_parent(subshape, subshape_index);
      }
      SetOutputBytesAccessed(shape_index, bytes_accessed);
      return bytes_accessed;
    };
    // Erase the stale root entry so the recursion recomputes it from the
    // leaves.
    current_properties_.erase(
        current_properties_.find(GetOutputBytesAccessedKey()));
    propagate_output_size_to_parent(fusion->shape(), {});
  }

  // Account for bytes read through each fusion parameter.
  for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
    const HloInstruction* operand = fusion->fused_parameter(i);
    int64 operand_size = 0;
    if (!fusion->shape().IsTuple()) {
      operand_size = FusionParameterReadBytes(operand);
    } else {
      // If the fusion parameter is a tuple type, find the gte for the leaf
      // shape and calculate the bytes accessed for those array types.
      for (const auto& indexed_shape :
           ShapeUtil::GetLeafShapes(operand->shape())) {
        const HloInstruction* gte = operand;
        // Walk down the tuple index, following the matching
        // get-tuple-element user at each level (if one exists).
        for (int64 index : indexed_shape.index) {
          for (const HloInstruction* user : gte->users()) {
            if (user->opcode() == HloOpcode::kGetTupleElement &&
                user->tuple_index() == index) {
              gte = user;
              break;
            }
          }
        }
        int64 size = FusionParameterReadBytes(gte);
        operand_size += size;
        SetOperandBytesAccessed(i, indexed_shape.index, size);
      }
    }
    current_properties_[kBytesAccessedKey] += operand_size;
    SetOperandBytesAccessed(i, operand_size);
  }

  return Status::OK();
}
928
Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
  // The call's cost is exactly the cost of the called computation.
  TF_ASSIGN_OR_RETURN(current_properties_,
                      ProcessSubcomputation(call->to_apply()));
  // The subcomputation already accumulated optimal-seconds; don't rederive it
  // from rates for this node.
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}
935
HandleCustomCall(const HloInstruction * custom_call)936 Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
937 // Mark applicable fields as "unknown", since we don't know what CustomCall
938 // does. This is better than returning an error, which would stop iteration,
939 // and therefore would prevent us from getting *any* stats for a computation
940 // which contains a CustomCall.
941 current_properties_[kOptimalSecondsKey] = -1;
942 current_properties_[kBytesAccessedKey] = -1;
943 SetOutputBytesAccessed(-1);
944 for (int i = 0; i < custom_call->operand_count(); ++i) {
945 SetOperandBytesAccessed(i, -1);
946 }
947 current_properties_[kFlopsKey] = -1;
948 current_should_compute_bottleneck_time_ = false;
949 return Status::OK();
950 }
951
HandleSort(const HloInstruction * sort)952 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
953 // This assumes a comparison based N*log(N) algorithm. As for all ops, the
954 // actual properties of the op depend on the backend implementation.
955 int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
956 current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
957 return Status::OK();
958 }
959
HandleWhile(const HloInstruction * xla_while)960 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
961 // Since the number of iterations of the while node will not always be
962 // something that we can statically analyze, we cannot precisely compute the
963 // cost of a while node. For now compute the cost of a single iteration.
964 TF_ASSIGN_OR_RETURN(const Properties body_properties,
965 ProcessSubcomputation(xla_while->while_body()));
966
967 TF_ASSIGN_OR_RETURN(const Properties condition_properties,
968 ProcessSubcomputation(xla_while->while_condition()));
969
970 current_properties_.clear();
971 for (const auto& property : body_properties) {
972 current_properties_[property.first] += property.second;
973 }
974 for (const auto& property : condition_properties) {
975 current_properties_[property.first] += property.second;
976 }
977 current_should_compute_bottleneck_time_ = false;
978
979 return Status::OK();
980 }
981
HandleConditional(const HloInstruction * conditional)982 Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
983 // Compute the cost of the branch computations and take the maximum from those
984 // for each property.
985 TF_ASSIGN_OR_RETURN(
986 const Properties branch0_computation_properties,
987 ProcessSubcomputation(conditional->branch_computation(0)));
988 current_properties_ = branch0_computation_properties;
989 for (int j = 1; j < conditional->branch_count(); ++j) {
990 TF_ASSIGN_OR_RETURN(
991 const Properties branch_computation_properties,
992 ProcessSubcomputation(conditional->branch_computation(j)));
993 for (const auto& property : branch_computation_properties) {
994 if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_,
995 property)) {
996 auto& current_property = current_properties_[property.first];
997 current_property = std::max(current_property, property.second);
998 }
999 }
1000 }
1001 current_should_compute_bottleneck_time_ = false;
1002
1003 return Status::OK();
1004 }
1005
Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
  // Gather doesn't read the whole input buffer, it's equivalent to a copy the
  // size of the output shape and a read of the gather indices.
  int64 output_size = GetShapeSize(gather->shape());
  // Total = read of gathered elements (== output size) + write of the output
  // + read of the indices (operand 1).
  current_properties_[kBytesAccessedKey] =
      output_size * 2 + GetShapeSize(gather->operand(1)->shape());
  // Operand 0 (the source array) is charged only the output's size, since
  // only the gathered elements are actually touched.
  SetOperandBytesAccessed(0, output_size);
  SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
  SetOutputBytesAccessed(output_size);
  // Gather does not issue any flops.
  return Status::OK();
}
1018
Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
  // Scatter accesses the equivalent of 3 update shapes (input, output, and
  // updates), and the scatter indices.
  int64 update_size = GetShapeSize(scatter->operand(2)->shape());
  current_properties_[kBytesAccessedKey] =
      update_size * 3 + GetShapeSize(scatter->operand(1)->shape());
  // Operand 0 (the target array) is charged one update's worth of reads:
  // only the scattered-into elements are touched.
  SetOperandBytesAccessed(0, update_size);
  SetOperandBytesAccessed(1, GetShapeSize(scatter->operand(1)->shape()));
  SetOperandBytesAccessed(2, update_size);
  SetOutputBytesAccessed(update_size);
  // The update computation runs once per updated element; scale its
  // per-invocation properties by the number of update elements. Bytes
  // accessed inside the subcomputation are intentionally not propagated.
  const int64 element_count =
      ShapeUtil::ElementsIn(scatter->operand(2)->shape());
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(scatter->to_apply()));
  for (const auto& property : sub_properties) {
    if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return Status::OK();
}
1040
Status HloCostAnalysis::HandleGetDimensionSize(
    const HloInstruction* /*get_size*/) {
  // Produces a scalar from metadata; considered free.
  return Status::OK();
}
1045
Status HloCostAnalysis::HandleSetDimensionSize(
    const HloInstruction* /*set_size*/) {
  // Updates metadata only; considered free.
  return Status::OK();
}
1050
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  // No post-traversal work is needed; per-instruction accumulation happens
  // elsewhere during the visit.
  return Status::OK();
}
1054
// Total flops accumulated over all visited instructions.
float HloCostAnalysis::flop_count() const {
  return GetProperty(kFlopsKey, properties_sum_);
}
1058
// Total transcendental-op count accumulated over all visited instructions.
float HloCostAnalysis::transcendental_count() const {
  return GetProperty(kTranscendentalsKey, properties_sum_);
}
1062
// Total bytes accessed accumulated over all visited instructions.
float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}
1066
// Total optimal-seconds estimate accumulated over all visited instructions.
float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}
1070
// Flops recorded for a single instruction (0 if not visited).
int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}
1074
// Transcendental-op count recorded for a single instruction.
int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}
1078
// Total bytes accessed recorded for a single instruction.
int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}
1082
// Bytes read from one operand of `hlo`, optionally at a specific shape index
// within that operand's (possibly tuple) shape.
int64 HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
                                              int64 operand_num,
                                              ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOperandBytesAccessedKey(operand_num, index),
                           hlo_properties_);
}
1089
// Bytes written to `hlo`'s output, optionally at a specific shape index
// within a tuple-shaped output.
int64 HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
                                             ShapeIndex index) const {
  return GetPropertyForHlo(hlo, GetOutputBytesAccessedKey(index),
                           hlo_properties_);
}
1095
// Optimal-seconds estimate recorded for a single instruction.
float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}
1099
GetBytesRead(const HloInstruction & hlo,absl::optional<int64> memory_space) const1100 int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
1101 absl::optional<int64> memory_space) const {
1102 int64 bytes_read = 0;
1103 for (int operand_number = 0; operand_number < hlo.operand_count();
1104 ++operand_number) {
1105 for (const ShapeUtil::IndexedShape& indexed_shape :
1106 ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) {
1107 absl::optional<int64> index_memory_space;
1108 if (indexed_shape.shape.has_layout()) {
1109 index_memory_space = indexed_shape.shape.layout().memory_space();
1110 }
1111 if (!memory_space || memory_space == index_memory_space) {
1112 bytes_read +=
1113 operand_bytes_accessed(hlo, operand_number, indexed_shape.index);
1114 }
1115 }
1116 }
1117 return bytes_read;
1118 }
1119
GetBytesWritten(const HloInstruction & hlo,absl::optional<int64> memory_space) const1120 int64 HloCostAnalysis::GetBytesWritten(
1121 const HloInstruction& hlo, absl::optional<int64> memory_space) const {
1122 int64 bytes_written = 0;
1123 for (const ShapeUtil::IndexedShape& indexed_shape :
1124 ShapeUtil::GetLeafShapes(hlo.shape())) {
1125 absl::optional<int64> index_memory_space;
1126 if (indexed_shape.shape.has_layout()) {
1127 index_memory_space = indexed_shape.shape.layout().memory_space();
1128 }
1129 if (!memory_space || memory_space == index_memory_space) {
1130 bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
1131 }
1132 }
1133 return bytes_written;
1134 }
1135
// Runs a fresh nested cost analysis over `computation` and returns its
// aggregate properties; per-instruction results are merged into
// hlo_properties_ so callers can still query individual nested instructions.
StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
    HloComputation* computation) {
  auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
  visitor->ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->Accept(visitor.get()));
  // insert() keeps existing entries on key collision — outer results win.
  hlo_properties_.insert(visitor->hlo_properties_.begin(),
                         visitor->hlo_properties_.end());
  return visitor->properties();
}
1145
// Factory for the analysis used on subcomputations; overridable by
// subclasses so nested analyses use the derived type's cost model.
std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis(
    const ShapeSizeFunction& shape_size, const Properties& per_second_rates) {
  return absl::WrapUnique(new HloCostAnalysis(shape_size, per_second_rates));
}
1150
SetOperandBytesAccessed(int64 operand_num,float value)1151 void HloCostAnalysis::SetOperandBytesAccessed(int64 operand_num, float value) {
1152 current_properties_[GetOperandBytesAccessedKey(operand_num).c_str()] = value;
1153 }
1154
SetOperandBytesAccessed(int64 operand_num,ShapeIndex index,float value)1155 void HloCostAnalysis::SetOperandBytesAccessed(int64 operand_num,
1156 ShapeIndex index, float value) {
1157 current_properties_[GetOperandBytesAccessedKey(operand_num, index).c_str()] =
1158 value;
1159 }
1160
// Records the total output bytes accessed for the instruction currently
// being visited (shape index {}).
void HloCostAnalysis::SetOutputBytesAccessed(float value) {
  current_properties_[GetOutputBytesAccessedKey()] = value;
}
1164
// Records the output bytes accessed at a specific shape index for the
// instruction currently being visited.
void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
  current_properties_[GetOutputBytesAccessedKey(index)] = value;
}
1168
// Builds the properties-map key under which per-operand (and per-shape-index)
// bytes-accessed values are stored, e.g. "bytes accessed operand 0 {1}".
/*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
    int64 operand_num, ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
                      index.ToString());
}
1174
// Builds the properties-map key under which per-shape-index output
// bytes-accessed values are stored, e.g. "bytes accessed output {0}".
/*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
    ShapeIndex index) {
  return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
}
1179
1180 } // namespace xla
1181