1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
17
18 #include <cstdint>
19
20 #include "llvm/Support/FormatVariadic.h"
21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
22 #include "mlir/IR/MLIRContext.h" // from @llvm-project
23 #include "mlir/IR/OpImplementation.h" // from @llvm-project
24 #include "mlir/Support/LogicalResult.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27
28 //===----------------------------------------------------------------------===//
29 // TableGen'd op method definitions
30 //===----------------------------------------------------------------------===//
31
32 namespace mlir {
33 namespace TF {
34 namespace {
35
GetRankedTensorType(mlir::Value val)36 RankedTensorType GetRankedTensorType(mlir::Value val) {
37 mlir::Type type = val.getType();
38 if (auto type_with_subtype =
39 mlir::getElementTypeOrSelf(val)
40 .dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>()) {
41 if (type_with_subtype.GetSubtypes().size() == 1) {
42 type = type_with_subtype.GetSubtypes().front();
43 }
44 }
45 return type.dyn_cast_or_null<RankedTensorType>();
46 }
47 } // namespace
48
verify()49 mlir::LogicalResult DTensorLayout::verify() {
50 DTensorLayout op = *this;
51 const auto& layout = op.layout();
52 if (layout.IsEmpty()) return mlir::success();
53
54 auto input_value = op.input();
55
56 RankedTensorType type = GetRankedTensorType(input_value);
57
58 if (!type) return mlir::success();
59
60 const auto& num_shards = layout.num_shards();
61 if (num_shards.size() != type.getRank()) {
62 return op.emitOpError(llvm::formatv(
63 "requires matching rank for layout and input, but got {0} as suggested "
64 "rank from layout but {1} from shape.",
65 num_shards.size(), type.getRank()));
66 }
67
68 for (const auto& dim_and_index :
69 llvm::enumerate(llvm::zip(type.getShape(), num_shards))) {
70 const int dimension_index = dim_and_index.index();
71 const auto& dim_and_shards = dim_and_index.value();
72 const int dim = std::get<0>(dim_and_shards);
73 const int num_shard_for_dim = std::get<1>(dim_and_shards);
74 if (dim <= 0) continue;
75
76 if (dim % num_shard_for_dim != 0)
77 return op.emitOpError(llvm::formatv(
78 "requires dimension {0} to be divisible by sharding "
79 "specified in DTensorLayout, but got dimension size={1} is not "
80 "divisible by number of shards in layout for this dimension={2}.",
81 dimension_index, dim, num_shard_for_dim));
82 }
83
84 return mlir::success();
85 }
86
verify()87 mlir::LogicalResult DTensorAllGatherOp::verify() {
88 DTensorAllGatherOp op = *this;
89 const tensorflow::dtensor::Layout input_layout = op.input_layout();
90 const tensorflow::dtensor::Layout output_layout = op.output_layout();
91
92 if (input_layout.rank() != output_layout.rank())
93 return op.emitOpError()
94 << "received input and output layouts of unequal ranks "
95 << input_layout.rank() << " and " << output_layout.rank();
96
97 for (int32_t i = 0; i < input_layout.rank(); ++i) {
98 if (input_layout.sharding_spec(i) != output_layout.sharding_spec(i) &&
99 tensorflow::dtensor::Layout::IsShardedDimension(
100 output_layout.sharding_spec(i))) {
101 return op.emitOpError()
102 << "dimension " << i << " of output layout has sharding spec "
103 << output_layout.sharding_spec(i)
104 << " which is more sharded then the input layout spec "
105 << input_layout.sharding_spec(i);
106 }
107 }
108
109 RankedTensorType input_type =
110 op.input().getType().dyn_cast<RankedTensorType>();
111 if (!input_type) return mlir::success();
112
113 if (input_type.getRank() != input_layout.rank())
114 return op.emitOpError()
115 << "input layout rank " << input_layout.rank()
116 << " is not equal to input rank " << input_type.getRank();
117
118 RankedTensorType output_type =
119 op.output().getType().dyn_cast<RankedTensorType>();
120 if (!output_type) return mlir::success();
121
122 if (output_type.getRank() != output_layout.rank())
123 return op.emitOpError()
124 << "output layout rank " << output_layout.rank()
125 << " is not equal to output rank " << output_type.getRank();
126
127 std::vector<int64_t> computed_output_shape =
128 output_layout.LocalShapeFromGlobalShape(
129 input_layout.GlobalShapeFromLocalShape(input_type.getShape()));
130
131 for (int32_t i = 0; i < computed_output_shape.size(); ++i) {
132 if (computed_output_shape[i] != output_type.getShape()[i]) {
133 return op.emitOpError()
134 << "computed output shape " << computed_output_shape[i]
135 << " at dimension " << i << " is not equal to actual output shape "
136 << output_type.getShape()[i];
137 }
138 }
139
140 return mlir::success();
141 }
142
verify()143 mlir::LogicalResult DTensorAllScatterOp::verify() {
144 DTensorAllScatterOp op = *this;
145 const tensorflow::dtensor::Layout input_layout = op.input_layout();
146 const tensorflow::dtensor::Layout output_layout = op.output_layout();
147
148 if (input_layout.rank() != output_layout.rank())
149 return op.emitOpError()
150 << "received input and output layouts of unequal ranks "
151 << input_layout.rank() << " and " << output_layout.rank();
152
153 for (int32_t i = 0; i < input_layout.rank(); ++i) {
154 if (input_layout.sharding_spec(i) != output_layout.sharding_spec(i) &&
155 tensorflow::dtensor::Layout::IsShardedDimension(
156 input_layout.sharding_spec(i))) {
157 return op.emitOpError()
158 << "dimension " << i << " of input layout has sharding spec "
159 << input_layout.sharding_spec(i)
160 << " which is more sharded then the output layout spec "
161 << output_layout.sharding_spec(i);
162 }
163 }
164
165 RankedTensorType input_type =
166 op.input().getType().dyn_cast<RankedTensorType>();
167 if (!input_type) return mlir::success();
168
169 if (input_type.getRank() != input_layout.rank())
170 return op.emitOpError()
171 << "input layout rank " << input_layout.rank()
172 << " is not equal to input rank " << input_type.getRank();
173
174 RankedTensorType output_type =
175 op.output().getType().dyn_cast<RankedTensorType>();
176 if (!output_type) return mlir::success();
177
178 if (output_type.getRank() != output_layout.rank())
179 return op.emitOpError()
180 << "output layout rank " << output_layout.rank()
181 << " is not equal to output rank " << output_type.getRank();
182
183 std::vector<int64_t> computed_output_shape =
184 output_layout.LocalShapeFromGlobalShape(
185 input_layout.GlobalShapeFromLocalShape(input_type.getShape()));
186
187 for (int32_t i = 0; i < computed_output_shape.size(); ++i) {
188 if (computed_output_shape[i] != output_type.getShape()[i]) {
189 return op.emitOpError()
190 << "computed output shape " << computed_output_shape[i]
191 << " at dimension " << i << " is not equal to actual output shape "
192 << output_type.getShape()[i];
193 }
194 }
195
196 return mlir::success();
197 }
198
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)199 LogicalResult DTensorLayout::inferReturnTypes(
200 MLIRContext* context, Optional<Location> location, ValueRange operands,
201 DictionaryAttr attributes, RegionRange regions,
202 SmallVectorImpl<Type>& inferredReturnTypes) {
203 assert(operands.size() == 1);
204 inferredReturnTypes.assign({operands[0].getType()});
205 return success();
206 }
207
DTensorOpAdderHook(TensorFlowDialect & dialect)208 void DTensorOpAdderHook(TensorFlowDialect& dialect) {
209 dialect.addOperations<
210 #define GET_OP_LIST
211 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.cc.inc"
212 >();
213 }
214
RegisterOnce()215 int RegisterOnce() {
216 TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS(DTensorOpAdderHook)
217 return 0;
218 }
219
RegisterDTensorTFOps()220 int RegisterDTensorTFOps() {
221 static int r = RegisterOnce();
222 return r;
223 }
224
225 } // namespace TF
226 } // namespace mlir
227
228 //===----------------------------------------------------------------------===//
229 // TableGen'd op method definitions
230 //===----------------------------------------------------------------------===//
231
232 #define GET_OP_CLASSES
233 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.cc.inc"
234