1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Location.h" // from @llvm-project
26 #include "mlir/IR/MLIRContext.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/OperationSupport.h" // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
30 #include "mlir/IR/Types.h" // from @llvm-project
31 #include "mlir/IR/UseDefLists.h" // from @llvm-project
32 #include "mlir/IR/Value.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
40 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
43 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
44 #include "tensorflow/core/util/device_name_utils.h"
45
46 namespace mlir {
47 namespace TFTPU {
48
49 namespace {
50
// Op attribute carrying the device assignment of an op.
constexpr char kDeviceAttr[] = "device";
// Device type a resource generator must live on for this transform to apply.
constexpr char kDeviceCPU[] = "CPU";
// Function-argument attribute carrying a device assignment.
constexpr char kFuncDeviceAttr[] = "tf.device";
54
// A pass that allows TPU input layout to be determined after JIT compilation.
// This is done by adding run-time ops that interpret compilation result and
// copy the input to device with that layout.
//
// Example: original program:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
//
// Without this pass, later TF graph partitioning passes will insert send/recv
// between %input and %execute and data will be copied to device in a fixed
// layout. With this pass, the program will be transformed into:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
//   %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
//     {device = "/TPU:0"}
//   %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
//     {device = "/TPU:0"}
//
// This way, %compile will determine the layout, which will be respected by
// %copy_to_device. There will not be send/recv ops added by later passes,
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
  // Per-function entry point. Receives the resource alias analysis info the
  // aggregate-analysis base class computed for `func`.
  void runOnFunction(
      FuncOp func,
      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);

  // Command-line argument used to invoke this pass.
  StringRef getArgument() const final { return "tf-tpu-dynamic-layout-pass"; }

  // One-line summary shown in pass documentation / --help output.
  StringRef getDescription() const final {
    return "Adds ops that allow TPU program inputs to have layouts determined "
           "at JIT compile time.";
  }
};
95
96 // Checks if the input producer op is supported in this transform. Right now, we
97 // only check if it is a tf.IteratorGetNext where resource input is coming from
98 // a VarHandle on CPU or a function argument assigned to CPU.
IsSupportedInputOp(Operation * op,const TF::ResourceAliasAnalysis::Info & resource_alias_analysis)99 bool IsSupportedInputOp(
100 Operation* op,
101 const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
102 TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
103 if (!iterator_op) return false;
104
105 Value resource_iterator = iterator_op.iterator();
106
107 if (resource_alias_analysis.IsUnknownResource(resource_iterator))
108 return false;
109 llvm::SmallSetVector<Value, 8> aliases =
110 resource_alias_analysis.GetResourceAliases(resource_iterator);
111
112 auto is_generator = [](Value val) {
113 if (val.isa<BlockArgument>()) return true;
114 Operation* definition = val.getDefiningOp();
115 return definition->getNumOperands() == 0 &&
116 definition->getNumResults() == 1;
117 };
118
119 // Check all generator aliases (ops or function argument) are on CPU.
120 FuncOp func = iterator_op->getParentOfType<FuncOp>();
121 return llvm::all_of(aliases, [&](Value alias) {
122 // Ignore non-generator aliases.
123 if (!is_generator(alias)) return true;
124
125 StringAttr device;
126 if (auto arg = alias.dyn_cast<BlockArgument>()) {
127 device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
128 kFuncDeviceAttr);
129 } else {
130 device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
131 }
132
133 if (!device) return false;
134 tensorflow::DeviceNameUtils::ParsedName parsed_device;
135 if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
136 &parsed_device)) {
137 return false;
138 }
139 return parsed_device.has_type && parsed_device.type == kDeviceCPU;
140 });
141 }
142
CreateBuilderAfterOp(Operation * op)143 OpBuilder CreateBuilderAfterOp(Operation* op) {
144 return OpBuilder(op->getBlock(), ++Block::iterator(op));
145 }
146
147 // Builds a TPUGetLayoutOp with the given compile op and input index.
BuildGetLayout(const int64_t execute_arg_index,Value compilation_key,tf_device::LaunchOp compile_launch,OpBuilder * builder)148 TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index,
149 Value compilation_key,
150 tf_device::LaunchOp compile_launch,
151 OpBuilder* builder) {
152 return builder->create<TF::TPUGetLayoutOp>(
153 compile_launch.getLoc(),
154 llvm::ArrayRef<Type>{RankedTensorType::get({ShapedType::kDynamicSize},
155 builder->getIntegerType(64))},
156 llvm::ArrayRef<Value>{compilation_key},
157 llvm::ArrayRef<NamedAttribute>{
158 builder->getNamedAttr("index",
159 builder->getI64IntegerAttr(execute_arg_index)),
160 builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
161 }
162
163 // Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
BuildCopyWithLayout(tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch,TF::TPUGetLayoutOp get_layout,Value input,OpBuilder * builder)164 TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
165 tf_device::LaunchOp compile_launch,
166 TF::TPUGetLayoutOp get_layout,
167 Value input, OpBuilder* builder) {
168 return builder->create<TF::TPUCopyWithLayoutOp>(
169 execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
170 llvm::ArrayRef<Value>{input, get_layout.layout()});
171 }
172
173 // Performs transformation for a non-replicated input.
HandleInput(Value input,const int64_t execute_arg_index,TF::TPUExecuteOp execute,tf_device::LaunchOp execute_launch,tf_device::LaunchOp compile_launch)174 void HandleInput(Value input, const int64_t execute_arg_index,
175 TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch,
176 tf_device::LaunchOp compile_launch) {
177 OpBuilder builder = CreateBuilderAfterOp(compile_launch);
178 auto get_layout = BuildGetLayout(execute_arg_index, execute.key(),
179 compile_launch, &builder);
180 builder.setInsertionPoint(execute_launch);
181 auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
182 get_layout, input, &builder);
183 copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
184 execute.setOperand(execute_arg_index, copy_with_layout);
185 }
186
// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleReplicatedInputs(
    const int64_t execute_arg_index, Value compilation_key,
    tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
    mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(replicate_arg);
  // Bail out unless every per-replica input is produced by a supported op
  // (tf.IteratorGetNext with CPU-placed generators).
  for (auto entry : llvm::enumerate(inputs)) {
    auto input_op = entry.value().get().getDefiningOp();
    if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
      return false;
  }
  // One layout query, inserted right after the compile, is shared by all
  // replicas; the copies are inserted before the replicate op because the
  // inputs are defined outside it.
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
                                   compile_launch, &builder);
  builder.setInsertionPoint(replicate);
  for (auto entry : llvm::enumerate(inputs)) {
    auto copy_with_layout =
        BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
                            entry.value().get(), &builder);

    // Resolve the concrete device for this replica from the replicate op's
    // device map, keyed by the execute launch's device alias; entry.index()
    // is the replica id.
    auto device_list = replicate.devices()
                           .getValue()
                           .get(execute_launch.getDevice())
                           .cast<ArrayAttr>();
    copy_with_layout->setAttr(kDeviceAttr,
                              device_list.getValue()[entry.index()]);

    // Rewire this replica's operand to consume the layout-aware copy.
    entry.value().set(copy_with_layout);
  }
  return true;
}
224
// Performs transformation on a compile and associated execute(s) ops. The
// compile should not have other uses.
void HandleCompileAndExecutes(
    tf_device::LaunchOp compile_launch,
    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  auto compile =
      llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  // NOTE(review): the ParseFromString return value is ignored — a malformed
  // metadata attribute would silently yield an empty proto; presumably
  // earlier passes guarantee validity. Confirm before relying on this.
  metadata.ParseFromString(compile.metadata().str());
  // input_mappings[i][j] maps operand j of execute i to its index in the
  // compile metadata's args.
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings =
      tensorflow::GetMetadataArgumentMapping(metadata);

  bool metadata_updated = false;
  // Null if the executes are not enclosed in a tf_device.replicate.
  auto maybe_replicate =
      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();

  for (auto execute_and_input_mapping :
       llvm::zip(execute_launches, input_mappings)) {
    auto& execute_launch = std::get<0>(execute_and_input_mapping);
    auto execute =
        llvm::cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
    const auto& input_mapping = std::get<1>(execute_and_input_mapping);

    for (auto& input_and_idx : llvm::enumerate(execute.args())) {
      Value input = input_and_idx.value();
      const int64_t execute_arg_index = input_and_idx.index();
      if (auto block_arg = input.dyn_cast<BlockArgument>()) {
        // For a block argument, consider transforms only when it is a
        // replicated input (defining ops will be outside the replicate node).
        if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
            !HandleReplicatedInputs(execute_arg_index, execute.key(),
                                    execute_launch, compile_launch, block_arg,
                                    maybe_replicate, resource_alias_analysis)) {
          continue;
        }
      } else {
        // For an op output, consider transforms only when 1) there is no
        // replication or 2) it is outside the replicate node that encloses the
        // execute node. (Because if the op is inside replicate, it is probably
        // not on the host.)
        auto* input_op = input.getDefiningOp();
        if (maybe_replicate &&
            maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
          continue;
        }
        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
        HandleInput(input, execute_arg_index, execute, execute_launch,
                    compile_launch);
      }

      // The input was transformed: allow the compiler to pick its layout.
      metadata.mutable_args(input_mapping[execute_arg_index])
          ->set_unrestricted_layout(true);
      metadata_updated = true;
    }
  }

  // Re-serialize the (possibly updated) metadata back onto the compile op.
  if (metadata_updated)
    compile->setAttr("metadata", StringAttr::get(compile.getContext(),
                                                 metadata.SerializeAsString()));
}
286
// Walks `func` for tf._TPUCompileMlir ops and, for each compile whose program
// results feed exactly one launch-wrapped tf.TPUExecute each, applies the
// dynamic-layout transformation.
void TPUDynamicLayoutPass::runOnFunction(
    FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  func.walk([&](TF::_TPUCompileMlirOp compile) {
    // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
    auto compile_launch =
        llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
    if (!compile_launch || !compile_launch.WrapsSingleOp()) return;

    // Result 0 of the compile launch is the compilation status; each
    // remaining (program) result must have exactly one user, which is a
    // TPUExecute wrapped in its own single-op launch. Otherwise skip.
    llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
    execute_launches.reserve(compile_launch.getNumResults() - 1);
    for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) {
      if (!program_result.hasOneUse()) return;
      Operation* user = *program_result.user_begin();
      auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
      if (!execute) return;
      auto execute_launch =
          llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
      if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
      execute_launches.push_back(execute_launch);
    }

    HandleCompileAndExecutes(compile_launch, execute_launches,
                             resource_alias_analysis);
  });
}
313
314 } // namespace
315
CreateTPUDynamicLayoutPass()316 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
317 return std::make_unique<TPUDynamicLayoutPass>();
318 }
319
// Registers the pass with the global registry under its getArgument() name.
static PassRegistration<TPUDynamicLayoutPass> pass;
321
322 } // namespace TFTPU
323 } // namespace mlir
324