• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
18 
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
31 
32 namespace xla {
33 
34 // DynamicDimensionInference analyzes each HLO instruction in a graph and
35 // inferences which dimensions are dynamic and which scalar instructions
36 // represent the runtime real size of those dynamic dimensions.
37 class DynamicDimensionInference {
38  public:
39   static StatusOr<DynamicDimensionInference> Run(HloModule* module);
40 
41   string ToString() const;
42 
43   // If the dimension `dim` of instruction `inst` at `index` has a dynamic size,
44   // returns a scalar HloInstruction that represents the runtime size of that
45   // dimension. Otherwise returns nullptr.
46   HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
47                                  int64 dim) const;
48 
49   friend class DynamicDimensionInferenceVisitor;
50 
51  private:
52   explicit DynamicDimensionInference(HloModule* module);
53 
54   // DynamicDimension is used as a key in the dynamic key-value mapping. It
55   // unambiguously represents a dynamic dimension of a instruction at a given
56   // index.
57   struct DynamicDimension {
58     // HloInstruction that holds the dimension.
59     HloInstruction* inst;
60     // Subshape of the instruction that holds the dimension.
61     ShapeIndex index;
62     // The dimension number of the dynamic dimension at given index of a given
63     // instruction.
64     int64 dim;
65 
66     // Artifacts needed to make this struct able to be used as a `key` in absl
67     // maps. "friend" keywords are added so these functions can be found through
68     // ADL.
69     template <typename H>
AbslHashValueDynamicDimension70     friend H AbslHashValue(H h, const DynamicDimension& m) {
71       return H::combine(std::move(h), m.inst, m.index, m.dim);
72     }
73 
74     friend bool operator==(const DynamicDimension& lhs,
75                            const DynamicDimension& rhs) {
76       return lhs.inst == rhs.inst && lhs.index == rhs.index &&
77              lhs.dim == rhs.dim;
78     }
79   };
80 
81   // Update the dynamic mapping so that we know dimension `dim` of instruction
82   // `inst` at `index` has a dynamic size, and its runtime size is represented
83   // by a scalar instruction `size`.
SetDynamicSize(HloInstruction * inst,const ShapeIndex & index,int64 dim,HloInstruction * size)84   void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
85                       HloInstruction* size) {
86     dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size);
87     auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
88     iter.first->second.emplace(DynamicDimension{inst, index, dim});
89   }
90 
91   // Copies the internal mapping from instruction `from` to instruction `to`.
92   // This is useful when an instruction is replaced by the other during the
93   // inferencing process.
94   void CopyMapping(HloInstruction* from, HloInstruction* to);
95 
96   // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in
97   // module_.
98   Status AnalyzeDynamicDimensions();
99 
100   // HloModule being analyzed.
101   HloModule* module_;
102 
103   // dynamic_mapping_ holds the result of the analysis. It maps a dynamic
104   // dimension to a scalar HloInstruction that represents the real dynamic size
105   // of the dynamic dimension.
106   using DynamicMapping = absl::flat_hash_map<DynamicDimension, HloInstruction*>;
107   DynamicMapping dynamic_mapping_;
108 
109   // A convenient mapping from an hlo to the set of dynamic dimensions that it
110   // holds.
111   using PerHloDynamicDimensions =
112       absl::flat_hash_map<HloInstruction*,
113                           absl::flat_hash_set<DynamicDimension>>;
114   PerHloDynamicDimensions per_hlo_dynamic_dimensions_;
115 };
116 
117 }  // namespace xla
118 
119 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_
120