• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_META_SPMD_EXPANDER_H_
17 #define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_META_SPMD_EXPANDER_H_
18 
19 #include "tensorflow/dtensor/cc/dstatus.h"
20 #include "tensorflow/dtensor/mlir/shape_utils.h"
21 #include "tensorflow/dtensor/mlir/spmd_expander.h"
22 
23 namespace tensorflow {
24 namespace dtensor {
25 
26 // Pack/Unpack (aka tf.stack/unstack)
27 // For Pack, we verify input tensors have the same layout, and produce a new
28 // tensor of rank N + 1 with an unsharded first dimension.
29 class PackSPMDExpander : public SPMDExpanderBase {
30  private:
31   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
32 
33   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
34       mlir::Operation* op,
35       const llvm::DenseMap<int, Layout>& input_layouts) override;
36 
37   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
38       mlir::Operation* op,
39       const llvm::DenseMap<int, Layout>& output_layouts) override;
40 };
41 
42 class UnpackSPMDExpander : public SPMDExpanderBase {
43  private:
44   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
45 
46   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
47       mlir::Operation* op,
48       const llvm::DenseMap<int, Layout>& input_layouts) override;
49 
50   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
51       mlir::Operation* op,
52       const llvm::DenseMap<int, Layout>& output_layouts) override;
53 };
54 
55 class PadSPMDExpander : public SPMDExpanderBase {
56  private:
57   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
58 
59   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
60       mlir::Operation* op,
61       const llvm::DenseMap<int, Layout>& input_layouts) override;
62 
63   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
64       mlir::Operation* op,
65       const llvm::DenseMap<int, Layout>& output_layouts) override;
66 };
67 
68 class TileSPMDExpander : public SPMDExpanderBase {
69  private:
70   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
71 
72   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
73       mlir::Operation* op,
74       const llvm::DenseMap<int, Layout>& input_layouts) override;
75 
76   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
77       mlir::Operation* op,
78       const llvm::DenseMap<int, Layout>& output_layouts) override;
79 };
80 
81 // SPMD expansion for reshape.
82 //
83 // If an explicit layout is provided, reshape will adjust the output to
84 // conform to the new layout. N.B. not all possible input/output shapes+layouts
85 // are implemented.
86 //
87 // A fully general reshape involves arbitrary send/recv or collective
88 // permutations, and may be inefficient.
89 //
90 // We provide special cases for a number of common cases.
91 class ReshapeSPMDExpander : public SPMDExpanderBase {
92  private:
93   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
94 
95   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
96       mlir::Operation* op,
97       const llvm::DenseMap<int, Layout>& input_layouts) override;
98 
99   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
100       mlir::Operation* op,
101       const llvm::DenseMap<int, Layout>& output_layouts) override;
102 };
103 
104 class TransposeSPMDExpander : public SPMDExpanderBase {
105  private:
106   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
107 
108   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
109       mlir::Operation* op,
110       const llvm::DenseMap<int, Layout>& input_layouts) override;
111 
112   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
113       mlir::Operation* op,
114       const llvm::DenseMap<int, Layout>& output_layouts) override;
115 };
116 
117 class OneHotSPMDExpander : public SPMDExpanderBase {
118  public:
119   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
120 
121   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
122       mlir::Operation* op,
123       const llvm::DenseMap<int, Layout>& input_layouts) override;
124 
125   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
126       mlir::Operation* op,
127       const llvm::DenseMap<int, Layout>& output_layouts) override;
128 };
129 
130 // SPMD expansion for shape/rank metadata operations.
131 class ShapeSPMDExpander : public SPMDExpanderBase {
132  public:
133   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
134 
135   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
136       mlir::Operation* op,
137       const llvm::DenseMap<int, Layout>& input_layouts) override;
138 
139   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
140       mlir::Operation* op,
141       const llvm::DenseMap<int, Layout>& output_layouts) override;
142 };
143 
144 }  // namespace dtensor
145 }  // namespace tensorflow
146 
147 #endif  // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_META_SPMD_EXPANDER_H_
148