• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
32 
33 namespace xla {
34 
35 // DynamicDimensionInference analyzes each HLO instruction in a graph and
36 // inferences which dimensions are dynamic and which scalar instructions
37 // represent the runtime real size of those dynamic dimensions.
38 class DynamicDimensionInference {
39  public:
40   using CustomCallInferenceHandler =
41       std::function<Status(HloInstruction*, DynamicDimensionInference*)>;
42 
43   static StatusOr<DynamicDimensionInference> Run(
44       HloModule* module,
45       CustomCallInferenceHandler custom_call_handler = nullptr);
46 
47   string ToString() const;
48 
49   // If the dimension `dim` of instruction `inst` at `index` has a dynamic size,
50   // returns a scalar HloInstruction that represents the runtime size of that
51   // dimension. Otherwise returns nullptr.
52   HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
53                                  int64_t dim) const;
54 
55   // Returns dynamic sizes of all dimensions of `inst`'s leaf node at `index`.
56   // Static sizes are represented by nullptr.
57   std::vector<HloInstruction*> GetDynamicSizes(HloInstruction* inst,
58                                                const ShapeIndex& index) const;
59 
60   // Returns if `index` at `inst` contains any dynamic dimension.
61   // Recursively go into tuples.
62   bool HasDynamicDimension(HloInstruction* inst,
63                            ShapeIndexView index = {}) const;
64 
65   // Forward dynamic dimension size at `dim` from `inst` to `new_inst`.
66   Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst,
67                             const ShapeIndex& index);
68 
69   // Update the dynamic mapping so that we know dimension `dim` of instruction
70   // `inst` at `index` has a dynamic size, and its runtime size is represented
71   // by a scalar instruction `size`.
72   void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
73                       int64_t dim, HloInstruction* size);
74 
75   // For all tensors whose dynamic dimension is `replace`, replace them with
76   // `with`.
77   void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
78                                           HloInstruction* with);
79 
80   // Update dynamic dimension inference to analyze `inst`. Useful to
81   // incrementally track new instructions added after initial run.
82   Status Update(HloInstruction* inst);
83 
84   friend class DynamicDimensionInferenceVisitor;
85 
86  private:
87   explicit DynamicDimensionInference(
88       HloModule* module, CustomCallInferenceHandler custom_call_handler);
89 
90   // DynamicDimension is used as a key in the dynamic key-value mapping. It
91   // unambiguously represents a dynamic dimension of a instruction at a given
92   // index.
93   struct DynamicDimension {
94     // HloInstruction that holds the dimension.
95     HloInstruction* inst;
96     // Subshape of the instruction that holds the dimension.
97     ShapeIndex index;
98     // The dimension number of the dynamic dimension at given index of a given
99     // instruction.
100     int64 dim;
101 
102     // Artifacts needed to make this struct able to be used as a `key` in absl
103     // maps. "friend" keywords are added so these functions can be found through
104     // ADL.
105     template <typename H>
AbslHashValueDynamicDimension106     friend H AbslHashValue(H h, const DynamicDimension& m) {
107       return H::combine(std::move(h), m.inst, m.index, m.dim);
108     }
109 
110     friend bool operator==(const DynamicDimension& lhs,
111                            const DynamicDimension& rhs) {
112       return lhs.inst == rhs.inst && lhs.index == rhs.index &&
113              lhs.dim == rhs.dim;
114     }
115   };
116 
117   // Copies the internal mapping from instruction `from` to instruction `to`.
118   // This is useful when an instruction is replaced by the other during the
119   // inferencing process.
120   void CopyMapping(HloInstruction* from, HloInstruction* to);
121 
122   // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in
123   // module_.
124   Status AnalyzeDynamicDimensions();
125 
126   // HloModule being analyzed.
127   HloModule* module_;
128 
129   // dynamic_mapping_ holds the result of the analysis. It maps a dynamic
130   // dimension to a scalar HloInstruction that represents the real dynamic size
131   // of the dynamic dimension.
132   using DynamicMapping = absl::flat_hash_map<DynamicDimension, HloInstruction*>;
133   DynamicMapping dynamic_mapping_;
134 
135   // A convenient mapping from an hlo to the set of dynamic dimensions that it
136   // holds.
137   using PerHloDynamicDimensions =
138       absl::flat_hash_map<HloInstruction*,
139                           absl::flat_hash_set<DynamicDimension>>;
140   PerHloDynamicDimensions per_hlo_dynamic_dimensions_;
141 
142   // A handler for custom calls.
143   CustomCallInferenceHandler custom_call_handler_;
144 };
145 
146 }  // namespace xla
147 
148 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
149