• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
18 
#include <functional>
#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
24 
25 namespace xla {
26 
27 // Visitor which verifies that the output shape is correctly set. Verifies
28 // against the inferred shape for the instruction.
29 // TODO(b/26024837): Check output shape for all instruction types.
30 class ShapeVerifier : public DfsHloVisitor {
31  public:
ShapeVerifier(bool layout_sensitive,bool allow_mixed_precision)32   ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
33       : layout_sensitive_(layout_sensitive),
34         allow_mixed_precision_(allow_mixed_precision) {}
35 
36   // Verifies that entry computation layout matches parameters and root shape of
37   // the module's entry computation.
38   virtual Status VerifyEntryComputationLayout(const HloModule& module);
39 
40   Status Preprocess(HloInstruction* hlo) override;
41 
42   Status HandleElementwiseUnary(HloInstruction* hlo) override;
43   Status HandleElementwiseBinary(HloInstruction* hlo) override;
44   Status HandleClamp(HloInstruction* clamp) override;
45   Status HandleSelect(HloInstruction* select) override;
46   Status HandleTupleSelect(HloInstruction* tuple_select) override;
47   Status HandleConcatenate(HloInstruction* concatenate) override;
48   Status HandleIota(HloInstruction* iota) override;
49   Status HandleConvert(HloInstruction* convert) override;
50   Status HandleBitcastConvert(HloInstruction* convert) override;
51   Status HandleCopy(HloInstruction* copy) override;
52   Status HandleDot(HloInstruction* dot) override;
53   Status HandleConvolution(HloInstruction* convolution) override;
54   Status HandleFft(HloInstruction* fft) override;
55   Status HandleCholesky(HloInstruction* hlo) override;
56   Status HandleTriangularSolve(HloInstruction* hlo) override;
57   Status HandleAllReduce(HloInstruction* crs) override;
58   Status HandleAllToAll(HloInstruction* hlo) override;
59   Status HandleCollectivePermute(HloInstruction* hlo) override;
60   Status HandleReplicaId(HloInstruction* hlo) override;
61   Status HandleReducePrecision(HloInstruction* reduce_precision) override;
62   Status HandleInfeed(HloInstruction*) override;
63   Status HandleOutfeed(HloInstruction*) override;
64   Status HandleRng(HloInstruction*) override;
65   Status HandleReverse(HloInstruction* reverse) override;
66   Status HandleSort(HloInstruction* sort) override;
67   Status HandleConstant(HloInstruction* constant) override;
68   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
69   Status HandleReduce(HloInstruction* reduce) override;
70   Status HandleBitcast(HloInstruction* bitcast) override;
71   Status HandleBroadcast(HloInstruction* broadcast) override;
72   Status HandleReshape(HloInstruction* reshape) override;
73   Status HandleTranspose(HloInstruction* transpose) override;
74   Status HandleParameter(HloInstruction*) override;
75   Status HandleFusion(HloInstruction*) override;
76   Status HandleCall(HloInstruction* call) override;
77   Status HandleCustomCall(HloInstruction*) override;
78   Status HandleSlice(HloInstruction* slice) override;
79   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
80   Status HandleDynamicUpdateSlice(
81       HloInstruction* dynamic_update_slice) override;
82   Status HandleTuple(HloInstruction* tuple) override;
83   Status HandleMap(HloInstruction* map) override;
84   Status HandleReduceWindow(HloInstruction* reduce_window) override;
85   Status HandleSelectAndScatter(HloInstruction* instruction) override;
86   Status HandleWhile(HloInstruction* xla_while) override;
87   Status HandleConditional(HloInstruction* conditional) override;
88   Status HandlePad(HloInstruction* pad) override;
89   Status HandleSend(HloInstruction* send) override;
90   Status HandleSendDone(HloInstruction* send_done) override;
91   Status HandleRecv(HloInstruction* recv) override;
92   Status HandleRecvDone(HloInstruction* recv_done) override;
93   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
94   Status HandleBatchNormInference(
95       HloInstruction* batch_norm_inference) override;
96   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
97   Status HandleGather(HloInstruction* gather) override;
98   Status HandleScatter(HloInstruction* scatter) override;
99   Status HandleAfterAll(HloInstruction* token) override;
100   Status HandleGetDimensionSize(HloInstruction* get_size) override;
101   Status HandleAddDependency(HloInstruction* add_dependency) override;
102 
FinishVisit(HloInstruction *)103   Status FinishVisit(HloInstruction*) override { return Status::OK(); }
104 
105  protected:
106   // Check the instruction's shape against the shape given by ShapeInference
107   // and return an appropriate error if there is a mismatch.
108   Status CheckShape(const HloInstruction* instruction,
109                     const Shape& inferred_shape);
110 
111   // Overload which takes a StatusOr to reduce boilerplate in the caller.
112   Status CheckShape(const HloInstruction* instruction,
113                     const StatusOr<Shape>& inferred_shape_status);
114 
115   // Check a unary (binary, etc) instruction's shape against the inferred shape.
116   Status CheckUnaryShape(const HloInstruction* instruction);
117   Status CheckBinaryShape(const HloInstruction* instruction);
118   Status CheckTernaryShape(const HloInstruction* instruction);
119   Status CheckVariadicShape(const HloInstruction* instruction);
120 
121  private:
122   // Helpers that switch on layout_sensitive_.
ShapesSame(const Shape & a,const Shape & b)123   bool ShapesSame(const Shape& a, const Shape& b) {
124     return layout_sensitive_ ? ShapeUtil::Equal(a, b)
125                              : ShapeUtil::Compatible(a, b);
126   }
ShapesSameIgnoringFpPrecision(const Shape & a,const Shape & b)127   bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) {
128     return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b)
129                              : ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
130   }
StringifyShape(const Shape & s)131   string StringifyShape(const Shape& s) {
132     return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
133                              : ShapeUtil::HumanString(s);
134   }
135 
136   // Helpers that switch on allow_mixed_precision_.
SameElementType(const Shape & a,const Shape & b)137   bool SameElementType(const Shape& a, const Shape& b) {
138     return allow_mixed_precision_
139                ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b)
140                : ShapeUtil::SameElementType(a, b);
141   }
142 
143   // Checks that the given operand of the given instruction is of type TOKEN.
144   Status CheckIsTokenOperand(const HloInstruction* instruction,
145                              int64 operand_no);
146 
147   // Checks that the shape of the given operand of the given instruction matches
148   // the given parameter of the given computation.
149   Status CheckOperandAndParameter(const HloInstruction* instruction,
150                                   int64 operand_number,
151                                   const HloComputation* computation,
152                                   int64 parameter_number);
153 
154   // Returns true if the shapes of the two operands have the same element type,
155   // and the result shape either has the same element type as the operand shapes
156   // or mixed precision is allowed and the result shape and the operand shapes
157   // have floating point element types.
158   bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
159                                  const Shape& result_shape);
160 
161   // If the verifier is layout-sensitive, shapes must be equal to what's
162   // expected.  Otherwise, the shapes must simply be compatible.
163   bool layout_sensitive_;
164 
165   // Whether the inputs and output of an instruction can contain both F32s and
166   // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
167   // this flag.
168   bool allow_mixed_precision_;
169 };
170 
171 // An interface used to encapsulate target-specific verification quirks.
172 class TargetVerifierMetadata {
173  public:
TargetVerifierMetadata(std::function<int64 (const Shape &)> shape_size_function)174   TargetVerifierMetadata(std::function<int64(const Shape&)> shape_size_function)
175       : shape_size_function_(shape_size_function) {}
176 
177   // Returns a target-specific shape size.
ShapeSize(const Shape & shape)178   int64 ShapeSize(const Shape& shape) const {
179     return shape_size_function_(shape);
180   }
181 
182   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
183 
TargetVerifierMetadata()184   TargetVerifierMetadata() {}
~TargetVerifierMetadata()185   virtual ~TargetVerifierMetadata() {}
186 
187   TargetVerifierMetadata(const TargetVerifierMetadata&) = delete;
188   TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete;
189 
190  private:
191   // Returns a target-specific shape size.
192   std::function<int64(const Shape&)> shape_size_function_;
193 };
194 
195 // The default implementation of TargetVerifierMetadata, used unless the target
196 // needs to override it.
197 class DefaultVerifierMetadata : public TargetVerifierMetadata {
198  public:
DefaultVerifierMetadata(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)199   DefaultVerifierMetadata(
200       bool layout_sensitive, bool allow_mixed_precision,
201       std::function<int64(const Shape&)> shape_size_function)
202       : TargetVerifierMetadata(shape_size_function),
203         layout_sensitive_(layout_sensitive),
204         allow_mixed_precision_(allow_mixed_precision) {}
205 
206   // Creates a ShapeVerifier that checks that shapes match inferred
207   // expectations. This creates a new verifier every time because ShapeVerifier,
208   // being a DfsHloVisitor, is stateful. We want a clean object for each run of
209   // the verifier.
GetVerifier()210   std::unique_ptr<ShapeVerifier> GetVerifier() const override {
211     return absl::make_unique<ShapeVerifier>(layout_sensitive_,
212                                             allow_mixed_precision_);
213   }
214 
215  private:
216   bool layout_sensitive_;
217   bool allow_mixed_precision_;
218 };
219 
220 // HLO pass that verifies invariants of HLO instructions for each computation in
221 // the module.
222 class HloVerifier : public HloModulePass {
223  public:
224   explicit HloVerifier(
225       bool layout_sensitive, bool allow_mixed_precision,
226       std::function<bool(const HloInstruction*)>
227           instruction_can_change_layout_func = {},
228       std::function<int64(const Shape&)> shape_size_func =
229           [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); })
target_metadata_(absl::make_unique<DefaultVerifierMetadata> (layout_sensitive,allow_mixed_precision,shape_size_func))230       : target_metadata_(absl::make_unique<DefaultVerifierMetadata>(
231             layout_sensitive, allow_mixed_precision, shape_size_func)),
232         instruction_can_change_layout_func_(
233             std::move(instruction_can_change_layout_func)) {
234     CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
235   }
236 
237   // Uses custom target metadata
HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)238   explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)
239       : target_metadata_(std::move(target_metadata)) {}
240 
241   ~HloVerifier() override = default;
name()242   absl::string_view name() const override { return "verifier"; }
243 
244   // Never returns true; no instructions are ever modified by this pass.
245   StatusOr<bool> Run(HloModule* module) override;
246 
247  private:
248   std::unique_ptr<TargetVerifierMetadata> target_metadata_;
249 
250   // Determines whether an instruction can change layouts.
251   std::function<bool(const HloInstruction*)>
252       instruction_can_change_layout_func_;
253 };
254 
255 }  // namespace xla
256 
257 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
258