1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 18 19 #include <memory> 20 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 21 22 #include "absl/memory/memory.h" 23 #include "tensorflow/compiler/xla/service/shape_inference.h" 24 25 namespace xla { 26 27 // Visitor which verifies that the output shape is correctly set. Verifies 28 // against the inferred shape for the instruction. 29 // TODO(b/26024837): Check output shape for all instruction types. 30 class ShapeVerifier : public DfsHloVisitor { 31 public: ShapeVerifier(bool layout_sensitive,bool allow_mixed_precision)32 ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) 33 : layout_sensitive_(layout_sensitive), 34 allow_mixed_precision_(allow_mixed_precision) {} 35 36 // Verifies that entry computation layout matches parameters and root shape of 37 // the module's entry computation. 38 virtual Status VerifyEntryComputationLayout(const HloModule& module); 39 40 Status Preprocess(HloInstruction* hlo) override; 41 42 Status HandleElementwiseUnary(HloInstruction* hlo) override; 43 Status HandleElementwiseBinary(HloInstruction* hlo) override; 44 Status HandleClamp(HloInstruction* clamp) override; 45 Status HandleSelect(HloInstruction* select) override; 46 Status HandleTupleSelect(HloInstruction* tuple_select) override; 47 Status HandleConcatenate(HloInstruction* concatenate) override; 48 Status HandleIota(HloInstruction* iota) override; 49 Status HandleConvert(HloInstruction* convert) override; 50 Status HandleBitcastConvert(HloInstruction* convert) override; 51 Status HandleCopy(HloInstruction* copy) override; 52 Status HandleDot(HloInstruction* dot) override; 53 Status HandleConvolution(HloInstruction* convolution) override; 54 Status HandleFft(HloInstruction* fft) override; 55 Status HandleCholesky(HloInstruction* hlo) override; 56 Status HandleTriangularSolve(HloInstruction* hlo) override; 57 Status HandleAllReduce(HloInstruction* crs) override; 58 Status HandleAllToAll(HloInstruction* hlo) override; 59 Status HandleCollectivePermute(HloInstruction* hlo) override; 60 Status HandleReplicaId(HloInstruction* hlo) override; 61 Status HandleReducePrecision(HloInstruction* reduce_precision) override; 62 Status HandleInfeed(HloInstruction*) override; 63 Status HandleOutfeed(HloInstruction*) override; 64 Status HandleRng(HloInstruction*) override; 65 Status HandleReverse(HloInstruction* reverse) override; 66 Status HandleSort(HloInstruction* sort) override; 67 Status HandleConstant(HloInstruction* constant) override; 68 Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; 69 Status HandleReduce(HloInstruction* reduce) override; 70 Status HandleBitcast(HloInstruction* bitcast) override; 71 Status HandleBroadcast(HloInstruction* broadcast) override; 72 Status HandleReshape(HloInstruction* reshape) override; 73 Status HandleTranspose(HloInstruction* transpose) override; 74 Status HandleParameter(HloInstruction*) override; 75 Status HandleFusion(HloInstruction*) override; 76 Status HandleCall(HloInstruction* call) override; 77 Status HandleCustomCall(HloInstruction*) override; 78 Status HandleSlice(HloInstruction* slice) override; 79 Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; 80 Status HandleDynamicUpdateSlice( 81 HloInstruction* dynamic_update_slice) override; 82 Status HandleTuple(HloInstruction* tuple) override; 83 Status HandleMap(HloInstruction* map) override; 84 Status HandleReduceWindow(HloInstruction* reduce_window) override; 85 Status HandleSelectAndScatter(HloInstruction* instruction) override; 86 Status HandleWhile(HloInstruction* xla_while) override; 87 Status HandleConditional(HloInstruction* conditional) override; 88 Status HandlePad(HloInstruction* pad) override; 89 Status HandleSend(HloInstruction* send) override; 90 Status HandleSendDone(HloInstruction* send_done) override; 91 Status HandleRecv(HloInstruction* recv) override; 92 Status HandleRecvDone(HloInstruction* recv_done) override; 93 Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; 94 Status HandleBatchNormInference( 95 HloInstruction* batch_norm_inference) override; 96 Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; 97 Status HandleGather(HloInstruction* gather) override; 98 Status HandleScatter(HloInstruction* scatter) override; 99 Status HandleAfterAll(HloInstruction* token) override; 100 Status HandleGetDimensionSize(HloInstruction* get_size) override; 101 Status HandleAddDependency(HloInstruction* add_dependency) override; 102 FinishVisit(HloInstruction *)103 Status FinishVisit(HloInstruction*) override { return Status::OK(); } 104 105 protected: 106 // Check the instruction's shape against the shape given by ShapeInference 107 // and return an appropriate error if there is a mismatch. 108 Status CheckShape(const HloInstruction* instruction, 109 const Shape& inferred_shape); 110 111 // Overload which takes a StatusOr to reduce boilerplate in the caller. 112 Status CheckShape(const HloInstruction* instruction, 113 const StatusOr<Shape>& inferred_shape_status); 114 115 // Check a unary (binary, etc) instruction's shape against the inferred shape. 116 Status CheckUnaryShape(const HloInstruction* instruction); 117 Status CheckBinaryShape(const HloInstruction* instruction); 118 Status CheckTernaryShape(const HloInstruction* instruction); 119 Status CheckVariadicShape(const HloInstruction* instruction); 120 121 private: 122 // Helpers that switch on layout_sensitive_. ShapesSame(const Shape & a,const Shape & b)123 bool ShapesSame(const Shape& a, const Shape& b) { 124 return layout_sensitive_ ? ShapeUtil::Equal(a, b) 125 : ShapeUtil::Compatible(a, b); 126 } ShapesSameIgnoringFpPrecision(const Shape & a,const Shape & b)127 bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { 128 return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) 129 : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); 130 } StringifyShape(const Shape & s)131 string StringifyShape(const Shape& s) { 132 return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) 133 : ShapeUtil::HumanString(s); 134 } 135 136 // Helpers that switch on allow_mixed_precision_. SameElementType(const Shape & a,const Shape & b)137 bool SameElementType(const Shape& a, const Shape& b) { 138 return allow_mixed_precision_ 139 ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b) 140 : ShapeUtil::SameElementType(a, b); 141 } 142 143 // Checks that the given operand of the given instruction is of type TOKEN. 144 Status CheckIsTokenOperand(const HloInstruction* instruction, 145 int64 operand_no); 146 147 // Checks that the shape of the given operand of the given instruction matches 148 // the given parameter of the given computation. 149 Status CheckOperandAndParameter(const HloInstruction* instruction, 150 int64 operand_number, 151 const HloComputation* computation, 152 int64 parameter_number); 153 154 // Returns true if the shapes of the two operands have the same element type, 155 // and the result shape either has the same element type as the operand shapes 156 // or mixed precision is allowed and the result shape and the operand shapes 157 // have floating point element types. 158 bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, 159 const Shape& result_shape); 160 161 // If the verifier is layout-sensitive, shapes must be equal to what's 162 // expected. Otherwise, the shapes must simply be compatible. 163 bool layout_sensitive_; 164 165 // Whether the inputs and output of an instruction can contain both F32s and 166 // BF16s. Tuples that include both F32s and BF16s are allowed regardless of 167 // this flag. 168 bool allow_mixed_precision_; 169 }; 170 171 // An interface used to encapsulate target-specific verification quirks. 172 class TargetVerifierMetadata { 173 public: TargetVerifierMetadata(std::function<int64 (const Shape &)> shape_size_function)174 TargetVerifierMetadata(std::function<int64(const Shape&)> shape_size_function) 175 : shape_size_function_(shape_size_function) {} 176 177 // Returns a target-specific shape size. ShapeSize(const Shape & shape)178 int64 ShapeSize(const Shape& shape) const { 179 return shape_size_function_(shape); 180 } 181 182 virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0; 183 TargetVerifierMetadata()184 TargetVerifierMetadata() {} ~TargetVerifierMetadata()185 virtual ~TargetVerifierMetadata() {} 186 187 TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; 188 TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; 189 190 private: 191 // Returns a target-specific shape size. 192 std::function<int64(const Shape&)> shape_size_function_; 193 }; 194 195 // The default implementation of TargetVerifierMetadata, used unless the target 196 // needs to override it. 197 class DefaultVerifierMetadata : public TargetVerifierMetadata { 198 public: DefaultVerifierMetadata(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)199 DefaultVerifierMetadata( 200 bool layout_sensitive, bool allow_mixed_precision, 201 std::function<int64(const Shape&)> shape_size_function) 202 : TargetVerifierMetadata(shape_size_function), 203 layout_sensitive_(layout_sensitive), 204 allow_mixed_precision_(allow_mixed_precision) {} 205 206 // Creates a ShapeVerifier that checks that shapes match inferred 207 // expectations. This creates a new verifier every time because ShapeVerifier, 208 // being a DfsHloVisitor, is stateful. We want a clean object for each run of 209 // the verifier. GetVerifier()210 std::unique_ptr<ShapeVerifier> GetVerifier() const override { 211 return absl::make_unique<ShapeVerifier>(layout_sensitive_, 212 allow_mixed_precision_); 213 } 214 215 private: 216 bool layout_sensitive_; 217 bool allow_mixed_precision_; 218 }; 219 220 // HLO pass that verifies invariants of HLO instructions for each computation in 221 // the module. 222 class HloVerifier : public HloModulePass { 223 public: 224 explicit HloVerifier( 225 bool layout_sensitive, bool allow_mixed_precision, 226 std::function<bool(const HloInstruction*)> 227 instruction_can_change_layout_func = {}, 228 std::function<int64(const Shape&)> shape_size_func = 229 [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) target_metadata_(absl::make_unique<DefaultVerifierMetadata> (layout_sensitive,allow_mixed_precision,shape_size_func))230 : target_metadata_(absl::make_unique<DefaultVerifierMetadata>( 231 layout_sensitive, allow_mixed_precision, shape_size_func)), 232 instruction_can_change_layout_func_( 233 std::move(instruction_can_change_layout_func)) { 234 CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); 235 } 236 237 // Uses custom target metadata HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)238 explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata) 239 : target_metadata_(std::move(target_metadata)) {} 240 241 ~HloVerifier() override = default; name()242 absl::string_view name() const override { return "verifier"; } 243 244 // Never returns true; no instructions are ever modified by this pass. 245 StatusOr<bool> Run(HloModule* module) override; 246 247 private: 248 std::unique_ptr<TargetVerifierMetadata> target_metadata_; 249 250 // Determines whether an instruction can change layouts. 251 std::function<bool(const HloInstruction*)> 252 instruction_can_change_layout_func_; 253 }; 254 255 } // namespace xla 256 257 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ 258