1 /* Copyright 2019 The TensorFlow Authors. Al Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_ 17 18 #include <string> 19 20 #include "tensorflow/core/lib/gtl/flatset.h" 21 #include "tensorflow/core/platform/types.h" 22 23 namespace tensorflow { 24 25 // TensorFlow runtime (both eager and graph) will aim to colocate ops with 26 // their resource inputs so that the ops can access the resource state. In some 27 // cases, such as tf.data ops, this is not desirable as the ops themselves might 28 // not have a kernel registered for the device on which the resource is placed 29 // and instead use a mechanism, such as a multi-device function, to access the 30 // resource state. 31 // 32 // This registry can be used to register and list ops that should be exempt from 33 // the input colocation described above. 34 // 35 // Example usage: 36 // REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset"); 37 class InputColocationExemptionRegistry { 38 public: 39 // Returns a pointer to a global InputColocationExemptionRegistry object. 40 static InputColocationExemptionRegistry* Global(); 41 42 // Returns the set of ops exempt from the input colocation constraints. Get()43 const gtl::FlatSet<string>& Get() { return ops_; } 44 45 // Registers an op to be excluded from the input colocation constraints. 46 void Register(const string& op); 47 48 private: 49 gtl::FlatSet<string> ops_; 50 }; 51 52 namespace input_colocation_exemption_registration { 53 54 class InputColocationExemptionRegistration { 55 public: InputColocationExemptionRegistration(const string & op)56 explicit InputColocationExemptionRegistration(const string& op) { 57 InputColocationExemptionRegistry::Global()->Register(op); 58 } 59 }; 60 61 } // namespace input_colocation_exemption_registration 62 63 #define REGISTER_INPUT_COLOCATION_EXEMPTION(op) \ 64 REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(__COUNTER__, op) 65 66 #define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(ctr, op) \ 67 REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op) 68 69 #define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op) \ 70 static input_colocation_exemption_registration:: \ 71 InputColocationExemptionRegistration \ 72 input_colocation_exemption_registration_fn_##ctr(op) 73 74 } // namespace tensorflow 75 76 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_ 77