• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass hoists a `tf_device.launch` body and assigns a `device` attribute
// to each TensorFlow dialect op in the body based on the `device` attribute on
// the `tf_device.launch`. A TensorFlow dialect op that has no device attribute,
// or has an empty one, is given the `tf_device.launch` device. An op whose
// device attribute is already set to a different (non-empty) device, or is not
// a string attribute, is reported as an error and the pass fails.
21 //
22 // For example:
23 //   %island:5 = tf_executor.island {
24 //     %a = "tf.opA"() : () -> tensor<i1>
25 //     %launch:2 = "tf_device.launch"() ( {
26 //       %b = "tf.opB"() : () -> tensor<i32>
27 //       %c = "tf.opC"() : () -> tensor<f32>
28 //       tf_device.return %c, %b : tensor<f32>, tensor<i32>
29 //     }) {device = "CPU:0"} : () -> (tensor<f32>, tensor<i32>)
30 //     %d = "tf.opD"() : () -> tensor<i1>
31 //     tf_executor.yield %a, %launch#0, %launch#1, %d :
32 //                       tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
33 //   }
34 //
35 // Will be transformed into:
36 //   %island:5 = tf_executor.island {
37 //     %a = "tf.opA"() : () -> tensor<i1>
38 //     %b = "tf.opB"() {device = "CPU:0"} : () -> tensor<i32>
39 //     %c = "tf.opC"() {device = "CPU:0"} : () -> tensor<f32>
40 //     %d = "tf.opD"() : () -> tensor<i1>
41 //     tf_executor.yield %a, %c, %b, %d :
42 //                       tensor<i1>, tensor<f32>, tensor<i32>, tensor<i1>
43 //   }
44 
45 #include "llvm/Support/Casting.h"
46 #include "mlir/IR/Attributes.h"  // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
48 #include "mlir/IR/Dialect.h"  // from @llvm-project
49 #include "mlir/IR/Operation.h"  // from @llvm-project
50 #include "mlir/IR/Visitors.h"  // from @llvm-project
51 #include "mlir/Pass/Pass.h"  // from @llvm-project
52 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
54 
55 namespace mlir {
56 namespace TFDevice {
57 namespace {
// Name of the op attribute that records the assigned device.
constexpr const char kDeviceAttr[] = "device";
59 
60 struct LaunchToDeviceAttributePass
61     : public PassWrapper<LaunchToDeviceAttributePass, FunctionPass> {
62   void runOnFunction() override;
63 };
64 
65 // Assign all ops in region with specified device from launch.
AssignDevicesInRegion(const Dialect * tf_dialect,tf_device::LaunchOp launch,Region & region)66 LogicalResult AssignDevicesInRegion(const Dialect* tf_dialect,
67                                     tf_device::LaunchOp launch,
68                                     Region& region) {
69   auto result = region.walk([&](Operation* op) -> WalkResult {
70     if (op->getDialect() != tf_dialect) return WalkResult::advance();
71 
72     auto device_attr = op->getAttr(kDeviceAttr);
73     if (!device_attr) {
74       op->setAttr(kDeviceAttr, launch.deviceAttr());
75       return WalkResult::advance();
76     }
77 
78     if (auto device_str_attr = device_attr.dyn_cast<StringAttr>()) {
79       if (device_str_attr.getValue().empty()) {
80         op->setAttr(kDeviceAttr, launch.deviceAttr());
81         return WalkResult::advance();
82       } else if (device_str_attr.getValue() != launch.device()) {
83         return launch.emitOpError()
84                << "inner op has conflicting 'device' attribute, "
85                   "got '"
86                << device_str_attr.getValue() << "' but expected '"
87                << launch.device() << "'";
88       }
89     } else {
90       return launch.emitOpError()
91              << "inner op has bad 'device' attribute, got " << device_attr;
92     }
93 
94     return WalkResult::advance();
95   });
96 
97   return failure(result.wasInterrupted());
98 }
99 
HoistOpsAndAnnotateWithDevice(const Dialect * tf_dialect,tf_device::LaunchOp launch)100 LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect,
101                                             tf_device::LaunchOp launch) {
102   // Forward launch inner op results to launch op results.
103   launch.replaceAllUsesWith(launch.GetBody().getTerminator()->getOperands());
104 
105   // For all inner ops, assign the launch device as a `device` attribute.
106   if (failed(AssignDevicesInRegion(tf_dialect, launch, launch.body())))
107     return failure();
108 
109   // Move all inner ops of the launch to the block containing the launch.
110   auto body = launch.GetBody().without_terminator();
111   Operation* launch_op = launch.getOperation();
112   launch_op->getBlock()->getOperations().splice(
113       launch_op->getIterator(), launch.GetBody().getOperations(), body.begin(),
114       body.end());
115 
116   launch.erase();
117 
118   return success();
119 }
120 
runOnFunction()121 void LaunchToDeviceAttributePass::runOnFunction() {
122   const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
123   if (!tf_dialect) {
124     getOperation().emitError() << "'tf' dialect is not registered";
125     return signalPassFailure();
126   }
127 
128   auto result = getOperation().walk([&tf_dialect](tf_device::LaunchOp launch) {
129     if (failed(HoistOpsAndAnnotateWithDevice(tf_dialect, launch)))
130       return WalkResult::interrupt();
131 
132     return WalkResult::advance();
133   });
134 
135   if (result.wasInterrupted()) return signalPassFailure();
136 }
137 
138 }  // anonymous namespace
139 
CreateLaunchToDeviceAttributePass()140 std::unique_ptr<OperationPass<FuncOp>> CreateLaunchToDeviceAttributePass() {
141   return std::make_unique<LaunchToDeviceAttributePass>();
142 }
143 
144 static PassRegistration<LaunchToDeviceAttributePass> pass(
145     "tf-launch-to-device-attribute",
146     "Hoists and annotates device launch inner ops with associated device "
147     "attribute");
148 
149 }  // namespace TFDevice
150 }  // namespace mlir
151