Skip to content

Commit 1c5c19d

Browse files
authored
weakref alternative (#55)
1 parent 21d26cc commit 1c5c19d

File tree

2 files changed

+267
-4
lines changed

2 files changed

+267
-4
lines changed

deepview_profile/tracking/memory/weights.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import torch
2-
import weakref
2+
import inspect
33

44
from deepview_profile.tracking.base import TrackerBase
55
from deepview_profile.tracking.call_stack import CallStack
66
from deepview_profile.tracking.hook_manager import HookManager
77
from deepview_profile.tracking.utils import tensor_size_bytes
8-
8+
from deepview_profile.util_weak import WeakTensorKeyDictionary
99

1010
class WeightsTracker(TrackerBase):
1111
def __init__(self, project_root):
1212
super().__init__()
1313
self._hook_manager = HookManager()
14-
self._module_parameters = weakref.WeakKeyDictionary()
14+
self._module_parameters = WeakTensorKeyDictionary()
1515
self._project_root = project_root
1616

1717
def start_tracking(self):
@@ -47,7 +47,7 @@ def hook(*args, **kwargs):
4747
name = args[1]
4848
parameter = args[2]
4949
retval = func(*args, **kwargs)
50-
if parameter is not None:
50+
if parameter is not None and parameter not in self._module_parameters:
5151
self._module_parameters[parameter] = (
5252
name,
5353
CallStack.from_here(self._project_root, start_from=2),

deepview_profile/util_weak.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
import weakref
2+
from weakref import ref
3+
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
4+
from collections.abc import MutableMapping, Mapping
5+
from typing import Dict
6+
import collections.abc as _collections_abc
7+
8+
9+
__all__ = ['WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
10+
11+
12+
# This file defines a variant of WeakKeyDictionary that overrides the hashing
13+
# behavior of the key to use object identity, rather than the builtin
14+
# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
15+
# __eq__ implementation returns a Tensor (elementwise equality), which means
16+
# you can't use them directly with the WeakKeyDictionary in standard library.
17+
#
18+
# Our implementation strategy is to create a wrapper weak key object, which we
19+
# use as a key in a stock Python dictionary. This is similar to how weakref
20+
# implements WeakKeyDictionary, but instead of using weakref.ref as the
21+
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
22+
# behavior. Note that we subsequently store this weak key directly in an
23+
# ORDINARY dictionary, since the newly constructed WeakIdRef's only use would
24+
# be a dictionary so it would have no strong references. Ensuring that
25+
# only live WeakIdRefs are in the map is handled by putting finalizers on the
26+
# original key object.
27+
28+
29+
# It is simpler to implement this with composition, but if we want to
30+
# directly reuse the callback mechanism on weakref, we need the weakref
31+
# and the key to be exactly the same object. Reusing the callback mechanism
32+
# minimizes the divergence between our implementation and Lib/weakref.py
33+
#
34+
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
35+
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
36+
# easy to get wrong cases transparently for you.
37+
class WeakIdRef(weakref.ref):
    """Weak reference that hashes and compares by the identity of its referent.

    The hash is ``id(key)`` captured at construction time, and two references
    are equal only when both are alive and point at the very same object (or
    when they are the same reference object).  The referent's own
    ``__eq__``/``__hash__`` are never invoked.

    Args:
        key: The object to reference weakly.
        callback: Optional callable forwarded to ``weakref.ref``.
    """

    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we can eagerly
        # cache the id of the key as we know this is definitely the hash
        # method
        self._id = id(key)
        super().__init__(key, callback)

    def __call__(self):
        """Return the referent, or ``None`` once it has been collected."""
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        """Identity-based equality against another weak reference.

        Returns ``NotImplemented`` for operands that are not weak references,
        so ``ref == 5`` evaluates to ``False`` instead of raising ``TypeError``
        from trying to call the other operand.
        """
        if not isinstance(other, weakref.ref):
            return NotImplemented
        # An attractive but wrong alternate implementation is to only test if
        # the stored _ids match.  This can lead to an ABA problem if you have:
        #
        #   a1 = A()
        #   w1 = WeakIdRef(a1)
        #   del a1
        #   a2 = A()  # suppose it gets the same ID as a1
        #   w2 = WeakIdRef(a2)
        #   print(w1 == w2)
        #
        # This should be False, as a1 and a2 are unrelated (and a1 is
        # dead anyway)
        a = self()
        b = other()
        if a is not None and b is not None:
            return a is b
        # At least one side is dead: only the same reference object is equal.
        return self is other
77+
78+
# This is directly adapted from cpython/Lib/weakref.py
79+
class WeakIdKeyDictionary(MutableMapping):
    """Mutable mapping whose keys are held weakly and matched by identity.

    Every key is wrapped in a ``WeakIdRef`` before it touches ``self.data``,
    so lookups never call the key's own ``__eq__``/``__hash__``.  The
    ``self._remove`` callback attached to each stored reference drops the
    entry, or queues the drop while an iteration is in progress.
    """

    data: Dict[WeakIdRef, object]

    def __init__(self, dict=None):
        self.data = {}

        def remove(k, selfref=ref(self)):
            # Only a weak reference to the mapping is captured here, so the
            # callback does not keep the mapping itself alive.
            owner = selfref()
            if owner is None:
                return
            if owner._iterating:
                # Mutating during iteration is unsafe; defer the removal.
                owner._pending_removals.append(k)
                return
            owner.data.pop(k, None)

        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        self._iterating = set()
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        """Drop every entry whose removal was deferred during iteration."""
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        take = self._pending_removals.pop
        store = self.data
        while True:
            try:
                dead = take()
            except IndexError:
                break
            store.pop(dead, None)

    def _scrub_removals(self):
        """Forget pending removals whose keys are no longer stored."""
        store = self.data
        self._pending_removals = [
            dead for dead in self._pending_removals if dead in store
        ]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        probe = WeakIdRef(key)
        del self.data[probe]

    def __getitem__(self, key):
        probe = WeakIdRef(key)
        return self.data[probe]

    def __len__(self):
        needs_scrub = self._dirty_len and self._pending_removals
        if needs_scrub:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return "<%s at %#x>" % (cls_name, id(self))

    def __setitem__(self, key, value):
        wrapped = WeakIdRef(key, self._remove)
        self.data[wrapped] = value

    def copy(self):
        """Return a shallow copy containing only the still-live keys."""
        clone = WeakIdKeyDictionary()
        with _IterationGuard(self):
            for wr, value in self.data.items():
                target = wr()
                if target is None:
                    continue
                clone[target] = value
        return clone

    __copy__ = copy

    def __deepcopy__(self, memo):
        """Return a copy with the same live keys and deep-copied values."""
        from copy import deepcopy

        clone = self.__class__()
        with _IterationGuard(self):
            for wr, value in self.data.items():
                target = wr()
                if target is None:
                    continue
                clone[target] = deepcopy(value, memo)
        return clone

    def get(self, key, default=None):
        probe = WeakIdRef(key)
        return self.data.get(probe, default)

    def __contains__(self, key):
        try:
            probe = WeakIdRef(key)
        except TypeError:
            # The key cannot be wrapped in a weak reference, so it cannot
            # have been stored.
            return False
        else:
            return probe in self.data

    def items(self):
        """Yield ``(key, value)`` pairs for keys that are still alive."""
        with _IterationGuard(self):
            for wr, value in self.data.items():
                target = wr()
                if target is None:
                    continue
                yield target, value

    def keys(self):
        """Yield the keys that are still alive."""
        with _IterationGuard(self):
            for wr in self.data:
                target = wr()
                if target is None:
                    continue
                yield target

    __iter__ = keys

    def values(self):
        """Yield the values whose keys are still alive."""
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is None:
                    continue
                yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.

        """
        return [*self.data]

    def popitem(self):
        """Remove and return a ``(key, value)`` pair, skipping dead keys."""
        self._dirty_len = True
        while True:
            wr, value = self.data.popitem()
            target = wr()
            if target is None:
                continue
            return target, value

    def pop(self, key, *args):
        self._dirty_len = True
        probe = WeakIdRef(key)
        return self.data.pop(probe, *args)

    def setdefault(self, key, default=None):
        wrapped = WeakIdRef(key, self._remove)
        return self.data.setdefault(wrapped, default)

    def update(self, dict=None, **kwargs):
        """Insert entries from a mapping or iterable of pairs, then kwargs."""
        store = self.data
        if dict is not None:
            # ``dict`` shadows the builtin here, hence ``type({})``.
            source = dict if hasattr(dict, "items") else type({})(dict)
            for key, value in source.items():
                store[WeakIdRef(key, self._remove)] = value
        if kwargs:
            self.update(kwargs)

    def __ior__(self, other):
        self.update(other)
        return self

    def __or__(self, other):
        if not isinstance(other, _collections_abc.Mapping):
            return NotImplemented
        merged = self.copy()
        merged.update(other)
        return merged

    def __ror__(self, other):
        if not isinstance(other, _collections_abc.Mapping):
            return NotImplemented
        merged = self.__class__()
        merged.update(other)
        merged.update(self)
        return merged

    # Default Mapping equality will test keys for equality, but
    # we want to test ids for equality
    def __eq__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        mine = {id(k): v for k, v in self.items()}
        theirs = {id(k): v for k, v in other.items()}
        return mine == theirs
261+
262+
# Convenience alias
263+
WeakTensorKeyDictionary = WeakIdKeyDictionary

0 commit comments

Comments
 (0)