1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/transforms/utils/utils.h"

#include <cassert>
#include <cstdint>
#include <numeric>
#include <string>

#include "absl/strings/match.h"
#include "llvm/ADT/BitVector.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/interfaces.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/ir/utility.h"
#include "tensorflow/core/util/device_name_utils.h"
33
34 namespace mlir {
35 namespace tfg {
36 namespace util {
37
OpHasDevice(Operation * op,const char * device_name)38 bool OpHasDevice(Operation *op, const char *device_name) {
39 std::string task, device;
40 return tensorflow::DeviceNameUtils::SplitDeviceName(TFOp(op).device().data(),
41 &task, &device) &&
42 absl::StartsWith(device, device_name);
43 }
44
EraseRegularNodeAttributes(NamedAttrList & attr_list)45 void EraseRegularNodeAttributes(NamedAttrList &attr_list) {
46 NamedAttrList new_attr_list;
47 for (NamedAttribute attr : attr_list) {
48 if (attr.getName().strref().startswith("_")) new_attr_list.append(attr);
49 }
50 attr_list = new_attr_list;
51 }
52
ForwardNonIntrinsicAttributes(Operation * src,Operation * dst)53 void ForwardNonIntrinsicAttributes(Operation *src, Operation *dst) {
54 NamedAttrList dst_attrs(dst->getAttrDictionary());
55 DenseSet<StringAttr> name_set;
56
57 // Forward all non-intrinsic attributes. If the source op is unregistered,
58 // forward all attributes.
59 if (Optional<RegisteredOperationName> src_name = src->getRegisteredInfo()) {
60 ArrayRef<StringAttr> src_attr_names = src_name->getAttributeNames();
61 name_set.insert(src_attr_names.begin(), src_attr_names.end());
62 }
63 for (const NamedAttribute &attr : src->getAttrs()) {
64 if (!name_set.contains(attr.getName())) dst_attrs.append(attr);
65 }
66
67 dst->setAttrs(dst_attrs.getDictionary(dst->getContext()));
68 }
69
UpdateIfPresent(Region & region,function_ref<RegionAttr (RegionAttr)> copy_update)70 static void UpdateIfPresent(Region ®ion,
71 function_ref<RegionAttr(RegionAttr)> copy_update) {
72 unsigned index = region.getRegionNumber();
73 auto iface = cast<PreservedAttributesInterface>(region.getParentOp());
74 if (auto attrs = iface.getPreservedAttrs(index))
75 iface.setPreservedAttrs(index, copy_update(attrs));
76 }
77
UpdateArgAttrsIfPresent(Region & region,function_ref<void (SmallVectorImpl<Attribute> &)> update)78 static void UpdateArgAttrsIfPresent(
79 Region ®ion, function_ref<void(SmallVectorImpl<Attribute> &)> update) {
80 UpdateIfPresent(region, [&](RegionAttr attrs) {
81 SmallVector<Attribute> args = llvm::to_vector(attrs.getArgAttrs());
82 update(args);
83 return RegionAttr::get(attrs.getAttrs(),
84 ArrayAttr::get(attrs.getContext(), args),
85 attrs.getResAttrs());
86 });
87 }
88
UpdateResultAttrsIfPresent(Region & region,function_ref<void (SmallVectorImpl<Attribute> &)> update)89 static void UpdateResultAttrsIfPresent(
90 Region ®ion, function_ref<void(SmallVectorImpl<Attribute> &)> update) {
91 UpdateIfPresent(region, [&](RegionAttr attrs) {
92 SmallVector<Attribute> results = llvm::to_vector(attrs.getResAttrs());
93 update(results);
94 return RegionAttr::get(attrs.getAttrs(), attrs.getArgAttrs(),
95 ArrayAttr::get(attrs.getContext(), results));
96 });
97 }
98
LoopRegionAddArgument(Region & region,Type type)99 LoopRegionArgumentUpdate LoopRegionAddArgument(Region ®ion, Type type) {
100 // Add the arguments.
101 BlockArgument data = region.insertArgument(
102 GetLoopRegionDataArgs(region).size(), type, region.getLoc());
103 BlockArgument ctl =
104 region.addArgument(ControlType::get(type.getContext()), region.getLoc());
105
106 UpdateArgAttrsIfPresent(region, [&](SmallVectorImpl<Attribute> &arg_attrs) {
107 arg_attrs.push_back(DictionaryAttr::get(type.getContext(), {}));
108 });
109
110 return {data, ctl};
111 }
112
LoopRegionEraseArgument(Region & region,unsigned index)113 void LoopRegionEraseArgument(Region ®ion, unsigned index) {
114 Block::BlockArgListType args = GetLoopRegionDataArgs(region);
115 assert(index < args.size());
116
117 // Erase the arguments.
118 SmallVector<unsigned, 2> indices;
119 indices.push_back(args[index].getArgNumber());
120 indices.push_back(GetLoopRegionControlOf(args[index]).getArgNumber());
121 region.front().eraseArguments(indices);
122
123 UpdateArgAttrsIfPresent(region, [&](SmallVectorImpl<Attribute> &arg_attrs) {
124 arg_attrs.erase(arg_attrs.begin() + index);
125 });
126 }
127
LoopRegionResultAdded(Region & region,unsigned num)128 void LoopRegionResultAdded(Region ®ion, unsigned num) {
129 UpdateResultAttrsIfPresent(
130 region, [&](SmallVectorImpl<Attribute> &res_attrs) {
131 res_attrs.append(num, DictionaryAttr::get(region.getContext(), {}));
132 });
133 }
134
LoopRegionResultErased(Region & region,unsigned index)135 void LoopRegionResultErased(Region ®ion, unsigned index) {
136 UpdateResultAttrsIfPresent(region,
137 [&](SmallVectorImpl<Attribute> &res_attrs) {
138 res_attrs.erase(res_attrs.begin() + index);
139 });
140 }
141
SizedOperandSegmentsEraseOperands(Operation * op,ArrayRef<unsigned> indices)142 void SizedOperandSegmentsEraseOperands(Operation *op,
143 ArrayRef<unsigned> indices) {
144 llvm::BitVector erase(op->getNumOperands());
145 for (unsigned index : indices) erase.set(index);
146 SizedOperandSegmentsEraseOperands(op, erase);
147 }
148
SizedOperandSegmentsEraseOperands(Operation * op,const llvm::BitVector & erase)149 void SizedOperandSegmentsEraseOperands(Operation *op,
150 const llvm::BitVector &erase) {
151 // Update the segment sizes if present.
152 Builder b(op->getContext());
153 StringAttr attr_name = b.getStringAttr(
154 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr());
155 auto segment_sizes = op->getAttrOfType<DenseI32ArrayAttr>(attr_name);
156 if (segment_sizes) {
157 auto values = segment_sizes.asArrayRef();
158 SmallVector<int32_t> new_sizes = llvm::to_vector(values);
159
160 unsigned base = 0;
161 for (auto it : llvm::zip(values, new_sizes)) {
162 int32_t size = std::get<0>(it);
163 int32_t &new_size = std::get<1>(it);
164 for (int32_t i = 0; i < size; ++i)
165 if (erase.test(base + i)) --new_size;
166 base += size;
167 }
168 assert(llvm::all_of(new_sizes, [](int32_t size) { return size >= 0; }));
169 assert(std::accumulate(new_sizes.begin(), new_sizes.end(), 0) ==
170 op->getNumOperands() - erase.count());
171 segment_sizes = b.getDenseI32ArrayAttr(new_sizes);
172 }
173
174 op->eraseOperands(erase);
175 if (segment_sizes) op->setAttr(attr_name, segment_sizes);
176 }
177
178 } // namespace util
179 } // namespace tfg
180 } // namespace mlir
181