• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17 
18 #include <cmath>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/core/lib/core/bits.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 
33 namespace xla {
34 
35 constexpr const char HloCostAnalysis::kFlopsKey[];
36 constexpr const char HloCostAnalysis::kTranscendentalsKey[];
37 constexpr const char HloCostAnalysis::kBytesAccessedKey[];
38 constexpr const char HloCostAnalysis::kOptimalSecondsKey[];
39 
HloCostAnalysis(const ShapeSizeFunction & shape_size)40 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
41     : HloCostAnalysis(shape_size, {}) {}
42 
HloCostAnalysis(const ShapeSizeFunction & shape_size,const Properties & per_second_rates)43 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
44                                  const Properties& per_second_rates)
45     : shape_size_(shape_size), per_second_rates_(per_second_rates) {}
46 
Preprocess(const HloInstruction * hlo)47 Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
48   // Set current instruction cost values to reasonable default values. Each
49   // handler can overwrite these values. In Postprocess, these values are
50   // accumulated and written to the per-instruction maps.
51   current_properties_.clear();
52   current_should_compute_bottleneck_time_ = true;
53 
54   // The default number of bytes accessed for an instruction is the sum of the
55   // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
56   // handle opaque types.
57   float bytes_accessed = GetShapeSize(hlo->shape());
58   SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
59   for (int64 i = 0; i < hlo->operand_count(); ++i) {
60     const HloInstruction* operand = hlo->operand(i);
61     bytes_accessed += GetShapeSize(operand->shape());
62     SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
63   }
64   current_properties_[kBytesAccessedKey] = bytes_accessed;
65 
66   return Status::OK();
67 }
68 
Postprocess(const HloInstruction * hlo)69 Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
70   if (current_should_compute_bottleneck_time_) {
71     // Compute the time as the time of the bottleneck, i.e. the slowest property
72     // given the per-second rate of each property.
73     float optimal_seconds = 0.0f;
74     for (const auto& property : current_properties_) {
75       if (property.first != kOptimalSecondsKey) {
76         optimal_seconds = std::max(
77             optimal_seconds,
78             property.second /
79                 GetProperty(property.first, per_second_rates_, INFINITY));
80       }
81     }
82     current_properties_[kOptimalSecondsKey] = optimal_seconds;
83   }
84 
85   TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
86   for (const auto& property : current_properties_) {
87     properties_sum_[property.first] += property.second;
88   }
89 
90   return Status::OK();
91 }
92 
HandleElementwiseOp(const HloInstruction * hlo_instruction)93 Status HloCostAnalysis::HandleElementwiseOp(
94     const HloInstruction* hlo_instruction) {
95   const auto& shape = hlo_instruction->shape();
96   // For element-wise operations, the number of computations is the same as the
97   // number of elements in the output shape.
98   auto computation_count = ShapeUtil::ElementsIn(shape);
99   auto opcode = hlo_instruction->opcode();
100   // We treat transcendental operations separately since one transcendental
101   // operation can correspond to several floating point ops.
102   // kLogistic is included in "trascendental" as it is implemented using
103   // trascendental ops (tanh or exp).
104   if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
105       opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
106       opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt ||
107       opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
108       opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
109       opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
110       opcode == HloOpcode::kAtan2) {
111     current_properties_[kTranscendentalsKey] = computation_count;
112   } else {
113     // Note: transcendental operations are considered a separate category from
114     // FLOPs.
115     current_properties_[kFlopsKey] = computation_count;
116   }
117   return Status::OK();
118 }
119 
GetProperty(const string & key,const Properties & properties,const float default_value)120 /*static*/ float HloCostAnalysis::GetProperty(const string& key,
121                                               const Properties& properties,
122                                               const float default_value) {
123   auto key_value = properties.find(key);
124   return key_value == properties.end() ? default_value : key_value->second;
125 }
126 
GetPropertyForHlo(const HloInstruction & hlo,const string & key,const HloToProperties & hlo_to_properties)127 /*static*/ float HloCostAnalysis::GetPropertyForHlo(
128     const HloInstruction& hlo, const string& key,
129     const HloToProperties& hlo_to_properties) {
130   auto it = hlo_to_properties.find(&hlo);
131   if (it == hlo_to_properties.end()) {
132     return 0.0f;
133   } else {
134     return GetProperty(key, it->second);
135   }
136 }
137 
GetShapeSize(const Shape & shape) const138 int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
139   if (!LayoutUtil::HasLayout(shape)) {
140     return 0;
141   }
142   return shape_size_(shape);
143 }
144 
FusionParameterReadBytes(const HloInstruction * hlo) const145 int64 HloCostAnalysis::FusionParameterReadBytes(
146     const HloInstruction* hlo) const {
147   int64 size = 0;
148   bool seen_trivial_user = false;
149   CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
150                            hlo->opcode() == HloOpcode::kGetTupleElement));
151   for (const HloInstruction* user : hlo->users()) {
152     switch (user->opcode()) {
153       case HloOpcode::kFusion: {
154         for (int64 idx : user->OperandIndices(hlo)) {
155           size += FusionParameterReadBytes(user->fused_parameter(idx));
156         }
157         break;
158       }
159       case HloOpcode::kSlice:
160         size += GetShapeSize(user->shape());
161         break;
162       case HloOpcode::kDynamicSlice:
163         size += hlo == user->operand(0) ? GetShapeSize(user->shape())
164                                         : GetShapeSize(hlo->shape());
165         break;
166       case HloOpcode::kDynamicUpdateSlice:
167         // Uses the same shape as 'update' which is operand 1.
168         size += hlo == user->operand(0)
169                     ? GetShapeSize(user->operand(1)->shape())
170                     : GetShapeSize(hlo->shape());
171         break;
172       case HloOpcode::kBroadcast:
173       case HloOpcode::kReshape:
174         size += GetShapeSize(hlo->shape());
175         break;
176       default:
177         // Other instructions reading this parameter are assumed to be able to
178         // share the read from memory.
179         if (!seen_trivial_user) {
180           seen_trivial_user = true;
181           size += GetShapeSize(hlo->shape());
182         }
183     }
184   }
185   return size;
186 }
187 
HandleElementwiseUnary(const HloInstruction * hlo)188 Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
189   return HandleElementwiseOp(hlo);
190 }
191 
HandleElementwiseBinary(const HloInstruction * hlo)192 Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
193   return HandleElementwiseOp(hlo);
194 }
195 
HandleCompare(const HloInstruction * compare)196 Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
197   return HandleElementwiseOp(compare);
198 }
199 
HandleClamp(const HloInstruction * clamp)200 Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
201   return HandleElementwiseOp(clamp);
202 }
203 
HandleReducePrecision(const HloInstruction * hlo)204 Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
205   return HandleElementwiseOp(hlo);
206 }
207 
HandleParameter(const HloInstruction *)208 Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
209   current_should_compute_bottleneck_time_ = false;
210   current_properties_[kBytesAccessedKey] = 0;
211   SetOutputBytesAccessed(0);
212   current_properties_[kOptimalSecondsKey] = 0;
213   return Status::OK();
214 }
215 
HandleConstant(const HloInstruction *)216 Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
217   current_should_compute_bottleneck_time_ = false;
218   current_properties_[kBytesAccessedKey] = 0;
219   SetOutputBytesAccessed(0);
220   current_properties_[kOptimalSecondsKey] = 0;
221   return Status::OK();
222 }
223 
HandleIota(const HloInstruction *)224 Status HloCostAnalysis::HandleIota(const HloInstruction*) {
225   return Status::OK();
226 }
227 
HandleGetTupleElement(const HloInstruction * get_tuple_element)228 Status HloCostAnalysis::HandleGetTupleElement(
229     const HloInstruction* get_tuple_element) {
230   // GetTupleElement forwards a pointer and does not touch each element in the
231   // output.
232   current_should_compute_bottleneck_time_ = false;
233   current_properties_[kBytesAccessedKey] = 0;
234   SetOutputBytesAccessed(0);
235   SetOperandBytesAccessed(0, 0);
236   current_properties_[kOptimalSecondsKey] = 0;
237   return Status::OK();
238 }
239 
HandleSelect(const HloInstruction * hlo)240 Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
241   return HandleElementwiseOp(hlo);
242 }
243 
HandleTupleSelect(const HloInstruction *)244 Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) {
245   return Status::OK();
246 }
247 
HandleReverse(const HloInstruction *)248 Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
249   return Status::OK();
250 }
251 
HandleSlice(const HloInstruction * slice)252 Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
253   current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
254   SetOutputBytesAccessed(GetShapeSize(slice->shape()));
255   SetOperandBytesAccessed(0, GetShapeSize(slice->shape()));
256   return Status::OK();
257 }
258 
HandleDynamicSlice(const HloInstruction * dynamic_slice)259 Status HloCostAnalysis::HandleDynamicSlice(
260     const HloInstruction* dynamic_slice) {
261   current_properties_[kBytesAccessedKey] =
262       GetShapeSize(dynamic_slice->shape()) * 2 +
263       GetShapeSize(dynamic_slice->operand(1)->shape());
264   SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
265   SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
266   SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
267   return Status::OK();
268 }
269 
HandleDynamicUpdateSlice(const HloInstruction * dynamic_update_slice)270 Status HloCostAnalysis::HandleDynamicUpdateSlice(
271     const HloInstruction* dynamic_update_slice) {
272   current_properties_[kBytesAccessedKey] =
273       GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
274       GetShapeSize(dynamic_update_slice->operand(2)->shape());
275   // Operand 0 aliases with the output.
276   SetOutputBytesAccessed(
277       GetShapeSize(dynamic_update_slice->operand(1)->shape()));
278   SetOperandBytesAccessed(0, 0);
279   SetOperandBytesAccessed(
280       1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
281   SetOperandBytesAccessed(
282       2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
283   return Status::OK();
284 }
285 
HandleTuple(const HloInstruction * tuple)286 Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
287   // The tuple instruction only gathers pointers from inputs (it doesn't iterate
288   // through them). The memory touched is then only the size of the output
289   // index table of the tuple.
290 
291   current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
292   SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
293   for (int i = 0; i < tuple->operand_count(); ++i) {
294     SetOperandBytesAccessed(i, 0);
295   }
296   return Status::OK();
297 }
298 
HandleConcatenate(const HloInstruction *)299 Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
300   return Status::OK();
301 }
302 
HandleConvert(const HloInstruction * convert)303 Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
304   return HandleElementwiseOp(convert);
305 }
306 
HandleCopy(const HloInstruction *)307 Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
308   return Status::OK();
309 }
310 
HandleDomain(const HloInstruction * domain)311 Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
312   // Domain does not have any computation or data transfer.
313   current_should_compute_bottleneck_time_ = false;
314   current_properties_[kBytesAccessedKey] = 0;
315   SetOutputBytesAccessed(0);
316   for (int i = 0; i < domain->operand_count(); ++i) {
317     SetOperandBytesAccessed(i, 0);
318   }
319   current_properties_[kOptimalSecondsKey] = 0;
320   return Status::OK();
321 }
322 
HandleDot(const HloInstruction * dot)323 Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
324   const Shape& lhs_shape = dot->operand(0)->shape();
325   const Shape& dot_shape = dot->shape();
326   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
327   // Count of elements along the reduction dimension (last dimension for the
328   // rhs).
329   int64 reduction_width = 1;
330   for (auto dim : dnums.lhs_contracting_dimensions()) {
331     reduction_width *= lhs_shape.dimensions(dim);
332   }
333   // Each output element requires reduction_width FMA operations.
334   current_properties_[kFlopsKey] =
335       kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
336   return Status::OK();
337 }
338 
HandleInfeed(const HloInstruction * infeed)339 Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
340   // Count nested infeed output tuples.
341   int64 size = 0;
342   for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
343     size += GetShapeSize(indexed_shape.shape);
344     SetOutputBytesAccessed(indexed_shape.index,
345                            GetShapeSize(indexed_shape.shape));
346   }
347   SetOutputBytesAccessed(size);
348   current_properties_[kBytesAccessedKey] = size;
349   return Status::OK();
350 }
351 
HandleOutfeed(const HloInstruction * outfeed)352 Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
353   // Count nested outfeed operand tuples.
354   current_properties_[kBytesAccessedKey] = 0;
355   for (int64 i = 0; i < outfeed->operand_count(); ++i) {
356     const HloInstruction* operand = outfeed->operand(i);
357     int64 size = 0;
358     for (const auto& indexed_shape :
359          ShapeUtil::GetLeafShapes(operand->shape())) {
360       size += GetShapeSize(indexed_shape.shape);
361       SetOperandBytesAccessed(i, indexed_shape.index,
362                               GetShapeSize(indexed_shape.shape));
363     }
364     SetOperandBytesAccessed(i, size);
365     current_properties_[kBytesAccessedKey] += size;
366   }
367   return Status::OK();
368 }
369 
HandleMap(const HloInstruction * map)370 Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
371   // Compute properties of the mapped function.
372   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
373                       ProcessSubcomputation(map->to_apply()));
374 
375   // Compute the cost of all elements for this Map operation.
376   const int64 element_count = ShapeUtil::ElementsIn(map->shape());
377   for (const auto& property : sub_properties) {
378     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
379       current_properties_[property.first] = property.second * element_count;
380     }
381   }
382   return Status::OK();
383 }
384 
HandleReduce(const HloInstruction * reduce)385 Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
386   HloComputation* function = reduce->to_apply();
387   // Compute the cost of the user function.
388   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
389                       ProcessSubcomputation(function));
390 
391   // Compute the cost of all elements for this Reduce operation.
392   // This counts the number of times the reduction function is applied, so it
393   // does not need to be multiplied by the number of input tensors - that's
394   // already "priced in" by the sub-computation doing more work.
395   auto arg = reduce->operand(0);
396   auto output_shape = reduce->shape().IsArray()
397                           ? reduce->shape()
398                           : reduce->shape().tuple_shapes(0);
399   int64 reduction_count =
400       ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
401   for (const auto& property : sub_properties) {
402     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
403       current_properties_[property.first] = property.second * reduction_count;
404     }
405   }
406   return Status::OK();
407 }
408 
HandleReduceWindow(const HloInstruction * reduce_window)409 Status HloCostAnalysis::HandleReduceWindow(
410     const HloInstruction* reduce_window) {
411   const Window& window = reduce_window->window();
412   auto function = reduce_window->to_apply();
413   // Compute the properties of the reduction function.
414   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
415                       ProcessSubcomputation(function));
416 
417   // Compute the cost of all elements for this ReduceWindow operation. For each
418   // output element there are window_size - 1 reductions to perform.
419   int64 window_element_count = 1;
420   for (const auto& dimension : window.dimensions()) {
421     window_element_count *= dimension.size();
422   }
423 
424   const int64 output_element_count =
425       ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
426                                 ? reduce_window->shape()
427                                 : reduce_window->shape().tuple_shapes(0));
428   const int64 reduction_count =
429       (window_element_count - 1) * output_element_count;
430   for (const auto& property : sub_properties) {
431     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
432       current_properties_[property.first] = property.second * reduction_count;
433     }
434   }
435   return Status::OK();
436 }
437 
HandleSelectAndScatter(const HloInstruction * instruction)438 Status HloCostAnalysis::HandleSelectAndScatter(
439     const HloInstruction* instruction) {
440   // Compute the properties of the select and scatter function.
441   // Compute the properties of the reduction function.
442   TF_ASSIGN_OR_RETURN(const Properties select_properties,
443                       ProcessSubcomputation(instruction->select()));
444   TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
445                       ProcessSubcomputation(instruction->scatter()));
446 
447   // Compute the cost of all elements for this operation. For each scatter
448   // source element there are window_size - 1 select computations to perform and
449   // 1 scatter computation to perform.
450   const auto source = instruction->operand(1);
451   const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
452   int64 window_element_count = 1;
453   for (const auto& dimension : instruction->window().dimensions()) {
454     window_element_count *= dimension.size();
455   }
456   const int64 select_count = source_element_count * (window_element_count - 1);
457   for (const auto& property : select_properties) {
458     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
459       current_properties_[property.first] += property.second * select_count;
460     }
461   }
462   for (const auto& property : scatter_properties) {
463     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
464       current_properties_[property.first] +=
465           property.second * source_element_count;
466     }
467   }
468   return Status::OK();
469 }
470 
HandleBitcast(const HloInstruction *)471 Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
472   // A bitcast does no computation and touches no memory.
473   current_properties_[kBytesAccessedKey] = 0;
474   SetOutputBytesAccessed(0);
475   SetOperandBytesAccessed(0, 0);
476   current_properties_[kOptimalSecondsKey] = 0;
477   return Status::OK();
478 }
479 
HandleBroadcast(const HloInstruction *)480 Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
481   return Status::OK();
482 }
483 
HandlePad(const HloInstruction *)484 Status HloCostAnalysis::HandlePad(const HloInstruction*) {
485   return Status::OK();
486 }
487 
HandleCopyStart(const HloInstruction *)488 Status HloCostAnalysis::HandleCopyStart(const HloInstruction*) {
489   return Status::OK();
490 }
491 
HandleCopyDone(const HloInstruction *)492 Status HloCostAnalysis::HandleCopyDone(const HloInstruction*) {
493   return Status::OK();
494 }
495 
HandleSend(const HloInstruction *)496 Status HloCostAnalysis::HandleSend(const HloInstruction*) {
497   return Status::OK();
498 }
499 
HandleSendDone(const HloInstruction *)500 Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
501   return Status::OK();
502 }
503 
HandleRecv(const HloInstruction *)504 Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
505   return Status::OK();
506 }
507 
HandleRecvDone(const HloInstruction *)508 Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
509   return Status::OK();
510 }
511 
HandleReshape(const HloInstruction *)512 Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
513   return Status::OK();
514 }
515 
HandleDynamicReshape(const HloInstruction *)516 Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
517   return Status::OK();
518 }
519 
HandleBatchNormTraining(const HloInstruction *)520 Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
521   // TODO(b/62294698): Implement cost analysis for batch-norm-training.
522   return Status::OK();
523 }
524 
HandleBatchNormInference(const HloInstruction *)525 Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
526   // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
527   return Status::OK();
528 }
529 
HandleBatchNormGrad(const HloInstruction *)530 Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
531   // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
532   return Status::OK();
533 }
534 
HandleTranspose(const HloInstruction * transpose)535 Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
536   if (transpose->IsEffectiveBitcast()) {
537     return HandleBitcast(transpose);
538   }
539   return Status::OK();
540 }
541 
HandleAfterAll(const HloInstruction * token)542 Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
543   // This instruction is used to enforce ordering at compile time. No code is
544   // emitted.
545   current_should_compute_bottleneck_time_ = false;
546   current_properties_[kBytesAccessedKey] = 0;
547   SetOutputBytesAccessed(0);
548   for (int i = 0; i < token->operand_count(); ++i) {
549     SetOperandBytesAccessed(i, 0);
550   }
551   current_properties_[kOptimalSecondsKey] = 0;
552   return Status::OK();
553 }
554 
HandleAddDependency(const HloInstruction * add_dependency)555 Status HloCostAnalysis::HandleAddDependency(
556     const HloInstruction* add_dependency) {
557   // This instruction is used to enforce ordering at compile time. No code is
558   // emitted.
559   current_should_compute_bottleneck_time_ = false;
560   current_properties_[kBytesAccessedKey] = 0;
561   SetOutputBytesAccessed(0);
562   for (int i = 0; i < add_dependency->operand_count(); ++i) {
563     SetOperandBytesAccessed(i, 0);
564   }
565   current_properties_[kOptimalSecondsKey] = 0;
566   return Status::OK();
567 }
568 
HandleConvolution(const HloInstruction * convolution)569 Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
570   auto lhs = convolution->operand(0);
571   auto rhs = convolution->operand(1);
572   Window window = convolution->window();
573   const auto& result_shape = convolution->shape();
574   const Shape& lhs_shape = lhs->shape();
575   const Shape& rhs_shape = rhs->shape();
576 
577   const auto& dnums = convolution->convolution_dimension_numbers();
578 
579   const int64 input_batch_dim = dnums.input_batch_dimension();
580   const int64 input_feature_dim = dnums.input_feature_dimension();
581   const int64 output_feature_dim = dnums.output_feature_dimension();
582   const int64 input_feature =
583       ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
584   const int64 output_feature =
585       ShapeUtil::GetDimension(result_shape, output_feature_dim);
586   const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);
587 
588   DimensionVector kernel_limits;
589   DimensionVector output_limits;
590   DimensionVector input_limits;
591   if (window.dimensions().empty()) {
592     window = window_util::MakeWindow({1});
593     kernel_limits.push_back(1);
594     output_limits.push_back(1);
595     input_limits.push_back(1);
596   } else {
597     for (int64 spatial_dimension = 0;
598          spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
599       // Spatial dimension number for kernel (rhs).
600       const int64 kernel_spatial_dim =
601           dnums.kernel_spatial_dimensions(spatial_dimension);
602       const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
603       kernel_limits.push_back(kernel_limit);
604 
605       // Spatial dimension number for output.
606       const int64 output_spatial_dim =
607           dnums.output_spatial_dimensions(spatial_dimension);
608       const int64 output_limit = result_shape.dimensions(output_spatial_dim);
609       output_limits.push_back(output_limit);
610 
611       // Spatial dimension number for input (lhs).
612       const int64 input_spatial_dim =
613           dnums.input_spatial_dimensions(spatial_dimension);
614       const int64 input_limit = lhs_shape.dimensions(input_spatial_dim);
615       input_limits.push_back(input_limit);
616     }
617   }
618 
619   DimensionVector valid_position_counts;
620 
621   // Loop over each spatial dimension.
622   for (int64 spatial_dimension = 0;
623        spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
624     const auto& window_dim = window.dimensions(spatial_dimension);
625     // These two conditions will create an N^2 iteration pattern with only N
626     // valid elements. This is a performance optimization and produces the same
627     // result as the whole loop.
628     if (input_limits[spatial_dimension] == output_limits[spatial_dimension] &&
629         kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
630         input_limits[spatial_dimension] == window_dim.base_dilation() &&
631         window_dim.window_dilation() == 1 &&
632         std::max<int64>(1, input_limits[spatial_dimension] - 1) ==
633             window_dim.stride() &&
634         window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
635       valid_position_counts.push_back(input_limits[spatial_dimension]);
636       continue;
637     }
638 
639     if (input_limits[spatial_dimension] == 1 &&
640         kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
641         window_dim.window_dilation() == 1 && window_dim.base_dilation() == 1 &&
642         window_dim.stride() == 1 &&
643         window_dim.padding_high() == output_limits[spatial_dimension] - 1 &&
644         window_dim.padding_low() == output_limits[spatial_dimension] - 1) {
645       valid_position_counts.push_back(output_limits[spatial_dimension]);
646       continue;
647     }
648 
649     int64 valid_position_count = 0;
650     // Loop over each point in the kernel.
651     for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
652          ++kernel_idx) {
653       // Loop over each point in the output.
654       for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension];
655            ++output_idx) {
656         // Calculate lhs (input) index without taking base dilation into
657         // account.
658         const int64 undilated_index = output_idx * window_dim.stride() -
659                                       window_dim.padding_low() +
660                                       kernel_idx * window_dim.window_dilation();
661 
662         // Calculate the actual lhs (input) index after dilation. Avoid the
663         // division as an optimization.
664         const int64 lhs_spatial_index =
665             window_dim.base_dilation() > 1
666                 ? undilated_index / window_dim.base_dilation()
667                 : undilated_index;
668 
669         // Skip if the lhs (input) index is to be dilated.
670         if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
671           continue;
672         }
673 
674         // Skip if input index is not in bound.
675         if (lhs_spatial_index < 0 ||
676             lhs_spatial_index >= input_limits[spatial_dimension]) {
677           continue;
678         }
679 
680         valid_position_count += 1;
681       }
682     }
683     valid_position_counts.push_back(valid_position_count);
684   }
685 
686   const int64 fma_count = (input_feature / convolution->feature_group_count()) *
687                           output_feature *
688                           (batch / convolution->batch_group_count()) *
689                           Product(valid_position_counts);
690   current_properties_[kFlopsKey] = fma_count * kFmaFlops;
691   return Status::OK();
692 }
693 
HandleFft(const HloInstruction * fft)694 Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
695   auto real_shape =
696       fft->operand(0)->shape().IsTuple()
697           ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
698           : fft->operand(0)->shape();
699   constexpr int kFmaPerComplexMul = 4;
700   int64 log_factors = 1;
701   for (int64 dim : fft->fft_length()) {
702     log_factors *= tensorflow::Log2Floor(dim);
703   }
704   current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
705                                    ShapeUtil::ElementsIn(real_shape);
706   return Status::OK();
707 }
708 
HandleTriangularSolve(const HloInstruction * hlo)709 Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
710   // Half of operand 0 is read.
711   float bytes_accessed = GetShapeSize(hlo->shape());
712   SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
713   bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
714   SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
715   bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
716   SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(1)->shape()));
717   current_properties_[kBytesAccessedKey] = bytes_accessed;
718 
719   const Shape& a_shape = hlo->operand(0)->shape();
720   const Shape& b_shape = hlo->operand(1)->shape();
721   // Estimate as batch * mn^2 / 2 flops.
722   int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
723   elems *= ShapeUtil::ElementsIn(b_shape);
724   current_properties_[kFlopsKey] = kFmaFlops * elems;
725   return Status::OK();
726 }
727 
HandleCholesky(const HloInstruction * hlo)728 Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
729   // Half of operand 0 is read and half of the output will be written.
730   float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
731   SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
732   bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
733   SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
734   current_properties_[kBytesAccessedKey] = bytes_accessed;
735 
736   const Shape& a_shape = hlo->operand(0)->shape();
737   // Estimate as batch * n^3 / 3 flops.
738   int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
739   elems *= ShapeUtil::ElementsIn(a_shape);
740   current_properties_[kFlopsKey] = elems / 3;
741   return Status::OK();
742 }
743 
HandleAllGather(const HloInstruction * hlo)744 Status HloCostAnalysis::HandleAllGather(const HloInstruction* hlo) {
745   return Status::OK();
746 }
747 
HandleAllReduce(const HloInstruction * crs)748 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
749   // We assume 2 replicas, so that each output element is the sum of two input
750   // elements.
751   //
752   // TODO(b/33004697): Compute correct cost here, taking the actual number of
753   // replicas into account.
754   double flops = 0.0;
755   ShapeUtil::ForEachSubshape(crs->shape(),
756                              [&](const Shape& subshape, const ShapeIndex&) {
757                                if (subshape.IsArray()) {
758                                  flops += ShapeUtil::ElementsIn(subshape);
759                                }
760                              });
761   current_properties_[kFlopsKey] = flops;
762   return Status::OK();
763 }
764 
HandleAllToAll(const HloInstruction * hlo)765 Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
766   return Status::OK();
767 }
768 
HandleCollectivePermute(const HloInstruction *)769 Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
770   return Status::OK();
771 }
772 
HandleCollectivePermuteStart(const HloInstruction *)773 Status HloCostAnalysis::HandleCollectivePermuteStart(
774     const HloInstruction* /*hlo*/) {
775   return Status::OK();
776 }
777 
HandleCollectivePermuteDone(const HloInstruction *)778 Status HloCostAnalysis::HandleCollectivePermuteDone(
779     const HloInstruction* /*hlo*/) {
780   return Status::OK();
781 }
782 
HandlePartitionId(const HloInstruction *)783 Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
784   return Status::OK();
785 }
786 
HandleReplicaId(const HloInstruction *)787 Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
788   return Status::OK();
789 }
790 
HandleRng(const HloInstruction * random)791 Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
792   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
793   // cost changes with the implementation and the distribution. For now, assume
794   // the cost of each RNG is same as a transcendental operation.
795   current_properties_[kTranscendentalsKey] =
796       ShapeUtil::ElementsIn(random->shape());
797   return Status::OK();
798 }
799 
HandleRngBitGenerator(const HloInstruction * random)800 Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
801   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
802   // cost changes with the implementation and the distribution. For now, assume
803   // the cost of each RNG is same as a transcendental operation.
804   current_properties_[kTranscendentalsKey] =
805       ShapeUtil::ElementsInRecursive(random->shape());
806   return Status::OK();
807 }
808 
HandleRngGetAndUpdateState(const HloInstruction * random)809 Status HloCostAnalysis::HandleRngGetAndUpdateState(
810     const HloInstruction* random) {
811   return Status::OK();
812 }
813 
HandleFusion(const HloInstruction * fusion)814 Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
815   if (fusion->IsCustomFusion()) {
816     for (const HloInstruction* hlo :
817          fusion->fused_instructions_computation()->instructions()) {
818       if (hlo->opcode() == HloOpcode::kGather) {
819         return HandleGather(hlo);
820       }
821       if (hlo->opcode() == HloOpcode::kScatter) {
822         return HandleScatter(hlo);
823       }
824     }
825   }
826   TF_ASSIGN_OR_RETURN(
827       current_properties_,
828       ProcessSubcomputation(fusion->fused_instructions_computation()));
829 
830   // Fusion nodes that produce a tuple also produce the entries in the tuple.
831   // Ignore the memory accessed inside fused ops, since fusion is supposed to
832   // prevent intermediate data from touching slow memory.
833   current_properties_[kBytesAccessedKey] = 0;
834   ShapeUtil::ForEachSubshape(
835       fusion->shape(),
836       [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
837         if (!subshape.IsArray()) {
838           return;
839         }
840         if (shape_index.empty()) {
841           if (fusion->fused_expression_root()->opcode() ==
842               HloOpcode::kDynamicUpdateSlice) {
843             int64 size = GetShapeSize(
844                 fusion->fused_expression_root()->operand(1)->shape());
845             current_properties_[kBytesAccessedKey] += size;
846             SetOutputBytesAccessed(shape_index, size);
847             return;
848           }
849         } else if (shape_index.size() == 1) {
850           if (fusion->fused_expression_root()->opcode() == HloOpcode::kTuple &&
851               fusion->fused_expression_root()
852                       ->operand(shape_index[0])
853                       ->opcode() == HloOpcode::kDynamicUpdateSlice) {
854             int64 size = GetShapeSize(fusion->fused_expression_root()
855                                           ->operand(shape_index[0])
856                                           ->operand(1)
857                                           ->shape());
858             current_properties_[kBytesAccessedKey] += size;
859             SetOutputBytesAccessed(shape_index, size);
860             return;
861           }
862         }
863         current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
864         SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
865       });
866 
867   if (fusion->shape().IsTuple()) {
868     // Propagate and accumulate the output tuple bytes from the tuple subshapes.
869     // This ensures we have the correct output bytes accessed for the shape
870     // index
871     // {}.
872     std::function<float(const Shape&, const ShapeIndex&)>
873         propagate_output_size_to_parent;
874     propagate_output_size_to_parent = [&](const Shape& shape,
875                                           const ShapeIndex& shape_index) {
876       auto output_bytes_it =
877           current_properties_.find(GetOutputBytesAccessedKey(shape_index));
878       if (output_bytes_it != current_properties_.end()) {
879         return output_bytes_it->second;
880       }
881       float bytes_accessed = 0;
882       for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
883         const Shape& subshape = shape.tuple_shapes(i);
884         ShapeIndex subshape_index(shape_index);
885         subshape_index.push_back(i);
886         bytes_accessed +=
887             propagate_output_size_to_parent(subshape, subshape_index);
888       }
889       SetOutputBytesAccessed(shape_index, bytes_accessed);
890       return bytes_accessed;
891     };
892     current_properties_.erase(
893         current_properties_.find(GetOutputBytesAccessedKey()));
894     propagate_output_size_to_parent(fusion->shape(), {});
895   }
896 
897   for (int64 i = 0; i < fusion->fused_parameters().size(); ++i) {
898     const HloInstruction* operand = fusion->fused_parameter(i);
899     int64 operand_size = 0;
900     if (!fusion->shape().IsTuple()) {
901       operand_size = FusionParameterReadBytes(operand);
902     } else {
903       // If the fusion parameter is a tuple type, find the gte for the leaf
904       // shape and calculate the bytes accessed for those array types.
905       for (const auto& indexed_shape :
906            ShapeUtil::GetLeafShapes(operand->shape())) {
907         const HloInstruction* gte = operand;
908         for (int64 index : indexed_shape.index) {
909           for (const HloInstruction* user : gte->users()) {
910             if (user->opcode() == HloOpcode::kGetTupleElement &&
911                 user->tuple_index() == index) {
912               gte = user;
913               break;
914             }
915           }
916         }
917         int64 size = FusionParameterReadBytes(gte);
918         operand_size += size;
919         SetOperandBytesAccessed(i, indexed_shape.index, size);
920       }
921     }
922     current_properties_[kBytesAccessedKey] += operand_size;
923     SetOperandBytesAccessed(i, operand_size);
924   }
925 
926   return Status::OK();
927 }
928 
HandleCall(const HloInstruction * call)929 Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
930   TF_ASSIGN_OR_RETURN(current_properties_,
931                       ProcessSubcomputation(call->to_apply()));
932   current_should_compute_bottleneck_time_ = false;
933   return Status::OK();
934 }
935 
HandleCustomCall(const HloInstruction * custom_call)936 Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
937   // Mark applicable fields as "unknown", since we don't know what CustomCall
938   // does.  This is better than returning an error, which would stop iteration,
939   // and therefore would prevent us from getting *any* stats for a computation
940   // which contains a CustomCall.
941   current_properties_[kOptimalSecondsKey] = -1;
942   current_properties_[kBytesAccessedKey] = -1;
943   SetOutputBytesAccessed(-1);
944   for (int i = 0; i < custom_call->operand_count(); ++i) {
945     SetOperandBytesAccessed(i, -1);
946   }
947   current_properties_[kFlopsKey] = -1;
948   current_should_compute_bottleneck_time_ = false;
949   return Status::OK();
950 }
951 
HandleSort(const HloInstruction * sort)952 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
953   // This assumes a comparison based N*log(N) algorithm. As for all ops, the
954   // actual properties of the op depend on the backend implementation.
955   int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
956   current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
957   return Status::OK();
958 }
959 
HandleWhile(const HloInstruction * xla_while)960 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
961   // Since the number of iterations of the while node will not always be
962   // something that we can statically analyze, we cannot precisely compute the
963   // cost of a while node. For now compute the cost of a single iteration.
964   TF_ASSIGN_OR_RETURN(const Properties body_properties,
965                       ProcessSubcomputation(xla_while->while_body()));
966 
967   TF_ASSIGN_OR_RETURN(const Properties condition_properties,
968                       ProcessSubcomputation(xla_while->while_condition()));
969 
970   current_properties_.clear();
971   for (const auto& property : body_properties) {
972     current_properties_[property.first] += property.second;
973   }
974   for (const auto& property : condition_properties) {
975     current_properties_[property.first] += property.second;
976   }
977   current_should_compute_bottleneck_time_ = false;
978 
979   return Status::OK();
980 }
981 
HandleConditional(const HloInstruction * conditional)982 Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
983   // Compute the cost of the branch computations and take the maximum from those
984   // for each property.
985   TF_ASSIGN_OR_RETURN(
986       const Properties branch0_computation_properties,
987       ProcessSubcomputation(conditional->branch_computation(0)));
988   current_properties_ = branch0_computation_properties;
989   for (int j = 1; j < conditional->branch_count(); ++j) {
990     TF_ASSIGN_OR_RETURN(
991         const Properties branch_computation_properties,
992         ProcessSubcomputation(conditional->branch_computation(j)));
993     for (const auto& property : branch_computation_properties) {
994       if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
995                                                property)) {
996         auto& current_property = current_properties_[property.first];
997         current_property = std::max(current_property, property.second);
998       }
999     }
1000   }
1001   current_should_compute_bottleneck_time_ = false;
1002 
1003   return Status::OK();
1004 }
1005 
HandleGather(const HloInstruction * gather)1006 Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
1007   // Gather doesn't read the whole input buffer, it's equivalent to a copy the
1008   // size of the output shape and a read of the gather indices.
1009   int64 output_size = GetShapeSize(gather->shape());
1010   current_properties_[kBytesAccessedKey] =
1011       output_size * 2 + GetShapeSize(gather->operand(1)->shape());
1012   SetOperandBytesAccessed(0, output_size);
1013   SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
1014   SetOutputBytesAccessed(output_size);
1015   // Gather does not issue any flops.
1016   return Status::OK();
1017 }
1018 
HandleScatter(const HloInstruction * scatter)1019 Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
1020   // Scatter accesses the equivalent of 3 update shapes (input, output, and
1021   // updates), and the scatter indices.
1022   int64 update_size = GetShapeSize(scatter->operand(2)->shape());
1023   current_properties_[kBytesAccessedKey] =
1024       update_size * 3 + GetShapeSize(scatter->operand(1)->shape());
1025   SetOperandBytesAccessed(0, update_size);
1026   SetOperandBytesAccessed(1, GetShapeSize(scatter->operand(1)->shape()));
1027   SetOperandBytesAccessed(2, update_size);
1028   SetOutputBytesAccessed(update_size);
1029   const int64 element_count =
1030       ShapeUtil::ElementsIn(scatter->operand(2)->shape());
1031   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
1032                       ProcessSubcomputation(scatter->to_apply()));
1033   for (const auto& property : sub_properties) {
1034     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
1035       current_properties_[property.first] = property.second * element_count;
1036     }
1037   }
1038   return Status::OK();
1039 }
1040 
HandleGetDimensionSize(const HloInstruction *)1041 Status HloCostAnalysis::HandleGetDimensionSize(
1042     const HloInstruction* /*get_size*/) {
1043   return Status::OK();
1044 }
1045 
HandleSetDimensionSize(const HloInstruction *)1046 Status HloCostAnalysis::HandleSetDimensionSize(
1047     const HloInstruction* /*set_size*/) {
1048   return Status::OK();
1049 }
1050 
FinishVisit(const HloInstruction *)1051 Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
1052   return Status::OK();
1053 }
1054 
flop_count() const1055 float HloCostAnalysis::flop_count() const {
1056   return GetProperty(kFlopsKey, properties_sum_);
1057 }
1058 
transcendental_count() const1059 float HloCostAnalysis::transcendental_count() const {
1060   return GetProperty(kTranscendentalsKey, properties_sum_);
1061 }
1062 
bytes_accessed() const1063 float HloCostAnalysis::bytes_accessed() const {
1064   return GetProperty(kBytesAccessedKey, properties_sum_);
1065 }
1066 
optimal_seconds() const1067 float HloCostAnalysis::optimal_seconds() const {
1068   return GetProperty(kOptimalSecondsKey, properties_sum_);
1069 }
1070 
flop_count(const HloInstruction & hlo) const1071 int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
1072   return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
1073 }
1074 
transcendental_count(const HloInstruction & hlo) const1075 int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
1076   return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
1077 }
1078 
bytes_accessed(const HloInstruction & hlo) const1079 int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
1080   return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
1081 }
1082 
operand_bytes_accessed(const HloInstruction & hlo,int64 operand_num,ShapeIndex index) const1083 int64 HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
1084                                               int64 operand_num,
1085                                               ShapeIndex index) const {
1086   return GetPropertyForHlo(hlo, GetOperandBytesAccessedKey(operand_num, index),
1087                            hlo_properties_);
1088 }
1089 
output_bytes_accessed(const HloInstruction & hlo,ShapeIndex index) const1090 int64 HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
1091                                              ShapeIndex index) const {
1092   return GetPropertyForHlo(hlo, GetOutputBytesAccessedKey(index),
1093                            hlo_properties_);
1094 }
1095 
optimal_seconds(const HloInstruction & hlo) const1096 float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
1097   return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
1098 }
1099 
GetBytesRead(const HloInstruction & hlo,absl::optional<int64> memory_space) const1100 int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
1101                                     absl::optional<int64> memory_space) const {
1102   int64 bytes_read = 0;
1103   for (int operand_number = 0; operand_number < hlo.operand_count();
1104        ++operand_number) {
1105     for (const ShapeUtil::IndexedShape& indexed_shape :
1106          ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) {
1107       absl::optional<int64> index_memory_space;
1108       if (indexed_shape.shape.has_layout()) {
1109         index_memory_space = indexed_shape.shape.layout().memory_space();
1110       }
1111       if (!memory_space || memory_space == index_memory_space) {
1112         bytes_read +=
1113             operand_bytes_accessed(hlo, operand_number, indexed_shape.index);
1114       }
1115     }
1116   }
1117   return bytes_read;
1118 }
1119 
GetBytesWritten(const HloInstruction & hlo,absl::optional<int64> memory_space) const1120 int64 HloCostAnalysis::GetBytesWritten(
1121     const HloInstruction& hlo, absl::optional<int64> memory_space) const {
1122   int64 bytes_written = 0;
1123   for (const ShapeUtil::IndexedShape& indexed_shape :
1124        ShapeUtil::GetLeafShapes(hlo.shape())) {
1125     absl::optional<int64> index_memory_space;
1126     if (indexed_shape.shape.has_layout()) {
1127       index_memory_space = indexed_shape.shape.layout().memory_space();
1128     }
1129     if (!memory_space || memory_space == index_memory_space) {
1130       bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
1131     }
1132   }
1133   return bytes_written;
1134 }
1135 
ProcessSubcomputation(HloComputation * computation)1136 StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
1137     HloComputation* computation) {
1138   auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
1139   visitor->ReserveVisitStates(computation->instruction_count());
1140   TF_RETURN_IF_ERROR(computation->Accept(visitor.get()));
1141   hlo_properties_.insert(visitor->hlo_properties_.begin(),
1142                          visitor->hlo_properties_.end());
1143   return visitor->properties();
1144 }
1145 
CreateNestedCostAnalysis(const ShapeSizeFunction & shape_size,const Properties & per_second_rates)1146 std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis(
1147     const ShapeSizeFunction& shape_size, const Properties& per_second_rates) {
1148   return absl::WrapUnique(new HloCostAnalysis(shape_size, per_second_rates));
1149 }
1150 
SetOperandBytesAccessed(int64 operand_num,float value)1151 void HloCostAnalysis::SetOperandBytesAccessed(int64 operand_num, float value) {
1152   current_properties_[GetOperandBytesAccessedKey(operand_num).c_str()] = value;
1153 }
1154 
SetOperandBytesAccessed(int64 operand_num,ShapeIndex index,float value)1155 void HloCostAnalysis::SetOperandBytesAccessed(int64 operand_num,
1156                                               ShapeIndex index, float value) {
1157   current_properties_[GetOperandBytesAccessedKey(operand_num, index).c_str()] =
1158       value;
1159 }
1160 
SetOutputBytesAccessed(float value)1161 void HloCostAnalysis::SetOutputBytesAccessed(float value) {
1162   current_properties_[GetOutputBytesAccessedKey()] = value;
1163 }
1164 
SetOutputBytesAccessed(ShapeIndex index,float value)1165 void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
1166   current_properties_[GetOutputBytesAccessedKey(index)] = value;
1167 }
1168 
GetOperandBytesAccessedKey(int64 operand_num,ShapeIndex index)1169 /*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
1170     int64 operand_num, ShapeIndex index) {
1171   return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
1172                       index.ToString());
1173 }
1174 
GetOutputBytesAccessedKey(ShapeIndex index)1175 /*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
1176     ShapeIndex index) {
1177   return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
1178 }
1179 
1180 }  // namespace xla
1181