Skip to content

Commit 6f074cb

Browse files
committed
refactor: Convert platform hooks from weak symbols to C++ polymorphism
Replace the weak symbol-based platform hook system (init, config_endpoint, sort_rails, device_set_guid) with a modern C++ polymorphic design using abstract base classes, virtual methods, and singleton pattern with compile-time platform selection. This architectural change eliminates runtime null pointer checks, provides type safety, and improves code organization while preserving existing functionality. Signed-off-by: Hershel Shah <[email protected]>
1 parent 8f15f98 commit 6f074cb

14 files changed

+500
-223
lines changed

include/nccl_ofi_platform.h

Lines changed: 160 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,58 +5,175 @@
55
#ifndef NCCL_OFI_PLATFORM_H_
66
#define NCCL_OFI_PLATFORM_H_
77

8+
#include <memory>
9+
#include <map>
10+
811
#include <rdma/fabric.h>
912
#include <rdma/fi_endpoint.h>
1013

11-
/* Declare platform-specific hooks that can be provided by platform-specific
12-
* source files (such as the optionally compiled platform_aws.c). The functions
13-
* here are declared as weak symbols so that linkage will not break if no
14-
* platform specific hook is provided; in that case the hook will be NULL at
15-
* runtime.
16-
*/
17-
18-
/* Platform-specific initialization hook.
19-
*/
20-
int platform_init(const char **provider_filter) __attribute__((weak));
21-
22-
/* Platform-specific endpoint configuration hook
23-
*/
24-
int platform_config_endpoint(struct fi_info *info, struct fid_ep *ep) __attribute__((weak));
14+
#include "nccl_ofi_system.h"
2515

26-
/* Platform-specific hook to sort in the multi-rail protocol of the
27-
* plugin
16+
/**
17+
* @brief Abstract base class representing a platform implementation for NCCL OFI plugin
2818
*
29-
* Rail-oriented networks or traffic flows are a common performance
30-
* optimization for ML networks. Generally, Libfabric providers sort
31-
their provider list by BDFs, which are indicative of physical
32-
* ordering and good enough. However, on some platforms (especially
33-
* virtualized platforms), this might not actually be sufficient and
34-
* another sorting mechanism may be required to properly group NICs.
19+
* The Platform class provides an interface for platform-specific operations and configurations
20+
* in the NCCL OFI plugin. It defines virtual methods that must be implemented by concrete
21+
* platform implementations to handle platform-specific initialization, endpoint configuration,
22+
* and rail sorting operations.
3523
*
36-
* This interface is called in the topology initialization code to
37-
* order NICs that are behind the same PCIe root complex / switch.
38-
* The info_list will have num_rails providers listed, and will later
39-
* be split into num_groups groups (based on the number of
40-
* accelerators that are also behind the PCIe switch).
24+
* Each platform implementation can specify its priority level for selection, with higher
25+
* priority platforms being preferred over lower priority ones.
4126
*
42-
* Providers of this interface should sort the provided info_list such
43-
* that the Nth provider on this node will be assumed to talk to the
44-
* Nth provider on remote nodes (ie, identify the "rail id" and sort
45-
* by that).
27+
* Future platforms are to be implemented by inheriting from this class and overriding the
28+
* given functions. Look at PlatformAWS or Default as an example.
4629
*
47-
* @param info_list: pointer to list of `num_rails` info objects
48-
* @param num_rails: number of rails
30+
* @see Default
31+
* @see PlatformAWS
32+
* @see PlatformManager
4933
*/
50-
void platform_sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) __attribute__((weak));
34+
class Platform {
35+
public:
36+
virtual ~Platform() = default;
37+
38+
/**
39+
* @brief Get platform name
40+
*
41+
* @return Platform name string
42+
*/
43+
virtual const char* get_name() const = 0;
44+
45+
/**
46+
* @brief Get platform priority for selection.
47+
*
48+
* @return Priority value (higher values have higher priority)
49+
*/
50+
virtual int get_priority() = 0;
51+
52+
/**
53+
* @brief Platform-specific initialization hook
54+
*
55+
* @param provider_filter Pointer to provider filter string
56+
*
57+
* @return 0 on success, error code on failure
58+
*/
59+
virtual int init(const char **provider_filter) = 0;
60+
61+
/**
62+
* @brief Platform-specific endpoint configuration hook
63+
*
64+
* @param info Fabric info structure
65+
* @param ep Fabric endpoint
66+
*
67+
* @return 0 on success, error code on failure
68+
*/
69+
virtual int config_endpoint(struct fi_info *info, struct fid_ep *ep) = 0;
70+
71+
/**
72+
* @brief Platform-specific hook to sort in the multi-rail protocol of the plugin
73+
*
74+
* Rail-oriented networks or traffic flows are a common performance
75+
* optimization for ML networks. Generally, Libfabric providers sort
76+
* their provider list by BDFs, which are indicitive of physical
77+
* ordering and good enough. However, on some platforms (especially
78+
* virtualized platforms), this might not actually be sufficient and
79+
* another sorting mechanism may be required to properly group NICs.
80+
*
81+
* This interface is called in the topology initialization code to
82+
* order NICs that are behind the same PCIe root complex / switch.
83+
* The info_list will have num_rails providers listed, and will later
84+
* be split into num_groups groups (based on the number of
85+
* accelerators that are also behind the PCIe switch).
86+
*
87+
* Providers of this interface should sort the provided info_list such
88+
* that the Nth provider on this node will be assumed to talk to the
89+
* Nth provider on remote nodes (ie, identify the "rail id" and sort
90+
* by that).
91+
*
92+
* @param info_list Array of fabric info pointers to sort
93+
* @param num_rails Number of rails in the list
94+
* @param num_groups Number of groups to split rails into
95+
*/
96+
virtual void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) = 0;
97+
98+
/**
99+
* @brief Platform-specific device GUID setter
100+
*
101+
* Sets device GUID to uniquely identify the network device
102+
*
103+
* @param info Fabric info structure
104+
* @param device Network device to set GUID for
105+
*/
106+
virtual void device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) = 0;
107+
};
108+
109+
using PlatformPtr = std::unique_ptr<Platform>;
110+
111+
class Default : public Platform {
112+
public:
113+
const char* get_name() const override { return "Default"; }
114+
int get_priority() override { return 0; }
115+
int init(const char **provider_filter) override { return 0; }
116+
int config_endpoint(struct fi_info *info, struct fid_ep *ep) override { return 0; }
117+
void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) override {}
118+
void device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) override {
119+
uint32_t node_id = nccl_ofi_get_unique_node_id();
120+
/*
121+
* Use device_index as lower 8 bits
122+
* Use node_id as next 32 bits (bits 8-39)
123+
* Upper 24 bits remain 0
124+
*/
125+
device->guid = (static_cast<uint64_t>(node_id) << 8) | device->dev_id;
126+
}
127+
};
128+
129+
class PlatformManager {
130+
public:
131+
/**
132+
* @brief Get global instance
133+
*
134+
* @return Reference to global instance
135+
*/
136+
static PlatformManager& get_global();
137+
138+
/**
139+
* @brief Register a platform with the manager
140+
*
141+
* Platforms are automatically sorted by priority in the internal map.
142+
* Higher priority values take precedence and duplicates are dropped.
143+
*
144+
* @param platform Platform instance to register (moved)
145+
*/
146+
void register_platform(PlatformPtr&& platform);
147+
148+
/**
149+
* @brief Get the highest priority platform instance
150+
*
151+
* Returns the platform with the highest priority value.
152+
*
153+
* @return Reference to highest priority platform
154+
*/
155+
inline Platform& get_platform() { return *platforms_.rbegin()->second; }
156+
157+
/**
158+
* @brief Get number of registered platforms (for testing)
159+
*
160+
* @return Number of platforms in the manager
161+
*/
162+
inline size_t get_platform_count() { return platforms_.size(); }
163+
protected:
164+
/**
165+
* @brief Default constructor
166+
* Register the default Platform by default. A static global
167+
* instance is meant to be used in the plugin and the unit
168+
* tests leverage the protected scope.
169+
*/
170+
PlatformManager() {
171+
register_platform(std::make_unique<Default>());
172+
}
173+
174+
private:
175+
std::map<int, PlatformPtr> platforms_;
176+
};
51177

52-
/*
53-
* Platform-specific guid property setter
54-
*
55-
* This overrides the default guid setter (nccl_net_ofi_device_set_guid()) which
56-
* is based on network device index and IP address. Platforms can set
57-
* device->guid to be any 64-bit value as they seem fit to uniquely identify the
58-
* network device.
59-
*/
60-
void platform_device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) __attribute__((weak));
61178

62179
#endif // End NCCL_OFI_PLATFORM_H_

include/platform-aws.h

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,51 +11,80 @@
1111
#define PLATFORM_AWS_H_
1212

1313
#include <map>
14+
#include <mutex>
1415
#include <string>
16+
#include <unordered_map>
1517

1618
#include "nccl_ofi_param.h"
19+
#include "nccl_ofi_platform.h"
1720

1821
#define PLATFORM_NAME_P6E_GB200 "p6e-gb200"
1922

20-
struct ec2_platform_data {
21-
const char* name;
22-
const char* regex;
23-
const char* topology;
24-
int default_dup_conns;
25-
float latency;
26-
bool gdr_required;
27-
PROTOCOL default_protocol;
28-
bool domain_per_thread;
29-
std::map<std::string, std::string> env;
30-
};
23+
class PlatformAWS : public Platform {
24+
public:
25+
const char* get_name() const override { return "AWS"; }
26+
int get_priority() override { return 100; }
27+
int init(const char **provider_filter) override;
28+
int config_endpoint(struct fi_info *info, struct fid_ep *ep) override;
29+
void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) override;
30+
void device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) override;
3131

32+
protected:
33+
struct ec2_platform_data {
34+
const char* name;
35+
const char* regex;
36+
const char* topology;
37+
int default_dup_conns;
38+
float latency;
39+
bool gdr_required;
40+
PROTOCOL default_protocol;
41+
bool domain_per_thread;
42+
std::map<std::string, std::string> env;
43+
};
3244

33-
struct platform_aws_node_guid {
34-
uint8_t func_idx;
35-
uint8_t per_card_pci_bus;
36-
uint16_t per_card_pci_domain;
37-
uint32_t func_mac_low_bytes;
38-
};
45+
struct platform_aws_node_guid {
46+
uint8_t func_idx;
47+
uint8_t per_card_pci_bus;
48+
uint16_t per_card_pci_domain;
49+
uint32_t func_mac_low_bytes;
50+
};
3951

40-
/*
41-
* @brief Get the platform data map
42-
*
43-
* This function exists solely to test
44-
* platform_aws_get_platform_entry() against the production data map.
45-
*/
46-
struct ec2_platform_data *platform_aws_get_platform_map(size_t *len);
52+
static const ec2_platform_data platform_data_map[];
4753

54+
// Platform data functions
55+
const ec2_platform_data *get_platform_data();
56+
const ec2_platform_data *get_platform_map(size_t *len) const;
57+
static const ec2_platform_data *get_platform_entry(const char *platform_type,
58+
const ec2_platform_data *platform_data_list,
59+
size_t platform_data_len);
4860

49-
/*
50-
* @brief Returns platform data for current platform type, if found
51-
*
52-
* @input Platform type
53-
*
54-
* @return NULL, if no topology found
55-
* platform data, if match found
56-
*/
57-
struct ec2_platform_data *platform_aws_get_platform_entry(const char *platform_type,
58-
struct ec2_platform_data *platform_data_list,
59-
size_t platform_data_len);
61+
// Endpoint configuration functions
62+
int validate_rdma_write(struct fid_ep *ep);
63+
int configure_ep_inorder(struct fid_ep *ep, int optname, const char* optname_name, bool *have_ordering);
64+
int configure_ep_max_msg_size(struct fid_ep *ep);
65+
int configure_nvls_option();
66+
int configure_tuner();
67+
68+
// GUID and rail functions
69+
const platform_aws_node_guid* get_node_guid_fields(struct fi_info *info);
70+
inline int get_rail_vf_idx(struct fi_info *info) {
71+
const auto* fields = get_node_guid_fields(info);
72+
return fields ? fields->func_idx : -EIO;
73+
}
74+
75+
private:
76+
std::mutex mutex_;
77+
78+
// Cache for GUID fields to avoid repeated sysfs reads
79+
std::unordered_map<std::string, platform_aws_node_guid> guid_cache_;
80+
81+
// Platform data state
82+
bool platform_data_init_ = false;
83+
const ec2_platform_data *cached_platform_data_ = nullptr;
84+
85+
// Endpoint config state
86+
bool nccl_proto_configured_ = false;
87+
bool need_ordering_ = false;
88+
};
6089

6190
#endif // End PLATFORM_AWS_H_

m4/ax_platform_aws.m4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ AC_DEFUN([AX_CHECK_PLATFORM_AWS],[
2121
AM_CONDITIONAL([WANT_PLATFORM_AWS], [test "${want_platform_aws}" = "yes"])
2222
AS_IF([test "${want_platform_aws}" = "yes"],
2323
[NCCL_OFI_PLATFORM="AWS"
24+
AC_DEFINE([WANT_AWS_PLATFORM], [1], [Define to 1 if AWS platform optimizations are enabled])
2425
AC_MSG_CHECKING([for Libfabric 1.22.0 or greater])
2526
AC_COMPILE_IFELSE([AC_LANG_PROGRAM(
2627
[[#include <rdma/fabric.h>

src/Makefile.am

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ sources = \
3333
nccl_ofi_dmabuf.cpp \
3434
nccl_ofi_ep_addr_list.cpp \
3535
nccl_ofi_param.cpp \
36+
nccl_ofi_platform.cpp \
3637
tracepoint.cpp
3738

3839
if WANT_PLATFORM_AWS

src/nccl_ofi_net.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
#include "nccl_ofi_idpool.h"
3131
#include "nccl_ofi_dmabuf.h"
3232
#include "nccl_ofi_platform.h"
33+
#ifdef WANT_AWS_PLATFORM
34+
#include "platform-aws.h"
35+
#endif
3336
#include "nccl_ofi_ofiutils.h"
3437
#include "nccl_ofi_system.h"
3538

@@ -175,11 +178,13 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
175178
nic_dup_conns = ofi_nccl_nic_dup_conns();
176179
cq_read_count = ofi_nccl_cq_read_count();
177180

178-
if (platform_init) {
179-
ret = platform_init(&provider_filter);
180-
if (ret != 0)
181-
goto exit;
182-
}
181+
#ifdef WANT_AWS_PLATFORM
182+
PlatformManager::get_global().register_platform(std::make_unique<PlatformAWS>());
183+
#endif
184+
185+
ret = PlatformManager::get_global().get_platform().init(&provider_filter);
186+
if (ret != 0)
187+
goto exit;
183188

184189
if (ofi_nccl_progress_model.get_source() != ParamSource::DEFAULT) {
185190
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Requesting progress model %s",
@@ -794,11 +799,7 @@ nccl_net_ofi_device_t::nccl_net_ofi_device_t(nccl_net_ofi_plugin_t *plugin_arg,
794799
throw std::runtime_error("Base device constructor: device name alloc failed");
795800
}
796801

797-
if (platform_device_set_guid) {
798-
platform_device_set_guid(info, this);
799-
} else {
800-
nccl_net_ofi_device_set_guid(info, this);
801-
}
802+
PlatformManager::get_global().get_platform().device_set_guid(info, this);
802803

803804
/* Initialize mutex for endpoint access */
804805
ret = nccl_net_ofi_mutex_init(&this->device_lock, nullptr);

0 commit comments

Comments
 (0)