/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {

// DynamicDimensionInference analyzes each HLO instruction in a graph and
// inferences which dimensions are dynamic and which scalar instructions
// represent the runtime real size of those dynamic dimensions.
37 class DynamicDimensionInference { 38 public: 39 using CustomCallInferenceHandler = 40 std::function<Status(HloInstruction*, DynamicDimensionInference*)>; 41 42 static StatusOr<DynamicDimensionInference> Run( 43 HloModule* module, 44 CustomCallInferenceHandler custom_call_handler = nullptr); 45 46 string ToString() const; 47 48 // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, 49 // returns a scalar HloInstruction that represents the runtime size of that 50 // dimension. Otherwise returns nullptr. 51 HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, 52 int64 dim) const; 53 54 // Returns dynamic sizes of all dimensions of `inst`'s leaf node at `index`. 55 // Static sizes are represented by nullptr. 56 std::vector<HloInstruction*> GetDynamicSizes(HloInstruction* inst, 57 const ShapeIndex& index) const; 58 59 // Returns if current instruction contains any dynamic dimension. 60 // Recursively go into tuples. 61 bool HasDynamicDimension(HloInstruction* inst) const; 62 63 // Forward dynamic dimension size at `dim` from `inst` to `new_inst`. 64 Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst, 65 const ShapeIndex& index); 66 67 // Update the dynamic mapping so that we know dimension `dim` of instruction 68 // `inst` at `index` has a dynamic size, and its runtime size is represented 69 // by a scalar instruction `size`. 70 void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, 71 HloInstruction* size); 72 73 // For all tensors whose dynamic dimension is `replace`, replace them with 74 // `with`. 75 void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace, 76 HloInstruction* with); 77 78 friend class DynamicDimensionInferenceVisitor; 79 80 private: 81 explicit DynamicDimensionInference( 82 HloModule* module, CustomCallInferenceHandler custom_call_handler); 83 84 // DynamicDimension is used as a key in the dynamic key-value mapping. 
It 85 // unambiguously represents a dynamic dimension of a instruction at a given 86 // index. 87 struct DynamicDimension { 88 // HloInstruction that holds the dimension. 89 HloInstruction* inst; 90 // Subshape of the instruction that holds the dimension. 91 ShapeIndex index; 92 // The dimension number of the dynamic dimension at given index of a given 93 // instruction. 94 int64 dim; 95 96 // Artifacts needed to make this struct able to be used as a `key` in absl 97 // maps. "friend" keywords are added so these functions can be found through 98 // ADL. 99 template <typename H> AbslHashValueDynamicDimension100 friend H AbslHashValue(H h, const DynamicDimension& m) { 101 return H::combine(std::move(h), m.inst, m.index, m.dim); 102 } 103 104 friend bool operator==(const DynamicDimension& lhs, 105 const DynamicDimension& rhs) { 106 return lhs.inst == rhs.inst && lhs.index == rhs.index && 107 lhs.dim == rhs.dim; 108 } 109 }; 110 111 // Copies the internal mapping from instruction `from` to instruction `to`. 112 // This is useful when an instruction is replaced by the other during the 113 // inferencing process. 114 void CopyMapping(HloInstruction* from, HloInstruction* to); 115 116 // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in 117 // module_. 118 Status AnalyzeDynamicDimensions(); 119 120 // HloModule being analyzed. 121 HloModule* module_; 122 123 // dynamic_mapping_ holds the result of the analysis. It maps a dynamic 124 // dimension to a scalar HloInstruction that represents the real dynamic size 125 // of the dynamic dimension. 126 using DynamicMapping = absl::flat_hash_map<DynamicDimension, HloInstruction*>; 127 DynamicMapping dynamic_mapping_; 128 129 // A convenient mapping from an hlo to the set of dynamic dimensions that it 130 // holds. 
131 using PerHloDynamicDimensions = 132 absl::flat_hash_map<HloInstruction*, 133 absl::flat_hash_set<DynamicDimension>>; 134 PerHloDynamicDimensions per_hlo_dynamic_dimensions_; 135 136 // A handler for custom calls. 137 CustomCallInferenceHandler custom_call_handler_; 138 }; 139 140 } // namespace xla 141 142 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ 143