Skip to content

Commit 4decd69

Browse files
committed
ofi: manage ofi resource lifecycle with smart ptrs
To ensure that Libfabric resources managed by transport datatypes are automatically cleaned up during stack unwinding, this wraps Libfabric resources in unique_ptr with custom destructors that call "fi_close". Adds and refactors existing "ofiutils" helper functions to create and initialize the unique_ptr Libfabric resources. Also adds a ofi_result struct as a custom return type to return both a Libfabric resource unique_ptr and a return code in a single object. I only use the unique_ptr wrapper for classes that own the associated Libfabric resource and the connection management endpoint class' reference to a domain. I continue to use raw pointers for other classes that just use the resource owned by someone else. To support the unique_ptr requirement to not be copyable, this updates the RDMA transport protocol rail datatypes that own Libfabric resources to delete the copy constructor and copy assignment operator. Signed-off-by: Aviv Benchorin <[email protected]>
1 parent 3ac9392 commit 4decd69

11 files changed

+619
-397
lines changed

include/Makefile.am

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ noinst_HEADERS = \
3838
nccl_ofi_system.h \
3939
nccl_ofi_topo.h \
4040
nccl_ofi_tracepoint.h \
41+
ofi/resource_wrapper.h \
4142
platform-aws.h \
4243
internal/tuner/nccl_defaults.h \
4344
stats/histogram.h \

include/cm/nccl_ofi_cm_resources.h

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "cm/nccl_ofi_cm_reqs.h"
1616

1717
#include "nccl_ofi_freelist.h"
18+
#include "ofi/resource_wrapper.h"
1819

1920
namespace nccl_ofi_cm {
2021

@@ -28,8 +29,20 @@ class endpoint
2829
{
2930
public:
3031
/* Memory registration handle */
31-
struct mr_handle_t {
32-
fid_mr *mr;
32+
class mr_handle_t {
33+
public:
34+
/* Default constructor */
35+
mr_handle_t() = default;
36+
37+
/* Move constructor and assignment */
38+
mr_handle_t(mr_handle_t&&) = default;
39+
mr_handle_t& operator=(mr_handle_t&&) = default;
40+
41+
/* Delete copy operations since smart pointers are non-copyable */
42+
mr_handle_t(const mr_handle_t&) = delete;
43+
mr_handle_t& operator=(const mr_handle_t&) = delete;
44+
45+
ofi_mr_ptr mr;
3346
uint64_t mr_key;
3447
endpoint &ep;
3548
};
@@ -41,6 +54,14 @@ class endpoint
4154
* OFI domain against which to construct this ep
4255
*/
4356
endpoint(nccl_net_ofi_domain_t &domain);
57+
58+
/* Move constructor and assignment */
59+
endpoint(endpoint&&) = default;
60+
endpoint& operator=(endpoint&&) = default;
61+
62+
/* Delete copy operations since smart pointers are non-copyable */
63+
endpoint(const endpoint&) = delete;
64+
endpoint& operator=(const endpoint&) = delete;
4465

4566
/**
4667
* Destructor. Closes OFI endpoint if not already closed, as well as
@@ -64,21 +85,21 @@ class endpoint
6485
*
6586
* @param req: used for the context of the operation
6687
*/
67-
int send(nccl_ofi_cm_conn_msg &conn_msg, size_t size, mr_handle_t mr_handle,
88+
int send(nccl_ofi_cm_conn_msg &conn_msg, size_t size, mr_handle_t &mr_handle,
6889
fi_addr_t dest_addr, nccl_ofi_cm_req &req);
6990

7091
/**
7192
* Post a recv to the endpoint, with given parameters
7293
*
7394
* @param req: used for the context of the operation
7495
*/
75-
int recv(nccl_ofi_cm_conn_msg &conn_msg, size_t size, mr_handle_t mr_handle,
96+
int recv(nccl_ofi_cm_conn_msg &conn_msg, size_t size, mr_handle_t &mr_handle,
7697
nccl_ofi_cm_req &req);
7798

7899
/**
79100
* Close associated ofi_ep, while leaving other resources open
80101
*/
81-
int close_ofi_ep();
102+
void close_ofi_ep();
82103

83104
/* Menory registration/deregistration. Note: these functions are static
84105
to be usable with the freelist interface */
@@ -87,12 +108,12 @@ class endpoint
87108
static int dereg_mr(void *handle_ptr);
88109
private:
89110
/* Input to CM */
90-
fid_domain *ofi_domain;
111+
ofi_domain_ptr &ofi_domain;
91112
nccl_ofi_idpool_t &mr_key_pool;
92113

93114
/* Created by CM */
94-
fid_ep *ofi_ep;
95-
fid_av *av;
115+
ofi_av_ptr av;
116+
ofi_ep_ptr ofi_ep;
96117
};
97118

98119
/**

include/nccl_ofi.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "nccl_ofi_topo.h"
2020
#include "nccl_ofi_idpool.h"
2121
#include "nccl_ofi_mr.h"
22+
#include "ofi/resource_wrapper.h"
2223

2324
/*
2425
* NCCL_NET_HANDLE_MAXSIZE is a limited resource (and defined in NCCL).
@@ -148,7 +149,8 @@ struct nccl_net_ofi_req {
148149
int (*test)(nccl_net_ofi_req_t *req, int *done, int *size);
149150
};
150151

151-
struct nccl_net_ofi_mr_handle_t {
152+
class nccl_net_ofi_mr_handle_t {
153+
public:
152154
/**
153155
* @brief Default constructor
154156
*/
@@ -458,15 +460,15 @@ class nccl_net_ofi_domain_t {
458460
* depending on the transport; in that case, this will be the domain object
459461
* associated with the "leader NIC".
460462
*/
461-
virtual struct fid_domain *get_ofi_domain_for_cm() = 0;
463+
virtual ofi_domain_ptr &get_ofi_domain_for_cm() = 0;
462464

463465
/**
464466
* Retrieve an fid_cq object associated with this domain to be used for
465467
* connection management. There may be more than one fid_cq per domain, depending
466468
* on the transport; in that case, this will be the cq object associated with the
467469
* "leader NIC".
468470
*/
469-
virtual struct fid_cq *get_ofi_cq_for_cm() = 0;
471+
virtual ofi_cq_ptr &get_ofi_cq_for_cm() = 0;
470472

471473
/* Create a new endpoint
472474
*

include/nccl_ofi_ofiutils.h

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <rdma/fabric.h>
99

1010
#include "nccl_ofi_param.h"
11+
#include "ofi/resource_wrapper.h"
1112

1213
/*
1314
* Memeory util functions to ensure that the compiler does not optimize
@@ -28,22 +29,68 @@ int nccl_ofi_ofiutils_get_providers(const char *prov_include,
2829
unsigned int *num_prov_infos);
2930

3031

31-
/*
32-
* @brief Allocates and initialises libfabric endpoint and AV.
32+
/**
33+
* @brief Release libfabric endpoint and address vector
34+
*/
35+
void nccl_ofi_ofiutils_ep_release(ofi_ep_ptr& ep, ofi_av_ptr& av, int dev_id);
36+
37+
/**
38+
* @brief Create and initialize libfabric fabric
3339
*
34-
* @param cq: Completion queue to which the new endpoint will be bound
35-
* @return Endpoint ep
36-
* @return Address vector av
40+
* @param info: Fabric info for fabric creation
41+
* @return Result containing error code and fabric pointer
3742
*/
38-
int nccl_ofi_ofiutils_init_connection(struct fi_info *info, struct fid_domain *domain,
39-
struct fid_ep **ep, struct fid_av **av,
40-
struct fid_cq *cq);
43+
ofi_fabric_result nccl_ofi_ofiutils_fabric_create(struct fi_info *info);
4144

42-
/*
43-
* @brief Release libfabric endpoint and address vector
45+
/**
46+
* @brief Create and initialize libfabric domain
47+
*
48+
* @param fabric: Fabric handle
49+
* @param info: Fabric info for domain creation
50+
* @return Result containing error code and domain pointer
51+
*/
52+
ofi_domain_result nccl_ofi_ofiutils_domain_create(ofi_fabric_ptr& fabric, struct fi_info *info);
53+
54+
/**
55+
* @brief Create and initialize libfabric endpoint
56+
*
57+
* @param info: Fabric info for endpoint creation
58+
* @param domain: Fabric domain
59+
* @param av: Address vector to which the new endpoint will be bound
60+
* @param cq: Completion queue to which the new endpoint will be bound
61+
* @return Result containing error code and endpoint pointer
62+
*/
63+
ofi_ep_result nccl_ofi_ofiutils_ep_create(struct fi_info *info, ofi_domain_ptr &domain,
64+
ofi_av_ptr &av, ofi_cq_ptr &cq);
65+
66+
/**
67+
* @brief Create and initialize libfabric address vector
68+
*
69+
* @param domain: Domain handle
70+
* @return Result containing error code and address vector pointer
71+
*/
72+
ofi_av_result nccl_ofi_ofiutils_av_create(ofi_domain_ptr &domain);
73+
74+
/**
75+
* @brief Create and initialize libfabric completion queue
76+
*
77+
* @param domain: Domain handle
78+
* @param cq_attr: CQ attributes
79+
* @return Result containing error code and completion queue pointer
80+
*/
81+
ofi_cq_result nccl_ofi_ofiutils_cq_create(ofi_domain_ptr &domain, struct fi_cq_attr *cq_attr);
82+
83+
/**
84+
* @brief Register memory region with libfabric using fi_mr_regattr
85+
*
86+
* @param domain: Domain handle
87+
* @param mr_attr: Memory region attributes structure
88+
* @param flags: Registration flags
89+
* @return Result containing error code and memory region pointer
4490
*/
45-
void nccl_ofi_ofiutils_ep_release(struct fid_ep *ep, struct fid_av *av,
46-
int dev_id);
91+
ofi_mr_result nccl_ofi_ofiutils_mr_regattr(ofi_domain_ptr &domain,
92+
struct fi_mr_attr *mr_attr,
93+
uint64_t flags);
4794

4895
/*
4996
* @brief Free libfabric NIC info list.

0 commit comments

Comments
 (0)