1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ 17 18 #include <map> 19 #include <vector> 20 21 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 22 23 namespace tensorflow { 24 25 // Eager op rewrites should inherit from this class and 26 // implement the Run method. 27 class EagerOpRewrite { 28 public: EagerOpRewrite(string name,string file,string line)29 EagerOpRewrite(string name, string file, string line) { 30 debug_info_.name = name; 31 debug_info_.file = file; 32 debug_info_.line = line; 33 } 34 ~EagerOpRewrite()35 virtual ~EagerOpRewrite() {} 36 37 // To be implemented by an Eager op rewrite pass. 38 virtual Status Run(EagerOperation* orig_op, 39 std::unique_ptr<tensorflow::EagerOperation>* out_op) = 0; 40 41 // Holds information about the rewrite registration. 42 struct DebugInfo { 43 string name, file, line; 44 }; 45 46 // Returns information about the registered Eager op rewrite. GetDebugInfo()47 DebugInfo GetDebugInfo() const { return debug_info_; } 48 49 private: 50 DebugInfo debug_info_; 51 }; 52 53 class EagerOpRewriteRegistry { 54 public: 55 // Phases at which the Eager op rewrite pass should run. 56 // For now we only added PRE_EXECUTION. Expand as needed. 57 enum Phase { 58 PRE_EXECUTION = 0, // right before executing an eager op 59 POST_PLACEMENT = 1 // after device placement 60 }; 61 62 // Add a rewrite pass to the registry. 63 // Only one rewrite pass is allowed per phase. 64 void Register(Phase phase, std::unique_ptr<EagerOpRewrite> pass); 65 66 // Run the rewrite pass registered for a given phase. 67 Status RunRewrite(Phase phase, EagerOperation* orig_op, 68 std::unique_ptr<tensorflow::EagerOperation>* out_op); 69 70 // Returns the global registry of rewrite passes. 71 static EagerOpRewriteRegistry* Global(); 72 73 private: 74 static constexpr int32 kNumPhases = 2; 75 // Holds all the registered Eager op rewrites. 76 std::array<std::unique_ptr<EagerOpRewrite>, kNumPhases> rewrites_; 77 }; 78 79 namespace eager_rewrite_registration { 80 81 // This class is used to register a new Eager Op rewrite. 82 class EagerRewriteRegistration { 83 public: EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase,std::unique_ptr<EagerOpRewrite> pass)84 EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase, 85 std::unique_ptr<EagerOpRewrite> pass) { 86 EagerOpRewriteRegistry::Global()->Register(phase, std::move(pass)); 87 } 88 }; 89 90 } // namespace eager_rewrite_registration 91 92 #define REGISTER_REWRITE(phase, rewrite) \ 93 REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, rewrite) 94 95 #define REGISTER_REWRITE_UNIQ_HELPER(ctr, file, line, phase, rewrite) \ 96 REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite) 97 98 #define REGISTER_REWRITE_UNIQ(ctr, file, line, phase, rewrite) \ 99 static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \ 100 register_rewrite_##ctr(phase, \ 101 ::std::unique_ptr<::tensorflow::EagerOpRewrite>( \ 102 new rewrite(#rewrite, file, #line))) 103 104 } // namespace tensorflow 105 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ 106