1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_ 18 19 #include <deque> 20 21 #include "absl/base/thread_annotations.h" 22 #include "absl/container/inlined_vector.h" 23 #include "absl/synchronization/mutex.h" 24 #include "absl/types/span.h" 25 #include "pybind11/pybind11.h" 26 27 namespace xla { 28 29 // Class that manages destruction of Python objects. 30 // 31 // We must not destroy Python objects without holding the GIL. However, we 32 // frequently want to hold references to Python objects for the duration of 33 // an asynchronous transfer on a Stream, and release our reference when the 34 // transfer completes. 35 // 36 // This class holds references to Python objects outside a GIL scope, that can 37 // be collected later when the GIL is held by calling CollectGarbage(). 38 class PythonRefManager { 39 public: 40 PythonRefManager() = default; 41 42 // Holds references to a set of pybind11::objects, adding the references to 43 // the PythonRefManager on destruction. 44 class ManagedPyObjects { 45 public: 46 ManagedPyObjects() = default; 47 ManagedPyObjects(PythonRefManager* manager, 48 absl::Span<pybind11::object> objects); 49 50 ~ManagedPyObjects(); 51 52 ManagedPyObjects(const ManagedPyObjects& other) = delete; 53 ManagedPyObjects(ManagedPyObjects&& other) = default; 54 ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; 55 ManagedPyObjects& operator=(ManagedPyObjects&& other) = default; 56 57 private: 58 PythonRefManager* manager_ = nullptr; 59 absl::InlinedVector<pybind11::object, 1> objects_; 60 }; 61 62 // Creates a managed std::shared_ptr to an object. When the shared_ptr is 63 // destroyed, the reference to 'object' will be added to python_garbage_, 64 // and collected next time CollectGarbage() is called. 65 std::shared_ptr<ManagedPyObjects> ManageReference(pybind11::object object); 66 std::shared_ptr<ManagedPyObjects> ManageReferences( 67 absl::Span<pybind11::object> objects); 68 69 // Adds garbage objects to the manager. 70 void AddGarbage(absl::Span<pybind11::object> garbage); 71 void AddGarbage(absl::Span<std::pair<PyCodeObject*, int> const> garbage); 72 73 // Releases the contents of python_garbage_. Requires that the GIL is held. 74 // The client calls this method during API entry points where the GIL is held 75 // to free any garbage that has accumulated. 76 void CollectGarbage(); 77 78 private: 79 absl::Mutex mu_; 80 std::deque<pybind11::object> python_garbage_ ABSL_GUARDED_BY(mu_); 81 }; 82 83 // A global PythonRefManager. Unless `CollectGarbage()` is called before 84 // shutdown, this container will hold on to Python objects and thus cause a 85 // leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. 86 PythonRefManager* GlobalPyRefManager(); 87 88 } // namespace xla 89 90 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_ 91