/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

#include <cmath>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

constexpr char HloCostAnalysis::kFlopsKey[];
constexpr char HloCostAnalysis::kTranscendentalsKey[];
constexpr char HloCostAnalysis::kBytesAccessedKey[];
constexpr char HloCostAnalysis::kOptimalSecondsKey[];

HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
    : HloCostAnalysis(shape_size, {}) {}

HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
                                 const Properties& per_second_rates)
    : shape_size_(shape_size), per_second_rates_(per_second_rates) {}

Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
  // Set current instruction cost values to reasonable default values. Each
  // handler can overwrite these values. In Postprocess, these values are
  // accumulated and written to the per-instruction maps.
  current_properties_.clear();
  current_should_compute_bottleneck_time_ = true;

  // The default number of bytes accessed for an instruction is the sum of the
  // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does
  // not handle opaque types.
  float bytes_accessed = shape_size_(hlo->shape());
  for (const HloInstruction* operand : hlo->operands()) {
    bytes_accessed += shape_size_(operand->shape());
  }
  current_properties_[kBytesAccessedKey] = bytes_accessed;

  return Status::OK();
}

Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
  if (current_should_compute_bottleneck_time_) {
    // Compute the time as the time of the bottleneck, i.e. the slowest
    // property given the per-second rate of each property.
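    // For example, with a flops rate of 2e12 per second, an instruction that
    // performs 4e9 flops contributes 2e-3 seconds to its term; the largest
    // per-property term determines optimal_seconds.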
    float optimal_seconds = 0.0f;
    for (const auto& property : current_properties_) {
      if (property.first != kOptimalSecondsKey) {
        optimal_seconds = std::max(
            optimal_seconds,
            property.second /
                GetProperty(property.first, per_second_rates_, INFINITY));
      }
    }
    current_properties_[kOptimalSecondsKey] = optimal_seconds;
  }

  TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
  for (const auto& property : current_properties_) {
    properties_sum_[property.first] += property.second;
  }

  return Status::OK();
}

Status HloCostAnalysis::HandleElementwiseOp(
    const HloInstruction* hlo_instruction) {
  const auto& shape = hlo_instruction->shape();
  // For element-wise operations, the number of computations is the same as the
  // number of elements in the output shape.
  auto computation_count = ShapeUtil::ElementsIn(shape);
  auto opcode = hlo_instruction->opcode();
  // We treat transcendental operations separately since one transcendental
  // operation can correspond to several floating point ops.
  if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower ||
      opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin ||
      opcode == HloOpcode::kCos) {
    current_properties_[kTranscendentalsKey] = computation_count;
  } else {
    // Note: transcendental operations are considered a separate category from
    // FLOPs.
    current_properties_[kFlopsKey] = computation_count;
  }
  return Status::OK();
}

/*static*/ float HloCostAnalysis::GetProperty(const string& key,
                                              const Properties& properties,
                                              const float default_value) {
  auto key_value = properties.find(key);
  return key_value == properties.end() ? default_value : key_value->second;
}

/*static*/ float HloCostAnalysis::GetPropertyForHlo(
    const HloInstruction& hlo, const string& key,
    const HloToProperties& hlo_to_properties) {
  auto it = hlo_to_properties.find(&hlo);
  if (it == hlo_to_properties.end()) {
    return 0.0f;
  } else {
    return GetProperty(key, it->second);
  }
}

Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
  return HandleElementwiseOp(compare);
}

Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
  return HandleElementwiseOp(clamp);
}

Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
  return HandleElementwiseOp(hlo);
}

Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
  // GetTupleElement forwards a pointer and does not touch each element in the
  // output.
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleSelect(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSlice(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
  // The tuple instruction only gathers pointers from inputs (it doesn't
  // iterate through them). The memory touched is then only the size of the
  // output index table of the tuple.

  current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
  return Status::OK();
}

Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
  return HandleElementwiseOp(convert);
}

Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  const Shape& lhs_shape = dot->operand(0)->shape();
  const Shape& rhs_shape = dot->operand(1)->shape();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  // Count of elements along the reduction dimension (last dimension for the
  // rhs).
  int64 reduction_width =
      lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0));
  // First divide by reduction width before multiplying by rhs elements to
  // avoid overflow.
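  // For example, an [M, K] x [K, N] dot has reduction_width K, so
  // ElementsIn(lhs_shape) / reduction_width is M and fma_count is M * K * N.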
  int64 fma_count;
  if (reduction_width == 0) {
    fma_count = 0;
  } else {
    fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) *
                ShapeUtil::ElementsIn(rhs_shape);
  }

  // We count an FMA operation as 2 floating point operations.
  current_properties_[kFlopsKey] = kFmaFlops * fma_count;
  return Status::OK();
}

Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
  // Compute properties of the mapped function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(map->to_apply()));

  // Compute the cost of all elements for this Map operation.
  const int64 element_count = ShapeUtil::ElementsIn(map->shape());
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * element_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
  auto arg = reduce->operand(0);
  HloComputation* function = reduce->to_apply();
  // Compute the cost of the user function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this Reduce operation.
  int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
                          ShapeUtil::ElementsIn(reduce->shape());
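  // For example, reducing an f32[10, 100] operand to f32[10] applies the
  // user computation 10 * 100 - 10 = 990 times, and each application is
  // charged the per-call cost computed above.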
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleReduceWindow(
    const HloInstruction* reduce_window) {
  const Window& window = reduce_window->window();
  auto function = reduce_window->to_apply();
  // Compute the properties of the reduction function.
  TF_ASSIGN_OR_RETURN(const Properties sub_properties,
                      ProcessSubcomputation(function));

  // Compute the cost of all elements for this ReduceWindow operation. For each
  // output element there are window_size - 1 reductions to perform.
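  // For example, a 3x3 window performs 9 - 1 = 8 reductions per output
  // element.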
  int64 window_element_count = 1;
  for (const auto& dimension : window.dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 output_element_count =
      ShapeUtil::ElementsIn(reduce_window->shape());
  const int64 reduction_count =
      (window_element_count - 1) * output_element_count;
  for (const auto& property : sub_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] = property.second * reduction_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleSelectAndScatter(
    const HloInstruction* instruction) {
  // Compute the properties of the select and scatter functions.
  TF_ASSIGN_OR_RETURN(const Properties select_properties,
                      ProcessSubcomputation(instruction->select()));
  TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
                      ProcessSubcomputation(instruction->scatter()));

  // Compute the cost of all elements for this operation. For each scatter
  // source element there are window_size - 1 select computations to perform
  // and 1 scatter computation to perform.
  const auto source = instruction->operand(1);
  const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
  int64 window_element_count = 1;
  for (const auto& dimension : instruction->window().dimensions()) {
    window_element_count *= dimension.size();
  }
  const int64 select_count = source_element_count * (window_element_count - 1);
  for (const auto& property : select_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] += property.second * select_count;
    }
  }
  for (const auto& property : scatter_properties) {
    if (property.first != kBytesAccessedKey) {
      current_properties_[property.first] +=
          property.second * source_element_count;
    }
  }
  return Status::OK();
}

Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
  // A bitcast does no computation and touches no memory.
  current_properties_[kBytesAccessedKey] = 0;
  return Status::OK();
}

Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandlePad(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSend(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
  return Status::OK();
}

Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  auto rhs_instruction = convolution->operand(1);
  const auto& dnums = convolution->convolution_dimension_numbers();
  const int64 output_features =
      convolution->shape().dimensions(dnums.output_feature_dimension());

  // For each output element, we do one fma per element in the kernel at some
  // given output feature index.
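  // For example, a kernel with spatial dims Kh x Kw, Ci input features, and Co
  // output features yields ElementsIn(kernel) / output_features =
  // Kh * Kw * Ci fmas per output element.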
  const int64 fmas_per_output_element =
      output_features > 0
          ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features
          : 0;
  const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
  current_properties_[kFlopsKey] =
      output_elements * fmas_per_output_element * kFmaFlops;
  return Status::OK();
}

Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
  auto real_shape =
      ShapeUtil::IsTuple(fft->operand(0)->shape())
          ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
          : fft->operand(0)->shape();
  constexpr int kFmaPerComplexMul = 4;
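  // Estimate the FFT as ElementsIn(real_shape) * prod_i floor(log2(
  // fft_length[i])) complex multiplies, each counted as kFmaPerComplexMul
  // FMAs; for a 1D FFT of length N this is the usual N * log2(N) estimate.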
  int64 log_factors = 1;
  for (int64 dim : fft->fft_length()) {
    log_factors *= tensorflow::Log2Floor(dim);
  }
  current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
                                   ShapeUtil::ElementsIn(real_shape);
  return Status::OK();
}

Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
  // We assume 2 replicas, so that each output element is the sum of two input
  // elements.
  //
  // TODO(b/33004697): Compute correct cost here, taking the actual number of
  // replicas into account.
  double flops = 0.0;
  ShapeUtil::ForEachSubshape(
      crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
        if (ShapeUtil::IsArray(subshape)) {
          flops += ShapeUtil::ElementsIn(subshape);
        }
      });
  current_properties_[kFlopsKey] = flops;
  return Status::OK();
}

Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is same as a transcendental operation.
  current_properties_[kTranscendentalsKey] =
      ShapeUtil::ElementsIn(random->shape());
  return Status::OK();
}

Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
  // Compute the properties of the fused expression and attribute them to the
  // fusion node. Use a dummy shape_size to avoid any errors from trying to
  // calculate the size of a shape that does not have a layout, since nodes
  // inside fusion nodes do not necessarily have a layout assigned.
  ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
  TF_ASSIGN_OR_RETURN(
      current_properties_,
      ProcessSubcomputation(fusion->fused_instructions_computation(),
                            &shape_size));

  // Fusion nodes that produce a tuple also produce the entries in the tuple.
  // Ignore the memory accessed inside fused ops, since fusion is supposed to
  // prevent intermediate data from touching slow memory.
  current_properties_[kBytesAccessedKey] = 0;
  ShapeUtil::ForEachSubshape(
      fusion->shape(),
      [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
        current_properties_[kBytesAccessedKey] += shape_size_(subshape);
      });

  for (const HloInstruction* operand : fusion->operands()) {
    current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
  }

  return Status::OK();
}

Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
  TF_ASSIGN_OR_RETURN(current_properties_,
                      ProcessSubcomputation(call->to_apply()));
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}

Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
  // We can't do anything sane with CustomCalls, since we don't know what they
  // do, and returning an error status will stop iteration over this
  // computation, which is probably also not what we want. So just punt and
  // return OK. This will cause all of the properties to be reported as 0,
  // which is fine.
  current_should_compute_bottleneck_time_ = false;
  return Status::OK();
}

Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
  // This assumes a comparison based N*log(N) algorithm. As for all ops, the
  // actual properties of the op depend on the backend implementation.
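  // For example, sorting a 1024-element operand is charged 1024 * 10 flops.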
  int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
  current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
  return Status::OK();
}

Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
  // Since the number of iterations of the while node will not always be
  // something that we can statically analyze, we cannot precisely compute the
  // cost of a while node. For now compute the cost of a single iteration.
  //
  // TODO(b/26346211): Improve the cost analysis for while nodes.
  TF_ASSIGN_OR_RETURN(const Properties body_properties,
                      ProcessSubcomputation(xla_while->while_body()));

  TF_ASSIGN_OR_RETURN(const Properties condition_properties,
                      ProcessSubcomputation(xla_while->while_condition()));

  current_properties_.clear();
  for (const auto& property : body_properties) {
    current_properties_[property.first] += property.second;
  }
  for (const auto& property : condition_properties) {
    current_properties_[property.first] += property.second;
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}

Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
  // Compute the cost of the true and false computations and take the maximum
  // from those for each property.
  TF_ASSIGN_OR_RETURN(const Properties true_computation_properties,
                      ProcessSubcomputation(conditional->true_computation()));
  TF_ASSIGN_OR_RETURN(const Properties false_computation_properties,
                      ProcessSubcomputation(conditional->false_computation()));
  current_properties_ = true_computation_properties;
  for (const auto& property : false_computation_properties) {
    if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) {
      current_properties_[property.first] =
          std::max(current_properties_[property.first], property.second);
    }
  }
  current_should_compute_bottleneck_time_ = false;

  return Status::OK();
}

Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
  // Gather does not issue any flops.
  return Status::OK();
}

Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
  return Status::OK();
}

float HloCostAnalysis::flop_count() const {
  return GetProperty(kFlopsKey, properties_sum_);
}

float HloCostAnalysis::transcendental_count() const {
  return GetProperty(kTranscendentalsKey, properties_sum_);
}

float HloCostAnalysis::bytes_accessed() const {
  return GetProperty(kBytesAccessedKey, properties_sum_);
}

float HloCostAnalysis::optimal_seconds() const {
  return GetProperty(kOptimalSecondsKey, properties_sum_);
}

int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
}

int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
}

int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
}

float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
  return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
}

StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
    HloComputation* computation, const ShapeSizeFunction* shape_size) {
  if (shape_size == nullptr) {
    shape_size = &shape_size_;
  }
  HloCostAnalysis visitor(*shape_size, per_second_rates_);
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.properties();
}

}  // namespace xla