/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"

#include <cstdint>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

namespace mlir {
namespace TF {
namespace {

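// Returns `val`'s type as a RankedTensorType. For TensorFlow types that carry
// exactly one subtype (e.g. resource or variant handles), the subtype is used
// instead. Returns null if the resulting type is not ranked.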
RankedTensorType GetRankedTensorType(mlir::Value val) {
  mlir::Type type = val.getType();
  if (auto type_with_subtype =
          mlir::getElementTypeOrSelf(val)
              .dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>()) {
    if (type_with_subtype.GetSubtypes().size() == 1) {
      type = type_with_subtype.GetSubtypes().front();
    }
  }
  return type.dyn_cast_or_null<RankedTensorType>();
}
}  // namespace

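// Verifies that the layout attached to a DTensorLayout op is consistent with
// its input's ranked type: the layout rank must match the tensor rank, and
// every statically known dimension must be divisible by its number of shards.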
mlir::LogicalResult DTensorLayout::verify() {
  DTensorLayout op = *this;
  const auto& layout = op.layout();
  if (layout.IsEmpty()) return mlir::success();

  auto input_value = op.input();

  RankedTensorType type = GetRankedTensorType(input_value);

  if (!type) return mlir::success();

  const auto& num_shards = layout.num_shards();
  if (num_shards.size() != type.getRank()) {
    return op.emitOpError(llvm::formatv(
        "requires matching rank for layout and input, but got {0} as suggested "
        "rank from layout but {1} from shape.",
        num_shards.size(), type.getRank()));
  }

  for (const auto& dim_and_index :
       llvm::enumerate(llvm::zip(type.getShape(), num_shards))) {
    const int dimension_index = dim_and_index.index();
    const auto& dim_and_shards = dim_and_index.value();
    const int dim = std::get<0>(dim_and_shards);
    const int num_shard_for_dim = std::get<1>(dim_and_shards);
    if (dim <= 0) continue;

    if (dim % num_shard_for_dim != 0)
      return op.emitOpError(llvm::formatv(
          "requires dimension {0} to be divisible by sharding "
          "specified in DTensorLayout, but got dimension size={1} is not "
          "divisible by number of shards in layout for this dimension={2}.",
          dimension_index, dim, num_shard_for_dim));
  }

  return mlir::success();
}

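// Verifies a DTensorAllGatherOp: the input and output layouts must have equal
// ranks, no output dimension may be more sharded than the corresponding input
// dimension, and, when ranked operand/result types are available, their ranks
// and the computed local output shape must agree with the layouts.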
mlir::LogicalResult DTensorAllGatherOp::verify() {
  DTensorAllGatherOp op = *this;
  const tensorflow::dtensor::Layout input_layout = op.input_layout();
  const tensorflow::dtensor::Layout output_layout = op.output_layout();

  if (input_layout.rank() != output_layout.rank())
    return op.emitOpError()
           << "received input and output layouts of unequal ranks "
           << input_layout.rank() << " and " << output_layout.rank();

  for (int32_t i = 0; i < input_layout.rank(); ++i) {
    if (input_layout.sharding_spec(i) != output_layout.sharding_spec(i) &&
        tensorflow::dtensor::Layout::IsShardedDimension(
            output_layout.sharding_spec(i))) {
      return op.emitOpError()
             << "dimension " << i << " of output layout has sharding spec "
             << output_layout.sharding_spec(i)
             << " which is more sharded than the input layout spec "
             << input_layout.sharding_spec(i);
    }
  }

  RankedTensorType input_type =
      op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return mlir::success();

  if (input_type.getRank() != input_layout.rank())
    return op.emitOpError()
           << "input layout rank " << input_layout.rank()
           << " is not equal to input rank " << input_type.getRank();

  RankedTensorType output_type =
      op.output().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return mlir::success();

  if (output_type.getRank() != output_layout.rank())
    return op.emitOpError()
           << "output layout rank " << output_layout.rank()
           << " is not equal to output rank " << output_type.getRank();

  std::vector<int64_t> computed_output_shape =
      output_layout.LocalShapeFromGlobalShape(
          input_layout.GlobalShapeFromLocalShape(input_type.getShape()));

  for (int32_t i = 0; i < computed_output_shape.size(); ++i) {
    if (computed_output_shape[i] != output_type.getShape()[i]) {
      return op.emitOpError()
             << "computed output shape " << computed_output_shape[i]
             << " at dimension " << i << " is not equal to actual output shape "
             << output_type.getShape()[i];
    }
  }

  return mlir::success();
}

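// Verifies a DTensorAllScatterOp: the mirror image of DTensorAllGatherOp. The
// input and output layouts must have equal ranks, no input dimension may be
// more sharded than the corresponding output dimension, and, when ranked
// types are available, their ranks and the computed local output shape must
// agree with the layouts.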
mlir::LogicalResult DTensorAllScatterOp::verify() {
  DTensorAllScatterOp op = *this;
  const tensorflow::dtensor::Layout input_layout = op.input_layout();
  const tensorflow::dtensor::Layout output_layout = op.output_layout();

  if (input_layout.rank() != output_layout.rank())
    return op.emitOpError()
           << "received input and output layouts of unequal ranks "
           << input_layout.rank() << " and " << output_layout.rank();

  for (int32_t i = 0; i < input_layout.rank(); ++i) {
    if (input_layout.sharding_spec(i) != output_layout.sharding_spec(i) &&
        tensorflow::dtensor::Layout::IsShardedDimension(
            input_layout.sharding_spec(i))) {
      return op.emitOpError()
             << "dimension " << i << " of input layout has sharding spec "
             << input_layout.sharding_spec(i)
             << " which is more sharded than the output layout spec "
             << output_layout.sharding_spec(i);
    }
  }

  RankedTensorType input_type =
      op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return mlir::success();

  if (input_type.getRank() != input_layout.rank())
    return op.emitOpError()
           << "input layout rank " << input_layout.rank()
           << " is not equal to input rank " << input_type.getRank();

  RankedTensorType output_type =
      op.output().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return mlir::success();

  if (output_type.getRank() != output_layout.rank())
    return op.emitOpError()
           << "output layout rank " << output_layout.rank()
           << " is not equal to output rank " << output_type.getRank();

  std::vector<int64_t> computed_output_shape =
      output_layout.LocalShapeFromGlobalShape(
          input_layout.GlobalShapeFromLocalShape(input_type.getShape()));

  for (int32_t i = 0; i < computed_output_shape.size(); ++i) {
    if (computed_output_shape[i] != output_type.getShape()[i]) {
      return op.emitOpError()
             << "computed output shape " << computed_output_shape[i]
             << " at dimension " << i << " is not equal to actual output shape "
             << output_type.getShape()[i];
    }
  }

  return mlir::success();
}

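// DTensorLayout is a pass-through op: its result type is always identical to
// the type of its single operand.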
LogicalResult DTensorLayout::inferReturnTypes(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type>& inferredReturnTypes) {
  assert(operands.size() == 1);
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

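// Adds the TableGen-generated DTensor ops to the TensorFlow dialect.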
void DTensorOpAdderHook(TensorFlowDialect& dialect) {
  dialect.addOperations<
#define GET_OP_LIST
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.cc.inc"
      >();
}

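// Registers DTensorOpAdderHook with the TensorFlow dialect via the
// TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS macro. The dummy return value
// exists so the call can be latched in a function-local static (see below).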
int RegisterOnce() {
  TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS(DTensorOpAdderHook)
  return 0;
}

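// Ensures the DTensor ops are registered exactly once per process.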
int RegisterDTensorTFOps() {
  static int r = RegisterOnce();
  return r;
}

}  // namespace TF
}  // namespace mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.cc.inc"