• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Error.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/Regex.h"
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
30 #include "mlir/IR/Location.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
33 #include "tensorflow/core/common_runtime/device.h"
34 #include "tensorflow/core/common_runtime/device_set.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 
// Name of the operation attribute that the functions in this file use to
// store (`AddDevicesToOp`) and look up (`GetDevicesFromOp`) runtime devices.
constexpr char kDevicesAttr[] = "tf.devices";
40 
41 namespace {
42 
43 // Parse GPU compute capability from physical device description. If compute
44 // capability is not found in device description, return an empty dictionary
45 // attribute.
ParseGpuDeviceMetadata(const Device & device,mlir::Builder * builder)46 mlir::DictionaryAttr ParseGpuDeviceMetadata(const Device& device,
47                                             mlir::Builder* builder) {
48   // Parse GPU device compute capability from physical device description.
49   static auto* r = new llvm::Regex("compute capability: ([0-9]+)\\.([0-9]+)");
50 
51   llvm::SmallVector<llvm::StringRef, 3> cc;
52   if (r->match(device.attributes().physical_device_desc(), &cc)) {
53     return mlir::TF::GpuDeviceMetadata::get(
54         builder->getI32IntegerAttr(std::stoi(cc[1].str())),
55         builder->getI32IntegerAttr(std::stoi(cc[2].str())),
56         builder->getContext());
57   }
58 
59   return builder->getDictionaryAttr({});
60 }
61 
62 // Get devices from an array of string attributes.
63 // TODO(ezhulenev): Update all tests to use dictionary attribute for
64 // `tf.devices` and remove this function.
GetDevicesFromOp(mlir::Operation * op,mlir::ArrayAttr array_attr,mlir::TF::RuntimeDevices * devices)65 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
66                                      mlir::ArrayAttr array_attr,
67                                      mlir::TF::RuntimeDevices* devices) {
68   DeviceNameUtils::ParsedName device;
69 
70   for (auto& kv : llvm::enumerate(array_attr)) {
71     const int idx = kv.index();
72 
73     auto string_attr = kv.value().dyn_cast<mlir::StringAttr>();
74     if (!string_attr)
75       return op->emitOpError(llvm::formatv(
76           "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
77 
78     if (DeviceNameUtils::ParseFullName(string_attr.getValue().str(), &device)) {
79       devices->AddDevice(device);
80     } else {
81       return op->emitOpError(
82           llvm::formatv("bad '{0}' attribute, '{1}', not a valid device",
83                         kDevicesAttr, string_attr.getValue()));
84     }
85   }
86 
87   return mlir::success();
88 }
89 
90 // Get devices from a dictionary attribute.
GetDevicesFromOp(mlir::Operation * op,mlir::DictionaryAttr dict_attr,mlir::TF::RuntimeDevices * devices)91 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
92                                      mlir::DictionaryAttr dict_attr,
93                                      mlir::TF::RuntimeDevices* devices) {
94   DeviceNameUtils::ParsedName device;
95 
96   // Parse device names and metadata from dictionary attribute.
97   for (auto& kv : dict_attr) {
98     const mlir::Identifier name = kv.first;
99     const mlir::Attribute attr = kv.second;
100 
101     if (!DeviceNameUtils::ParseFullName(name.str(), &device))
102       return op->emitOpError(
103           llvm::formatv("bad '{0}' attribute, '{1}', not a valid device",
104                         kDevicesAttr, name.strref()));
105 
106     if (auto gpu_metadata = attr.dyn_cast<mlir::TF::GpuDeviceMetadata>()) {
107       devices->AddGpuDevice(device, gpu_metadata);
108     } else {
109       devices->AddDevice(device);
110     }
111   }
112 
113   return mlir::success();
114 }
115 
116 }  // namespace
117 
AddDevicesToOp(mlir::Operation * op,const DeviceSet * device_set)118 void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) {
119   if (!device_set) return;
120 
121   mlir::MLIRContext* ctx = op->getContext();
122   mlir::Builder builder(ctx);
123 
124   // Collect devices with attached metadata.
125   llvm::SmallVector<mlir::NamedAttribute, 8> devices;
126   devices.reserve(device_set->devices().size());
127 
128   // For device that do not have any metadata, or if we failed to parse metadata
129   // from the DeviceSet, we add empty dictionary to the `tf.devices` attribute.
130   for (Device* device : device_set->devices()) {
131     string name = DeviceNameUtils::ParsedNameToString(device->parsed_name());
132 
133     if (device->device_type() == DEVICE_GPU) {
134       auto metadata = ParseGpuDeviceMetadata(*device, &builder);
135       devices.push_back(builder.getNamedAttr(name, metadata));
136     } else {
137       auto metadata = builder.getDictionaryAttr({});
138       devices.push_back(builder.getNamedAttr(name, metadata));
139     }
140   }
141 
142   op->setAttr(kDevicesAttr, builder.getDictionaryAttr(devices));
143 }
144 
GetDevicesFromOp(mlir::Operation * op,mlir::TF::RuntimeDevices * devices)145 mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op,
146                                      mlir::TF::RuntimeDevices* devices) {
147   auto devices_attr = op->getAttr(kDevicesAttr);
148   if (!devices_attr) return mlir::success();
149 
150   if (auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>()) {
151     return GetDevicesFromOp(op, array_attr, devices);
152 
153   } else if (auto dict_attr = devices_attr.dyn_cast<mlir::DictionaryAttr>()) {
154     return GetDevicesFromOp(op, dict_attr, devices);
155   }
156 
157   return op->emitOpError(
158       llvm::formatv("unsupported '{0}' attribute", kDevicesAttr));
159 }
160 
GetDeviceOrdinalFromDeviceString(mlir::Location loc,llvm::StringRef device,int64_t * device_ordinal)161 mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc,
162                                                      llvm::StringRef device,
163                                                      int64_t* device_ordinal) {
164   DeviceNameUtils::ParsedName parsed_name;
165   if (!DeviceNameUtils::ParseFullName(
166           absl::string_view(device.data(), device.size()), &parsed_name))
167     return mlir::emitError(loc) << "invalid device '" << device << "'";
168 
169   if (!parsed_name.has_id)
170     return mlir::emitError(loc) << "device '" << device << "' has no id";
171 
172   *device_ordinal = parsed_name.id;
173   return mlir::success();
174 }
175 
176 }  // namespace tensorflow
177