• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the operations used in the LMHLO dialect.
17 
18 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
19 
20 #include <assert.h>
21 #include <stddef.h>
22 #include <stdint.h>
23 
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/APInt.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/FormatVariadic.h"
32 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
33 #include "mlir/Dialect/StandardOps/IR/Ops.h"
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/Builders.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/Dialect.h"
38 #include "mlir/IR/Location.h"
39 #include "mlir/IR/MLIRContext.h"
40 #include "mlir/IR/OpDefinition.h"
41 #include "mlir/IR/OpImplementation.h"
42 #include "mlir/IR/Operation.h"
43 #include "mlir/IR/OperationSupport.h"
44 #include "mlir/IR/PatternMatch.h"
45 #include "mlir/IR/TypeUtilities.h"
46 #include "mlir/IR/Types.h"
47 #include "mlir/IR/Value.h"
48 
49 namespace mlir {
50 namespace lmhlo {
51 
// Constructs the LMHLO dialect and registers all of its operations.
// The operation list is generated by TableGen into lhlo_ops.cc.inc;
// GET_OP_LIST selects the comma-separated op class list from that file.
LmhloDialect::LmhloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context, TypeID::get<LmhloDialect>()) {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
      >();
}
59 
60 // Verifies replica groups attached to collective communication operations.
61 // If the attribute is not empty, it must be a rank 2 tensor, and each replica
62 // should appear exactly once. If `is_uniform_sized` is true, then we also check
63 // that each group is of the same size. If the operation has
64 // `use_global_device_id` set, then replica group cannot be empty.
65 template <typename OpT>
VerifyReplicaGroups(OpT op,bool is_uniform_sized)66 LogicalResult VerifyReplicaGroups(OpT op, bool is_uniform_sized) {
67   DenseIntElementsAttr attr = op.replica_groups();
68   auto replica_group_type = attr.getType().dyn_cast<RankedTensorType>();
69   if (!replica_group_type || replica_group_type.getRank() != 2 ||
70       !replica_group_type.getElementType().isInteger(/*width=*/64))
71     return op.emitOpError(
72         "replica groups should be a rank 2 tensor of 64 bit integers");
73 
74   if (replica_group_type.getShape().equals(ArrayRef<int64_t>{0, 0}))
75     return success();
76 
77   int64_t max_replica_id_seen = 0;
78   llvm::SmallSet<int64_t, 8> replica_seen;
79   for (int64_t id : attr.getValues<int64_t>()) {
80     if (is_uniform_sized && id == -1) {
81       return op.emitOpError("Invalid replica id -1");
82     }
83     if (id != -1) {
84       if (!replica_seen.insert(id).second) {
85         return op.emitOpError("replica id #") << id << " seen more than once";
86       }
87       max_replica_id_seen = std::max(max_replica_id_seen, id);
88     }
89   }
90 
91   for (int64_t id = 0; id <= max_replica_id_seen; id++) {
92     if (!replica_seen.contains(id)) {
93       return op.emitOpError("replica id #")
94              << id << " not seen in replica groups";
95     }
96   }
97   return success();
98 }
99 
// Verifies AllGatherOp's replica groups. Uniform group sizes are required
// (the rank-2 attribute admits no -1 padding), so every group has the same
// number of replicas.
// TODO(jurahul): Add verification for output shape.
static LogicalResult Verify(AllGatherOp op) {
  return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
}
104 
// Verifies AllToAllOp's replica groups. Uniform group sizes are required
// (the rank-2 attribute admits no -1 padding), so every group has the same
// number of replicas.
// TODO(jurahul): Add verification for output shape.
static LogicalResult Verify(AllToAllOp op) {
  return VerifyReplicaGroups(op, /*is_uniform_sized=*/true);
}
109 
//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

// Verifies AllReduceOp: replica groups may be non-uniform (so -1 padding is
// allowed), and each operand's type must match the corresponding result's.
static LogicalResult Verify(AllReduceOp op) {
  if (failed(VerifyReplicaGroups(op, /*is_uniform_sized=*/false)))
    return failure();

  // AllReduce has variadic operands and results that have the same size.
  // Each member of the operand should have the same type as the corresponding
  // member of the result.
  // NOTE(review): llvm::zip stops at the shorter range, so a mismatch in
  // operand/result *count* is not diagnosed here — confirm it is checked by
  // the generated op verifier.
  for (auto it : llvm::enumerate(
           llvm::zip(op.operands().getTypes(), op.results().getTypes()))) {
    Type operandType = std::get<0>(it.value());
    Type resultType = std::get<1>(it.value());
    if (operandType != resultType)
      return op.emitOpError("requires operand #")
             << it.index() << " (type: " << operandType << ") and result #"
             << it.index() << " (type: " << resultType << ") to have same type";
  }
  return success();
}
132 
133 //===----------------------------------------------------------------------===//
134 // ConstOp.
135 //===----------------------------------------------------------------------===//
136 
137 /// An lho.constant on an memref that is locally allocated and with no other
138 /// users (other than dealloc's) can be erased.
139 // TODO: This can be generalized to an arbitrary op by making use of memory
140 // effects (write memory effect).
141 struct EraseConstOp : public OpRewritePattern<ConstOp> {
142   using OpRewritePattern<ConstOp>::OpRewritePattern;
143 
matchAndRewritemlir::lmhlo::EraseConstOp144   LogicalResult matchAndRewrite(ConstOp op,
145                                 PatternRewriter& rewriter) const override {
146     Value memref = op.output();
147     if (!memref.getDefiningOp<AllocOp>()) {
148       return failure();
149     }
150 
151     // Check that all uses of the memref are either DeallocOps or this op.
152     for (Operation* user : memref.getUsers())
153       if (user != op && !isa<DeallocOp>(user)) return failure();
154 
155     rewriter.eraseOp(op);
156     return success();
157   }
158 };
159 
// Registers ConstOp's canonicalization patterns (currently only the
// dead-constant erasure pattern above).
void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                          MLIRContext* context) {
  results.insert<EraseConstOp>(context);
}
164 
165 }  // namespace lmhlo
166 }  // namespace mlir
167 
168 #define GET_OP_CLASSES
169 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
170 
171 namespace mlir {
172 namespace lmhlo {
173 
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.

// Builds a FusionOp: attaches the given attributes and creates the single
// body region, ensuring it ends with a terminator even before the caller
// populates it.
void FusionOp::build(OpBuilder& builder, OperationState& result,
                     ArrayRef<NamedAttribute> attributes) {
  result.addAttributes(attributes);
  Region* bodyRegion = result.addRegion();
  FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}
182 
183 }  // namespace lmhlo
184 }  // namespace mlir
185