1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ 18 19 #include <utility> 20 21 #include "absl/container/flat_hash_set.h" 22 #include "absl/types/optional.h" 23 #include "tensorflow/compiler/xla/service/hlo.pb.h" 24 #include "tensorflow/compiler/xla/shape_tree.h" 25 #include "tensorflow/compiler/xla/shape_util.h" 26 27 namespace xla { 28 29 class HloModule; 30 31 // This class specifies the alias map from output index to parameter number and 32 // parameter index in the entry computation. 33 class HloInputOutputAliasConfig { 34 public: 35 // The kind of aliases which can be set. A kMayAlias is one setup at 36 // compilation time by the user, and has to be respected. A kMustAlias one 37 // might be setup by the compiler, if it decides it is convenient to do so. 38 enum AliasKind { 39 kMayAlias, 40 kMustAlias, 41 }; 42 // Defines the alias information for a given output buffer. A given output 43 // buffer shape index can refer only to one parameter+index. 44 struct Alias { 45 Alias(int64 parameter_number, ShapeIndex parameter_index, 46 AliasKind kind = kMayAlias) parameter_numberAlias47 : parameter_number(parameter_number), 48 parameter_index(std::move(parameter_index)), 49 kind(kind) {} 50 51 int64 parameter_number; 52 ShapeIndex parameter_index; 53 AliasKind kind; 54 must_aliasAlias55 bool must_alias() const { return kind == kMustAlias; } 56 ToStringAlias57 std::string ToString() { 58 return absl::StrFormat("(%lld, %s, %s)", parameter_number, 59 parameter_index.ToString(), 60 kind == kMustAlias ? "must-alias" : "may-alias"); 61 } 62 }; 63 64 HloInputOutputAliasConfig() = default; 65 HloInputOutputAliasConfig(Shape output_shape)66 explicit HloInputOutputAliasConfig(Shape output_shape) 67 : alias_(std::move(output_shape)) {} 68 69 virtual ~HloInputOutputAliasConfig() = default; 70 71 // Sets up alias config from `output_index` to `param_index` at 72 // `param_number`. 73 Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, 74 const ShapeIndex& param_index, 75 AliasKind must_alias = kMayAlias); 76 77 // Returns true if the given parameter is aliased with one of the output 78 // buffers. ParameterHasAlias(int64 param_number,const ShapeIndex & param_index)79 bool ParameterHasAlias(int64 param_number, 80 const ShapeIndex& param_index) const { 81 return GetAliasedOutput(param_number, param_index).has_value(); 82 } 83 84 // Checks whether the provided output index has already been aliased. 85 bool OutputHasAlias(const ShapeIndex& output_index) const; 86 87 // (De)Serializes an HloInputOutputAliasConfig to/from an 88 // HloInputOutputAliasProto. 89 HloInputOutputAliasProto ToProto() const; 90 91 static StatusOr<HloInputOutputAliasConfig> CreateFromProto( 92 Shape output_shape, const HloInputOutputAliasProto& proto); 93 94 // Returns the output index that the given parameter and parameter index is 95 // aliased with. A nullopt is returned if there is no output that is aliased 96 // with the parameter number and index. 97 absl::optional<ShapeIndex> GetAliasedOutput( 98 int64 param_number, const ShapeIndex& param_index) const; 99 100 // Returns the number of parameter and index of the parameter buffer that the 101 // given output buffer index is aliased with. A nullopt is returned if there 102 // is no parameter is aliased with the specific output. 103 absl::optional<Alias> GetAliasedParameter( 104 const ShapeIndex& output_index) const; 105 106 // Returns if the parameter at the given parameter number and parameter 107 // index must-alias with an output. 108 bool ParameterMustAlias(int64 param_number, 109 const ShapeIndex& param_index) const; 110 111 using AliasFn = 112 std::function<void(const ShapeIndex& output_index, const Alias&)>; 113 114 // Iterates through each aliased output and input. 115 void ForEachAlias(AliasFn fn) const; 116 117 using AliasFnWithStatus = 118 std::function<Status(const ShapeIndex& output_index, const Alias&)>; 119 120 // Verifies that the given config is valid for the given module. 121 // Specifically, the config's input and output should be in-bound and size of 122 // the aliased buffers should match. 123 Status Verify(const HloModule& module, 124 std::function<int64(const Shape&)> size_func_) const; 125 126 Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; 127 128 // Returns the shape of the output of the alias config. 129 const Shape& shape() const; 130 131 string ToString() const; 132 133 string ToShortString() const; 134 135 private: 136 // A ShapeTree which indicates the list of buffers that's expected to be 137 // aliased. The key on this shape tree represents the output index. The value 138 // is an Alias data structure which defines the input parameter coordinates. 139 // If the value is nullopt, it means there is no parameter aliasing for this 140 // output. 141 ShapeTree<absl::optional<Alias>> alias_; 142 }; 143 144 std::ostream& operator<<(std::ostream& out, 145 const HloInputOutputAliasConfig& config); 146 147 } // namespace xla 148 149 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ 150