• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_module.h"
19 
20 namespace xla {
21 
OutputHasAlias(const ShapeIndex & output_index) const22 bool HloInputOutputAliasConfig::OutputHasAlias(
23     const ShapeIndex& output_index) const {
24   return alias_.element(output_index).has_value();
25 }
26 
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,AliasKind kind)27 Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
28                                              int64 param_number,
29                                              const ShapeIndex& param_index,
30                                              AliasKind kind) {
31   TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias)
32       << kind;
33   TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
34       << "Trying to set up alias at " << output_index.ToString()
35       << " which is an invalid index for shape "
36       << ShapeUtil::HumanString(alias_.shape());
37   TF_RET_CHECK(param_number >= 0) << param_number;
38   TF_RET_CHECK(!OutputHasAlias(output_index))
39       << "Output index " << output_index << " already has an alias setup";
40   // Output can't be aliased with multiple parameters.
41   TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
42       "Trying to set up output alias for param %lld at %s but failed: output "
43       "index %s is already aliased with param %lld at %s",
44       param_number, param_index.ToString(), output_index.ToString(),
45       alias_.element(output_index)->parameter_number,
46       alias_.element(output_index)->parameter_index.ToString());
47   (*alias_.mutable_element(output_index)) =
48       Alias(kind, param_number, param_index);
49   VLOG(4) << "Set up alias between output index " << output_index.ToString()
50           << " and parameter " << param_index << " at index "
51           << param_index.ToString();
52   return Status::OK();
53 }
54 
ToProto() const55 HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
56   HloInputOutputAliasProto result;
57   alias_.ForEachElement(
58       [&](const ShapeIndex& index, const absl::optional<Alias>& data) {
59         if (data) {
60           HloInputOutputAliasProto::AliasEntryProto entry;
61           switch (data->kind) {
62             case AliasKind::kUserAlias:
63               entry.set_kind(HloInputOutputAliasProto::USER_ALIAS);
64               break;
65             case AliasKind::kSystemAlias:
66               entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS);
67               break;
68             default:
69               LOG(FATAL) << "Unknown alias kind " << data->kind;
70           }
71           for (int64 i : index) {
72             entry.add_output_shape_index(i);
73           }
74           entry.set_parameter_number(data->parameter_number);
75           for (int64 i : data->parameter_index) {
76             entry.add_parameter_shape_index(i);
77           }
78           result.add_entries()->Swap(&entry);
79         }
80       });
81   return result;
82 }
83 
CreateFromProto(Shape output_shape,const HloInputOutputAliasProto & proto)84 StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
85     Shape output_shape, const HloInputOutputAliasProto& proto) {
86   HloInputOutputAliasConfig result(std::move(output_shape));
87   for (const HloInputOutputAliasProto::AliasEntryProto& entry :
88        proto.entries()) {
89     ShapeIndex output_index(entry.output_shape_index().begin(),
90                             entry.output_shape_index().end());
91     int64 param_number = entry.parameter_number();
92     ShapeIndex param_index(entry.parameter_shape_index().begin(),
93                            entry.parameter_shape_index().end());
94     // Handle backward compatibility with existing protos, which only knew of
95     // system aliases.
96     AliasKind kind = AliasKind::kSystemAlias;
97     if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) {
98       kind = AliasKind::kUserAlias;
99     }
100     TF_RETURN_IF_ERROR(
101         result.SetUpAlias(output_index, param_number, param_index, kind));
102   }
103   return result;
104 }
105 
shape() const106 const Shape& HloInputOutputAliasConfig::shape() const { return alias_.shape(); }
107 
ToString() const108 string HloInputOutputAliasConfig::ToString() const {
109   std::vector<string> pieces;
110   pieces.push_back("HloInputOutputAliasConfig");
111   pieces.push_back(
112       absl::StrFormat("  Output shape: %s", alias_.shape().ToString()));
113 
114   ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
115     const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM";
116     pieces.push_back(absl::StrFormat(
117         "  OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:",
118         output_index.ToString(), kind, alias.parameter_number,
119         alias.parameter_index.ToString()));
120   });
121   return absl::StrJoin(pieces, "\n");
122 }
123 
124 HloInputOutputAliasConfig::AliasKind
ParameterAliasKind(int64 param_number,const ShapeIndex & param_index) const125 HloInputOutputAliasConfig::ParameterAliasKind(
126     int64 param_number, const ShapeIndex& param_index) const {
127   AliasKind kind = AliasKind::kNoAlias;
128   alias_.ForEachElement(
129       [&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
130         if (alias && alias->parameter_number == param_number &&
131             alias->parameter_index == param_index) {
132           kind = alias->kind;
133         }
134       });
135   return kind;
136 }
137 
GetAliasedOutput(int64 param_number,const ShapeIndex & param_index) const138 absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
139     int64 param_number, const ShapeIndex& param_index) const {
140   absl::optional<ShapeIndex> output;
141   alias_.ForEachElement(
142       [&](const xla::ShapeIndex& output_index, absl::optional<Alias> alias) {
143         if (alias && alias->parameter_number == param_number &&
144             alias->parameter_index == param_index) {
145           output = output_index;
146         }
147       });
148   return output;
149 }
150 
151 absl::optional<HloInputOutputAliasConfig::Alias>
GetAliasedParameter(const ShapeIndex & output_index) const152 HloInputOutputAliasConfig::GetAliasedParameter(
153     const ShapeIndex& output_index) const {
154   CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
155       << ToString() << " " << alias_.shape().ToString() << " " << output_index;
156   return alias_.element(output_index);
157 }
158 
ForEachAlias(AliasFn fn) const159 void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
160   alias_.ForEachElement(
161       [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
162         if (aliased) {
163           fn(output_index, *aliased);
164         }
165       });
166 }
167 
ForEachAliasWithStatus(AliasFnWithStatus fn) const168 Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
169     AliasFnWithStatus fn) const {
170   return alias_.ForEachElementWithStatus(
171       [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
172         if (aliased) {
173           TF_RETURN_IF_ERROR(fn(output_index, *aliased));
174         }
175         return Status::OK();
176       });
177 }
178 
Verify(const HloModule & module,std::function<int64 (const Shape &)> size_func) const179 Status HloInputOutputAliasConfig::Verify(
180     const HloModule& module,
181     std::function<int64(const Shape&)> size_func) const {
182   std::vector<ShapeTree<bool>> param_has_seen;
183   const HloComputation* entry = module.entry_computation();
184   for (int64 i = 0; i < entry->num_parameters(); ++i) {
185     HloInstruction* param = entry->parameter_instruction(i);
186     param_has_seen.emplace_back(param->shape());
187   }
188   return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
189                                     const Alias& alias) -> Status {
190     const HloInstruction* root = entry->root_instruction();
191 
192     TF_RET_CHECK(0 <= alias.parameter_number);
193     TF_RET_CHECK(entry->num_parameters() > alias.parameter_number);
194     const Shape& param_shape =
195         entry->parameter_instruction(alias.parameter_number)->shape();
196     const Shape& output_shape = root->shape();
197     TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index));
198     TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
199 
200     const Shape& param_subshape =
201         ShapeUtil::GetSubshape(param_shape, alias.parameter_index);
202     const Shape& output_subshape =
203         ShapeUtil::GetSubshape(output_shape, output_index);
204     TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape));
205     TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape));
206 
207     if (size_func(param_subshape) != size_func(output_subshape)) {
208       return InternalError(
209           "Expected aliased input %lld at index %s and output at index %s to "
210           "have the same size. Input sub-shape is %s with size %lld, output "
211           "sub-shape is %s with size %lld",
212           alias.parameter_number, alias.parameter_index.ToString(),
213           output_index.ToString(),
214           ShapeUtil::HumanStringWithLayout(param_subshape),
215           size_func(param_subshape),
216           ShapeUtil::HumanStringWithLayout(output_subshape),
217           size_func(output_subshape));
218     }
219 
220     // Check each alias.parameter_number and alias.parameter_index pair only
221     // show up once. No input can be aliased with output buffers.
222     TF_RET_CHECK(param_has_seen[alias.parameter_number].element(
223                      alias.parameter_index) == false);
224     *(param_has_seen[alias.parameter_number].mutable_element(
225         alias.parameter_index)) = true;
226     return Status::OK();
227   });
228 }
229 
operator <<(std::ostream & out,const HloInputOutputAliasConfig & config)230 std::ostream& operator<<(std::ostream& out,
231                          const HloInputOutputAliasConfig& config) {
232   out << config.ToString();
233   return out;
234 }
235 }  // namespace xla
236