• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
18 
#include <functional>
#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
24 
25 namespace xla {
26 
27 // Visitor which verifies that the output shape is correctly set. Verifies
28 // against the inferred shape for the instruction.
29 // TODO(b/26024837): Check output shape for all instruction types.
30 class ShapeVerifier : public DfsHloVisitor {
31  public:
ShapeVerifier(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)32   ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision,
33                 std::function<int64(const Shape&)> shape_size_function)
34       : layout_sensitive_(layout_sensitive),
35         allow_mixed_precision_(allow_mixed_precision),
36         shape_size_function_(shape_size_function) {}
37 
38   // Verifies that entry computation layout matches parameters and root shape of
39   // the module's entry computation.
40   virtual Status VerifyEntryComputationLayout(const HloModule& module);
41 
42   Status Preprocess(HloInstruction* hlo) override;
43 
44   Status HandleElementwiseUnary(HloInstruction* hlo) override;
45   Status HandleElementwiseBinary(HloInstruction* hlo) override;
46   Status HandleClamp(HloInstruction* clamp) override;
47   Status HandleSelect(HloInstruction* select) override;
48   Status HandleTupleSelect(HloInstruction* tuple_select) override;
49   Status HandleConcatenate(HloInstruction* concatenate) override;
50   Status HandleIota(HloInstruction* iota) override;
51   Status HandleConvert(HloInstruction* convert) override;
52   Status HandleBitcastConvert(HloInstruction* convert) override;
53   Status HandleCopy(HloInstruction* copy) override;
54   Status HandleDot(HloInstruction* dot) override;
55   Status HandleConvolution(HloInstruction* convolution) override;
56   Status HandleFft(HloInstruction* fft) override;
57   Status HandleCholesky(HloInstruction* hlo) override;
58   Status HandleTriangularSolve(HloInstruction* hlo) override;
59   Status HandleAllReduce(HloInstruction* crs) override;
60   Status HandleAllToAll(HloInstruction* hlo) override;
61   Status HandleCollectivePermute(HloInstruction* hlo) override;
62   Status HandlePartitionId(HloInstruction* hlo) override;
63   Status HandleReplicaId(HloInstruction* hlo) override;
64   Status HandleReducePrecision(HloInstruction* reduce_precision) override;
65   Status HandleInfeed(HloInstruction*) override;
66   Status HandleOutfeed(HloInstruction*) override;
67   Status HandleRng(HloInstruction*) override;
68   Status HandleRngGetAndUpdateState(HloInstruction*) override;
69   Status HandleReverse(HloInstruction* reverse) override;
70   Status HandleSort(HloInstruction* sort) override;
71   Status HandleConstant(HloInstruction* constant) override;
72   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
73   Status HandleReduce(HloInstruction* reduce) override;
74   Status HandleBitcast(HloInstruction* bitcast) override;
75   Status HandleBroadcast(HloInstruction* broadcast) override;
76   Status HandleReshape(HloInstruction* reshape) override;
77   Status HandleTranspose(HloInstruction* transpose) override;
78   Status HandleParameter(HloInstruction*) override;
79   Status HandleFusion(HloInstruction*) override;
80   Status HandleCall(HloInstruction* call) override;
81   Status HandleCustomCall(HloInstruction*) override;
82   Status HandleSlice(HloInstruction* slice) override;
83   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
84   Status HandleDynamicUpdateSlice(
85       HloInstruction* dynamic_update_slice) override;
86   Status HandleTuple(HloInstruction* tuple) override;
87   Status HandleMap(HloInstruction* map) override;
88   Status HandleReduceWindow(HloInstruction* reduce_window) override;
89   Status HandleSelectAndScatter(HloInstruction* instruction) override;
90   Status HandleWhile(HloInstruction* xla_while) override;
91   Status HandleConditional(HloInstruction* conditional) override;
92   Status HandlePad(HloInstruction* pad) override;
93   Status HandleCopyStart(HloInstruction* copy_start) override;
94   Status HandleCopyDone(HloInstruction* copy_done) override;
95   Status HandleSend(HloInstruction* send) override;
96   Status HandleSendDone(HloInstruction* send_done) override;
97   Status HandleRecv(HloInstruction* recv) override;
98   Status HandleRecvDone(HloInstruction* recv_done) override;
99   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override;
100   Status HandleBatchNormInference(
101       HloInstruction* batch_norm_inference) override;
102   Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
103   Status HandleGather(HloInstruction* gather) override;
104   Status HandleScatter(HloInstruction* scatter) override;
105   Status HandleAfterAll(HloInstruction* token) override;
106   Status HandleGetDimensionSize(HloInstruction* get_size) override;
107   Status HandleSetDimensionSize(HloInstruction* set_size) override;
108   Status HandleAddDependency(HloInstruction* add_dependency) override;
109 
FinishVisit(HloInstruction *)110   Status FinishVisit(HloInstruction*) override { return Status::OK(); }
111 
112  protected:
113   // Check the instruction's shape against the shape given by ShapeInference
114   // and return an appropriate error if there is a mismatch.
115   Status CheckShape(const HloInstruction* instruction,
116                     const Shape& inferred_shape,
117                     bool only_compare_minor_to_major_in_layout = false);
118 
119   // Overload which takes a StatusOr to reduce boilerplate in the caller.
120   Status CheckShape(const HloInstruction* instruction,
121                     const StatusOr<Shape>& inferred_shape_status);
122 
123   // Check a unary (binary, etc) instruction's shape against the inferred shape.
124   Status CheckUnaryShape(const HloInstruction* instruction);
125   Status CheckBinaryShape(const HloInstruction* instruction);
126   Status CheckTernaryShape(const HloInstruction* instruction);
127   Status CheckVariadicShape(const HloInstruction* instruction);
128 
129  private:
130   // Helpers that switch on layout_sensitive_.
131   bool ShapesSame(const Shape& a, const Shape& b,
132                   bool minor_to_major_only = false,
133                   bool ignore_memory_space = false) {
134     if (!layout_sensitive_) {
135       return ShapeUtil::Compatible(a, b);
136     }
137     Shape::Equal equal;
138     if (ignore_memory_space) {
139       equal.IgnoreMemorySpaceInLayout();
140     }
141     if (minor_to_major_only) {
142       equal.MinorToMajorOnlyInLayout();
143     }
144     return equal(a, b);
145   }
146 
147   bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b,
148                                      bool minor_to_major_only = false) {
149     if (!layout_sensitive_) {
150       return ShapeUtil::CompatibleIgnoringFpPrecision(a, b);
151     }
152     Shape::Equal equal;
153     if (minor_to_major_only) {
154       equal.MinorToMajorOnlyInLayout();
155     }
156     equal.IgnoreFpPrecision();
157     return equal(a, b);
158   }
159 
StringifyShape(const Shape & s)160   string StringifyShape(const Shape& s) {
161     return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s)
162                              : ShapeUtil::HumanString(s);
163   }
164 
165   // Helpers that switch on allow_mixed_precision_.
SameElementType(const Shape & a,const Shape & b)166   bool SameElementType(const Shape& a, const Shape& b) {
167     return allow_mixed_precision_
168                ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b)
169                : ShapeUtil::SameElementType(a, b);
170   }
171 
172   // Checks that the given operand of the given instruction is of type TOKEN.
173   Status CheckIsTokenOperand(const HloInstruction* instruction,
174                              int64 operand_no);
175 
176   // Checks that the shape of the given operand of the given instruction matches
177   // the given parameter of the given computation.
178   Status CheckOperandAndParameter(const HloInstruction* instruction,
179                                   int64 operand_number,
180                                   const HloComputation* computation,
181                                   int64 parameter_number);
182 
183   // Returns true if the shapes of the two operands have the same element type,
184   // and the result shape either has the same element type as the operand shapes
185   // or mixed precision is allowed and the result shape and the operand shapes
186   // have floating point element types.
187   bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1,
188                                  const Shape& result_shape);
189 
190   // If the verifier is layout-sensitive, shapes must be equal to what's
191   // expected.  Otherwise, the shapes must simply be compatible.
192   bool layout_sensitive_;
193 
194   // Whether the inputs and output of an instruction can contain both F32s and
195   // BF16s. Tuples that include both F32s and BF16s are allowed regardless of
196   // this flag.
197   bool allow_mixed_precision_;
198 
199   // Returns a target-specific shape size.
200   std::function<int64(const Shape&)> shape_size_function_;
201 };
202 
203 // An interface used to encapsulate target-specific verification quirks.
204 class TargetVerifierMetadata {
205  public:
TargetVerifierMetadata(std::function<int64 (const Shape &)> shape_size_function)206   TargetVerifierMetadata(std::function<int64(const Shape&)> shape_size_function)
207       : shape_size_function_(shape_size_function) {}
208 
209   // Returns a target-specific shape size.
ShapeSize(const Shape & shape)210   int64 ShapeSize(const Shape& shape) const {
211     return shape_size_function_(shape);
212   }
213 
214   virtual std::unique_ptr<ShapeVerifier> GetVerifier() const = 0;
215 
TargetVerifierMetadata()216   TargetVerifierMetadata() {}
~TargetVerifierMetadata()217   virtual ~TargetVerifierMetadata() {}
218 
219   TargetVerifierMetadata(const TargetVerifierMetadata&) = delete;
220   TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete;
221 
222  protected:
223   // Returns a target-specific shape size.
224   std::function<int64(const Shape&)> shape_size_function_;
225 };
226 
227 // The default implementation of TargetVerifierMetadata, used unless the target
228 // needs to override it.
229 class DefaultVerifierMetadata : public TargetVerifierMetadata {
230  public:
DefaultVerifierMetadata(bool layout_sensitive,bool allow_mixed_precision,std::function<int64 (const Shape &)> shape_size_function)231   DefaultVerifierMetadata(
232       bool layout_sensitive, bool allow_mixed_precision,
233       std::function<int64(const Shape&)> shape_size_function)
234       : TargetVerifierMetadata(shape_size_function),
235         layout_sensitive_(layout_sensitive),
236         allow_mixed_precision_(allow_mixed_precision) {}
237 
238   // Creates a ShapeVerifier that checks that shapes match inferred
239   // expectations. This creates a new verifier every time because ShapeVerifier,
240   // being a DfsHloVisitor, is stateful. We want a clean object for each run of
241   // the verifier.
GetVerifier()242   std::unique_ptr<ShapeVerifier> GetVerifier() const override {
243     return absl::make_unique<ShapeVerifier>(
244         layout_sensitive_, allow_mixed_precision_, shape_size_function_);
245   }
246 
247  private:
248   bool layout_sensitive_;
249   bool allow_mixed_precision_;
250 };
251 
252 // HLO pass that verifies invariants of HLO instructions for each computation in
253 // the module.
254 class HloVerifier : public HloModulePass {
255  public:
256   explicit HloVerifier(
257       bool layout_sensitive, bool allow_mixed_precision,
258       std::function<bool(const HloInstruction*)>
259           instruction_can_change_layout_func = {},
260       std::function<int64(const Shape&)> shape_size_func =
261           [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); })
target_metadata_(absl::make_unique<DefaultVerifierMetadata> (layout_sensitive,allow_mixed_precision,shape_size_func))262       : target_metadata_(absl::make_unique<DefaultVerifierMetadata>(
263             layout_sensitive, allow_mixed_precision, shape_size_func)),
264         instruction_can_change_layout_func_(
265             std::move(instruction_can_change_layout_func)) {
266     CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
267   }
268 
269   // Uses custom target metadata
HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)270   explicit HloVerifier(std::unique_ptr<TargetVerifierMetadata> target_metadata)
271       : target_metadata_(std::move(target_metadata)) {}
272 
273   ~HloVerifier() override = default;
name()274   absl::string_view name() const override { return "verifier"; }
275 
276   // Never returns true; no instructions are ever modified by this pass.
277   StatusOr<bool> Run(HloModule* module) override;
278 
279  private:
280   std::unique_ptr<TargetVerifierMetadata> target_metadata_;
281 
282   // Determines whether an instruction can change layouts.
283   std::function<bool(const HloInstruction*)>
284       instruction_can_change_layout_func_;
285 };
286 
287 }  // namespace xla
288 
289 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
290