• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/transforms/utils/utils.h"

#include <numeric>
#include <string>

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/BitVector.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/interfaces.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/ir/utility.h"
#include "tensorflow/core/util/device_name_utils.h"
33 
34 namespace mlir {
35 namespace tfg {
36 namespace util {
37 
OpHasDevice(Operation * op,const char * device_name)38 bool OpHasDevice(Operation *op, const char *device_name) {
39   std::string task, device;
40   return tensorflow::DeviceNameUtils::SplitDeviceName(TFOp(op).device().data(),
41                                                       &task, &device) &&
42          absl::StartsWith(device, device_name);
43 }
44 
EraseRegularNodeAttributes(NamedAttrList & attr_list)45 void EraseRegularNodeAttributes(NamedAttrList &attr_list) {
46   NamedAttrList new_attr_list;
47   for (NamedAttribute attr : attr_list) {
48     if (attr.getName().strref().startswith("_")) new_attr_list.append(attr);
49   }
50   attr_list = new_attr_list;
51 }
52 
ForwardNonIntrinsicAttributes(Operation * src,Operation * dst)53 void ForwardNonIntrinsicAttributes(Operation *src, Operation *dst) {
54   NamedAttrList dst_attrs(dst->getAttrDictionary());
55   DenseSet<StringAttr> name_set;
56 
57   // Forward all non-intrinsic attributes. If the source op is unregistered,
58   // forward all attributes.
59   if (Optional<RegisteredOperationName> src_name = src->getRegisteredInfo()) {
60     ArrayRef<StringAttr> src_attr_names = src_name->getAttributeNames();
61     name_set.insert(src_attr_names.begin(), src_attr_names.end());
62   }
63   for (const NamedAttribute &attr : src->getAttrs()) {
64     if (!name_set.contains(attr.getName())) dst_attrs.append(attr);
65   }
66 
67   dst->setAttrs(dst_attrs.getDictionary(dst->getContext()));
68 }
69 
UpdateIfPresent(Region & region,function_ref<RegionAttr (RegionAttr)> copy_update)70 static void UpdateIfPresent(Region &region,
71                             function_ref<RegionAttr(RegionAttr)> copy_update) {
72   unsigned index = region.getRegionNumber();
73   auto iface = cast<PreservedAttributesInterface>(region.getParentOp());
74   if (auto attrs = iface.getPreservedAttrs(index))
75     iface.setPreservedAttrs(index, copy_update(attrs));
76 }
77 
UpdateArgAttrsIfPresent(Region & region,function_ref<void (SmallVectorImpl<Attribute> &)> update)78 static void UpdateArgAttrsIfPresent(
79     Region &region, function_ref<void(SmallVectorImpl<Attribute> &)> update) {
80   UpdateIfPresent(region, [&](RegionAttr attrs) {
81     SmallVector<Attribute> args = llvm::to_vector(attrs.getArgAttrs());
82     update(args);
83     return RegionAttr::get(attrs.getAttrs(),
84                            ArrayAttr::get(attrs.getContext(), args),
85                            attrs.getResAttrs());
86   });
87 }
88 
UpdateResultAttrsIfPresent(Region & region,function_ref<void (SmallVectorImpl<Attribute> &)> update)89 static void UpdateResultAttrsIfPresent(
90     Region &region, function_ref<void(SmallVectorImpl<Attribute> &)> update) {
91   UpdateIfPresent(region, [&](RegionAttr attrs) {
92     SmallVector<Attribute> results = llvm::to_vector(attrs.getResAttrs());
93     update(results);
94     return RegionAttr::get(attrs.getAttrs(), attrs.getArgAttrs(),
95                            ArrayAttr::get(attrs.getContext(), results));
96   });
97 }
98 
LoopRegionAddArgument(Region & region,Type type)99 LoopRegionArgumentUpdate LoopRegionAddArgument(Region &region, Type type) {
100   // Add the arguments.
101   BlockArgument data = region.insertArgument(
102       GetLoopRegionDataArgs(region).size(), type, region.getLoc());
103   BlockArgument ctl =
104       region.addArgument(ControlType::get(type.getContext()), region.getLoc());
105 
106   UpdateArgAttrsIfPresent(region, [&](SmallVectorImpl<Attribute> &arg_attrs) {
107     arg_attrs.push_back(DictionaryAttr::get(type.getContext(), {}));
108   });
109 
110   return {data, ctl};
111 }
112 
LoopRegionEraseArgument(Region & region,unsigned index)113 void LoopRegionEraseArgument(Region &region, unsigned index) {
114   Block::BlockArgListType args = GetLoopRegionDataArgs(region);
115   assert(index < args.size());
116 
117   // Erase the arguments.
118   SmallVector<unsigned, 2> indices;
119   indices.push_back(args[index].getArgNumber());
120   indices.push_back(GetLoopRegionControlOf(args[index]).getArgNumber());
121   region.front().eraseArguments(indices);
122 
123   UpdateArgAttrsIfPresent(region, [&](SmallVectorImpl<Attribute> &arg_attrs) {
124     arg_attrs.erase(arg_attrs.begin() + index);
125   });
126 }
127 
LoopRegionResultAdded(Region & region,unsigned num)128 void LoopRegionResultAdded(Region &region, unsigned num) {
129   UpdateResultAttrsIfPresent(
130       region, [&](SmallVectorImpl<Attribute> &res_attrs) {
131         res_attrs.append(num, DictionaryAttr::get(region.getContext(), {}));
132       });
133 }
134 
LoopRegionResultErased(Region & region,unsigned index)135 void LoopRegionResultErased(Region &region, unsigned index) {
136   UpdateResultAttrsIfPresent(region,
137                              [&](SmallVectorImpl<Attribute> &res_attrs) {
138                                res_attrs.erase(res_attrs.begin() + index);
139                              });
140 }
141 
SizedOperandSegmentsEraseOperands(Operation * op,ArrayRef<unsigned> indices)142 void SizedOperandSegmentsEraseOperands(Operation *op,
143                                        ArrayRef<unsigned> indices) {
144   llvm::BitVector erase(op->getNumOperands());
145   for (unsigned index : indices) erase.set(index);
146   SizedOperandSegmentsEraseOperands(op, erase);
147 }
148 
SizedOperandSegmentsEraseOperands(Operation * op,const llvm::BitVector & erase)149 void SizedOperandSegmentsEraseOperands(Operation *op,
150                                        const llvm::BitVector &erase) {
151   // Update the segment sizes if present.
152   Builder b(op->getContext());
153   StringAttr attr_name = b.getStringAttr(
154       OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr());
155   auto segment_sizes = op->getAttrOfType<DenseI32ArrayAttr>(attr_name);
156   if (segment_sizes) {
157     auto values = segment_sizes.asArrayRef();
158     SmallVector<int32_t> new_sizes = llvm::to_vector(values);
159 
160     unsigned base = 0;
161     for (auto it : llvm::zip(values, new_sizes)) {
162       int32_t size = std::get<0>(it);
163       int32_t &new_size = std::get<1>(it);
164       for (int32_t i = 0; i < size; ++i)
165         if (erase.test(base + i)) --new_size;
166       base += size;
167     }
168     assert(llvm::all_of(new_sizes, [](int32_t size) { return size >= 0; }));
169     assert(std::accumulate(new_sizes.begin(), new_sizes.end(), 0) ==
170            op->getNumOperands() - erase.count());
171     segment_sizes = b.getDenseI32ArrayAttr(new_sizes);
172   }
173 
174   op->eraseOperands(erase);
175   if (segment_sizes) op->setAttr(attr_name, segment_sizes);
176 }
177 
178 }  // namespace util
179 }  // namespace tfg
180 }  // namespace mlir
181