1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "llvm/ADT/None.h" 17 #include "llvm/Support/raw_ostream.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project 19 #include "mlir/IR/Attributes.h" // from @llvm-project 20 #include "mlir/IR/Builders.h" // from @llvm-project 21 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 22 #include "mlir/IR/Operation.h" // from @llvm-project 23 #include "mlir/IR/PatternMatch.h" // from @llvm-project 24 #include "mlir/Pass/Pass.h" // from @llvm-project 25 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 26 #include "tensorflow/compiler/mlir/lite/transforms/passes.h" 27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" 28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" 29 30 namespace mlir { 31 namespace TFL { 32 namespace { 33 // Pass which removes any unused bounded function arguments which maps to 34 // variables, also removes the GlobalTensor which is the variable. 35 class RemoveArgsAndGlobalTensors 36 : public PassWrapper<RemoveArgsAndGlobalTensors, OperationPass<ModuleOp>> { 37 public: 38 RemoveArgsAndGlobalTensors() = default; RemoveArgsAndGlobalTensors(const RemoveArgsAndGlobalTensors &)39 RemoveArgsAndGlobalTensors(const RemoveArgsAndGlobalTensors&) {} 40 runOnOperation()41 void runOnOperation() override { 42 auto module = getOperation(); 43 SymbolTable symbol_table(module); 44 45 // Remove unused arguments in the functions which are bounded input 46 // for a global tensor. Also, removes the now unused global tensors. 47 std::set<mlir::tf_saved_model::GlobalTensorOp> global_tensors_to_remove; 48 for (auto func : module.getOps<FuncOp>()) { 49 llvm::SmallVector<unsigned int> index_to_remove; 50 for (int i = 0; i < func.getNumArguments(); ++i) { 51 if (auto sym = func.template getArgAttrOfType<FlatSymbolRefAttr>( 52 i, "tf_saved_model.bound_input")) { 53 auto global_tensor = 54 symbol_table.lookup<tf_saved_model::GlobalTensorOp>( 55 sym.getValue()); 56 if (global_tensor && func.getArgument(i).getUsers().empty()) { 57 index_to_remove.push_back(i); 58 global_tensors_to_remove.insert(global_tensor); 59 } 60 } 61 } 62 func.eraseArguments(index_to_remove); 63 } 64 for (auto global_tensor : global_tensors_to_remove) { 65 global_tensor->erase(); 66 } 67 } 68 }; 69 70 } // namespace 71 CreateRemoveArgsAndGlobalTensors()72std::unique_ptr<OperationPass<ModuleOp>> CreateRemoveArgsAndGlobalTensors() { 73 return std::make_unique<RemoveArgsAndGlobalTensors>(); 74 } 75 76 static PassRegistration<RemoveArgsAndGlobalTensors> pass( 77 "tfl-remove-unused-function-args", 78 "Removes unused bounded input arguments to function which are unused and " 79 "maps to GlobalTensor."); 80 81 } // namespace TFL 82 } // namespace mlir 83