• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
18 
19 #include "tensorflow/compiler/xla/service/op_expander_pass.h"
20 
21 namespace xla {
22 
23 // This pass rewrites scatter operations into (roughly) while loops of
24 // dynamic-update-slices.
25 //
26 // This pass can be used in two ways:
27 //
28 //   - kEliminateAllScatters: For backends that don't support scatter, this pass
29 //     can convert every scatter into a loop.
30 //
31 //   - kEliminateSimpleScatters: For backends that *do* support scatter, this
32 //     pass can strength-reduce "simple" scatters -- specifically, scatters that
33 //     can be represented without a loop -- to dynamic-update-slices.
34 //
35 // Note that even in kEliminateSimpleScatters mode, this pass may still expand a
36 // scatter into a loop (with a trip-count of 1).  It's up to other
37 // simplification passes to remove the loop.
38 class ScatterExpander : public OpExpanderPass {
39  public:
40   enum Mode {
41     kEliminateAllScatters,
42     kEliminateSimpleScatters,
43   };
44 
ScatterExpander(Mode m)45   explicit ScatterExpander(Mode m) : mode_(m) {}
46 
name()47   absl::string_view name() const override { return "scatter_expander"; }
48 
49  protected:
50   bool InstructionMatchesPattern(HloInstruction* inst) override;
51 
52   StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* scatter) override;
53 
54  private:
55   Mode mode_;
56 };
57 
58 }  // namespace xla
59 
60 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
61