1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 18 19 #include <memory> 20 21 #include "absl/memory/memory.h" 22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 23 #include "tensorflow/compiler/xla/service/shape_inference.h" 24 25 namespace xla { 26 27 // Visitor which verifies that the output shape is correctly set. Verifies 28 // against the inferred shape for the instruction. 29 // TODO(b/26024837): Check output shape for all instruction types. 30 class ShapeVerifier : public DfsHloVisitor { 31 public: ShapeVerifier(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)32 ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision, 33 std::function<int64(const Shape&)> shape_size_function) 34 : layout_sensitive_(layout_sensitive), 35 allow_mixed_precision_(allow_mixed_precision), 36 shape_size_function_(shape_size_function) {} 37 38 // Verifies that entry computation layout matches parameters and root shape of 39 // the module's entry computation. 40 virtual Status VerifyEntryComputationLayout(const HloModule& module); 41 42 Status Preprocess(HloInstruction* hlo) override; 43 44 Status HandleElementwiseUnary(HloInstruction* hlo) override; 45 Status HandleElementwiseBinary(HloInstruction* hlo) override; 46 Status HandleClamp(HloInstruction* clamp) override; 47 Status HandleSelect(HloInstruction* select) override; 48 Status HandleTupleSelect(HloInstruction* tuple_select) override; 49 Status HandleConcatenate(HloInstruction* concatenate) override; 50 Status HandleIota(HloInstruction* iota) override; 51 Status HandleConvert(HloInstruction* convert) override; 52 Status HandleBitcastConvert(HloInstruction* convert) override; 53 Status HandleCopy(HloInstruction* copy) override; 54 Status HandleDot(HloInstruction* dot) override; 55 Status HandleConvolution(HloInstruction* convolution) override; 56 Status HandleFft(HloInstruction* fft) override; 57 Status HandleCholesky(HloInstruction* hlo) override; 58 Status HandleTriangularSolve(HloInstruction* hlo) override; 59 Status HandleAllReduce(HloInstruction* crs) override; 60 Status HandleAllToAll(HloInstruction* hlo) override; 61 Status HandleCollectivePermute(HloInstruction* hlo) override; 62 Status HandlePartitionId(HloInstruction* hlo) override; 63 Status HandleReplicaId(HloInstruction* hlo) override; 64 Status HandleReducePrecision(HloInstruction* reduce_precision) override; 65 Status HandleInfeed(HloInstruction*) override; 66 Status HandleOutfeed(HloInstruction*) override; 67 Status HandleRng(HloInstruction*) override; 68 Status HandleRngGetAndUpdateState(HloInstruction*) override; 69 Status HandleReverse(HloInstruction* reverse) override; 70 Status HandleSort(HloInstruction* sort) override; 71 Status HandleConstant(HloInstruction* constant) override; 72 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 73 Status HandleReduce(HloInstruction* reduce) override; 74 Status HandleBitcast(HloInstruction* bitcast) override; 75 Status HandleBroadcast(HloInstruction* broadcast) override; 76 Status HandleReshape(HloInstruction* reshape) override; 77 Status HandleTranspose(HloInstruction* transpose) override; 78 Status HandleParameter(HloInstruction*) override; 79 Status HandleFusion(HloInstruction*) override; 80 Status HandleCall(HloInstruction* call) override; 81 Status HandleCustomCall(HloInstruction*) override; 82 Status HandleSlice(HloInstruction* slice) override; 83 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; 84 Status HandleDynamicUpdateSlice( 85 HloInstruction* dynamic_update_slice) override; 86 Status HandleTuple(HloInstruction* tuple) override; 87 Status HandleMap(HloInstruction* map) override; 88 Status HandleReduceWindow(HloInstruction* reduce_window) override; 89 Status HandleSelectAndScatter(HloInstruction* instruction) override; 90 Status HandleWhile(HloInstruction* xla_while) override; 91 Status HandleConditional(HloInstruction* conditional) override; 92 Status HandlePad(HloInstruction* pad) override; 93 Status HandleCopyStart(HloInstruction* copy_start) override; 94 Status HandleCopyDone(HloInstruction* copy_done) override; 95 Status HandleSend(HloInstruction* send) override; 96 Status HandleSendDone(HloInstruction* send_done) override; 97 Status HandleRecv(HloInstruction* recv) override; 98 Status HandleRecvDone(HloInstruction* recv_done) override; 99 Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; 100 Status HandleBatchNormInference( 101 HloInstruction* batch_norm_inference) override; 102 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; 103 Status HandleGather(HloInstruction* gather) override; 104 Status HandleScatter(HloInstruction* scatter) override; 105 Status HandleAfterAll(HloInstruction* token) override; 106 Status HandleGetDimensionSize(HloInstruction* get_size) override; 107 Status HandleSetDimensionSize(HloInstruction* set_size) override; 108 Status HandleAddDependency(HloInstruction* add_dependency) override; 109 FinishVisit(HloInstruction *)110 Status FinishVisit(HloInstruction*) override { return Status::OK(); } 111 112 protected: 113 // Check the instruction's shape against the shape given by ShapeInference 114 // and return an appropriate error if there is a mismatch. 115 Status CheckShape(const HloInstruction* instruction, 116 const Shape& inferred_shape, 117 bool only_compare_minor_to_major_in_layout = false); 118 119 // Overload which takes a StatusOr to reduce boilerplate in the caller. 120 Status CheckShape(const HloInstruction* instruction, 121 const StatusOr<Shape>& inferred_shape_status); 122 123 // Check a unary (binary, etc) instruction's shape against the inferred shape. 124 Status CheckUnaryShape(const HloInstruction* instruction); 125 Status CheckBinaryShape(const HloInstruction* instruction); 126 Status CheckTernaryShape(const HloInstruction* instruction); 127 Status CheckVariadicShape(const HloInstruction* instruction); 128 129 private: 130 // Helpers that switch on layout_sensitive_. 131 bool ShapesSame(const Shape& a, const Shape& b, 132 bool minor_to_major_only = false, 133 bool ignore_memory_space = false) { 134 if (!layout_sensitive_) { 135 return ShapeUtil::Compatible(a, b); 136 } 137 Shape::Equal equal; 138 if (ignore_memory_space) { 139 equal.IgnoreMemorySpaceInLayout(); 140 } 141 if (minor_to_major_only) { 142 equal.MinorToMajorOnlyInLayout(); 143 } 144 return equal(a, b); 145 } 146 147 bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b, 148 bool minor_to_major_only = false) { 149 if (!layout_sensitive_) { 150 return ShapeUtil::CompatibleIgnoringFpPrecision(a, b); 151 } 152 Shape::Equal equal; 153 if (minor_to_major_only) { 154 equal.MinorToMajorOnlyInLayout(); 155 } 156 equal.IgnoreFpPrecision(); 157 return equal(a, b); 158 } 159 StringifyShape(const Shape & s)160 string StringifyShape(const Shape& s) { 161 return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) 162 : ShapeUtil::HumanString(s); 163 } 164 165 // Helpers that switch on allow_mixed_precision_. SameElementType(const Shape & a,const Shape & b)166 bool SameElementType(const Shape& a, const Shape& b) { 167 return allow_mixed_precision_ 168 ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b) 169 : ShapeUtil::SameElementType(a, b); 170 } 171 172 // Checks that the given operand of the given instruction is of type TOKEN. 173 Status CheckIsTokenOperand(const HloInstruction* instruction, 174 int64 operand_no); 175 176 // Checks that the shape of the given operand of the given instruction matches 177 // the given parameter of the given computation. 178 Status CheckOperandAndParameter(const HloInstruction* instruction, 179 int64 operand_number, 180 const HloComputation* computation, 181 int64 parameter_number); 182 183 // Returns true if the shapes of the two operands have the same element type, 184 // and the result shape either has the same element type as the operand shapes 185 // or mixed precision is allowed and the result shape and the operand shapes 186 // have floating point element types. 187 bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, 188 const Shape& result_shape); 189 190 // If the verifier is layout-sensitive, shapes must be equal to what's 191 // expected. Otherwise, the shapes must simply be compatible. 192 bool layout_sensitive_; 193 194 // Whether the inputs and output of an instruction can contain both F32s and 195 // BF16s. Tuples that include both F32s and BF16s are allowed regardless of 196 // this flag. 197 bool allow_mixed_precision_; 198 199 // Returns a target-specific shape size. 200 std::function<int64(const Shape&)> shape_size_function_; 201 }; 202 203 // An interface used to encapsulate target-specific verification quirks. 204 class TargetVerifierMetadata { 205 public: TargetVerifierMetadata(std::function<int64 (const Shape &)> shape_size_function)206 TargetVerifierMetadata(std::function<int64(const Shape&)> shape_size_function) 207 : shape_size_function_(shape_size_function) {} 208 209 // Returns a target-specific shape size. ShapeSize(const Shape & shape)210 int64 ShapeSize(const Shape& shape) const { 211 return shape_size_function_(shape); 212 } 213 214 virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0; 215 TargetVerifierMetadata()216 TargetVerifierMetadata() {} ~TargetVerifierMetadata()217 virtual ~TargetVerifierMetadata() {} 218 219 TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; 220 TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; 221 222 protected: 223 // Returns a target-specific shape size. 224 std::function<int64(const Shape&)> shape_size_function_; 225 }; 226 227 // The default implementation of TargetVerifierMetadata, used unless the target 228 // needs to override it. 229 class DefaultVerifierMetadata : public TargetVerifierMetadata { 230 public: DefaultVerifierMetadata(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)231 DefaultVerifierMetadata( 232 bool layout_sensitive, bool allow_mixed_precision, 233 std::function<int64(const Shape&)> shape_size_function) 234 : TargetVerifierMetadata(shape_size_function), 235 layout_sensitive_(layout_sensitive), 236 allow_mixed_precision_(allow_mixed_precision) {} 237 238 // Creates a ShapeVerifier that checks that shapes match inferred 239 // expectations. This creates a new verifier every time because ShapeVerifier, 240 // being a DfsHloVisitor, is stateful. We want a clean object for each run of 241 // the verifier. GetVerifier()242 std::unique_ptr<ShapeVerifier> GetVerifier() const override { 243 return absl::make_unique<ShapeVerifier>( 244 layout_sensitive_, allow_mixed_precision_, shape_size_function_); 245 } 246 247 private: 248 bool layout_sensitive_; 249 bool allow_mixed_precision_; 250 }; 251 252 // HLO pass that verifies invariants of HLO instructions for each computation in 253 // the module. 254 class HloVerifier : public HloModulePass { 255 public: 256 explicit HloVerifier( 257 bool layout_sensitive, bool allow_mixed_precision, 258 std::function<bool(const HloInstruction*)> 259 instruction_can_change_layout_func = {}, 260 std::function<int64(const Shape&)> shape_size_func = 261 [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) target_metadata_(absl::make_unique<DefaultVerifierMetadata> (layout_sensitive,allow_mixed_precision,shape_size_func))262 : target_metadata_(absl::make_unique<DefaultVerifierMetadata>( 263 layout_sensitive, allow_mixed_precision, shape_size_func)), 264 instruction_can_change_layout_func_( 265 std::move(instruction_can_change_layout_func)) { 266 CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); 267 } 268 269 // Uses custom target metadata HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)270 explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata) 271 : target_metadata_(std::move(target_metadata)) {} 272 273 ~HloVerifier() override = default; name()274 absl::string_view name() const override { return "verifier"; } 275 276 // Never returns true; no instructions are ever modified by this pass. 277 StatusOr<bool> Run(HloModule* module) override; 278 279 private: 280 std::unique_ptr<TargetVerifierMetadata> target_metadata_; 281 282 // Determines whether an instruction can change layouts. 283 std::function<bool(const HloInstruction*)> 284 instruction_can_change_layout_func_; 285 }; 286 287 } // namespace xla 288 289 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 290