• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_
18 
#include <deque>
#include <memory>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
26 
27 namespace xla {
28 
29 // Class that manages destruction of Python objects.
30 //
31 // We must not destroy Python objects without holding the GIL. However, we
32 // frequently want to hold references to Python objects for the duration of
33 // an asynchronous transfer on a Stream, and release our reference when the
34 // transfer completes.
35 //
36 // This class holds references to Python objects outside a GIL scope, that can
37 // be collected later when the GIL is held by calling CollectGarbage().
38 class PythonRefManager {
39  public:
40   PythonRefManager() = default;
41 
42   // Holds references to a set of pybind11::objects, adding the references to
43   // the PythonRefManager on destruction.
44   class ManagedPyObjects {
45    public:
46     ManagedPyObjects() = default;
47     ManagedPyObjects(PythonRefManager* manager,
48                      absl::Span<pybind11::object> objects);
49 
50     ~ManagedPyObjects();
51 
52     ManagedPyObjects(const ManagedPyObjects& other) = delete;
53     ManagedPyObjects(ManagedPyObjects&& other) = default;
54     ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete;
55     ManagedPyObjects& operator=(ManagedPyObjects&& other) = default;
56 
57    private:
58     PythonRefManager* manager_ = nullptr;
59     absl::InlinedVector<pybind11::object, 1> objects_;
60   };
61 
62   // Creates a managed std::shared_ptr to an object. When the shared_ptr is
63   // destroyed, the reference to 'object' will be added to python_garbage_,
64   // and collected next time CollectGarbage() is called.
65   std::shared_ptr<ManagedPyObjects> ManageReference(pybind11::object object);
66   std::shared_ptr<ManagedPyObjects> ManageReferences(
67       absl::Span<pybind11::object> objects);
68 
69   // Adds garbage objects to the manager.
70   void AddGarbage(absl::Span<pybind11::object> garbage);
71   void AddGarbage(absl::Span<std::pair<PyCodeObject*, int> const> garbage);
72 
73   // Releases the contents of python_garbage_. Requires that the GIL is held.
74   // The client calls this method during API entry points where the GIL is held
75   // to free any garbage that has accumulated.
76   void CollectGarbage();
77 
78  private:
79   absl::Mutex mu_;
80   std::deque<pybind11::object> python_garbage_ ABSL_GUARDED_BY(mu_);
81 };
82 
// Returns a global PythonRefManager. Unless `CollectGarbage()` is called before
// shutdown, this container will hold on to Python objects and thus cause a
// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`.
// NOTE(review): presumably the returned pointer is non-null and owned by the
// implementation for the life of the process — confirm in the .cc file.
PythonRefManager* GlobalPyRefManager();
87 
88 }  // namespace xla
89 
90 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_
91