• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
18 
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/types/optional.h"
23 #include "tensorflow/compiler/xla/service/hlo.pb.h"
24 #include "tensorflow/compiler/xla/shape_tree.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 
27 namespace xla {
28 
29 class HloModule;
30 
31 // This class specifies the alias map from output index to parameter number and
32 // parameter index in the entry computation.
33 class HloInputOutputAliasConfig {
34  public:
35   // The kind of aliases which can be set. A kMayAlias is one setup at
36   // compilation time by the user, and has to be respected. A kMustAlias one
37   // might be setup by the compiler, if it decides it is convenient to do so.
38   enum AliasKind {
39     kMayAlias,
40     kMustAlias,
41   };
42   // Defines the alias information for a given output buffer. A given output
43   // buffer shape index can refer only to one parameter+index.
44   struct Alias {
45     Alias(int64 parameter_number, ShapeIndex parameter_index,
46           AliasKind kind = kMayAlias)
parameter_numberAlias47         : parameter_number(parameter_number),
48           parameter_index(std::move(parameter_index)),
49           kind(kind) {}
50 
51     int64 parameter_number;
52     ShapeIndex parameter_index;
53     AliasKind kind;
54 
must_aliasAlias55     bool must_alias() const { return kind == kMustAlias; }
56 
ToStringAlias57     std::string ToString() {
58       return absl::StrFormat("(%lld, %s, %s)", parameter_number,
59                              parameter_index.ToString(),
60                              kind == kMustAlias ? "must-alias" : "may-alias");
61     }
62   };
63 
64   HloInputOutputAliasConfig() = default;
65 
HloInputOutputAliasConfig(Shape output_shape)66   explicit HloInputOutputAliasConfig(Shape output_shape)
67       : alias_(std::move(output_shape)) {}
68 
69   virtual ~HloInputOutputAliasConfig() = default;
70 
71   // Sets up alias config from `output_index` to `param_index` at
72   // `param_number`.
73   Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
74                     const ShapeIndex& param_index,
75                     AliasKind must_alias = kMayAlias);
76 
77   // Returns true if the given parameter is aliased with one of the output
78   // buffers.
ParameterHasAlias(int64 param_number,const ShapeIndex & param_index)79   bool ParameterHasAlias(int64 param_number,
80                          const ShapeIndex& param_index) const {
81     return GetAliasedOutput(param_number, param_index).has_value();
82   }
83 
84   // Checks whether the provided output index has already been aliased.
85   bool OutputHasAlias(const ShapeIndex& output_index) const;
86 
87   // (De)Serializes an HloInputOutputAliasConfig to/from an
88   // HloInputOutputAliasProto.
89   HloInputOutputAliasProto ToProto() const;
90 
91   static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
92       Shape output_shape, const HloInputOutputAliasProto& proto);
93 
94   // Returns the output index that the given parameter and parameter index is
95   // aliased with. A nullopt is returned if there is no output that is aliased
96   // with the parameter number and index.
97   absl::optional<ShapeIndex> GetAliasedOutput(
98       int64 param_number, const ShapeIndex& param_index) const;
99 
100   // Returns the number of parameter and index of the parameter buffer that the
101   // given output buffer index is aliased with. A nullopt is returned if there
102   // is no parameter is aliased with the specific output.
103   absl::optional<Alias> GetAliasedParameter(
104       const ShapeIndex& output_index) const;
105 
106   // Returns if the parameter at the given parameter number and parameter
107   // index must-alias with an output.
108   bool ParameterMustAlias(int64 param_number,
109                           const ShapeIndex& param_index) const;
110 
111   using AliasFn =
112       std::function<void(const ShapeIndex& output_index, const Alias&)>;
113 
114   // Iterates through each aliased output and input.
115   void ForEachAlias(AliasFn fn) const;
116 
117   using AliasFnWithStatus =
118       std::function<Status(const ShapeIndex& output_index, const Alias&)>;
119 
120   // Verifies that the given config is valid for the given module.
121   // Specifically, the config's input and output should be in-bound and size of
122   // the aliased buffers should match.
123   Status Verify(const HloModule& module,
124                 std::function<int64(const Shape&)> size_func_) const;
125 
126   Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
127 
128   // Returns the shape of the output of the alias config.
129   const Shape& shape() const;
130 
131   string ToString() const;
132 
133   string ToShortString() const;
134 
135  private:
136   // A ShapeTree which indicates the list of buffers that's expected to be
137   // aliased. The key on this shape tree represents the output index. The value
138   // is an Alias data structure which defines the input parameter coordinates.
139   // If the value is nullopt, it means there is no parameter aliasing for this
140   // output.
141   ShapeTree<absl::optional<Alias>> alias_;
142 };
143 
144 std::ostream& operator<<(std::ostream& out,
145                          const HloInputOutputAliasConfig& config);
146 
147 }  // namespace xla
148 
149 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
150