-
Notifications
You must be signed in to change notification settings - Fork 74
Refactor communicator types to use C++ inheritance #983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
a1794ee
760f84b
b054b68
23d9e8b
ada5a72
a3c883a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -120,22 +120,17 @@ extern bool data_progress_auto; | |
/* Size of system memory pages */ | ||
extern size_t system_page_size; | ||
|
||
class nccl_net_ofi_listen_comm_t; | ||
class nccl_net_ofi_xfer_comm_t; | ||
|
||
class nccl_net_ofi_device_t; | ||
class nccl_net_ofi_domain_t; | ||
class nccl_net_ofi_ep_t; | ||
class nccl_net_ofi_plugin_t; | ||
|
||
struct nccl_net_ofi_req; | ||
struct nccl_net_ofi_comm; | ||
struct nccl_net_ofi_listen_comm; | ||
struct nccl_net_ofi_send_comm; | ||
struct nccl_net_ofi_recv_comm; | ||
|
||
typedef struct nccl_net_ofi_req nccl_net_ofi_req_t; | ||
typedef struct nccl_net_ofi_comm nccl_net_ofi_comm_t; | ||
typedef struct nccl_net_ofi_listen_comm nccl_net_ofi_listen_comm_t; | ||
typedef struct nccl_net_ofi_send_comm nccl_net_ofi_send_comm_t; | ||
typedef struct nccl_net_ofi_recv_comm nccl_net_ofi_recv_comm_t; | ||
|
||
/** | ||
* Request - handle for an outstanding non-blocking communication | ||
|
@@ -220,7 +215,7 @@ typedef enum nccl_ofi_comm_stage { | |
} nccl_ofi_comm_stage_t; | ||
|
||
typedef struct save_comm_state { | ||
nccl_net_ofi_comm_t *comm; | ||
nccl_net_ofi_xfer_comm_t *comm; | ||
nccl_ofi_comm_stage_t stage; | ||
} save_comm_state_t; | ||
|
||
|
@@ -476,6 +471,30 @@ class nccl_net_ofi_domain_t { | |
*/ | ||
virtual nccl_net_ofi_ep_t *create_endpoint() = 0; | ||
|
||
/** | ||
* @brief Register memory region (both Host and CUDA) | ||
* | ||
* @param comm
* Transfer communicator the registration is requested on
* (NOTE(review): how implementations use it is not visible here)
* @param ckey | ||
* MR cache key reference | ||
* @param type | ||
* Type of MR | ||
* @param mhandle
* Presumably an out-parameter that receives the resulting memory
* handle - TODO confirm against the transport implementations
|
||
* @return 0 on success | ||
* non-zero on error | ||
*/ | ||
virtual int regMr(nccl_net_ofi_xfer_comm_t *comm, nccl_ofi_mr_ckey_ref ckey, int type, | ||
void **mhandle) = 0; | ||
|
||
/** | ||
* @brief Deregister memory region (both Host and CUDA) | ||
* | ||
* @param comm
* Transfer communicator the deregistration is requested on
* (NOTE(review): how implementations use it is not visible here)
* @param mhandle
* Memory handle to deregister (presumably one produced by a prior
* registration - TODO confirm)
* @return 0 on success | ||
* non-zero on error | ||
*/ | ||
virtual int deregMr(nccl_net_ofi_xfer_comm_t *comm, nccl_net_ofi_mr_handle_t *mhandle) = 0; | ||
|
||
/** | ||
* @brief Returns the base domain's device back-pointer. | ||
*/ | ||
|
@@ -670,7 +689,7 @@ class nccl_net_ofi_ep_t { | |
* The callee must allocate memory for send_comm. | ||
*/ | ||
virtual int connect(nccl_net_ofi_conn_handle_t *handle, | ||
nccl_net_ofi_send_comm_t **send_comm, | ||
nccl_net_ofi_xfer_comm_t **send_comm, | ||
int trafficClass) = 0; | ||
|
||
/** | ||
|
@@ -689,6 +708,15 @@ class nccl_net_ofi_ep_t { | |
*/ | ||
virtual int release_ep(bool skip_lock, bool force_cleanup); | ||
|
||
/** | ||
* @brief Get base domain | ||
*
* Asserts that the `domain` member is non-null before handing it out.
*
* @return The `domain` pointer held by this endpoint
*/ | ||
inline nccl_net_ofi_domain_t *get_base_domain() | ||
{ | ||
assert(domain != nullptr); | ||
return domain; | ||
} | ||
|
||
/** | ||
* @brief Increments the base endpoint reference count. | ||
*/ | ||
|
@@ -753,96 +781,100 @@ enum nccl_net_ofi_comm_type_t { | |
}; | ||
|
||
/** | ||
* Communicator - base class for communicator structures | ||
* | ||
* This is the base class for the listen, send, and recv | ||
* communicators. It should not be directly extended by transports, | ||
* but instead underlying transports should extend the listen, send, | ||
* and recv communicators. | ||
* Listen Communicator - Communicator for a listen/accept pairing | ||
*/ | ||
struct nccl_net_ofi_comm { | ||
class nccl_net_ofi_listen_comm_t { | ||
public: | ||
/* Defaulted virtual destructor */
virtual ~nccl_net_ofi_listen_comm_t() = default; | ||
|
||
/**
* @brief Accept operation; pure virtual, no implementation in this class.
*
* @param recv_comm
* Presumably receives the transfer communicator created for the
* accepted connection - TODO confirm against the implementations
* @return int status code defined by the implementing class
*/
virtual int accept(nccl_net_ofi_xfer_comm_t **recv_comm) = 0; | ||
|
||
/**
* @brief Close operation; pure virtual, no implementation in this class.
*
* @return int status code defined by the implementing class
*/
virtual int close() = 0; | ||
|
||
/** | ||
* @brief Get base domain from endpoint | ||
*
* Asserts that `ep` is non-null, then forwards to ep->get_base_domain().
*/ | ||
inline nccl_net_ofi_domain_t *get_base_domain() | ||
{ | ||
assert(ep != nullptr); | ||
return ep->get_base_domain(); | ||
} | ||
|
||
/* Communicator type tag */
enum nccl_net_ofi_comm_type_t type; | ||
/* Endpoint pointer; dereferenced by get_base_domain() */
nccl_net_ofi_ep_t *ep; | ||
/* Device identifier (NOTE(review): exact meaning not visible here) */
int dev_id; | ||
}; | ||
|
||
/** | ||
* Listen Communicator - Communicator for a listen/accept pairing | ||
* Base transfer communicator class (derived by specialized recv and send comms) | ||
*/ | ||
struct nccl_net_ofi_listen_comm { | ||
nccl_net_ofi_comm_t base; | ||
|
||
int (*accept)(nccl_net_ofi_listen_comm_t *listen_comm, | ||
nccl_net_ofi_recv_comm_t **recv_comm); | ||
int (*close)(nccl_net_ofi_listen_comm_t *listen_comm); | ||
}; | ||
class nccl_net_ofi_xfer_comm_t { | ||
public: | ||
virtual ~nccl_net_ofi_xfer_comm_t() = default; | ||
|
||
struct nccl_net_ofi_send_comm { | ||
nccl_net_ofi_comm_t base; | ||
// TODO: Potentially store this here: int trafficClass; | ||
virtual int close() = 0; | ||
|
||
/* | ||
* @brief Register memory region on send communicator (both Host and CUDA) | ||
* | ||
* @return Memory handle for data send operations | ||
* @return 0 on success | ||
* non-zero on error | ||
*/ | ||
int (*regMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_ofi_mr_ckey_ref ckey, int type, | ||
void **mhandle); | ||
|
||
/* | ||
* @brief Deregister memory region on send communicator (both Host and CUDA) | ||
* | ||
* @return Memory handle for data send operations | ||
* @return 0 on success | ||
* non-zero on error | ||
/** | ||
* Data transfer API functions need to be implemented by the derived recv | ||
* and send communicator classes. Calling send/write/write_inline from | ||
* the derived recv comm class or recv/flush/read from the derived send | ||
* comm class will fail as an unsupported operation | ||
*/ | ||
int (*deregMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_net_ofi_mr_handle_t *mhandle); | ||
|
||
int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag, | ||
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req); | ||
/**
* @brief Base-class default for send.
*
* Performs no transfer: the call is reported as unsupported through
* log_unsupported_operation() under the API name "isend".
*
* @return Result of log_unsupported_operation() invoked with -EINVAL
*/
virtual int send(void *data, size_t size, int tag_arg, | ||
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req) | ||
{ | ||
return log_unsupported_operation("comm", "isend", -EINVAL); | ||
} | ||
|
||
int (*close)(nccl_net_ofi_send_comm_t *send_comm); | ||
/**
* @brief Base-class default for write.
*
* Performs no transfer: the call is reported as unsupported through
* log_unsupported_operation() under the API name "iwrite".
*
* @return Result of log_unsupported_operation() invoked with -EINVAL
*/
virtual int write(void *src, size_t size, void *src_mhandle, | ||
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **req) | ||
{ | ||
return log_unsupported_operation("comm", "iwrite", -EINVAL); | ||
} | ||
|
||
int (*write)(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t size, void* src_mhandle, | ||
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **req); | ||
int (*write_inline)(nccl_net_ofi_send_comm_t *, void* src, size_t size, | ||
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **request); | ||
}; | ||
/**
* @brief Base-class default for write_inline.
*
* Performs no transfer: the call is reported as unsupported through
* log_unsupported_operation() under the API name "iwrite_inline".
*
* @return Result of log_unsupported_operation() invoked with -EINVAL
*/
virtual int write_inline(void *src, size_t size, | ||
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **request) | ||
{ | ||
return log_unsupported_operation("comm", "iwrite_inline", -EINVAL); | ||
} | ||
|
||
struct nccl_net_ofi_recv_comm { | ||
nccl_net_ofi_comm_t base; | ||
/**
* @brief Base-class default for recv.
*
* Performs no transfer: the call is reported as unsupported through
* log_unsupported_operation() under the API name "irecv".
*
* @return Result of log_unsupported_operation() invoked with -EINVAL
*/
virtual int recv(int n, void **data, size_t *sizes, int *tags, | ||
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req) | ||
{ | ||
return log_unsupported_operation("comm", "irecv", -EINVAL); | ||
} | ||
|
||
/* | ||
* @brief Register memory region on recv communicator (both Host and CUDA) | ||
* | ||
* @return Memory handle for data recv operations | ||
* @return 0 on success | ||
* non-zero on error | ||
*/ | ||
int (*regMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_ofi_mr_ckey_ref ckey, int type, | ||
void **mhandle); | ||
/**
* @brief Base-class default for flush.
*
* Performs no flush: the call is reported as unsupported through
* log_unsupported_operation() under the API name "iflush".
*
* @return Result of log_unsupported_operation() invoked with -EINVAL
*/
virtual int flush(int n, void **data, int *sizes, | ||
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req) | ||
{ | ||
return log_unsupported_operation("comm", "iflush", -EINVAL); | ||
} | ||
/**
* @brief Base-class default for read.
*
* Performs no transfer: the call is reported as unsupported through
* log_unsupported_operation() under the API name "iread".
*
* @return Result of log_unsupported_operation() invoked with -EINVAL
*/
virtual int read(void *dest, size_t size, void *dest_mhandle, | ||
uint64_t src, uint64_t mr_key, nccl_net_ofi_req_t **req) | ||
{ | ||
return log_unsupported_operation("comm", "iread", -EINVAL); | ||
} | ||
|
||
/* | ||
* @brief Deregister memory region on recv communicator (both Host and CUDA) | ||
* | ||
* @return Memory handle for data recv operations | ||
* @return 0 on success | ||
* non-zero on error | ||
/** | ||
* @brief Get base domain from endpoint | ||
*/ | ||
int (*deregMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_net_ofi_mr_handle_t *mhandle); | ||
|
||
int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, size_t *sizes, int *tags, | ||
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req); | ||
|
||
int (*flush)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes, | ||
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req); | ||
/* Asserts that `ep` is non-null, then returns ep->get_base_domain(). */
inline nccl_net_ofi_domain_t *get_base_domain() | ||
{ | ||
assert(ep != nullptr); | ||
return ep->get_base_domain(); | ||
} | ||
|
||
int (*close)(nccl_net_ofi_recv_comm_t *recv_comm); | ||
enum nccl_net_ofi_comm_type_t type; | ||
nccl_net_ofi_ep_t *ep; | ||
int dev_id; | ||
Comment on lines
+868
to
+870
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Out of curiosity, is there a reason these should remain public? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These should be made protected; I intended to wait on enforcing the access control until the follow-up comm refactor PR, where all stand-alone functions associated with the comm classes are made members and can easily access these member variables without going through getters. |
||
|
||
int (*read)(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size, void* dest_mhandle, | ||
uint64_t src, uint64_t mr_key, nccl_net_ofi_req_t **req); | ||
protected: | ||
/**
* @brief Emit a warning that an API function is not supported.
*
* @param source
* Name of the reporting component, inserted into the warning text
* @param api_name
* Name of the unsupported API function, inserted into the warning text
* @param err_code
* Code written into the warning text and handed back to the caller
* @return err_code, unchanged
*/
inline int log_unsupported_operation(const char *source, const char *api_name, int err_code) | ||
{ | ||
NCCL_OFI_WARN("%s does not support %s API function, RC: %d", source, api_name, err_code); | ||
return err_code; | ||
} | ||
}; | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As this is an abstract class, this should be a pure virtual destructor.