Skip to content

Commit 00cb220

Browse files
committed
enh: Select platform optimizations at runtime
Replace compile-time AWS platform detection with runtime EFA device detection using hwloc. This allows a single binary to work on both AWS and non-AWS environments, automatically enabling optimizations when EFA hardware is present. Removes autotools platform checks and always builds AWS platform code. Signed-off-by: Hershel Shah <[email protected]>
1 parent 6f074cb commit 00cb220

12 files changed

+92
-78
lines changed

configure.ac

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,6 @@ CHECK_PKG_HWLOC([],
137137
CHECK_PKG_VALGRIND()
138138
CHECK_VAR_REDZONE()
139139

140-
NCCL_OFI_PLATFORM="none"
141-
AS_IF([test "${NCCL_OFI_PLATFORM}" = "none"], [AX_CHECK_PLATFORM_AWS()])
142-
143140
AS_IF([test "${valgrind_enabled}" = "1" -a "${enable_asan}" = "yes"],
144141
[AC_MSG_ERROR([Enabling ASAN and valgrind at the same time is not permitted])])
145142

@@ -278,5 +275,5 @@ AC_OUTPUT
278275
echo "*"
279276
echo "* AWS OFI NCCL plugin has been configured."
280277
echo "*"
281-
echo "* Platform-specific optimizations: ${NCCL_OFI_PLATFORM}"
278+
echo "* Platform optimizations: Runtime detection enabled"
282279
echo "*"

include/nccl_ofi_param.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,4 +374,10 @@ OFI_NCCL_PARAM_UINT(cm_num_rx_buffers, "CM_NUM_RX_BUFFERS", 32);
374374
OFI_NCCL_PARAM_VALUE_SET(PROGRESS_MODEL, (UNSPEC)(AUTO)(MANUAL))
375375
OFI_NCCL_PARAM(PROGRESS_MODEL, progress_model, "PROGRESS_MODEL", PROGRESS_MODEL::UNSPEC)
376376

377+
/*
378+
* Force non-AWS platform detection for testing. When set to true,
379+
* the AWS platform will report negative priority even on AWS hardware.
380+
*/
381+
OFI_NCCL_PARAM(bool, force_non_aws, "FORCE_NON_AWS", false);
382+
377383
#endif // End NCCL_OFI_PARAM_H_

include/nccl_ofi_platform.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,12 @@ class PlatformManager {
152152
*
153153
* @return Reference to highest priority platform
154154
*/
155-
inline Platform& get_platform() { return *platforms_.rbegin()->second; }
155+
inline Platform& get_platform() {
156+
Platform& selected = *platforms_.rbegin()->second;
157+
NCCL_OFI_INFO(NCCL_INIT, "Selected platform: %s with priority %d",
158+
selected.get_name(), platforms_.rbegin()->first);
159+
return selected;
160+
}
156161

157162
/**
158163
* @brief Get number of registered platforms (for testing)

include/platform-aws.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
class PlatformAWS : public Platform {
2424
public:
2525
const char* get_name() const override { return "AWS"; }
26-
int get_priority() override { return 100; }
26+
int get_priority() override { return BASE_PRIORITY * (is_aws() ? 1 : -1); }
2727
int init(const char **provider_filter) override;
2828
int config_endpoint(struct fi_info *info, struct fid_ep *ep) override;
2929
void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) override;
@@ -72,7 +72,15 @@ class PlatformAWS : public Platform {
7272
return fields ? fields->func_idx : -EIO;
7373
}
7474

75+
// Detect at runtime whether an AWS EFA PCI device is present (always false when FORCE_NON_AWS is set)
76+
bool is_aws() const;
77+
7578
private:
79+
// Constants
80+
static constexpr int BASE_PRIORITY = 100;
81+
static constexpr unsigned short AWS_VENDOR_ID = 0x1D0F;
82+
static constexpr unsigned short EFA_DEV = 0xEFA0;
83+
7684
std::mutex mutex_;
7785

7886
// Cache for GUID fields to avoid repeated sysfs reads

m4/ax_platform_aws.m4

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/Makefile.am

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,9 @@ sources = \
3434
nccl_ofi_ep_addr_list.cpp \
3535
nccl_ofi_param.cpp \
3636
nccl_ofi_platform.cpp \
37+
platform-aws.cpp \
3738
tracepoint.cpp
3839

39-
if WANT_PLATFORM_AWS
40-
sources += platform-aws.cpp
41-
endif
42-
4340
if ENABLE_NEURON
4441
sources += nccl_ofi_interface_neuron.cpp
4542
else
@@ -48,13 +45,11 @@ else
4845
nccl_ofi_interface_nvidia.cpp
4946

5047
# add the tuner sources into the library
51-
if WANT_PLATFORM_AWS
52-
sources += \
48+
sources += \
5349
tuner/nccl_ofi_regions.cpp \
5450
tuner/nccl_ofi_tuner.cpp \
5551
tuner/nccl_ofi_model.cpp
5652
endif
57-
endif
5853

5954
# Build an internal-only library that can be used by unit tests as
6055
# well as the actual nccl_net.so / nccom_net.so libraries. This saves
@@ -104,20 +99,18 @@ if ENABLE_NCCL_NET_LIBRARY
10499
libnccl_net_la_LIBTOOLFLAGS = --tag=CXX
105100
libnccl_net_la_LDFLAGS = -module -avoid-version
106101
endif
107-
if WANT_PLATFORM_AWS
102+
108103
# NCCL standardized on the libnccl-tuner-<interface> format after we released a
109104
# plugin with the tuner named libnccl-ofi-tuner.so. Create separate libraries
110-
# for each name.
111-
lib_LTLIBRARIES += libnccl-ofi-tuner.la libnccl-tuner-ofi.la
112-
libnccl_ofi_tuner_la_SOURCES =
113-
libnccl_ofi_tuner_la_LIBADD = libinternal_plugin.la
114-
libnccl_ofi_tuner_la_LIBTOOLFLAGS = --tag=CXX
115-
libnccl_ofi_tuner_la_LDFLAGS = -module -avoid-version
105+
lib_LTLIBRARIES += libnccl-ofi-tuner.la libnccl-tuner-ofi.la
106+
libnccl_ofi_tuner_la_SOURCES =
107+
libnccl_ofi_tuner_la_LIBADD = libinternal_plugin.la
108+
libnccl_ofi_tuner_la_LIBTOOLFLAGS = --tag=CXX
109+
libnccl_ofi_tuner_la_LDFLAGS = -module -avoid-version
116110

117-
libnccl_tuner_ofi_la_SOURCES =
118-
libnccl_tuner_ofi_la_LIBADD = libinternal_plugin.la
119-
libnccl_tuner_ofi_la_LIBTOOLFLAGS = --tag=CXX
120-
libnccl_tuner_ofi_la_LDFLAGS = -module -avoid-version
121-
endif
111+
libnccl_tuner_ofi_la_SOURCES =
112+
libnccl_tuner_ofi_la_LIBADD = libinternal_plugin.la
113+
libnccl_tuner_ofi_la_LIBTOOLFLAGS = --tag=CXX
114+
libnccl_tuner_ofi_la_LDFLAGS = -module -avoid-version
122115

123116
endif

src/nccl_ofi_net.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@
3030
#include "nccl_ofi_idpool.h"
3131
#include "nccl_ofi_dmabuf.h"
3232
#include "nccl_ofi_platform.h"
33-
#ifdef WANT_AWS_PLATFORM
3433
#include "platform-aws.h"
35-
#endif
3634
#include "nccl_ofi_ofiutils.h"
3735
#include "nccl_ofi_system.h"
3836

@@ -178,9 +176,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
178176
nic_dup_conns = ofi_nccl_nic_dup_conns();
179177
cq_read_count = ofi_nccl_cq_read_count();
180178

181-
#ifdef WANT_AWS_PLATFORM
182179
PlatformManager::get_global().register_platform(std::make_unique<PlatformAWS>());
183-
#endif
184180

185181
ret = PlatformManager::get_global().get_platform().init(&provider_filter);
186182
if (ret != 0)
@@ -1058,7 +1054,7 @@ int nccl_net_ofi_ep_t::release_ep(bool skip_lock, bool force_cleanup)
10581054

10591055
/* Store ref_cnt in local variable in case the endpoint gets deleted */
10601056
int local_ref_cnt = this->ref_cnt;
1061-
1057+
10621058
if (local_ref_cnt == 0 || force_cleanup) {
10631059
/* If this was the endpoint we stored in domain for connection
10641060
management, remove that reference as well */

src/nccl_ofi_platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@ void PlatformManager::register_platform(PlatformPtr&& platform) {
2121
// TODO: Add proper resolution mechanism for competing priorities
2222
priority++;
2323
}
24-
24+
NCCL_OFI_INFO(NCCL_INIT, "Adding %s platform with %d priority", name, priority);
2525
platforms_[priority] = std::move(platform);
2626
}

src/nccl_ofi_topo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ static hwloc_obj_t get_numa_mem_child(hwloc_obj_t node)
10721072
return child;
10731073
}
10741074

1075-
/*
1075+
/*
10761076
* @brief Return PCI device property of PCI device
10771077
*
10781078
* This function reads first `MAX_DEV_PROPERTY_LENGTH` characters from
@@ -1093,7 +1093,7 @@ static hwloc_obj_t get_numa_mem_child(hwloc_obj_t node)
10931093
* File name of the device property
10941094
* @return Pointer to an element of a char array to write device property to.
10951095
* The array has to be allocated by the caller of this function.
1096-
* There must be space for at least `MAX_DEV_PROPERTY_LENGTH`
1096+
* There must be space for at least `MAX_DEV_PROPERTY_LENGTH`
10971097
* characters in addition to the delimiting `\0`.
10981098
* @return 0, on success
10991099
* non-zero, on error

src/platform-aws.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include <map>
2222
#include <mutex>
2323
#include <string>
24+
#include <memory>
25+
#include <hwloc.h>
26+
#include <unistd.h>
2427

2528
#ifdef HAVE_RDMA_FI_EXT_H
2629
#include <rdma/fi_ext.h>
@@ -1047,3 +1050,26 @@ void PlatformAWS::sort_rails(struct fi_info **info_list, size_t num_rails, size_
10471050

10481051
return;
10491052
}
1053+
1054+
bool PlatformAWS::is_aws() const {
1055+
if (ofi_nccl_force_non_aws.get()) {
1056+
NCCL_OFI_INFO(NCCL_INIT, "Disabling PlatformAWS optimizations");
1057+
return false;
1058+
}
1059+
1060+
hwloc_topology_t topo;
1061+
if (hwloc_topology_init(&topo) != 0) return false;
1062+
1063+
auto topology = std::shared_ptr<hwloc_topology>(topo, hwloc_topology_destroy);
1064+
hwloc_topology_set_io_types_filter(topo, HWLOC_TYPE_FILTER_KEEP_ALL);
1065+
if (hwloc_topology_load(topo) != 0) return false;
1066+
1067+
hwloc_obj_t obj = nullptr;
1068+
while ((obj = hwloc_get_next_pcidev(topo, obj)) != nullptr) {
1069+
if (obj->attr->pcidev.vendor_id == AWS_VENDOR_ID &&
1070+
(obj->attr->pcidev.device_id & 0xFFF0) == EFA_DEV) {
1071+
return true;
1072+
}
1073+
}
1074+
return false;
1075+
}

0 commit comments

Comments
 (0)