Skip to content

Commit 4124747

Browse files
committed
enh: Select platform optimizations at runtime
Replace compile-time AWS platform detection with runtime EFA device detection using hwloc. This allows a single binary to work on both AWS and non-AWS environments, automatically enabling optimizations when EFA hardware is present. Removes autotools platform checks and always builds AWS platform code. Signed-off-by: Hershel Shah <[email protected]>
1 parent 4decd69 commit 4124747

File tree

10 files changed

+135
-94
lines changed

10 files changed

+135
-94
lines changed

configure.ac

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,5 +278,7 @@ AC_OUTPUT
278278
echo "*"
279279
echo "* AWS OFI NCCL plugin has been configured."
280280
echo "*"
281-
echo "* Platform-specific optimizations: ${NCCL_OFI_PLATFORM}"
281+
AS_IF([test "${NCCL_OFI_PLATFORM}" = "none"],
282+
[echo "* Platform Optimizations: DISABLED"],
283+
[echo "* Platform Optimizations: ${NCCL_OFI_PLATFORM}"])
282284
echo "*"

include/nccl_ofi_param.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,4 +374,9 @@ OFI_NCCL_PARAM_UINT(cm_num_rx_buffers, "CM_NUM_RX_BUFFERS", 32);
374374
OFI_NCCL_PARAM_VALUE_SET(PROGRESS_MODEL, (UNSPEC)(AUTO)(MANUAL))
375375
OFI_NCCL_PARAM(PROGRESS_MODEL, progress_model, "PROGRESS_MODEL", PROGRESS_MODEL::UNSPEC)
376376

377+
/*
378+
* Override platform selection. Valid options: "AWS", "Default", or empty string for auto-detection.
379+
*/
380+
OFI_NCCL_PARAM(std::string, platform_override, "PLATFORM_OVERRIDE", "");
381+
377382
#endif // End NCCL_OFI_PARAM_H_

include/nccl_ofi_platform.h

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
#include <memory>
99
#include <map>
10+
#include <limits>
1011

1112
#include <rdma/fabric.h>
1213
#include <rdma/fi_endpoint.h>
1314

15+
#include "nccl_ofi_param.h"
1416
#include "nccl_ofi_system.h"
1517

1618
/**
@@ -135,44 +137,57 @@ class PlatformManager {
135137
*/
136138
static PlatformManager& get_global();
137139

138-
/**
139-
* @brief Register a platform with the manager
140-
*
141-
* Platforms are automatically sorted by priority in the internal map.
142-
* Higher priority values take precedence and duplicates are dropped.
143-
*
144-
* @param platform Platform instance to register (moved)
145-
*/
146-
void register_platform(PlatformPtr&& platform);
147-
148140
/**
149141
* @brief Get the highest priority platform instance
150142
*
151143
* Returns the platform with the highest priority value.
152144
*
153145
* @return Reference to highest priority platform
154146
*/
155-
inline Platform& get_platform() { return *platforms_.rbegin()->second; }
147+
inline Platform& get_platform() {
148+
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Selected platform: %s",
149+
platform_->get_name());
150+
return *platform_;
151+
}
156152

157-
/**
158-
* @brief Get number of registered platforms (for testing)
159-
*
160-
* @return Number of platforms in the manager
161-
*/
162-
inline size_t get_platform_count() { return platforms_.size(); }
163153
protected:
164154
/**
165155
* @brief Default constructor
166156
* Register the default Platform by default. A static global
167157
* instance is meant to be used in the plugin and the unit
168158
* tests leverage the protected scope.
169159
*/
170-
PlatformManager() {
171-
register_platform(std::make_unique<Default>());
160+
PlatformManager();
161+
162+
/**
163+
* @brief Register a platform with the manager
164+
*
165+
* Platforms are selected by priority. Higher priority values take
166+
* precedence. This can only be done in the constructor as all platforms
167+
* must be added during object creation to allow the tuner and plugin
168+
* to operate consistently.
169+
*
170+
* @param platform Platform instance to register (moved)
171+
*/
172+
void register_platform(PlatformPtr&& platform) {
173+
int priority = platform->get_priority();
174+
// Replace if no current platform or higher priority
175+
if (!platform_ || priority > current_priority_) {
176+
platform_ = std::move(platform);
177+
current_priority_ = priority;
178+
179+
// Set max priority if this matches the override so no other platform can override
180+
if (!override_.empty() && override_ == platform_->get_name()) {
181+
current_priority_ = std::numeric_limits<int>::max();
182+
}
183+
}
172184
}
173185

186+
174187
private:
175-
std::map<int, PlatformPtr> platforms_;
188+
std::string override_ = "";
189+
int current_priority_ = -1;
190+
PlatformPtr platform_ = nullptr;
176191
};
177192

178193

include/nccl_ofi_topo.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#ifndef NCCL_NET_OFI_TOPO_H_
66
#define NCCL_NET_OFI_TOPO_H_
77

8+
#include <memory>
9+
810
#include <hwloc.h>
911
#include <rdma/fabric.h>
1012

@@ -310,4 +312,49 @@ struct fi_info *nccl_ofi_topo_next_info_list(nccl_ofi_topo_data_iterator_t *iter
310312
*/
311313
int nccl_ofi_topo_write_nccl_topology(nccl_ofi_topo_t *topo, FILE *file);
312314

315+
class TopologyManager {
316+
private:
317+
static inline std::unique_ptr<nccl_ofi_topo_t, decltype(&nccl_ofi_topo_free)> instance{
318+
nullptr,
319+
nccl_ofi_topo_free
320+
};
321+
322+
public:
323+
static nccl_ofi_topo_t* get() {
324+
return instance.get();
325+
}
326+
327+
static nccl_ofi_topo_t* initialize(struct fi_info *provider_list) {
328+
instance.reset(nccl_ofi_topo_create(provider_list));
329+
if (instance && nccl_ofi_topo_group(instance.get()) != 0) {
330+
instance.reset();
331+
}
332+
return instance.get();
333+
}
334+
335+
static void reset() {
336+
instance.reset();
337+
}
338+
339+
// TODO: Refactor all topology functions to be part of this class
340+
static bool has_efa_ena_devices() {
341+
auto* topo_instance = get();
342+
if (!topo_instance) return false;
343+
344+
hwloc_obj_t obj = nullptr;
345+
while ((obj = hwloc_get_next_pcidev(topo_instance->topo, obj)) != nullptr) {
346+
// Check for Amazon vendor id and EFA device or ENA device
347+
if (obj->attr->pcidev.vendor_id == 0x1D0F &&
348+
(((obj->attr->pcidev.device_id & 0xFFF0) == 0xEFA0 || (obj->attr->pcidev.device_id & 0xFFF0) == 0xEC20) ||
349+
(obj->attr->pcidev.device_id & 0x0FFF) == 0x0EC2)) {
350+
return true;
351+
}
352+
}
353+
return false;
354+
}
355+
356+
// Prevent instantiation
357+
TopologyManager() = delete;
358+
};
359+
313360
#endif // End NCCL_NET_OFI_TOPO_H_

include/platform-aws.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
class PlatformAWS : public Platform {
2424
public:
2525
const char* get_name() const override { return "AWS"; }
26-
int get_priority() override { return 100; }
26+
int get_priority() override { return TopologyManager::has_efa_ena_devices() ? 100 : -1; }
2727
int init(const char **provider_filter) override;
2828
int config_endpoint(struct fi_info *info, struct fid_ep *ep) override;
2929
void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) override;
@@ -72,6 +72,8 @@ class PlatformAWS : public Platform {
7272
return fields ? fields->func_idx : -EIO;
7373
}
7474

75+
// Determine if running on Amazon EC2 instance
76+
7577
private:
7678
std::mutex mutex_;
7779

src/nccl_ofi_net.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@
3030
#include "nccl_ofi_idpool.h"
3131
#include "nccl_ofi_dmabuf.h"
3232
#include "nccl_ofi_platform.h"
33-
#ifdef WANT_AWS_PLATFORM
34-
#include "platform-aws.h"
35-
#endif
3633
#include "nccl_ofi_ofiutils.h"
3734
#include "nccl_ofi_system.h"
3835

@@ -178,10 +175,6 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
178175
nic_dup_conns = ofi_nccl_nic_dup_conns();
179176
cq_read_count = ofi_nccl_cq_read_count();
180177

181-
#ifdef WANT_AWS_PLATFORM
182-
PlatformManager::get_global().register_platform(std::make_unique<PlatformAWS>());
183-
#endif
184-
185178
ret = PlatformManager::get_global().get_platform().init(&provider_filter);
186179
if (ret != 0)
187180
goto exit;

src/nccl_ofi_platform.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,19 @@
33
*/
44

55
#include "nccl_ofi_platform.h"
6+
#ifdef WANT_AWS_PLATFORM
7+
#include "platform-aws.h"
8+
#endif
9+
10+
PlatformManager::PlatformManager()
11+
: override_(ofi_nccl_platform_override.get()) {
12+
register_platform(std::make_unique<Default>());
13+
#ifdef WANT_AWS_PLATFORM
14+
register_platform(std::make_unique<PlatformAWS>());
15+
#endif
16+
}
617

718
PlatformManager& PlatformManager::get_global() {
819
static PlatformManager manager;
920
return manager;
1021
}
11-
12-
void PlatformManager::register_platform(PlatformPtr&& platform) {
13-
int priority = platform->get_priority();
14-
const char* name = platform->get_name();
15-
16-
auto it = platforms_.find(priority);
17-
if (it != platforms_.end()) {
18-
if (strcmp(it->second->get_name(), name) == 0) {
19-
return;
20-
}
21-
// TODO: Add proper resolution mechanism for competing priorities
22-
priority++;
23-
}
24-
25-
platforms_[priority] = std::move(platform);
26-
}

src/nccl_ofi_rdma.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6934,10 +6934,8 @@ static void get_hints(struct fi_info *hints)
69346934

69356935
nccl_net_ofi_rdma_plugin_t::~nccl_net_ofi_rdma_plugin_t()
69366936
{
6937-
if (this->topo != nullptr) {
6938-
nccl_ofi_topo_free(this->topo);
6939-
this->topo = nullptr;
6940-
}
6937+
// TODO: Refactor so plugin holds reference to TopoploguManager object
6938+
this->topo = nullptr;
69416939

69426940
if (r_comm_cleanup_list != nullptr) {
69436941
delete r_comm_cleanup_list;
@@ -7013,17 +7011,11 @@ nccl_net_ofi_rdma_plugin_t::nccl_net_ofi_rdma_plugin_t(struct fi_info *provider_
70137011
int ret = 0;
70147012
int num_devices = 0;
70157013

7016-
/* Create NCCL OFI topology */
7017-
this->topo = nccl_ofi_topo_create(provider_list);
7014+
/* Use shared NCCL OFI topology */
7015+
this->topo = TopologyManager::get();
70187016
if (!this->topo) {
7019-
NCCL_OFI_WARN("Failed to create NCCL OFI topology");
7020-
throw std::runtime_error("rdma plugin constructor: topo creation failed");
7021-
}
7022-
7023-
ret = nccl_ofi_topo_group(this->topo);
7024-
if (ret != 0) {
7025-
NCCL_OFI_WARN("Failed to group NICs");
7026-
throw std::runtime_error("rdma plugin constructor: NIC grouping failed");
7017+
NCCL_OFI_WARN("Failed to get NCCL OFI topology");
7018+
throw std::runtime_error("rdma plugin constructor: topo access failed");
70277019
}
70287020

70297021
if (this->topo->max_group_size < 1 || this->topo->max_group_size > MAX_NUM_RAILS) {
@@ -7091,6 +7083,10 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
70917083
FI_MAJOR(api_version),
70927084
FI_MINOR(api_version),
70937085
FI_VERSION_GE(api_version, FI_VERSION(1, 20)) ? "DMA-BUF" : "GPUDirect RDMA");
7086+
7087+
/* Initialize topology manager with provider list */
7088+
TopologyManager::initialize(provider_list);
7089+
70947090
/* The 1.18 API allows providers to use CUDA to
70957091
* support HMEM pointers, so just having HMEM doesn't
70967092
* tell us anything about the usability of CUDA

src/tuner/nccl_ofi_tuner.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "nccl_ofi_pthread.h"
2020
#include "nccl_ofi_system.h"
2121
#include "nccl_ofi_param.h"
22+
#include "nccl_ofi_platform.h"
2223

2324
#include "tuner/nccl_ofi_tuner_region.h"
2425
#include "tuner/nccl_ofi_tuner_model.h"
@@ -63,7 +64,17 @@ static ncclResult_t nccl_ofi_tuner_init(size_t nRanks, size_t nNodes, ncclDebugL
6364
* Retrieve platform type and pass to Region and Model based tuner support check functions.
6465
* If both Region and Model based tuner are not supported, log a warning and exit.
6566
*/
66-
platform_type = nccl_net_ofi_get_product_name();
67+
auto platform_name = PlatformManager::get_global().get_platform().get_name();
68+
NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING, "Tuner init: Platform function returned: %s", platform_name);
69+
if (strcmp(platform_name, "AWS") == 0) {
70+
platform_type = nccl_net_ofi_get_product_name();
71+
} else {
72+
/* Default platform or other non-AWS platforms should use internal tuner */
73+
NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING,
74+
"Non-AWS platform detected (%s), falling back to NCCL's internal tuner", platform_name);
75+
goto exit;
76+
}
77+
6778
if (platform_type == NULL) {
6879
NCCL_OFI_WARN("NCCL_OFI_TUNER is not available because platform type is unavailable.");
6980
goto exit;

0 commit comments

Comments
 (0)