1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_module.h" 27 #include "tensorflow/compiler/xla/status.h" 28 #include "tensorflow/compiler/xla/statusor.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/core/platform/macros.h" 31 32 namespace xla { 33 34 // DynamicDimensionInference analyzes each HLO instruction in a graph and 35 // inferences which dimensions are dynamic and which scalar instructions 36 // represent the runtime real size of those dynamic dimensions. 37 class DynamicDimensionInference { 38 public: 39 static StatusOr<DynamicDimensionInference> Run(HloModule* module); 40 41 string ToString() const; 42 43 // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, 44 // returns a scalar HloInstruction that represents the runtime size of that 45 // dimension. Otherwise returns nullptr. 46 HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, 47 int64 dim) const; 48 49 friend class DynamicDimensionInferenceVisitor; 50 51 private: 52 explicit DynamicDimensionInference(HloModule* module); 53 54 // DynamicDimension is used as a key in the dynamic key-value mapping. It 55 // unambiguously represents a dynamic dimension of a instruction at a given 56 // index. 57 struct DynamicDimension { 58 // HloInstruction that holds the dimension. 59 HloInstruction* inst; 60 // Subshape of the instruction that holds the dimension. 61 ShapeIndex index; 62 // The dimension number of the dynamic dimension at given index of a given 63 // instruction. 64 int64 dim; 65 66 // Artifacts needed to make this struct able to be used as a `key` in absl 67 // maps. "friend" keywords are added so these functions can be found through 68 // ADL. 69 template <typename H> AbslHashValueDynamicDimension70 friend H AbslHashValue(H h, const DynamicDimension& m) { 71 return H::combine(std::move(h), m.inst, m.index, m.dim); 72 } 73 74 friend bool operator==(const DynamicDimension& lhs, 75 const DynamicDimension& rhs) { 76 return lhs.inst == rhs.inst && lhs.index == rhs.index && 77 lhs.dim == rhs.dim; 78 } 79 }; 80 81 // Update the dynamic mapping so that we know dimension `dim` of instruction 82 // `inst` at `index` has a dynamic size, and its runtime size is represented 83 // by a scalar instruction `size`. SetDynamicSize(HloInstruction * inst,const ShapeIndex & index,int64 dim,HloInstruction * size)84 void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, 85 HloInstruction* size) { 86 dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); 87 auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); 88 iter.first->second.emplace(DynamicDimension{inst, index, dim}); 89 } 90 91 // Copies the internal mapping from instruction `from` to instruction `to`. 92 // This is useful when an instruction is replaced by the other during the 93 // inferencing process. 94 void CopyMapping(HloInstruction* from, HloInstruction* to); 95 96 // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in 97 // module_. 98 Status AnalyzeDynamicDimensions(); 99 100 // HloModule being analyzed. 101 HloModule* module_; 102 103 // dynamic_mapping_ holds the result of the analysis. It maps a dynamic 104 // dimension to a scalar HloInstruction that represents the real dynamic size 105 // of the dynamic dimension. 106 using DynamicMapping = absl::flat_hash_map<DynamicDimension, HloInstruction*>; 107 DynamicMapping dynamic_mapping_; 108 109 // A convenient mapping from an hlo to the set of dynamic dimensions that it 110 // holds. 111 using PerHloDynamicDimensions = 112 absl::flat_hash_map<HloInstruction*, 113 absl::flat_hash_set<DynamicDimension>>; 114 PerHloDynamicDimensions per_hlo_dynamic_dimensions_; 115 }; 116 117 } // namespace xla 118 119 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ 120