diff --git a/flagcx/adaptor/device/cann_adaptor.cc b/flagcx/adaptor/device/cann_adaptor.cc index 2848d8e3..3ac9f899 100644 --- a/flagcx/adaptor/device/cann_adaptor.cc +++ b/flagcx/adaptor/device/cann_adaptor.cc @@ -120,7 +120,7 @@ flagcxResult_t cannAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(aclrtStream)); + (*newStream)->base = (aclrtStream)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/cuda_adaptor.cc b/flagcx/adaptor/device/cuda_adaptor.cc index 6ea93bee..19dced48 100644 --- a/flagcx/adaptor/device/cuda_adaptor.cc +++ b/flagcx/adaptor/device/cuda_adaptor.cc @@ -191,7 +191,7 @@ flagcxResult_t cudaAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t)); + (*newStream)->base = (cudaStream_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/ducuda_adaptor.cc b/flagcx/adaptor/device/ducuda_adaptor.cc index ce10470a..e69b8ab3 100644 --- a/flagcx/adaptor/device/ducuda_adaptor.cc +++ b/flagcx/adaptor/device/ducuda_adaptor.cc @@ -136,7 +136,7 @@ flagcxResult_t ducudaAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t)); + (*newStream)->base = (cudaStream_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/hip_adaptor.cc b/flagcx/adaptor/device/hip_adaptor.cc index 1c7f2abb..603caa9d 100644 --- a/flagcx/adaptor/device/hip_adaptor.cc +++ b/flagcx/adaptor/device/hip_adaptor.cc @@ -135,7 +135,7 @@ flagcxResult_t hipAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(hipStream_t)); + (*newStream)->base = (hipStream_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/ixcuda_adaptor.cc b/flagcx/adaptor/device/ixcuda_adaptor.cc index 8b999b62..3f9c980f 100644 --- a/flagcx/adaptor/device/ixcuda_adaptor.cc +++ b/flagcx/adaptor/device/ixcuda_adaptor.cc @@ -136,7 +136,7 @@ flagcxResult_t ixcudaAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t)); + (*newStream)->base = (cudaStream_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/kunlunxin_adaptor.cc b/flagcx/adaptor/device/kunlunxin_adaptor.cc index 15ecbd27..c441e88e 100644 --- a/flagcx/adaptor/device/kunlunxin_adaptor.cc +++ b/flagcx/adaptor/device/kunlunxin_adaptor.cc @@ -159,7 +159,7 @@ flagcxResult_t kunlunAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(cudaStream_t)); + (*newStream)->base = (cudaStream_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/maca_adaptor.cc b/flagcx/adaptor/device/maca_adaptor.cc index 9779eb1a..f8e05f12 100644 --- a/flagcx/adaptor/device/maca_adaptor.cc +++ b/flagcx/adaptor/device/maca_adaptor.cc @@ -141,7 +141,7 @@ flagcxResult_t macaAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(mcStream_t)); + (*newStream)->base = (mcStream_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/mlu_adaptor.cc b/flagcx/adaptor/device/mlu_adaptor.cc index 3e16b625..cb241c52 100644 --- a/flagcx/adaptor/device/mlu_adaptor.cc +++ b/flagcx/adaptor/device/mlu_adaptor.cc @@ -119,7 +119,7 @@ flagcxResult_t mluAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(cnrtQueue_t)); + (*newStream)->base = (cnrtQueue_t)oldStream; return flagcxSuccess; } diff --git a/flagcx/adaptor/device/musa_adaptor.cc b/flagcx/adaptor/device/musa_adaptor.cc index 629525ea..41a00f35 100644 --- a/flagcx/adaptor/device/musa_adaptor.cc +++ b/flagcx/adaptor/device/musa_adaptor.cc @@ -140,7 +140,7 @@ flagcxResult_t musaAdaptorStreamCopy(flagcxStream_t *newStream, void *oldStream) { (*newStream) = NULL; flagcxCalloc(newStream, 1); - memcpy((void *)*newStream, oldStream, sizeof(musaStream_t)); + (*newStream)->base = (musaStream_t)oldStream; return flagcxSuccess; } diff --git a/plugin/interservice/flagcx_wrapper.py b/plugin/interservice/flagcx_wrapper.py index f0de12a2..cbc89037 100644 --- a/plugin/interservice/flagcx_wrapper.py +++ b/plugin/interservice/flagcx_wrapper.py @@ -19,24 +19,18 @@ flagcxMemcpyType_t = ctypes.c_int flagcxMemType_t = ctypes.c_int flagcxEventType_t = ctypes.c_int +flagcxIpcMemHandle_t = ctypes.c_void_p flagcxHandlerGroup_t = ctypes.c_void_p flagcxComm_t = ctypes.c_void_p flagcxEvent_t = ctypes.c_void_p -cudaStream_t = ctypes.c_void_p +flagcxStream_t = ctypes.c_void_p buffer_type = ctypes.c_void_p - -class flagcxStream(ctypes.Structure): - _fields_ = [("base", cudaStream_t)] -flagcxStream_t = ctypes.POINTER(flagcxStream) - - class flagcxUniqueId(ctypes.Structure): _fields_ = [("internal", ctypes.c_byte * 256)] flagcxUniqueId_t = ctypes.POINTER(flagcxUniqueId) - DEVICE_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t) DEVICE_MEMCPY_FUNCTYPE = ctypes.CFUNCTYPE( flagcxResult_t, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, @@ -57,6 +51,7 @@ class flagcxUniqueId(ctypes.Structure): GET_DEVICE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int)) GET_DEVICE_COUNT_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_int)) GET_VENDOR_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_char_p) +HOST_GET_DEVICE_POINTER_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p) STREAM_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxStream_t)) STREAM_DESTROY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxStream_t) @@ -71,6 +66,13 @@ class flagcxUniqueId(ctypes.Structure): EVENT_RECORD_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t, flagcxStream_t) EVENT_SYNCHRONIZE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) EVENT_QUERY_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxEvent_t) + +IPC_MEM_HANDLE_CREATE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.POINTER(flagcxIpcMemHandle_t), ctypes.POINTER(ctypes.c_size_t)) +IPC_MEM_HANDLE_GET_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.c_void_p) +IPC_MEM_HANDLE_OPEN_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t, ctypes.POINTER(ctypes.c_void_p)) +IPC_MEM_HANDLE_CLOSE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, ctypes.c_void_p) +IPC_MEM_HANDLE_FREE_FUNCTYPE = ctypes.CFUNCTYPE(flagcxResult_t, flagcxIpcMemHandle_t) + class flagcxDeviceHandle(ctypes.Structure): _fields_ = [ # Basic functions @@ -83,6 +85,7 @@ class flagcxDeviceHandle(ctypes.Structure): ("getDevice", GET_DEVICE_FUNCTYPE), ("getDeviceCount", GET_DEVICE_COUNT_FUNCTYPE), ("getVendor", GET_VENDOR_FUNCTYPE), + ("hostGetDevicePointer", HOST_GET_DEVICE_POINTER_FUNCTYPE), # Stream functions ("streamCreate", STREAM_CREATE_FUNCTYPE), ("streamDestroy", STREAM_DESTROY_FUNCTYPE), @@ -97,6 +100,12 @@ class flagcxDeviceHandle(ctypes.Structure): ("eventRecord", EVENT_RECORD_FUNCTYPE), ("eventSynchronize", EVENT_SYNCHRONIZE_FUNCTYPE), ("eventQuery", EVENT_QUERY_FUNCTYPE), + # IpcMemHandle functions + ("ipcMemHandleCreate", IPC_MEM_HANDLE_CREATE_FUNCTYPE), + ("ipcMemHandleGet", IPC_MEM_HANDLE_GET_FUNCTYPE), + ("ipcMemHandleOpen", IPC_MEM_HANDLE_OPEN_FUNCTYPE), + ("ipcMemHandleClose", IPC_MEM_HANDLE_CLOSE_FUNCTYPE), + ("ipcMemHandleFree", IPC_MEM_HANDLE_FREE_FUNCTYPE), ] flagcxDeviceHandle_t = ctypes.POINTER(flagcxDeviceHandle) @@ -401,7 +410,7 @@ def adaptor_stream_create(self): def adaptor_stream_copy(self, old_stream): new_stream = flagcxStream_t() - self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCopy(ctypes.byref(new_stream), ctypes.byref(cudaStream_t(old_stream.cuda_stream)))) + self.FLAGCX_CHECK(self.handler.contents.devHandle.contents.streamCopy(ctypes.byref(new_stream), ctypes.c_void_p(old_stream.cuda_stream))) return new_stream def adaptor_stream_free(self, stream): @@ -416,6 +425,5 @@ def sync_stream(self, stream): __all__ = [ "FLAGCXLibrary", "flagcxDataTypeEnum", "flagcxRedOpTypeEnum", "flagcxUniqueId", - "flagcxHandlerGroup_t", "flagcxComm_t", "flagcxStream_t", "flagcxEvent_t", "buffer_type", "cudaStream_t" + "flagcxHandlerGroup_t", "flagcxComm_t", "flagcxStream_t", "flagcxEvent_t", "buffer_type" ] -