-
Notifications
You must be signed in to change notification settings - Fork 19.6k
/
Copy pathtracking.py
299 lines (243 loc) · 8.92 KB
/
tracking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
from functools import wraps
from keras.src import tree
from keras.src.backend.common.global_state import get_global_attribute
from keras.src.backend.common.global_state import set_global_attribute
from keras.src.utils import python_utils
class DotNotTrackScope:
    """Context manager that switches automatic tracking off for its body.

    On entry the current value of the global ``"tracking_on"`` flag is
    remembered and the flag is set to ``False``. On exit (whether normal or
    through an exception, which is not suppressed) the remembered value is
    written back, so nested scopes unwind to the correct state.
    """

    def __enter__(self):
        # Snapshot the flag before this scope changes it.
        previous = is_tracking_enabled()
        self.original_value = previous
        set_global_attribute("tracking_on", False)

    def __exit__(self, *args, **kwargs):
        # Put back whatever was active before `__enter__`.
        set_global_attribute("tracking_on", self.original_value)
def is_tracking_enabled():
    """Report whether automatic attribute tracking is currently on.

    Reads the global ``"tracking_on"`` attribute, passing ``True`` as the
    default value to fall back on.
    """
    enabled = get_global_attribute("tracking_on", True)
    return enabled
def no_automatic_dependency_tracking(fn):
    """Decorate `fn` so every call runs with automatic tracking disabled.

    Each invocation of the returned wrapper executes inside a
    `DotNotTrackScope`. While that scope is active, `Tracker.track()` returns
    values unchanged.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper that forwards all arguments to `fn` and returns its result.
        It carries `fn`'s metadata (name, docstring, `__wrapped__`) via
        `functools.wraps`.
    """

    def _call_untracked(*args, **kwargs):
        with DotNotTrackScope():
            return fn(*args, **kwargs)

    return wraps(fn)(_call_untracked)
def safe_register_tree_node_class(cls):
    """Class decorator: register `cls` with `tree`, tolerating re-registration.

    optree raises `ValueError` when a class is already registered. That
    happens, for example, when `config.set_backend()` is called more than
    once. In that case the class is handed back unchanged instead of the
    error propagating.

    NOTE(review): any other `ValueError` raised by registration is swallowed
    in the same way.

    Returns:
        The result of `tree.register_tree_node_class(cls)` on success,
        otherwise `cls` itself.
    """
    try:
        registered = tree.register_tree_node_class(cls)
    except ValueError:
        # Already registered: keep the class as-is.
        registered = cls
    return registered
class Tracker:
    """Attribute tracker, used for e.g. Variable tracking.

    Monitors certain attribute types
    and puts them in appropriate lists in case of a match.

    Also passively tracks certain mutable collections
    (dict, list, set) so that items added to them later
    still get tracked. This is done by wrapping these
    collections into an equivalent, tracking-aware object.

    Args:
        config: Dict mapping a store name to a `(test_fn, store)` pair.
            `test_fn(value)` decides whether `value` belongs to that store.
            `store` is the list that matching values are appended to.
        exclusions: Optional dict mapping a store name to an iterable of
            other store names. A value that already sits in any of those
            other stores is returned untouched and not added anywhere.

    Example:

    ```python
    def __init__(self):
        self._tracker = Tracker(
            # Format: `name: (test_fn, store)`
            {
                "variables":
                    (lambda x: isinstance(x, Variable), self._variables),
                "metrics": (lambda x: isinstance(x, Metric), self._metrics),
                "layers": (lambda x: isinstance(x, Layer), self._layers),
            }
        )

    def __setattr__(self, name, value):
        if hasattr(self, "_tracker"):
            value = self._tracker.track(value)
        return super().__setattr__(name, value)
    ```
    """

    def __init__(self, config, exclusions=None):
        self.config = config
        # Per-store sets of `id()`s, so membership is by identity and O(1).
        self.stored_ids = {name: set() for name in self.config.keys()}
        self.locked = False
        self._lock_violation_msg = None
        self.exclusions = exclusions or {}

    def track(self, attr):
        """Record `attr` in its matching store and return the value to keep.

        - If tracking is disabled, `attr` is returned as-is.
        - Only the first store whose `test_fn` accepts `attr` is considered.
        - Tuples and namedtuples are rebuilt with every element tracked.
        - Lists, dicts and sets are wrapped in `TrackedList`, `TrackedDict`
          and `TrackedSet` so that later insertions are tracked too.
        """
        if not is_tracking_enabled():
            return attr

        for store_name, (is_attr_type, _) in self.config.items():
            if is_attr_type(attr):
                if store_name in self.exclusions:
                    for excl in self.exclusions[store_name]:
                        if self.is_in_store(excl, attr):
                            return attr
                if not self.is_in_store(store_name, attr):
                    self.add_to_store(store_name, attr)
                return attr
        if isinstance(attr, tuple) and hasattr(attr, "_fields"):
            # Named tuple case.
            wrapped_attr = {
                name: self.track(e) for name, e in attr._asdict().items()
            }
            return attr.__class__(**wrapped_attr)
        if isinstance(attr, tuple):
            wrapped_attr = [self.track(e) for e in attr]
            return attr.__class__(wrapped_attr)
        elif isinstance(attr, list):
            return TrackedList(attr, self)
        elif isinstance(attr, dict):
            # TODO: OrderedDict?
            return TrackedDict(attr, self)
        elif isinstance(attr, set):
            return TrackedSet(attr, self)
        return attr

    def untrack(self, value):
        """Remove `value` from every store whose id-set contains it."""
        for store_name in self.stored_ids.keys():
            if id(value) in self.stored_ids[store_name]:
                self.stored_ids[store_name].remove(id(value))
                python_utils.remove_by_id(self.config[store_name][1], value)

    def lock(self, msg=None):
        """Forbid further `add_to_store` calls.

        Args:
            msg: Error message used when a locked store is written to. If
                `None`, any previously set message is kept.
        """
        self.locked = True
        if msg is not None:
            self._lock_violation_msg = msg

    def unlock(self):
        """Allow `add_to_store` again after a `lock()`."""
        self.locked = False

    def add_to_store(self, store_name, value):
        """Append `value` to the list of `store_name` and record its id.

        Raises:
            ValueError: If the tracker is locked.
        """
        if self.locked:
            msg = self._lock_violation_msg
            if msg is None:
                # Without this fallback the error message would read "None".
                msg = (
                    f"Cannot add a new value to store '{store_name}': "
                    "the tracker is locked."
                )
            raise ValueError(msg)
        self.config[store_name][1].append(value)
        self.stored_ids[store_name].add(id(value))

    def is_in_store(self, store_name, value):
        """Return whether `value` (by identity) is in store `store_name`."""
        return id(value) in self.stored_ids[store_name]

    def replace_tracked_value(self, store_name, old_value, new_value):
        """Swap `old_value` for `new_value` in place within `store_name`.

        Raises:
            ValueError: If `old_value` is not in the store.
        """
        if not self.is_in_store(store_name, old_value):
            raise ValueError(f"Unknown value: {old_value}")
        store_list = self.config[store_name][1]
        # Locate the slot by identity. `list.index` compares with `==`,
        # which could match a *different* element that merely compares
        # equal, or raise for types whose `__eq__` returns an array-like.
        for index, item in enumerate(store_list):
            if item is old_value:
                break
        else:
            raise ValueError(f"Unknown value: {old_value}")
        store_list[index] = new_value
        self.stored_ids[store_name].remove(id(old_value))
        self.stored_ids[store_name].add(id(new_value))
@safe_register_tree_node_class
class TrackedList(list):
    """A `list` that reports its elements to a `Tracker`.

    Elements supplied at construction or added via `append`, `insert` and
    `extend` are passed to `tracker.track()`. Elements removed via `remove`,
    `pop`, `clear` and `del` (including slice deletion) are passed to
    `tracker.untrack()`. When `tracker` is falsy it behaves like a plain
    list.

    NOTE(review): `__setitem__` and `__iadd__` are not overridden, so values
    added through `l[i] = v` or `l += [...]` are not tracked.
    """

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = [tracker.track(v) for v in values]
        super().__init__(values or [])

    def append(self, value):
        if self.tracker:
            # NOTE(review): unlike `__init__`/`extend`, the value returned by
            # `track()` is discarded, so a nested list/dict is stored
            # unwrapped and later additions to it are not tracked.
            self.tracker.track(value)
        super().append(value)

    def insert(self, index, value):
        if self.tracker:
            # Same caveat as `append`: `track()`'s return value is unused.
            self.tracker.track(value)
        super().insert(index, value)

    def extend(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().extend(values)

    def remove(self, value):
        if self.tracker:
            self.tracker.untrack(value)
        try:
            super().remove(value)
        except ValueError:
            # `list.remove` compares with `==`. Presumably this fallback
            # covers values whose `==` does not give a usable bool (e.g.
            # tensors). Fall back to removing by identity.
            python_utils.remove_by_id(self, value)

    def pop(self, index=-1):
        if self.tracker:
            self.tracker.untrack(self[index])
        return super().pop(index)

    def clear(self):
        if self.tracker:
            for value in self:
                self.tracker.untrack(value)
        super().clear()

    def __delitem__(self, index):
        # Capture what is being removed before deleting it. For a slice
        # index this is a plain list of every removed element.
        removed = self[index]
        super().__delitem__(index)
        if self.tracker:
            if isinstance(index, slice):
                for value in removed:
                    self.tracker.untrack(value)
            else:
                self.tracker.untrack(removed)

    def tree_flatten(self):
        # For optree / dmtree
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # For optree / dmtree. The rebuilt list has no tracker attached.
        return cls(children)
@safe_register_tree_node_class
class TrackedDict(dict):
    """A `dict` that reports its values to a `Tracker`.

    Values supplied at construction or added via `__setitem__` and `update`
    are passed to `tracker.track()`. Values removed via `pop`, `popitem`,
    `clear` and `del d[key]` are passed to `tracker.untrack()`. When
    `tracker` is falsy it behaves like a plain dict.

    NOTE(review): `setdefault` and `|=` are not overridden, so values added
    that way are not tracked.
    """

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {k: tracker.track(v) for k, v in values.items()}
        super().__init__(values or [])

    def __setitem__(self, key, value):
        if self.tracker:
            self.tracker.track(value)
        super().__setitem__(key, value)

    def __delitem__(self, key):
        # Mirror `pop`: a deleted value must stop being tracked.
        value = self[key]
        super().__delitem__(key)
        if self.tracker:
            self.tracker.untrack(value)

    def update(self, mapping=(), **kwargs):
        """Like `dict.update`, tracking every incoming value.

        Accepts a mapping or an iterable of `(key, value)` pairs, plus
        keyword arguments.
        """
        if self.tracker:
            mapping = {
                k: self.tracker.track(v)
                for k, v in dict(mapping, **kwargs).items()
            }
            super().update(mapping)
        else:
            super().update(mapping, **kwargs)

    def pop(self, key, default=None):
        """Remove `key` and return its value, untracking it.

        Unlike `dict.pop`, a missing key returns `default` (`None` unless
        given) instead of raising `KeyError`.
        """
        value = super().pop(key, default)
        if self.tracker and value is not default:
            self.tracker.untrack(value)
        return value

    def popitem(self):
        key, value = super().popitem()
        if self.tracker:
            self.tracker.untrack(value)
        return key, value

    def clear(self):
        if self.tracker:
            for value in self.values():
                self.tracker.untrack(value)
        super().clear()

    def tree_flatten(self):
        # For optree / dmtree. Children are ordered by sorted key.
        keys = sorted(self)
        values = [self[k] for k in keys]
        return values, keys, keys

    @classmethod
    def tree_unflatten(cls, keys, values):
        # For optree / dmtree. The rebuilt dict has no tracker attached.
        return cls(zip(keys, values))
@safe_register_tree_node_class
class TrackedSet(set):
    """A `set` that reports its elements to a `Tracker`.

    Elements supplied at construction or added via `add` and `update` are
    passed to `tracker.track()`. Elements removed via `remove`, `discard`,
    `pop` and `clear` are passed to `tracker.untrack()`. When `tracker` is
    falsy it behaves like a plain set.

    NOTE(review): in-place operators (`|=`, `-=`, ...) and the
    `*_update` methods other than `update` are not overridden.
    """

    def __init__(self, values=None, tracker=None):
        self.tracker = tracker
        if tracker and values:
            values = {tracker.track(v) for v in values}
        super().__init__(values or [])

    def add(self, value):
        if self.tracker:
            self.tracker.track(value)
        super().add(value)

    def update(self, values):
        if self.tracker:
            values = [self.tracker.track(v) for v in values]
        super().update(values)

    def remove(self, value):
        # Remove first: if `value` is absent, `KeyError` propagates without
        # untracking a value that may still be referenced elsewhere.
        super().remove(value)
        if self.tracker:
            self.tracker.untrack(value)

    def discard(self, value):
        """Like `set.discard`, untracking `value` if it was present."""
        if value in self:
            self.remove(value)

    def pop(self):
        value = super().pop()
        if self.tracker:
            self.tracker.untrack(value)
        return value

    def clear(self):
        if self.tracker:
            for value in self:
                self.tracker.untrack(value)
        super().clear()

    def tree_flatten(self):
        # For optree / dmtree
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        # For optree / dmtree. The rebuilt set has no tracker attached.
        return cls(children)