Skip to content

Commit 8e6b0dd

Browse files
committed
[xpu] use cpu barrier
1 parent ed2dcec commit 8e6b0dd

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

fastdeploy/inter_communicator/engine_worker_queue.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class QueueManager(BaseManager):
9595
self.finish_request_barrier = [
9696
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
9797
]
98+
self.worker_process_tp_barrier = [
99+
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
100+
]
98101

99102
# Register shared objects with proxy types
100103
QueueManager.register(
@@ -161,6 +164,10 @@ class QueueManager(BaseManager):
161164
"get_finish_request_barrier",
162165
callable=lambda idx: self.finish_request_barrier[idx],
163166
)
167+
QueueManager.register(
168+
"get_worker_process_tp_barrier",
169+
callable=lambda idx: self.worker_process_tp_barrier[idx],
170+
)
164171
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
165172
self.manager.start()
166173
else:
@@ -180,6 +187,7 @@ class QueueManager(BaseManager):
180187
QueueManager.register("get_disaggregate_requests")
181188
QueueManager.register("get_available_prefill_instances")
182189
QueueManager.register("get_finish_request_barrier")
190+
QueueManager.register("get_worker_process_tp_barrier")
183191
self.manager = QueueManager(address=self.address, authkey=self.authkey)
184192
self._connect_with_retry()
185193

@@ -199,6 +207,7 @@ class QueueManager(BaseManager):
199207
self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id)
200208
self.available_prefill_instances = self.manager.get_available_prefill_instances()
201209
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
210+
self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id)
202211
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
203212
assert self.num_client == len(self.client_read_flag)
204213

fastdeploy/worker/worker_process.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ def _broadcast_model_weights_signal(self, src: int, group) -> int:
253253
paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
254254
return model_weights_signal_tensor.item()
255255

256+
def _tp_barrier_wait(self):
257+
if current_platform.is_xpu():
258+
self.task_queue.worker_process_tp_barrier.wait()
259+
else:
260+
paddle.distributed.barrier(self.parallel_config.tp_group)
261+
256262
def event_loop_normal(self) -> None:
257263
"""Main event loop for Paddle Distributed Workers.
258264
TODO(gongshaotian): support remote calling of functions that control worker.
@@ -295,7 +301,7 @@ def event_loop_normal(self) -> None:
295301

296302
if self.parallel_config.tensor_parallel_size > 1:
297303
# Synchronize the signal for other workers
298-
paddle.distributed.barrier(self.parallel_config.tp_group)
304+
self._tp_barrier_wait()
299305

300306
if self.fd_config.load_config.dynamic_load_weight:
301307
if self.parallel_config.enable_expert_parallel:
@@ -346,7 +352,7 @@ def event_loop_normal(self) -> None:
346352

347353
if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
348354
if self.ranks > 1:
349-
paddle.distributed.barrier(self.parallel_config.tp_group)
355+
self._tp_barrier_wait()
350356

351357
time.sleep(0.001)
352358
continue

0 commit comments

Comments (0)