Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 112 additions & 80 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,17 @@ extern bool data_progress_auto;
/* Size of system memory pages */
extern size_t system_page_size;

class nccl_net_ofi_listen_comm_t;
class nccl_net_ofi_xfer_comm_t;

class nccl_net_ofi_device_t;
class nccl_net_ofi_domain_t;
class nccl_net_ofi_ep_t;
class nccl_net_ofi_plugin_t;

struct nccl_net_ofi_req;
struct nccl_net_ofi_comm;
struct nccl_net_ofi_listen_comm;
struct nccl_net_ofi_send_comm;
struct nccl_net_ofi_recv_comm;

typedef struct nccl_net_ofi_req nccl_net_ofi_req_t;
typedef struct nccl_net_ofi_comm nccl_net_ofi_comm_t;
typedef struct nccl_net_ofi_listen_comm nccl_net_ofi_listen_comm_t;
typedef struct nccl_net_ofi_send_comm nccl_net_ofi_send_comm_t;
typedef struct nccl_net_ofi_recv_comm nccl_net_ofi_recv_comm_t;

/**
* Request - handle for an outstanding non-blocking communication
Expand Down Expand Up @@ -220,7 +215,7 @@ typedef enum nccl_ofi_comm_stage {
} nccl_ofi_comm_stage_t;

typedef struct save_comm_state {
nccl_net_ofi_comm_t *comm;
nccl_net_ofi_xfer_comm_t *comm;
nccl_ofi_comm_stage_t stage;
} save_comm_state_t;

Expand Down Expand Up @@ -476,6 +471,30 @@ class nccl_net_ofi_domain_t {
*/
virtual nccl_net_ofi_ep_t *create_endpoint() = 0;

/**
* @brief Register memory region (both Host and CUDA)
*
* @param ckey
* MR cache key reference
* @param type
* Type of MR

* @return Memory handle
* @return 0 on success
* non-zero on error
*/
virtual int regMr(nccl_net_ofi_xfer_comm_t *comm, nccl_ofi_mr_ckey_ref ckey, int type,
void **mhandle) = 0;

/**
* @brief Deregister memory region (both Host and CUDA)
*
* @return Memory handle
* @return 0 on success
* non-zero on error
*/
virtual int deregMr(nccl_net_ofi_xfer_comm_t *comm, nccl_net_ofi_mr_handle_t *mhandle) = 0;

/**
* @brief Returns the base domain's device back-pointer.
*/
Expand Down Expand Up @@ -670,7 +689,7 @@ class nccl_net_ofi_ep_t {
* The callee must allocate memory for send_comm.
*/
virtual int connect(nccl_net_ofi_conn_handle_t *handle,
nccl_net_ofi_send_comm_t **send_comm,
nccl_net_ofi_xfer_comm_t **send_comm,
int trafficClass) = 0;

/**
Expand All @@ -689,6 +708,15 @@ class nccl_net_ofi_ep_t {
*/
virtual int release_ep(bool skip_lock, bool force_cleanup);

/**
* @brief Get base domain
*/
inline nccl_net_ofi_domain_t *get_base_domain()
{
assert(domain != nullptr);
return domain;
}

/**
* @brief Increments the base endpoint reference count.
*/
Expand Down Expand Up @@ -753,96 +781,100 @@ enum nccl_net_ofi_comm_type_t {
};

/**
* Communicator - base class for communicator structures
*
* This is the base class for the listen, send, and recv
* communicators. It should not be directly extended by transports,
* but instead underlying transports should extend the listen, send,
* and recv communicators.
* Listen Communicator - Communicator for a listen/accept pairing
*/
struct nccl_net_ofi_comm {
class nccl_net_ofi_listen_comm_t {
public:
virtual ~nccl_net_ofi_listen_comm_t() = default;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is an abstract class this should be a pure virtual destructor.


virtual int accept(nccl_net_ofi_xfer_comm_t **recv_comm) = 0;

virtual int close() = 0;

/**
* @brief Get base domain from endpoint
*/
inline nccl_net_ofi_domain_t *get_base_domain()
{
assert(ep != nullptr);
return ep->get_base_domain();
}

enum nccl_net_ofi_comm_type_t type;
nccl_net_ofi_ep_t *ep;
int dev_id;
};

/**
* Listen Communicator - Communicator for a listen/accept pairing
* Base transfer communicator class (derived by specialized recv and send comms)
*/
struct nccl_net_ofi_listen_comm {
nccl_net_ofi_comm_t base;

int (*accept)(nccl_net_ofi_listen_comm_t *listen_comm,
nccl_net_ofi_recv_comm_t **recv_comm);
int (*close)(nccl_net_ofi_listen_comm_t *listen_comm);
};
class nccl_net_ofi_xfer_comm_t {
public:
virtual ~nccl_net_ofi_xfer_comm_t() = default;

struct nccl_net_ofi_send_comm {
nccl_net_ofi_comm_t base;
// TODO: Potentially store this here: int trafficClass;
virtual int close() = 0;

/*
* @brief Register memory region on send communicator (both Host and CUDA)
*
* @return Memory handle for data send operations
* @return 0 on success
* non-zero on error
*/
int (*regMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_ofi_mr_ckey_ref ckey, int type,
void **mhandle);

/*
* @brief Deregister memory region on send communicator (both Host and CUDA)
*
* @return Memory handle for data send operations
* @return 0 on success
* non-zero on error
/**
* Data transfer API functions need to be implemented by the derived recv
* and send communicator classes. Calling send/write/write_inline from
* the derived recv comm class or recv/flush/read from the derived send
* comm class will fail as an unsupported operation
*/
int (*deregMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_net_ofi_mr_handle_t *mhandle);

int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req);
virtual int send(void *data, size_t size, int tag_arg,
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req)
{
return log_unsupported_operation("comm", "isend", -EINVAL);
}

int (*close)(nccl_net_ofi_send_comm_t *send_comm);
virtual int write(void *src, size_t size, void *src_mhandle,
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **req)
{
return log_unsupported_operation("comm", "iwrite", -EINVAL);
}

int (*write)(nccl_net_ofi_send_comm_t *send_comm, void* src, size_t size, void* src_mhandle,
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **req);
int (*write_inline)(nccl_net_ofi_send_comm_t *, void* src, size_t size,
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **request);
};
virtual int write_inline(void *src, size_t size,
uint64_t dest, uint64_t mr_key, nccl_net_ofi_req_t **request)
{
return log_unsupported_operation("comm", "iwrite_inline", -EINVAL);
}

struct nccl_net_ofi_recv_comm {
nccl_net_ofi_comm_t base;
virtual int recv(int n, void **data, size_t *sizes, int *tags,
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req)
{
return log_unsupported_operation("comm", "irecv", -EINVAL);
}

/*
* @brief Register memory region on recv communicator (both Host and CUDA)
*
* @return Memory handle for data recv operations
* @return 0 on success
* non-zero on error
*/
int (*regMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_ofi_mr_ckey_ref ckey, int type,
void **mhandle);
virtual int flush(int n, void **data, int *sizes,
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req)
{
return log_unsupported_operation("comm", "iflush", -EINVAL);
}
virtual int read(void *dest, size_t size, void *dest_mhandle,
uint64_t src, uint64_t mr_key, nccl_net_ofi_req_t **req)
{
return log_unsupported_operation("comm", "iread", -EINVAL);
}

/*
* @brief Deregister memory region on recv communicator (both Host and CUDA)
*
* @return Memory handle for data recv operations
* @return 0 on success
* non-zero on error
/**
* @brief Get base domain from endpoint
*/
int (*deregMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_net_ofi_mr_handle_t *mhandle);

int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, size_t *sizes, int *tags,
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req);

int (*flush)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes,
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req);
inline nccl_net_ofi_domain_t *get_base_domain()
{
assert(ep != nullptr);
return ep->get_base_domain();
}

int (*close)(nccl_net_ofi_recv_comm_t *recv_comm);
enum nccl_net_ofi_comm_type_t type;
nccl_net_ofi_ep_t *ep;
int dev_id;
Comment on lines +868 to +870
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, is there a reasoning these should remain public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be made protected, I intended to wait on enforcing the access control until the follow-up comm refactor PR where all stand-alone functions associated with the comm classes are made members and can easily access these member variables without going through getters.


int (*read)(nccl_net_ofi_recv_comm_t *recv_comm, void* dest, size_t size, void* dest_mhandle,
uint64_t src, uint64_t mr_key, nccl_net_ofi_req_t **req);
protected:
inline int log_unsupported_operation(const char *source, const char *api_name, int err_code)
{
NCCL_OFI_WARN("%s does not support %s API function, RC: %d", source, api_name, err_code);
return err_code;
}
};

/**
Expand Down
Loading