Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/cann_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ flagcxResult_t cannAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(aclrtStream));
(*newStream)->base = (aclrtStream)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/cuda_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ flagcxResult_t cudaAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t));
(*newStream)->base = (cudaStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/ducuda_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ flagcxResult_t ducudaAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t));
(*newStream)->base = (cudaStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/hip_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ flagcxResult_t hipAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(hipStream_t));
(*newStream)->base = (hipStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/ixcuda_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ flagcxResult_t ixcudaAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t));
(*newStream)->base = (cudaStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/kunlunxin_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ flagcxResult_t kunlunAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t));
(*newStream)->base = (cudaStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/maca_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ flagcxResult_t macaAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(mcStream_t));
(*newStream)->base = (mcStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/mlu_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ flagcxResult_t mluAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(cnrtQueue_t));
(*newStream)->base = (cnrtQueue_t)oldStream;
return flagcxSuccess;
}

Expand Down
2 changes: 1 addition & 1 deletion flagcx/adaptor/device/musa_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ flagcxResult_t musaAdaptorStreamCopy(flagcxStream_t *newStream,
void *oldStream) {
(*newStream) = NULL;
flagcxCalloc(newStream, 1);
memcpy((void *)*newStream, oldStream, sizeof(musaStream_t));
(*newStream)->base = (musaStream_t)oldStream;
return flagcxSuccess;
}

Expand Down
30 changes: 19 additions & 11 deletions plugin/interservice/flagcx_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,18 @@
flagcxMemcpyType_t = ctypes.c_int
flagcxMemType_t = ctypes.c_int
flagcxEventType_t = ctypes.c_int
flagcxIpcMemHandle_t = ctypes.c_void_p

flagcxHandlerGroup_t = ctypes.c_void_p
flagcxComm_t = ctypes.c_void_p
flagcxEvent_t = ctypes.c_void_p
cudaStream_t = ctypes.c_void_p
flagcxStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p


class flagcxStream(ctypes.Structure):
_fields_ = [("base", cudaStream_t)]
flagcxStream_t = ctypes.POINTER(flagcxStream)


class flagcxUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 256)]
flagcxUniqueId_t = ctypes.POINTER(flagcxUniqueId)


DEVICE_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t)
DEVICE_MEMCPY_FUNCTYPE = ctypes.CFUNCTYPE(
flagcxResult_t, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t,
Expand All @@ -57,6 +51,7 @@ class flagcxUniqueId(ctypes.Structure):
GET_DEVICE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int))
GET_DEVICE_COUNT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int))
GET_VENDOR_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_char_p)
HOST_GET_DEVICE_POINTER_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p)

STREAM_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxStream_t))
STREAM_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t)
Expand All @@ -71,6 +66,13 @@ class flagcxUniqueId(ctypes.Structure):
EVENT_RECORD_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t, flagcxStream_t)
EVENT_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t)
EVENT_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t)

IPC_MEM_HANDLE_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxIpcMemHandle_t), ctypes.POINTER(ctypes.c_size_t))
Copy link

Copilot AI Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IPC_MEM_HANDLE_CREATE_FUNCTYPE signature appears inconsistent with the C header. According to flagcx.h line 158-159, ipcMemHandleCreate takes flagcxIpcMemHandle_t *handle (pointer to pointer) and size_t *size, but flagcxIpcMemHandle_t is already defined as ctypes.c_void_p (line 22), which is a single pointer. The first parameter should be ctypes.POINTER(ctypes.c_void_p) to match the double-pointer indirection.

Copilot uses AI. Check for mistakes.
IPC_MEM_HANDLE_GET_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.c_void_p)
IPC_MEM_HANDLE_OPEN_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.POINTER(ctypes.c_void_p))
Comment on lines +71 to +72
Copy link

Copilot AI Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parameter types for IPC_MEM_HANDLE_GET_FUNCTYPE and IPC_MEM_HANDLE_OPEN_FUNCTYPE are swapped. According to flagcx.h lines 160-162, ipcMemHandleGet takes (flagcxIpcMemHandle_t handle, void *devPtr) while ipcMemHandleOpen takes (flagcxIpcMemHandle_t handle, void **devPtr). The current definitions have the pointer indirection reversed - line 71 should have ctypes.POINTER(ctypes.c_void_p) as the second parameter, and line 72 should have ctypes.c_void_p.

Suggested change
IPC_MEM_HANDLE_GET_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.c_void_p)
IPC_MEM_HANDLE_OPEN_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.POINTER(ctypes.c_void_p))
IPC_MEM_HANDLE_GET_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.POINTER(ctypes.c_void_p))
IPC_MEM_HANDLE_OPEN_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.c_void_p)

Copilot uses AI. Check for mistakes.
IPC_MEM_HANDLE_CLOSE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_void_p)
IPC_MEM_HANDLE_FREE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t)

class flagcxDeviceHandle(ctypes.Structure):
_fields_ = [
# Basic functions
Expand All @@ -83,6 +85,7 @@ class flagcxDeviceHandle(ctypes.Structure):
("getDevice", GET_DEVICE_FUNCTYPE),
("getDeviceCount", GET_DEVICE_COUNT_FUNCTYPE),
("getVendor", GET_VENDOR_FUNCTYPE),
("hostGetDevicePointer", HOST_GET_DEVICE_POINTER_FUNCTYPE),
# Stream functions
("streamCreate", STREAM_CREATE_FUNCTYPE),
("streamDestroy", STREAM_DESTROY_FUNCTYPE),
Expand All @@ -97,6 +100,12 @@ class flagcxDeviceHandle(ctypes.Structure):
("eventRecord", EVENT_RECORD_FUNCTYPE),
("eventSynchronize", EVENT_SYNCHRONIZE_FUNCTYPE),
("eventQuery", EVENT_QUERY_FUNCTYPE),
# IpcMemHandle functions
("ipcMemHandleCreate", IPC_MEM_HANDLE_CREATE_FUNCTYPE),
("ipcMemHandleGet", IPC_MEM_HANDLE_GET_FUNCTYPE),
("ipcMemHandleOpen", IPC_MEM_HANDLE_OPEN_FUNCTYPE),
("ipcMemHandleClose", IPC_MEM_HANDLE_CLOSE_FUNCTYPE),
("ipcMemHandleFree", IPC_MEM_HANDLE_FREE_FUNCTYPE),
]
flagcxDeviceHandle_t = ctypes.POINTER(flagcxDeviceHandle)

Expand Down Expand Up @@ -401,7 +410,7 @@ def adaptor_stream_create(self):

def adaptor_stream_copy(self, old_stream):
new_stream = flagcxStream_t()
self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCopy(ctypes.byref(new_stream), ctypes.byref(cudaStream_t(old_stream.cuda_stream))))
self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCopy(ctypes.byref(new_stream), ctypes.c_void_p(old_stream.cuda_stream)))
return new_stream

def adaptor_stream_free(self, stream):
Expand All @@ -416,6 +425,5 @@ def sync_stream(self, stream):

__all__ = [
"FLAGCXLibrary", "flagcxDataTypeEnum", "flagcxRedOpTypeEnum", "flagcxUniqueId",
"flagcxHandlerGroup_t", "flagcxComm_t", "flagcxStream_t", "flagcxEvent_t", "buffer_type", "cudaStream_t"
"flagcxHandlerGroup_t", "flagcxComm_t", "flagcxStream_t", "flagcxEvent_t", "buffer_type"
]