1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
17 
18 #include <cmath>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/core/lib/core/bits.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 
33 namespace xla {
34 
35 constexpr const char HloCostAnalysis::kFlopsKey[];
36 constexpr const char HloCostAnalysis::kTranscendentalsKey[];
37 constexpr const char HloCostAnalysis::kBytesAccessedKey[];
38 constexpr const char HloCostAnalysis::kOptimalSecondsKey[];
39 
40 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size)
41     : HloCostAnalysis(shape_size, {}) {}
42 
43 HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size,
44                                  const Properties& per_second_rates)
45     : shape_size_(shape_size), per_second_rates_(per_second_rates) {}
46 
47 Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) {
48   // Set current instruction cost values to reasonable default values. Each
49   // handler can overwrite these values. In Postprocess, these values are
50   // accumulated and written to the per-instruction maps.
51   current_properties_.clear();
52   current_should_compute_bottleneck_time_ = true;
53 
54   // The default number of bytes accessed for an instruction is the sum of the
55   // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not
56   // handle opaque types.
57   float bytes_accessed = GetShapeSize(hlo->shape());
58   SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
59   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
60     const HloInstruction* operand = hlo->operand(i);
61     bytes_accessed += GetShapeSize(operand->shape());
62     SetOperandBytesAccessed(i, GetShapeSize(operand->shape()));
63   }
64   current_properties_[kBytesAccessedKey] = bytes_accessed;
65 
66   return Status::OK();
67 }
68 
69 Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) {
70   if (current_should_compute_bottleneck_time_) {
71     // Compute the time as the time of the bottleneck, i.e. the slowest property
72     // given the per-second rate of each property.
73     float optimal_seconds = 0.0f;
74     for (const auto& property : current_properties_) {
75       if (property.first != kOptimalSecondsKey) {
76         optimal_seconds = std::max(
77             optimal_seconds,
78             property.second /
79                 GetProperty(property.first, per_second_rates_, INFINITY));
80       }
81     }
82     current_properties_[kOptimalSecondsKey] = optimal_seconds;
83   }
84 
85   TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second);
86   for (const auto& property : current_properties_) {
87     properties_sum_[property.first] += property.second;
88   }
89 
90   return Status::OK();
91 }
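// Illustrative example of the bottleneck computation above (the rates are
// assumed for illustration, not taken from this file): with per-second rates
// of 1e12 flops/s and 1e11 bytes/s, an instruction with 2e9 flops and 1e8
// bytes accessed gets optimal_seconds = max(2e9 / 1e12, 1e8 / 1e11) = 0.002.
// Properties with no configured rate divide by INFINITY and contribute 0.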
92 
93 Status HloCostAnalysis::HandleElementwiseOp(
94     const HloInstruction* hlo_instruction) {
95   const auto& shape = hlo_instruction->shape();
96   // For element-wise operations, the number of computations is the same as the
97   // number of elements in the output shape.
98   auto computation_count = ShapeUtil::ElementsIn(shape);
99   auto opcode = hlo_instruction->opcode();
100   // We treat transcendental operations separately since one transcendental
101   // operation can correspond to several floating point ops.
102   // kLogistic is included in "transcendental" as it is implemented using
103   // transcendental ops (tanh or exp).
104   if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
105       opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
106       opcode == HloOpcode::kSqrt || opcode == HloOpcode::kCbrt ||
107       opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
108       opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
109       opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
110       opcode == HloOpcode::kAtan2) {
111     current_properties_[kTranscendentalsKey] = computation_count;
112   } else {
113     // Note: transcendental operations are considered a separate category from
114     // FLOPs.
115     current_properties_[kFlopsKey] = computation_count;
116   }
117   return Status::OK();
118 }
119 
120 /*static*/ float HloCostAnalysis::GetProperty(const string& key,
121                                               const Properties& properties,
122                                               const float default_value) {
123   auto key_value = properties.find(key);
124   return key_value == properties.end() ? default_value : key_value->second;
125 }
126 
127 /*static*/ float HloCostAnalysis::GetPropertyForHlo(
128     const HloInstruction& hlo, const string& key,
129     const HloToProperties& hlo_to_properties) {
130   auto it = hlo_to_properties.find(&hlo);
131   if (it == hlo_to_properties.end()) {
132     return 0.0f;
133   } else {
134     return GetProperty(key, it->second);
135   }
136 }
137 
138 int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const {
139   if (!LayoutUtil::HasLayout(shape)) {
140     return 0;
141   }
142   return shape_size_(shape);
143 }
144 
145 int64 HloCostAnalysis::FusionParameterReadBytes(
146     const HloInstruction* hlo) const {
147   int64_t size = 0;
148   bool seen_trivial_user = false;
149   CHECK(hlo->IsFused() && (hlo->opcode() == HloOpcode::kParameter ||
150                            hlo->opcode() == HloOpcode::kGetTupleElement));
151   for (const HloInstruction* user : hlo->users()) {
152     switch (user->opcode()) {
153       case HloOpcode::kFusion: {
154         for (int64_t idx : user->OperandIndices(hlo)) {
155           size += FusionParameterReadBytes(user->fused_parameter(idx));
156         }
157         break;
158       }
159       case HloOpcode::kSlice:
160         size += GetShapeSize(user->shape());
161         break;
162       case HloOpcode::kDynamicSlice:
163         size += hlo == user->operand(0) ? GetShapeSize(user->shape())
164                                         : GetShapeSize(hlo->shape());
165         break;
166       case HloOpcode::kDynamicUpdateSlice:
167         // Uses the same shape as 'update' which is operand 1.
168         size += hlo == user->operand(0)
169                     ? GetShapeSize(user->operand(1)->shape())
170                     : GetShapeSize(hlo->shape());
171         break;
172       case HloOpcode::kBroadcast:
173       case HloOpcode::kReshape:
174         size += GetShapeSize(hlo->shape());
175         break;
176       default:
177         // Other instructions reading this parameter are assumed to be able to
178         // share the read from memory.
179         if (!seen_trivial_user) {
180           seen_trivial_user = true;
181           size += GetShapeSize(hlo->shape());
182         }
183     }
184   }
185   return size;
186 }
187 
188 Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) {
189   return HandleElementwiseOp(hlo);
190 }
191 
192 Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) {
193   return HandleElementwiseOp(hlo);
194 }
195 
196 Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) {
197   return HandleElementwiseOp(compare);
198 }
199 
200 Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) {
201   return HandleElementwiseOp(clamp);
202 }
203 
204 Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) {
205   return HandleElementwiseOp(hlo);
206 }
207 
208 Status HloCostAnalysis::HandleParameter(const HloInstruction*) {
209   current_should_compute_bottleneck_time_ = false;
210   current_properties_[kBytesAccessedKey] = 0;
211   SetOutputBytesAccessed(0);
212   current_properties_[kOptimalSecondsKey] = 0;
213   return Status::OK();
214 }
215 
216 Status HloCostAnalysis::HandleConstant(const HloInstruction*) {
217   current_should_compute_bottleneck_time_ = false;
218   current_properties_[kBytesAccessedKey] = 0;
219   SetOutputBytesAccessed(0);
220   current_properties_[kOptimalSecondsKey] = 0;
221   return Status::OK();
222 }
223 
224 Status HloCostAnalysis::HandleIota(const HloInstruction*) {
225   return Status::OK();
226 }
227 
228 Status HloCostAnalysis::HandleGetTupleElement(
229     const HloInstruction* get_tuple_element) {
230   // GetTupleElement forwards a pointer and does not touch each element in the
231   // output.
232   current_should_compute_bottleneck_time_ = false;
233   current_properties_[kBytesAccessedKey] = 0;
234   SetOutputBytesAccessed(0);
235   SetOperandBytesAccessed(0, 0);
236   current_properties_[kOptimalSecondsKey] = 0;
237   return Status::OK();
238 }
239 
240 Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
241   return HandleElementwiseOp(hlo);
242 }
243 
244 Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) {
245   return Status::OK();
246 }
247 
248 Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
249   return Status::OK();
250 }
251 
252 Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
253   current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2;
254   SetOutputBytesAccessed(GetShapeSize(slice->shape()));
255   SetOperandBytesAccessed(0, GetShapeSize(slice->shape()));
256   return Status::OK();
257 }
258 
259 Status HloCostAnalysis::HandleDynamicSlice(
260     const HloInstruction* dynamic_slice) {
261   current_properties_[kBytesAccessedKey] =
262       GetShapeSize(dynamic_slice->shape()) * 2 +
263       GetShapeSize(dynamic_slice->operand(1)->shape());
264   SetOutputBytesAccessed(GetShapeSize(dynamic_slice->shape()));
265   SetOperandBytesAccessed(0, GetShapeSize(dynamic_slice->shape()));
266   SetOperandBytesAccessed(1, GetShapeSize(dynamic_slice->operand(1)->shape()));
267   return Status::OK();
268 }
269 
270 Status HloCostAnalysis::HandleDynamicUpdateSlice(
271     const HloInstruction* dynamic_update_slice) {
272   current_properties_[kBytesAccessedKey] =
273       GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2 +
274       GetShapeSize(dynamic_update_slice->operand(2)->shape());
275   // Operand 0 aliases with the output.
276   SetOutputBytesAccessed(
277       GetShapeSize(dynamic_update_slice->operand(1)->shape()));
278   SetOperandBytesAccessed(0, 0);
279   SetOperandBytesAccessed(
280       1, GetShapeSize(dynamic_update_slice->operand(1)->shape()));
281   SetOperandBytesAccessed(
282       2, GetShapeSize(dynamic_update_slice->operand(2)->shape()));
283   return Status::OK();
284 }
285 
286 Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) {
287   // The tuple instruction only gathers pointers from inputs (it doesn't iterate
288   // through them). The memory touched is then only the size of the output
289   // index table of the tuple.
290 
291   current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape());
292   SetOutputBytesAccessed(GetShapeSize(tuple->shape()));
293   for (int i = 0; i < tuple->operand_count(); ++i) {
294     SetOperandBytesAccessed(i, 0);
295   }
296   return Status::OK();
297 }
298 
299 Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) {
300   return Status::OK();
301 }
302 
303 Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) {
304   return HandleElementwiseOp(convert);
305 }
306 
307 Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
308   return Status::OK();
309 }
310 
311 Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
312   // Domain does not have any computation or data transfer.
313   current_should_compute_bottleneck_time_ = false;
314   current_properties_[kBytesAccessedKey] = 0;
315   SetOutputBytesAccessed(0);
316   for (int i = 0; i < domain->operand_count(); ++i) {
317     SetOperandBytesAccessed(i, 0);
318   }
319   current_properties_[kOptimalSecondsKey] = 0;
320   return Status::OK();
321 }
322 
323 Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
324   const Shape& lhs_shape = dot->operand(0)->shape();
325   const Shape& dot_shape = dot->shape();
326   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
327   // Count of elements along the reduction dimension (last dimension for the
328   // rhs).
329   int64_t reduction_width = 1;
330   for (auto dim : dnums.lhs_contracting_dimensions()) {
331     reduction_width *= lhs_shape.dimensions(dim);
332   }
333   // Each output element requires reduction_width FMA operations.
334   current_properties_[kFlopsKey] =
335       kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
336   return Status::OK();
337 }
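// Worked example of the dot cost above (assuming kFmaFlops == 2, i.e. one
// multiply plus one add per fused multiply-add): for a [64,128] x [128,32]
// dot, reduction_width is 128 and the output has 64 * 32 = 2048 elements, so
// the charged flops are 2 * 2048 * 128 = 524288.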
338 
339 Status HloCostAnalysis::HandleInfeed(const HloInstruction* infeed) {
340   // Count nested infeed output tuples.
341   int64_t size = 0;
342   for (const auto& indexed_shape : ShapeUtil::GetLeafShapes(infeed->shape())) {
343     size += GetShapeSize(indexed_shape.shape);
344     SetOutputBytesAccessed(indexed_shape.index,
345                            GetShapeSize(indexed_shape.shape));
346   }
347   SetOutputBytesAccessed(size);
348   current_properties_[kBytesAccessedKey] = size;
349   return Status::OK();
350 }
351 
352 Status HloCostAnalysis::HandleOutfeed(const HloInstruction* outfeed) {
353   // Count nested outfeed operand tuples.
354   current_properties_[kBytesAccessedKey] = 0;
355   for (int64_t i = 0; i < outfeed->operand_count(); ++i) {
356     const HloInstruction* operand = outfeed->operand(i);
357     int64_t size = 0;
358     for (const auto& indexed_shape :
359          ShapeUtil::GetLeafShapes(operand->shape())) {
360       size += GetShapeSize(indexed_shape.shape);
361       SetOperandBytesAccessed(i, indexed_shape.index,
362                               GetShapeSize(indexed_shape.shape));
363     }
364     SetOperandBytesAccessed(i, size);
365     current_properties_[kBytesAccessedKey] += size;
366   }
367   return Status::OK();
368 }
369 
370 Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
371   // Compute properties of the mapped function.
372   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
373                       ProcessSubcomputation(map->to_apply()));
374 
375   // Compute the cost of all elements for this Map operation.
376   const int64_t element_count = ShapeUtil::ElementsIn(map->shape());
377   for (const auto& property : sub_properties) {
378     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
379       current_properties_[property.first] = property.second * element_count;
380     }
381   }
382   return Status::OK();
383 }
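// Example of the scaling above (values assumed for illustration): if the
// mapped computation costs 3 flops per application and the map output has
// 10000 elements, the map is charged 3 * 10000 = 30000 flops. The
// subcomputation's bytes-accessed properties are deliberately not scaled;
// the defaults set in Preprocess remain in effect.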
384 
385 Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) {
386   HloComputation* function = reduce->to_apply();
387   // Compute the cost of the user function.
388   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
389                       ProcessSubcomputation(function));
390 
391   // Compute the cost of all elements for this Reduce operation.
392   // This counts the number of times the reduction function is applied, so it
393   // does not need to be multiplied by the number of input tensors - that's
394   // already "priced in" by the sub-computation doing more work.
395   auto arg = reduce->operand(0);
396   auto output_shape = reduce->shape().IsArray()
397                           ? reduce->shape()
398                           : reduce->shape().tuple_shapes(0);
399   int64_t reduction_count =
400       ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape);
401   for (const auto& property : sub_properties) {
402     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
403       current_properties_[property.first] = property.second * reduction_count;
404     }
405   }
406   return Status::OK();
407 }
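// Example of the reduction count above: reducing a [128,256] operand to a
// [128] output applies the reducer 128 * 256 - 128 = 32640 times; if the
// reducer performs a single add (1 flop per application), the reduce is
// charged 32640 flops.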
408 
409 Status HloCostAnalysis::HandleReduceWindow(
410     const HloInstruction* reduce_window) {
411   const Window& window = reduce_window->window();
412   auto function = reduce_window->to_apply();
413   // Compute the properties of the reduction function.
414   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
415                       ProcessSubcomputation(function));
416 
417   // Compute the cost of all elements for this ReduceWindow operation. For each
418   // output element there are window_size - 1 reductions to perform.
419   int64_t window_element_count = 1;
420   for (const auto& dimension : window.dimensions()) {
421     window_element_count *= dimension.size();
422   }
423 
424   const int64_t output_element_count =
425       ShapeUtil::ElementsIn(reduce_window->shape().IsArray()
426                                 ? reduce_window->shape()
427                                 : reduce_window->shape().tuple_shapes(0));
428   const int64_t reduction_count =
429       (window_element_count - 1) * output_element_count;
430   for (const auto& property : sub_properties) {
431     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
432       current_properties_[property.first] = property.second * reduction_count;
433     }
434   }
435   return Status::OK();
436 }
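// Example of the reduce-window count above: a 3x3 window (9 elements) over
// an output of 16 * 16 = 256 elements performs (9 - 1) * 256 = 2048
// reductions; with a 1-flop reducer this is charged as 2048 flops.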
437 
438 Status HloCostAnalysis::HandleSelectAndScatter(
439     const HloInstruction* instruction) {
440   // Compute the properties of the select and scatter function.
441   // Compute the properties of the reduction function.
442   TF_ASSIGN_OR_RETURN(const Properties select_properties,
443                       ProcessSubcomputation(instruction->select()));
444   TF_ASSIGN_OR_RETURN(const Properties scatter_properties,
445                       ProcessSubcomputation(instruction->scatter()));
446 
447   // Compute the cost of all elements for this operation. For each scatter
448   // source element there are window_size - 1 select computations to perform and
449   // 1 scatter computation to perform.
450   const auto source = instruction->operand(1);
451   const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
452   int64_t window_element_count = 1;
453   for (const auto& dimension : instruction->window().dimensions()) {
454     window_element_count *= dimension.size();
455   }
456   const int64_t select_count =
457       source_element_count * (window_element_count - 1);
458   for (const auto& property : select_properties) {
459     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
460       current_properties_[property.first] += property.second * select_count;
461     }
462   }
463   for (const auto& property : scatter_properties) {
464     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
465       current_properties_[property.first] +=
466           property.second * source_element_count;
467     }
468   }
469   return Status::OK();
470 }
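// Example of the accounting above: with an 8x8 source (64 elements) and a
// 2x2 window (4 elements), the select computation is applied
// 64 * (4 - 1) = 192 times and the scatter computation 64 times; each count
// is scaled by its own subcomputation's cost and the two sums are added.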
471 
472 Status HloCostAnalysis::HandleBitcast(const HloInstruction*) {
473   // A bitcast does no computation and touches no memory.
474   current_properties_[kBytesAccessedKey] = 0;
475   SetOutputBytesAccessed(0);
476   SetOperandBytesAccessed(0, 0);
477   current_properties_[kOptimalSecondsKey] = 0;
478   return Status::OK();
479 }
480 
481 Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
482   return Status::OK();
483 }
484 
485 Status HloCostAnalysis::HandlePad(const HloInstruction*) {
486   return Status::OK();
487 }
488 
489 Status HloCostAnalysis::HandleCopyStart(const HloInstruction*) {
490   return Status::OK();
491 }
492 
493 Status HloCostAnalysis::HandleCopyDone(const HloInstruction*) {
494   return Status::OK();
495 }
496 
497 Status HloCostAnalysis::HandleSend(const HloInstruction*) {
498   return Status::OK();
499 }
500 
501 Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
502   return Status::OK();
503 }
504 
505 Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
506   return Status::OK();
507 }
508 
509 Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
510   return Status::OK();
511 }
512 
513 Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
514   return Status::OK();
515 }
516 
517 Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
518   return Status::OK();
519 }
520 
521 Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
522   // TODO(b/62294698): Implement cost analysis for batch-norm-training.
523   return Status::OK();
524 }
525 
526 Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) {
527   // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
528   return Status::OK();
529 }
530 
531 Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) {
532   // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
533   return Status::OK();
534 }
535 
536 Status HloCostAnalysis::HandleTranspose(const HloInstruction* transpose) {
537   if (transpose->IsEffectiveBitcast()) {
538     return HandleBitcast(transpose);
539   }
540   return Status::OK();
541 }
542 
543 Status HloCostAnalysis::HandleAfterAll(const HloInstruction* token) {
544   // This instruction is used to enforce ordering at compile time. No code is
545   // emitted.
546   current_should_compute_bottleneck_time_ = false;
547   current_properties_[kBytesAccessedKey] = 0;
548   SetOutputBytesAccessed(0);
549   for (int i = 0; i < token->operand_count(); ++i) {
550     SetOperandBytesAccessed(i, 0);
551   }
552   current_properties_[kOptimalSecondsKey] = 0;
553   return Status::OK();
554 }
555 
556 Status HloCostAnalysis::HandleAddDependency(
557     const HloInstruction* add_dependency) {
558   // This instruction is used to enforce ordering at compile time. No code is
559   // emitted.
560   current_should_compute_bottleneck_time_ = false;
561   current_properties_[kBytesAccessedKey] = 0;
562   SetOutputBytesAccessed(0);
563   for (int i = 0; i < add_dependency->operand_count(); ++i) {
564     SetOperandBytesAccessed(i, 0);
565   }
566   current_properties_[kOptimalSecondsKey] = 0;
567   return Status::OK();
568 }
569 
570 Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
571   auto lhs = convolution->operand(0);
572   auto rhs = convolution->operand(1);
573   Window window = convolution->window();
574   const auto& result_shape = convolution->shape();
575   const Shape& lhs_shape = lhs->shape();
576   const Shape& rhs_shape = rhs->shape();
577 
578   const auto& dnums = convolution->convolution_dimension_numbers();
579 
580   const int64_t input_batch_dim = dnums.input_batch_dimension();
581   const int64_t input_feature_dim = dnums.input_feature_dimension();
582   const int64_t output_feature_dim = dnums.output_feature_dimension();
583   const int64_t input_feature =
584       ShapeUtil::GetDimension(lhs_shape, input_feature_dim);
585   const int64_t output_feature =
586       ShapeUtil::GetDimension(result_shape, output_feature_dim);
587   const int64_t batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim);
588 
589   DimensionVector kernel_limits;
590   DimensionVector output_limits;
591   DimensionVector input_limits;
592   if (window.dimensions().empty()) {
593     window = window_util::MakeWindow({1});
594     kernel_limits.push_back(1);
595     output_limits.push_back(1);
596     input_limits.push_back(1);
597   } else {
598     for (int64_t spatial_dimension = 0;
599          spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
600       // Spatial dimension number for kernel (rhs).
601       const int64_t kernel_spatial_dim =
602           dnums.kernel_spatial_dimensions(spatial_dimension);
603       const int64_t kernel_limit = rhs_shape.dimensions(kernel_spatial_dim);
604       kernel_limits.push_back(kernel_limit);
605 
606       // Spatial dimension number for output.
607       const int64_t output_spatial_dim =
608           dnums.output_spatial_dimensions(spatial_dimension);
609       const int64_t output_limit = result_shape.dimensions(output_spatial_dim);
610       output_limits.push_back(output_limit);
611 
612       // Spatial dimension number for input (lhs).
613       const int64_t input_spatial_dim =
614           dnums.input_spatial_dimensions(spatial_dimension);
615       const int64_t input_limit = lhs_shape.dimensions(input_spatial_dim);
616       input_limits.push_back(input_limit);
617     }
618   }
619 
620   DimensionVector valid_position_counts;
621 
622   // Loop over each spatial dimension.
623   for (int64_t spatial_dimension = 0;
624        spatial_dimension < window.dimensions_size(); ++spatial_dimension) {
625     const auto& window_dim = window.dimensions(spatial_dimension);
626     // These two conditions will create an N^2 iteration pattern with only N
627     // valid elements. This is a performance optimization and produces the same
628     // result as the whole loop.
629     if (input_limits[spatial_dimension] == output_limits[spatial_dimension] &&
630         kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
631         input_limits[spatial_dimension] == window_dim.base_dilation() &&
632         window_dim.window_dilation() == 1 &&
633         std::max<int64>(1, input_limits[spatial_dimension] - 1) ==
634             window_dim.stride() &&
635         window_dim.padding_low() == 0 && window_dim.padding_high() == 0) {
636       valid_position_counts.push_back(input_limits[spatial_dimension]);
637       continue;
638     }
639 
640     if (input_limits[spatial_dimension] == 1 &&
641         kernel_limits[spatial_dimension] == output_limits[spatial_dimension] &&
642         window_dim.window_dilation() == 1 && window_dim.base_dilation() == 1 &&
643         window_dim.stride() == 1 &&
644         window_dim.padding_high() == output_limits[spatial_dimension] - 1 &&
645         window_dim.padding_low() == output_limits[spatial_dimension] - 1) {
646       valid_position_counts.push_back(output_limits[spatial_dimension]);
647       continue;
648     }
649 
650     int64_t valid_position_count = 0;
651     // Loop over each point in the kernel.
652     for (int64_t kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension];
653          ++kernel_idx) {
654       // Loop over each point in the output.
655       for (int64_t output_idx = 0;
656            output_idx < output_limits[spatial_dimension]; ++output_idx) {
657         // Calculate lhs (input) index without taking base dilation into
658         // account.
659         const int64_t undilated_index =
660             output_idx * window_dim.stride() - window_dim.padding_low() +
661             kernel_idx * window_dim.window_dilation();
662 
663         // Calculate the actual lhs (input) index after dilation. Avoid the
664         // division as an optimization.
665         const int64_t lhs_spatial_index =
666             window_dim.base_dilation() > 1
667                 ? undilated_index / window_dim.base_dilation()
668                 : undilated_index;
669 
670         // Skip if the lhs (input) index is to be dilated.
671         if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) {
672           continue;
673         }
674 
675         // Skip if input index is not in bound.
676         if (lhs_spatial_index < 0 ||
677             lhs_spatial_index >= input_limits[spatial_dimension]) {
678           continue;
679         }
680 
681         valid_position_count += 1;
682       }
683     }
684     valid_position_counts.push_back(valid_position_count);
685   }
686 
687   const int64_t fma_count =
688       (input_feature / convolution->feature_group_count()) * output_feature *
689       (batch / convolution->batch_group_count()) *
690       Product(valid_position_counts);
691   current_properties_[kFlopsKey] = fma_count * kFmaFlops;
692   return Status::OK();
693 }
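// Worked example of the convolution cost above (assuming kFmaFlops == 2 and
// feature/batch group counts of 1): a convolution with batch 1, 16 input
// features, 32 output features, an 8x8 output, a 3x3 kernel, stride 1, no
// dilation and no padding has 8 * 3 = 24 valid (output, kernel) position
// pairs per spatial dimension, so fma_count = 16 * 32 * 1 * 24 * 24 = 294912
// and flops = 2 * 294912 = 589824, matching the familiar
// 2 * N * H * W * Cin * Cout * Kh * Kw estimate.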
694 
695 Status HloCostAnalysis::HandleFft(const HloInstruction* fft) {
696   auto real_shape =
697       fft->operand(0)->shape().IsTuple()
698           ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0)
699           : fft->operand(0)->shape();
700   constexpr int kFmaPerComplexMul = 4;
701   int64_t log_factors = 1;
702   for (int64_t dim : fft->fft_length()) {
703     log_factors *= tensorflow::Log2Floor(dim);
704   }
705   current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors *
706                                    ShapeUtil::ElementsIn(real_shape);
707   return Status::OK();
708 }
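// Worked example of the FFT estimate above (assuming kFmaFlops == 2): a 1-D
// FFT of length 1024 over a [1024] operand gives
// log_factors = floor(log2(1024)) = 10, so flops = 2 * 4 * 10 * 1024 = 81920.
// This is a rough N log N estimate; the real cost depends on the backend's
// FFT implementation.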
709 
710 Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
711   // Half of operand 0 is read.
712   float bytes_accessed = GetShapeSize(hlo->shape());
713   SetOutputBytesAccessed(GetShapeSize(hlo->shape()));
714   bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
715   SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
716   bytes_accessed += GetShapeSize(hlo->operand(1)->shape());
717   SetOperandBytesAccessed(1, GetShapeSize(hlo->operand(1)->shape()));
718   current_properties_[kBytesAccessedKey] = bytes_accessed;
719 
720   const Shape& a_shape = hlo->operand(0)->shape();
721   const Shape& b_shape = hlo->operand(1)->shape();
722   // Estimate as batch * mn^2 / 2 flops.
723   int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
724   elems *= ShapeUtil::ElementsIn(b_shape);
725   current_properties_[kFlopsKey] = kFmaFlops * elems;
726   return Status::OK();
727 }
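// Worked example of the code above (assuming kFmaFlops == 2): with a of
// shape [64,64] and b of shape [64,8], elems = 64 * (64 * 8) = 32768 and the
// charged flops are 2 * 32768 = 65536.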
728 
729 Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
730   // Half of operand 0 is read and half of the output will be written.
731   float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
732   SetOutputBytesAccessed(GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
733   bytes_accessed += GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
734   SetOperandBytesAccessed(0, GetShapeSize(hlo->operand(0)->shape()) / 2.0f);
735   current_properties_[kBytesAccessedKey] = bytes_accessed;
736 
737   const Shape& a_shape = hlo->operand(0)->shape();
738   // Estimate as batch * n^3 / 3 flops.
739   int64_t elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
740   elems *= ShapeUtil::ElementsIn(a_shape);
741   current_properties_[kFlopsKey] = elems / 3;
742   return Status::OK();
743 }
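// Worked example of the code above: for a of shape [64,64], elems =
// 64 * (64 * 64) = 262144 and the charged flops are 262144 / 3 = 87381, in
// line with the n^3 / 3 estimate in the comment.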
744 
745 Status HloCostAnalysis::HandleAllGather(const HloInstruction* /*hlo*/) {
746   return Status::OK();
747 }
748 
749 Status HloCostAnalysis::HandleAllGatherStart(const HloInstruction* hlo) {
750   return HandleAllGather(hlo);
751 }
752 
753 Status HloCostAnalysis::HandleAllGatherDone(const HloInstruction* /*hlo*/) {
754   return Status::OK();
755 }
756 
757 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
758   // We assume 2 replicas, so that each output element is the sum of two input
759   // elements.
760   //
761   // TODO(b/33004697): Compute correct cost here, taking the actual number of
762   // replicas into account.
763   double flops = 0.0;
764   int64_t output_bytes_accessed = 0;
765   ShapeUtil::ForEachSubshape(
766       crs->shape(), [&](const Shape& subshape, const ShapeIndex&) {
767         if (subshape.IsArray()) {
768           flops += ShapeUtil::ElementsIn(subshape);
769           output_bytes_accessed += GetShapeSize(subshape);
770         }
771       });
772   int64_t bytes_accessed = output_bytes_accessed;
773   for (const HloInstruction* operand : crs->operands()) {
774     bytes_accessed += GetShapeSize(operand->shape());
775   }
776   current_properties_[kFlopsKey] = flops;
777   SetOutputBytesAccessed(output_bytes_accessed);
778   current_properties_[kBytesAccessedKey] = bytes_accessed;
779   return Status::OK();
780 }
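// Example of the two-replica assumption above (assuming the configured
// shape_size function returns the dense byte size): all-reducing a single
// [1024] f32 array is charged 1024 flops (one add per output element),
// 4096 output bytes, and another 4096 bytes for the operand read, i.e.
// 8192 bytes accessed in total.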
781 
782 Status HloCostAnalysis::HandleReduceScatter(const HloInstruction* hlo) {
783   return Status::OK();
784 }
785 
786 Status HloCostAnalysis::HandleAllReduceStart(const HloInstruction* hlo) {
787   return HandleAllReduce(hlo);
788 }
789 
790 Status HloCostAnalysis::HandleAllReduceDone(const HloInstruction* /*hlo*/) {
791   return Status::OK();
792 }
793 
794 Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
795   return Status::OK();
796 }
797 
798 Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) {
799   return Status::OK();
800 }
801 
802 Status HloCostAnalysis::HandleCollectivePermuteStart(
803     const HloInstruction* /*hlo*/) {
804   return Status::OK();
805 }
806 
807 Status HloCostAnalysis::HandleCollectivePermuteDone(
808     const HloInstruction* /*hlo*/) {
809   return Status::OK();
810 }
811 
812 Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) {
813   return Status::OK();
814 }
815 
816 Status HloCostAnalysis::HandleReplicaId(const HloInstruction* /*hlo*/) {
817   return Status::OK();
818 }
819 
820 Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
821   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
822   // cost changes with the implementation and the distribution. For now, assume
823   // the cost of each RNG is same as a transcendental operation.
824   current_properties_[kTranscendentalsKey] =
825       ShapeUtil::ElementsIn(random->shape());
826   return Status::OK();
827 }
828 
829 Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
830   // TODO(b/26346211): Implement better estimates for the RNG cost, since the
831   // cost changes with the implementation and the distribution. For now, assume
832   // the cost of each RNG is same as a transcendental operation.
833   current_properties_[kTranscendentalsKey] =
834       ShapeUtil::ElementsInRecursive(random->shape());
835   return Status::OK();
836 }
837 
838 Status HloCostAnalysis::HandleRngGetAndUpdateState(
839     const HloInstruction* random) {
840   return Status::OK();
841 }
842 
843 Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) {
844   if (fusion->IsCustomFusion()) {
845     for (const HloInstruction* hlo :
846          fusion->fused_instructions_computation()->instructions()) {
847       if (hlo->opcode() == HloOpcode::kGather) {
848         return HandleGather(hlo);
849       }
850       if (hlo->opcode() == HloOpcode::kScatter) {
851         return HandleScatter(hlo);
852       }
853     }
854   }
855   TF_ASSIGN_OR_RETURN(
856       current_properties_,
857       ProcessSubcomputation(fusion->fused_instructions_computation()));
858 
859   // Fusion nodes that produce a tuple also produce the entries in the tuple.
860   // Ignore the memory accessed inside fused ops, since fusion is supposed to
861   // prevent intermediate data from touching slow memory.
862   current_properties_[kBytesAccessedKey] = 0;
863   ShapeUtil::ForEachSubshape(
864       fusion->shape(),
865       [this, fusion](const Shape& subshape, const ShapeIndex& shape_index) {
866         if (!subshape.IsArray()) {
867           return;
868         }
869         if (shape_index.empty()) {
870           if (fusion->fused_expression_root()->opcode() ==
871               HloOpcode::kDynamicUpdateSlice) {
872             int64_t size = GetShapeSize(
873                 fusion->fused_expression_root()->operand(1)->shape());
874             current_properties_[kBytesAccessedKey] += size;
875             SetOutputBytesAccessed(shape_index, size);
876             return;
877           }
878         } else if (shape_index.size() == 1) {
879           if (fusion->fused_expression_root()->opcode() == HloOpcode::kTuple &&
880               fusion->fused_expression_root()
881                       ->operand(shape_index[0])
882                       ->opcode() == HloOpcode::kDynamicUpdateSlice) {
883             int64_t size = GetShapeSize(fusion->fused_expression_root()
884                                             ->operand(shape_index[0])
885                                             ->operand(1)
886                                             ->shape());
887             current_properties_[kBytesAccessedKey] += size;
888             SetOutputBytesAccessed(shape_index, size);
889             return;
890           }
891         }
892         current_properties_[kBytesAccessedKey] += GetShapeSize(subshape);
893         SetOutputBytesAccessed(shape_index, GetShapeSize(subshape));
894       });
895 
896   if (fusion->shape().IsTuple()) {
897     // Propagate and accumulate the output tuple bytes from the tuple subshapes.
898     // This ensures we have the correct output bytes accessed for the shape
899     // index
900     // {}.
901     std::function<float(const Shape&, const ShapeIndex&)>
902         propagate_output_size_to_parent;
903     propagate_output_size_to_parent = [&](const Shape& shape,
904                                           const ShapeIndex& shape_index) {
905       auto output_bytes_it =
906           current_properties_.find(GetOutputBytesAccessedKey(shape_index));
907       if (output_bytes_it != current_properties_.end()) {
908         return output_bytes_it->second;
909       }
910       float bytes_accessed = 0;
911       for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
912         const Shape& subshape = shape.tuple_shapes(i);
913         ShapeIndex subshape_index(shape_index);
914         subshape_index.push_back(i);
915         bytes_accessed +=
916             propagate_output_size_to_parent(subshape, subshape_index);
917       }
918       SetOutputBytesAccessed(shape_index, bytes_accessed);
919       return bytes_accessed;
920     };
921     current_properties_.erase(
922         current_properties_.find(GetOutputBytesAccessedKey()));
923     propagate_output_size_to_parent(fusion->shape(), {});
924   }
925 
926   for (int64_t i = 0; i < fusion->fused_parameters().size(); ++i) {
927     const HloInstruction* operand = fusion->fused_parameter(i);
928     int64_t operand_size = 0;
929     if (!fusion->shape().IsTuple()) {
930       operand_size = FusionParameterReadBytes(operand);
931     } else {
932       // If the fusion parameter is a tuple type, find the gte for the leaf
933       // shape and calculate the bytes accessed for those array types.
934       for (const auto& indexed_shape :
935            ShapeUtil::GetLeafShapes(operand->shape())) {
936         const HloInstruction* gte = operand;
937         for (int64_t index : indexed_shape.index) {
938           for (const HloInstruction* user : gte->users()) {
939             if (user->opcode() == HloOpcode::kGetTupleElement &&
940                 user->tuple_index() == index) {
941               gte = user;
942               break;
943             }
944           }
945         }
946         int64_t size = FusionParameterReadBytes(gte);
947         operand_size += size;
948         SetOperandBytesAccessed(i, indexed_shape.index, size);
949       }
950     }
951     current_properties_[kBytesAccessedKey] += operand_size;
952     SetOperandBytesAccessed(i, operand_size);
953   }
954 
955   return Status::OK();
956 }
957 
958 Status HloCostAnalysis::HandleCall(const HloInstruction* call) {
959   TF_ASSIGN_OR_RETURN(current_properties_,
960                       ProcessSubcomputation(call->to_apply()));
961   current_should_compute_bottleneck_time_ = false;
962   return Status::OK();
963 }
964 
965 Status HloCostAnalysis::HandleCustomCall(const HloInstruction* custom_call) {
966   // Mark applicable fields as "unknown", since we don't know what CustomCall
967   // does.  This is better than returning an error, which would stop iteration,
968   // and therefore would prevent us from getting *any* stats for a computation
969   // which contains a CustomCall.
970   current_properties_[kOptimalSecondsKey] = -1;
971   current_properties_[kBytesAccessedKey] = -1;
972   SetOutputBytesAccessed(-1);
973   for (int i = 0; i < custom_call->operand_count(); ++i) {
974     SetOperandBytesAccessed(i, -1);
975   }
976   current_properties_[kFlopsKey] = -1;
977   current_should_compute_bottleneck_time_ = false;
978   return Status::OK();
979 }
980 
981 Status HloCostAnalysis::HandleSort(const HloInstruction* sort) {
982   // This assumes a comparison based N*log(N) algorithm. As for all ops, the
983   // actual properties of the op depend on the backend implementation.
984   int64_t elements = ShapeUtil::ElementsIn(sort->operand(0)->shape());
985   current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements);
986   return Status::OK();
987 }
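// Example of the N log N estimate above: sorting an operand with 1000
// elements is charged 1000 * ceil(log2(1000)) = 1000 * 10 = 10000 flops.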
988 
989 Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) {
990   // Since the number of iterations of the while node will not always be
991   // something that we can statically analyze, we cannot precisely compute the
992   // cost of a while node. For now compute the cost of a single iteration.
993   TF_ASSIGN_OR_RETURN(const Properties body_properties,
994                       ProcessSubcomputation(xla_while->while_body()));
995 
996   TF_ASSIGN_OR_RETURN(const Properties condition_properties,
997                       ProcessSubcomputation(xla_while->while_condition()));
998 
999   current_properties_.clear();
1000   for (const auto& property : body_properties) {
1001     current_properties_[property.first] += property.second;
1002   }
1003   for (const auto& property : condition_properties) {
1004     current_properties_[property.first] += property.second;
1005   }
1006   current_should_compute_bottleneck_time_ = false;
1007 
1008   return Status::OK();
1009 }
1010 
1011 Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
1012   // Compute the cost of the branch computations and take the maximum from those
1013   // for each property.
1014   TF_ASSIGN_OR_RETURN(
1015       const Properties branch0_computation_properties,
1016       ProcessSubcomputation(conditional->branch_computation(0)));
1017   current_properties_ = branch0_computation_properties;
1018   for (int j = 1; j < conditional->branch_count(); ++j) {
1019     TF_ASSIGN_OR_RETURN(
1020         const Properties branch_computation_properties,
1021         ProcessSubcomputation(conditional->branch_computation(j)));
1022     for (const auto& property : branch_computation_properties) {
1023       if (!tensorflow::gtl::InsertIfNotPresent(&current_properties_,
1024                                                property)) {
1025         auto& current_property = current_properties_[property.first];
1026         current_property = std::max(current_property, property.second);
1027       }
1028     }
1029   }
1030   current_should_compute_bottleneck_time_ = false;
1031 
1032   return Status::OK();
1033 }
1034 
1035 Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
1036   // Gather doesn't read the whole input buffer, it's equivalent to a copy the
1037   // size of the output shape and a read of the gather indices.
1038   int64_t output_size = GetShapeSize(gather->shape());
1039   current_properties_[kBytesAccessedKey] =
1040       output_size * 2 + GetShapeSize(gather->operand(1)->shape());
1041   SetOperandBytesAccessed(0, output_size);
1042   SetOperandBytesAccessed(1, GetShapeSize(gather->operand(1)->shape()));
1043   SetOutputBytesAccessed(output_size);
1044   // Gather does not issue any flops.
1045   return Status::OK();
1046 }
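// Example of the bytes-accessed accounting above (assuming a dense byte-size
// function): gathering a [100,8] f32 output using [100,1] s32 indices reads
// 100 * 8 * 4 = 3200 bytes from the operand, reads 100 * 4 = 400 bytes of
// indices, and writes 3200 bytes of output, for 6800 bytes total.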
1047 
1048 Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
1049   // Scatter accesses the equivalent of 3 update shapes (input, output, and
1050   // updates), and the scatter indices.
1051   int64_t update_size = GetShapeSize(scatter->operand(2)->shape());
1052   current_properties_[kBytesAccessedKey] =
1053       update_size * 3 + GetShapeSize(scatter->operand(1)->shape());
1054   SetOperandBytesAccessed(0, update_size);
1055   SetOperandBytesAccessed(1, GetShapeSize(scatter->operand(1)->shape()));
1056   SetOperandBytesAccessed(2, update_size);
1057   SetOutputBytesAccessed(update_size);
1058   const int64_t element_count =
1059       ShapeUtil::ElementsIn(scatter->operand(2)->shape());
1060   TF_ASSIGN_OR_RETURN(const Properties sub_properties,
1061                       ProcessSubcomputation(scatter->to_apply()));
1062   for (const auto& property : sub_properties) {
1063     if (!absl::StartsWith(property.first, kBytesAccessedKey)) {
1064       current_properties_[property.first] = property.second * element_count;
1065     }
1066   }
1067   return Status::OK();
1068 }
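// Example of the accounting above (assuming a dense byte-size function):
// scattering [100,8] f32 updates with [100,1] s32 indices is charged
// 3 * 3200 + 400 = 10000 bytes accessed (operand, output, and updates each
// at the update size, plus the indices), and the update computation's
// per-application cost is scaled by the 800 update elements.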
1069 
1070 Status HloCostAnalysis::HandleGetDimensionSize(
1071     const HloInstruction* /*get_size*/) {
1072   return Status::OK();
1073 }
1074 
1075 Status HloCostAnalysis::HandleSetDimensionSize(
1076     const HloInstruction* /*set_size*/) {
1077   return Status::OK();
1078 }
1079 
1080 Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
1081   return Status::OK();
1082 }
1083 
1084 float HloCostAnalysis::flop_count() const {
1085   return GetProperty(kFlopsKey, properties_sum_);
1086 }
1087 
1088 float HloCostAnalysis::transcendental_count() const {
1089   return GetProperty(kTranscendentalsKey, properties_sum_);
1090 }
1091 
1092 float HloCostAnalysis::bytes_accessed() const {
1093   return GetProperty(kBytesAccessedKey, properties_sum_);
1094 }
1095 
1096 float HloCostAnalysis::optimal_seconds() const {
1097   return GetProperty(kOptimalSecondsKey, properties_sum_);
1098 }
1099 
1100 int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const {
1101   return GetPropertyForHlo(hlo, kFlopsKey, hlo_properties_);
1102 }
1103 
1104 int64 HloCostAnalysis::transcendental_count(const HloInstruction& hlo) const {
1105   return GetPropertyForHlo(hlo, kTranscendentalsKey, hlo_properties_);
1106 }
1107 
1108 int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const {
1109   return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_);
1110 }
1111 
1112 int64 HloCostAnalysis::operand_bytes_accessed(const HloInstruction& hlo,
1113                                               int64_t operand_num,
1114                                               ShapeIndex index) const {
1115   return GetPropertyForHlo(hlo, GetOperandBytesAccessedKey(operand_num, index),
1116                            hlo_properties_);
1117 }
1118 
1119 int64 HloCostAnalysis::output_bytes_accessed(const HloInstruction& hlo,
1120                                              ShapeIndex index) const {
1121   return GetPropertyForHlo(hlo, GetOutputBytesAccessedKey(index),
1122                            hlo_properties_);
1123 }
1124 
1125 float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const {
1126   return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_);
1127 }
1128 
1129 int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo,
1130                                     absl::optional<int64> memory_space) const {
1131   int64_t bytes_read = 0;
1132   for (int operand_number = 0; operand_number < hlo.operand_count();
1133        ++operand_number) {
1134     for (const ShapeUtil::IndexedShape& indexed_shape :
1135          ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) {
1136       absl::optional<int64> index_memory_space;
1137       if (indexed_shape.shape.has_layout()) {
1138         index_memory_space = indexed_shape.shape.layout().memory_space();
1139       }
1140       if (!memory_space || memory_space == index_memory_space) {
1141         bytes_read +=
1142             operand_bytes_accessed(hlo, operand_number, indexed_shape.index);
1143       }
1144     }
1145   }
1146   return bytes_read;
1147 }
1148 
1149 int64 HloCostAnalysis::GetBytesWritten(
1150     const HloInstruction& hlo, absl::optional<int64> memory_space) const {
1151   int64_t bytes_written = 0;
1152   for (const ShapeUtil::IndexedShape& indexed_shape :
1153        ShapeUtil::GetLeafShapes(hlo.shape())) {
1154     absl::optional<int64> index_memory_space;
1155     if (indexed_shape.shape.has_layout()) {
1156       index_memory_space = indexed_shape.shape.layout().memory_space();
1157     }
1158     if (!memory_space || memory_space == index_memory_space) {
1159       bytes_written += output_bytes_accessed(hlo, indexed_shape.index);
1160     }
1161   }
1162   return bytes_written;
1163 }
1164 
1165 StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation(
1166     HloComputation* computation) {
1167   auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_);
1168   visitor->ReserveVisitStates(computation->instruction_count());
1169   TF_RETURN_IF_ERROR(computation->Accept(visitor.get()));
1170   hlo_properties_.insert(visitor->hlo_properties_.begin(),
1171                          visitor->hlo_properties_.end());
1172   return visitor->properties();
1173 }
1174 
1175 std::unique_ptr<HloCostAnalysis> HloCostAnalysis::CreateNestedCostAnalysis(
1176     const ShapeSizeFunction& shape_size, const Properties& per_second_rates) {
1177   return absl::WrapUnique(new HloCostAnalysis(shape_size, per_second_rates));
1178 }
1179 
1180 void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
1181                                               float value) {
1182   current_properties_[GetOperandBytesAccessedKey(operand_num).c_str()] = value;
1183 }
1184 
1185 void HloCostAnalysis::SetOperandBytesAccessed(int64_t operand_num,
1186                                               ShapeIndex index, float value) {
1187   current_properties_[GetOperandBytesAccessedKey(operand_num, index).c_str()] =
1188       value;
1189 }
1190 
1191 void HloCostAnalysis::SetOutputBytesAccessed(float value) {
1192   current_properties_[GetOutputBytesAccessedKey()] = value;
1193 }
1194 
1195 void HloCostAnalysis::SetOutputBytesAccessed(ShapeIndex index, float value) {
1196   current_properties_[GetOutputBytesAccessedKey(index)] = value;
1197 }
1198 
1199 /*static*/ std::string HloCostAnalysis::GetOperandBytesAccessedKey(
1200     int64_t operand_num, ShapeIndex index) {
1201   return absl::StrCat(kBytesAccessedKey, " operand ", operand_num, " ",
1202                       index.ToString());
1203 }
1204 
1205 /*static*/ std::string HloCostAnalysis::GetOutputBytesAccessedKey(
1206     ShapeIndex index) {
1207   return absl::StrCat(kBytesAccessedKey, " output ", index.ToString());
1208 }
1209 
1210 }  // namespace xla
1211