1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_META_SPMD_EXPANDER_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_META_SPMD_EXPANDER_H_ 18 19 #include "tensorflow/dtensor/cc/dstatus.h" 20 #include "tensorflow/dtensor/mlir/shape_utils.h" 21 #include "tensorflow/dtensor/mlir/spmd_expander.h" 22 23 namespace tensorflow { 24 namespace dtensor { 25 26 // Pack/Unpack (aka tf.stack/unstack) 27 // For Pack, we verify input tensors have the same layout, and produce a new 28 // tensor of rank N + 1 with an unsharded first dimension. 29 class PackSPMDExpander : public SPMDExpanderBase { 30 private: 31 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 32 33 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 34 mlir::Operation* op, 35 const llvm::DenseMap<int, Layout>& input_layouts) override; 36 37 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 38 mlir::Operation* op, 39 const llvm::DenseMap<int, Layout>& output_layouts) override; 40 }; 41 42 class UnpackSPMDExpander : public SPMDExpanderBase { 43 private: 44 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 45 46 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 47 mlir::Operation* op, 48 const llvm::DenseMap<int, Layout>& input_layouts) override; 49 50 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 51 mlir::Operation* op, 52 const llvm::DenseMap<int, Layout>& output_layouts) override; 53 }; 54 55 class PadSPMDExpander : public SPMDExpanderBase { 56 private: 57 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 58 59 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 60 mlir::Operation* op, 61 const llvm::DenseMap<int, Layout>& input_layouts) override; 62 63 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 64 mlir::Operation* op, 65 const llvm::DenseMap<int, Layout>& output_layouts) override; 66 }; 67 68 class TileSPMDExpander : public SPMDExpanderBase { 69 private: 70 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 71 72 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 73 mlir::Operation* op, 74 const llvm::DenseMap<int, Layout>& input_layouts) override; 75 76 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 77 mlir::Operation* op, 78 const llvm::DenseMap<int, Layout>& output_layouts) override; 79 }; 80 81 // SPMD expansion for reshape. 82 // 83 // If an explicit layout is provided, reshape will adjust the output to 84 // conform to the new layout. N.B. not all possible input/output shapes+layouts 85 // are implemented. 86 // 87 // A fully general reshape involves arbitrary send/recv or collective 88 // permutations, and may be inefficient. 89 // 90 // We provide special cases for a number of common cases. 91 class ReshapeSPMDExpander : public SPMDExpanderBase { 92 private: 93 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 94 95 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 96 mlir::Operation* op, 97 const llvm::DenseMap<int, Layout>& input_layouts) override; 98 99 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 100 mlir::Operation* op, 101 const llvm::DenseMap<int, Layout>& output_layouts) override; 102 }; 103 104 class TransposeSPMDExpander : public SPMDExpanderBase { 105 private: 106 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 107 108 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 109 mlir::Operation* op, 110 const llvm::DenseMap<int, Layout>& input_layouts) override; 111 112 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 113 mlir::Operation* op, 114 const llvm::DenseMap<int, Layout>& output_layouts) override; 115 }; 116 117 class OneHotSPMDExpander : public SPMDExpanderBase { 118 public: 119 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 120 121 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 122 mlir::Operation* op, 123 const llvm::DenseMap<int, Layout>& input_layouts) override; 124 125 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 126 mlir::Operation* op, 127 const llvm::DenseMap<int, Layout>& output_layouts) override; 128 }; 129 130 // SPMD expansion for shape/rank metadata operations. 131 class ShapeSPMDExpander : public SPMDExpanderBase { 132 public: 133 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 134 135 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 136 mlir::Operation* op, 137 const llvm::DenseMap<int, Layout>& input_layouts) override; 138 139 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 140 mlir::Operation* op, 141 const llvm::DenseMap<int, Layout>& output_layouts) override; 142 }; 143 144 } // namespace dtensor 145 } // namespace tensorflow 146 147 #endif // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_META_SPMD_EXPANDER_H_ 148