@@ -95,6 +95,9 @@ class QueueManager(BaseManager):
95
95
self .finish_request_barrier = [
96
96
threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
97
97
]
98
+ self .worker_process_tp_barrier = [
99
+ threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
100
+ ]
98
101
99
102
# Register shared objects with proxy types
100
103
QueueManager .register (
@@ -161,6 +164,10 @@ class QueueManager(BaseManager):
161
164
"get_finish_request_barrier" ,
162
165
callable = lambda idx : self .finish_request_barrier [idx ],
163
166
)
167
+ QueueManager .register (
168
+ "get_worker_process_tp_barrier" ,
169
+ callable = lambda idx : self .worker_process_tp_barrier [idx ],
170
+ )
164
171
self .manager : BaseManager = QueueManager (address = self .address , authkey = self .authkey )
165
172
self .manager .start ()
166
173
else :
@@ -180,6 +187,7 @@ class QueueManager(BaseManager):
180
187
QueueManager .register ("get_disaggregate_requests" )
181
188
QueueManager .register ("get_available_prefill_instances" )
182
189
QueueManager .register ("get_finish_request_barrier" )
190
+ QueueManager .register ("get_worker_process_tp_barrier" )
183
191
self .manager = QueueManager (address = self .address , authkey = self .authkey )
184
192
self ._connect_with_retry ()
185
193
@@ -199,6 +207,7 @@ class QueueManager(BaseManager):
199
207
self .disaggregate_requests = self .manager .get_disaggregate_requests (self .local_data_parallel_id )
200
208
self .available_prefill_instances = self .manager .get_available_prefill_instances ()
201
209
self .finish_request_barrier = self .manager .get_finish_request_barrier (self .local_data_parallel_id )
210
+ self .worker_process_tp_barrier = self .manager .get_worker_process_tp_barrier (self .local_data_parallel_id )
202
211
self .finished_req_queue = self .manager .get_finish_request_queue (self .local_data_parallel_id )
203
212
assert self .num_client == len (self .client_read_flag )
204
213
0 commit comments