• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17 
18 #include <cmath>
19 
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/compiler/xla/util.h"
23 #include "tensorflow/core/lib/core/bits.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/map_util.h"
26 
27 namespace xla {
28 
29 constexpr char HloCostAnalysis::kFlopsKey[];
30 constexpr char HloCostAnalysis::kTranscendentalsKey[];
31 constexpr char HloCostAnalysis::kBytesAccessedKey[];
32 constexpr char HloCostAnalysis::kOptimalSecondsKey[];
33 
HloCostAnalysis(const ShapeSizeFunction & shape_size)34 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
35     : HloCostAnalysis(shape_size, {}) {}
36 
HloCostAnalysis(const ShapeSizeFunction & shape_size,const Properties & per_second_rates)37 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
38                                  const Properties& per_second_rates)
39     : shape_size_(shape_size), per_second_rates_(per_second_rates) {}
40 
Preprocess(const HloInstruction * hlo)41 Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
42   // Set current instruction cost values to reasonable default values. Each
43   // handler can overwrite these values. In Postprocess, these values are
44   // accumulated and written to the per-instruction maps.
45   current_properties_.clear();
46   current_should_compute_bottleneck_time_ = true;
47 
48   // The default number of bytes accessed for an instruction is the sum of the
49   // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
50   // handle opaque types.
51   float bytes_accessed = shape_size_(hlo->shape());
52   for (const HloInstruction* operand : hlo->operands()) {
53     bytes_accessed += shape_size_(operand->shape());
54   }
55   current_properties_[kBytesAccessedKey] = bytes_accessed;
56 
57   return Status::OK();
58 }
59 
Postprocess(const HloInstruction * hlo)60 Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
61   if (current_should_compute_bottleneck_time_) {
62     // Compute the time as the time of the bottleneck, i.e. the slowest property
63     // given the per-second rate of each property.
64     float optimal_seconds = 0.0f;
65     for (const auto& property : current_properties_) {
66       if (property.first != kOptimalSecondsKey) {
67         optimal_seconds = std::max(
68             optimal_seconds,
69             property.second /
70                 GetProperty(property.first, per_second_rates_, INFINITY));
71       }
72     }
73     current_properties_[kOptimalSecondsKey] = optimal_seconds;
74   }
75 
76   TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
77   for (const auto& property : current_properties_) {
78     properties_sum_[property.first] += property.second;
79   }
80 
81   return Status::OK();
82 }
83 
HandleElementwiseOp(const HloInstruction * hlo_instruction)84 Status HloCostAnalysis::HandleElementwiseOp(
85     const HloInstruction* hlo_instruction) {
86   const auto& shape = hlo_instruction->shape();
87   // For element-wise operations, the number of computations is the same as the
88   // number of elements in the output shape.
89   auto computation_count = ShapeUtil::ElementsIn(shape);
90   auto opcode = hlo_instruction->opcode();
91   // We treat transcendental operations separately since one transcendental
92   // operation can correspond to several floating point ops.
93   if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower ||
94       opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin ||
95       opcode == HloOpcode::kCos) {
96     current_properties_[kTranscendentalsKey] = computation_count;
97   } else {
98     // Note: transcendental operations are considered a separate category from
99     // FLOPs.
100     current_properties_[kFlopsKey] = computation_count;
101   }
102   return Status::OK();
103 }
104 
GetProperty(const string & key,const Properties & properties,const float default_value)105 /*static*/ float HloCostAnalysis::GetProperty(const string& key,
106                                               const Properties& properties,
107                                               const float default_value) {
108   auto key_value = properties.find(key);
109   return key_value == properties.end() ? default_value : key_value->second;
110 }
111 
GetPropertyForHlo(const HloInstruction & hlo,const string & key,const HloToProperties & hlo_to_properties)112 /*static*/ float HloCostAnalysis::GetPropertyForHlo(
113     const HloInstruction& hlo, const string& key,
114     const HloToProperties& hlo_to_properties) {
115   auto it = hlo_to_properties.find(&hlo);
116   if (it == hlo_to_properties.end()) {
117     return 0.0f;
118   } else {
119     return GetProperty(key, it->second);
120   }
121 }
122 
HandleElementwiseUnary(const HloInstruction * hlo)123 Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
124   return HandleElementwiseOp(hlo);
125 }
126 
HandleElementwiseBinary(const HloInstruction * hlo)127 Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
128   return HandleElementwiseOp(hlo);
129 }
130 
HandleCompare(const HloInstruction * compare)131 Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
132   return HandleElementwiseOp(compare);
133 }
134 
HandleClamp(const HloInstruction * clamp)135 Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
136   return HandleElementwiseOp(clamp);
137 }
138 
HandleReducePrecision(const HloInstruction * hlo)139 Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
140   return HandleElementwiseOp(hlo);
141 }
142 
HandleParameter(const HloInstruction *)143 Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
144   current_properties_[kBytesAccessedKey] = 0;
145   return Status::OK();
146 }
147 
HandleConstant(const HloInstruction *)148 Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
149   current_properties_[kBytesAccessedKey] = 0;
150   return Status::OK();
151 }
152 
HandleGetTupleElement(const HloInstruction *)153 Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
154   // GetTupleElement forwards a pointer and does not touch each element in the
155   // output.
156   current_properties_[kBytesAccessedKey] = 0;
157   return Status::OK();
158 }
159 
HandleSelect(const HloInstruction *)160 Status HloCostAnalysis::HandleSelect(const HloInstruction*) {
161   return Status::OK();
162 }
163 
HandleReverse(const HloInstruction *)164 Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
165   return Status::OK();
166 }
167 
HandleSlice(const HloInstruction *)168 Status HloCostAnalysis::HandleSlice(const HloInstruction*) {
169   return Status::OK();
170 }
171 
HandleDynamicSlice(const HloInstruction *)172 Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) {
173   return Status::OK();
174 }
175 
HandleDynamicUpdateSlice(const HloInstruction *)176 Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) {
177   return Status::OK();
178 }
179 
HandleTuple(const HloInstruction * tuple)180 Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
181   // The tuple instruction only gathers pointers from inputs (it doesn't iterate
182   // through them). The memory touched is then only the size of the output
183   // index table of the tuple.
184 
185   current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape());
186   return Status::OK();
187 }
188 
HandleConcatenate(const HloInstruction *)189 Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
190   return Status::OK();
191 }
192 
HandleConvert(const HloInstruction * convert)193 Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
194   return HandleElementwiseOp(convert);
195 }
196 
HandleCopy(const HloInstruction *)197 Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
198   return Status::OK();
199 }
200 
HandleDot(const HloInstruction * dot)201 Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
202   const Shape& lhs_shape = dot->operand(0)->shape();
203   const Shape& rhs_shape = dot->operand(1)->shape();
204   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
205   // Count of elements along the reduction dimension (last dimension for the
206   // rhs).
207   int64 reduction_width =
208       lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0));
209   // First divide by reduction width before multiplying by rhs elements to avoid
210   // overflow.
211   int64 fma_count;
212   if (reduction_width == 0) {
213     fma_count = 0;
214   } else {
215     fma_count = (ShapeUtil::ElementsIn(lhs_shape) / reduction_width) *
216                 ShapeUtil::ElementsIn(rhs_shape);
217   }
218 
219   // We count an FMA operation as 2 floating point operations.
220   current_properties_[kFlopsKey] = kFmaFlops * fma_count;
221   return Status::OK();
222 }
223 
HandleInfeed(const HloInstruction *)224 Status HloCostAnalysis::HandleInfeed(const HloInstruction*) {
225   return Status::OK();
226 }
227 
HandleOutfeed(const HloInstruction *)228 Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
229   return Status::OK();
230 }
231 
HandleHostCompute(const HloInstruction *)232 Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
233   return Status::OK();
234 }
235 
HandleMap(const HloInstruction * map)236 Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
237   // Compute properties of the mapped function.
238   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
239                       ProcessSubcomputation(map->to_apply()));
240 
241   // Compute the cost of all elements for this Map operation.
242   const int64 element_count = ShapeUtil::ElementsIn(map->shape());
243   for (const auto& property : sub_properties) {
244     if (property.first != kBytesAccessedKey) {
245       current_properties_[property.first] = property.second * element_count;
246     }
247   }
248   return Status::OK();
249 }
250 
HandleReduce(const HloInstruction * reduce)251 Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
252   auto arg = reduce->operand(0);
253   HloComputation* function = reduce->to_apply();
254   // Compute the cost of the user function.
255   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
256                       ProcessSubcomputation(function));
257 
258   // Compute the cost of all elements for this Reduce operation.
259   int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
260                           ShapeUtil::ElementsIn(reduce->shape());
261   for (const auto& property : sub_properties) {
262     if (property.first != kBytesAccessedKey) {
263       current_properties_[property.first] = property.second * reduction_count;
264     }
265   }
266   return Status::OK();
267 }
268 
HandleReduceWindow(const HloInstruction * reduce_window)269 Status HloCostAnalysis::HandleReduceWindow(
270     const HloInstruction* reduce_window) {
271   const Window& window = reduce_window->window();
272   auto function = reduce_window->to_apply();
273   // Compute the properties of the reduction function.
274   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
275                       ProcessSubcomputation(function));
276 
277   // Compute the cost of all elements for this ReduceWindow operation. For each
278   // output element there are window_size - 1 reductions to perform.
279   int64 window_element_count = 1;
280   for (const auto& dimension : window.dimensions()) {
281     window_element_count *= dimension.size();
282   }
283   const int64 output_element_count =
284       ShapeUtil::ElementsIn(reduce_window->shape());
285   const int64 reduction_count =
286       (window_element_count - 1) * output_element_count;
287   for (const auto& property : sub_properties) {
288     if (property.first != kBytesAccessedKey) {
289       current_properties_[property.first] = property.second * reduction_count;
290     }
291   }
292   return Status::OK();
293 }
294 
HandleSelectAndScatter(const HloInstruction * instruction)295 Status HloCostAnalysis::HandleSelectAndScatter(
296     const HloInstruction* instruction) {
297   // Compute the properties of the select and scatter function.
298   // Compute the properties of the reduction function.
299   TF_ASSIGN_OR_RETURN(const Properties select_properties,
300                       ProcessSubcomputation(instruction->select()));
301   TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
302                       ProcessSubcomputation(instruction->scatter()));
303 
304   // Compute the cost of all elements for this operation. For each scatter
305   // source element there are window_size - 1 select computations to perform and
306   // 1 scatter computation to perform.
307   const auto source = instruction->operand(1);
308   const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
309   int64 window_element_count = 1;
310   for (const auto& dimension : instruction->window().dimensions()) {
311     window_element_count *= dimension.size();
312   }
313   const int64 select_count = source_element_count * (window_element_count - 1);
314   for (const auto& property : select_properties) {
315     if (property.first != kBytesAccessedKey) {
316       current_properties_[property.first] += property.second * select_count;
317     }
318   }
319   for (const auto& property : scatter_properties) {
320     if (property.first != kBytesAccessedKey) {
321       current_properties_[property.first] +=
322           property.second * source_element_count;
323     }
324   }
325   return Status::OK();
326 }
327 
HandleBitcast(const HloInstruction *)328 Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
329   // A bitcast does no computation and touches no memory.
330   current_properties_[kBytesAccessedKey] = 0;
331   return Status::OK();
332 }
333 
HandleBroadcast(const HloInstruction *)334 Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
335   return Status::OK();
336 }
337 
HandlePad(const HloInstruction *)338 Status HloCostAnalysis::HandlePad(const HloInstruction*) {
339   return Status::OK();
340 }
341 
HandleSend(const HloInstruction *)342 Status HloCostAnalysis::HandleSend(const HloInstruction*) {
343   return Status::OK();
344 }
345 
HandleSendDone(const HloInstruction *)346 Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
347   return Status::OK();
348 }
349 
HandleRecv(const HloInstruction *)350 Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
351   return Status::OK();
352 }
353 
HandleRecvDone(const HloInstruction *)354 Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
355   return Status::OK();
356 }
357 
HandleReshape(const HloInstruction *)358 Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
359   return Status::OK();
360 }
361 
HandleBatchNormTraining(const HloInstruction *)362 Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
363   // TODO(b/62294698): Implement cost analysis for batch-norm-training.
364   return Status::OK();
365 }
366 
HandleBatchNormInference(const HloInstruction *)367 Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
368   // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
369   return Status::OK();
370 }
371 
HandleBatchNormGrad(const HloInstruction *)372 Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
373   // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
374   return Status::OK();
375 }
376 
HandleTranspose(const HloInstruction *)377 Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
378   return Status::OK();
379 }
380 
HandleConvolution(const HloInstruction * convolution)381 Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
382   auto rhs_instruction = convolution->operand(1);
383   const auto& dnums = convolution->convolution_dimension_numbers();
384   const int64 output_features =
385       convolution->shape().dimensions(dnums.output_feature_dimension());
386 
387   // For each output element, we do one fma per element in the kernel at some
388   // given output feature index.
389   const int64 fmas_per_output_element =
390       output_features > 0
391           ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features
392           : 0;
393   const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
394   current_properties_[kFlopsKey] =
395       output_elements * fmas_per_output_element * kFmaFlops;
396   return Status::OK();
397 }
398 
HandleFft(const HloInstruction * fft)399 Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
400   auto real_shape =
401       ShapeUtil::IsTuple(fft->operand(0)->shape())
402           ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
403           : fft->operand(0)->shape();
404   constexpr int kFmaPerComplexMul = 4;
405   int64 log_factors = 1;
406   for (int64 dim : fft->fft_length()) {
407     log_factors *= tensorflow::Log2Floor(dim);
408   }
409   current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
410                                    ShapeUtil::ElementsIn(real_shape);
411   return Status::OK();
412 }
413 
HandleCrossReplicaSum(const HloInstruction * crs)414 Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
415   // We assume 2 replicas, so that each output element is the sum of two input
416   // elements.
417   //
418   // TODO(b/33004697): Compute correct cost here, taking the actual number of
419   // replicas into account.
420   double flops = 0.0;
421   ShapeUtil::ForEachSubshape(
422       crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
423         if (ShapeUtil::IsArray(subshape)) {
424           flops += ShapeUtil::ElementsIn(subshape);
425         }
426       });
427   current_properties_[kFlopsKey] = flops;
428   return Status::OK();
429 }
430 
HandleRng(const HloInstruction * random)431 Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
432   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
433   // cost changes with the implementation and the distribution. For now, assume
434   // the cost of each RNG is same as a transcendental operation.
435   current_properties_[kTranscendentalsKey] =
436       ShapeUtil::ElementsIn(random->shape());
437   return Status::OK();
438 }
439 
HandleFusion(const HloInstruction * fusion)440 Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
441   // Compute the properties of the fused expression and attribute them to the
442   // fusion node. Use a dummy shape_size to avoid any errors from trying to
443   // calculate the size of a shape that does not have a layout, since nodes
444   // inside fusion nodes do not necessarily have a layout assigned.
445   ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; };
446   TF_ASSIGN_OR_RETURN(
447       current_properties_,
448       ProcessSubcomputation(fusion->fused_instructions_computation(),
449                             &shape_size));
450 
451   // Fusion nodes that produce a tuple also produce the entries in the tuple.
452   // Ignore the memory accessed inside fused ops, since fusion is supposed to
453   // prevent intermediate data from touching slow memory.
454   current_properties_[kBytesAccessedKey] = 0;
455   ShapeUtil::ForEachSubshape(
456       fusion->shape(),
457       [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) {
458         current_properties_[kBytesAccessedKey] += shape_size_(subshape);
459       });
460 
461   for (const HloInstruction* operand : fusion->operands()) {
462     current_properties_[kBytesAccessedKey] += shape_size_(operand->shape());
463   }
464 
465   return Status::OK();
466 }
467 
HandleCall(const HloInstruction * call)468 Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
469   TF_ASSIGN_OR_RETURN(current_properties_,
470                       ProcessSubcomputation(call->to_apply()));
471   current_should_compute_bottleneck_time_ = false;
472   return Status::OK();
473 }
474 
HandleCustomCall(const HloInstruction *)475 Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) {
476   // We can't do anything sane with CustomCalls, since we don't know what they
477   // do, and returning an error status will stop iteration over this
478   // computation, which is probably also not what we want.  So just punt and
479   // return OK.  This will cause all of the properties to be reported as 0,
480   // which is fine.
481   current_should_compute_bottleneck_time_ = false;
482   return Status::OK();
483 }
484 
HandleSort(const HloInstruction * sort)485 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
486   // This assumes a comparison based N*log(N) algorithm. As for all ops, the
487   // actual properties of the op depend on the backend implementation.
488   int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
489   current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
490   return Status::OK();
491 }
492 
HandleWhile(const HloInstruction * xla_while)493 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
494   // Since the number of iterations of the while node will not always be
495   // something that we can statically analyze, we cannot precisely compute the
496   // cost of a while node. For now compute the cost of a single iteration.
497   //
498   // TODO(b/26346211): Improve the cost analysis for while nodes.
499   TF_ASSIGN_OR_RETURN(const Properties body_properties,
500                       ProcessSubcomputation(xla_while->while_body()));
501 
502   TF_ASSIGN_OR_RETURN(const Properties condition_properties,
503                       ProcessSubcomputation(xla_while->while_condition()));
504 
505   current_properties_.clear();
506   for (const auto& property : body_properties) {
507     current_properties_[property.first] += property.second;
508   }
509   for (const auto& property : condition_properties) {
510     current_properties_[property.first] += property.second;
511   }
512   current_should_compute_bottleneck_time_ = false;
513 
514   return Status::OK();
515 }
516 
HandleConditional(const HloInstruction * conditional)517 Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
518   // Compute the cost of the true and false computations and take the maximum
519   // from those for each property.
520   TF_ASSIGN_OR_RETURN(const Properties true_computation_properties,
521                       ProcessSubcomputation(conditional->true_computation()));
522   TF_ASSIGN_OR_RETURN(const Properties false_computation_properties,
523                       ProcessSubcomputation(conditional->false_computation()));
524   current_properties_ = true_computation_properties;
525   for (const auto& property : false_computation_properties) {
526     if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_, property)) {
527       current_properties_[property.first] =
528           std::max(current_properties_[property.first], property.second);
529     }
530   }
531   current_should_compute_bottleneck_time_ = false;
532 
533   return Status::OK();
534 }
535 
HandleGather(const HloInstruction * gather)536 Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
537   // Gather does not issue any flops.
538   return Status::OK();
539 }
540 
FinishVisit(const HloInstruction *)541 Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
542   return Status::OK();
543 }
544 
flop_count() const545 float HloCostAnalysis::flop_count() const {
546   return GetProperty(kFlopsKey, properties_sum_);
547 }
548 
transcendental_count() const549 float HloCostAnalysis::transcendental_count() const {
550   return GetProperty(kTranscendentalsKey, properties_sum_);
551 }
552 
bytes_accessed() const553 float HloCostAnalysis::bytes_accessed() const {
554   return GetProperty(kBytesAccessedKey, properties_sum_);
555 }
556 
optimal_seconds() const557 float HloCostAnalysis::optimal_seconds() const {
558   return GetProperty(kOptimalSecondsKey, properties_sum_);
559 }
560 
flop_count(const HloInstruction & hlo) const561 int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
562   return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
563 }
564 
transcendental_count(const HloInstruction & hlo) const565 int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
566   return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
567 }
568 
bytes_accessed(const HloInstruction & hlo) const569 int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
570   return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
571 }
572 
optimal_seconds(const HloInstruction & hlo) const573 float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
574   return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
575 }
576 
ProcessSubcomputation(HloComputation * computation,const ShapeSizeFunction * shape_size)577 StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
578     HloComputation* computation, const ShapeSizeFunction* shape_size) {
579   if (shape_size == nullptr) {
580     shape_size = &shape_size_;
581   }
582   HloCostAnalysis visitor(*shape_size, per_second_rates_);
583   TF_RETURN_IF_ERROR(computation->Accept(&visitor));
584   return visitor.properties();
585 }
586 
587 }  // namespace xla
588