|
5 | 5 | #ifndef NCCL_OFI_PLATFORM_H_
|
6 | 6 | #define NCCL_OFI_PLATFORM_H_
|
7 | 7 |
|
| 8 | +#include <memory> |
| 9 | +#include <map> |
| 10 | + |
8 | 11 | #include <rdma/fabric.h>
|
9 | 12 | #include <rdma/fi_endpoint.h>
|
10 | 13 |
|
11 |
| -/* Declare platform-specific hooks that can be provided by platform-specific |
12 |
| - * source files (such as the optionally compiled platform_aws.c). The functions |
13 |
| - * here are declared as weak symbols so that linkage will not break if no |
14 |
| - * platform specific hook is provided; in that case the hook will be NULL at |
15 |
| - * runtime. |
16 |
| - */ |
17 |
| - |
18 |
| -/* Platform-specific initialization hook. |
19 |
| - */ |
20 |
| -int platform_init(const char **provider_filter) __attribute__((weak)); |
21 |
| - |
22 |
| -/* Platform-specific endpoint configuration hook |
23 |
| - */ |
24 |
| -int platform_config_endpoint(struct fi_info *info, struct fid_ep *ep) __attribute__((weak)); |
| 14 | +#include "nccl_ofi_system.h" |
25 | 15 |
|
26 |
| -/* Platform-specific hook to sort in the multi-rail protocol of the |
27 |
| - * plugin |
| 16 | +/** |
| 17 | + * @brief Abstract base class representing a platform implementation for NCCL OFI plugin |
28 | 18 | *
|
29 |
| - * Rail-oriented networks or traffic flows are a common performance |
30 |
| - * optimization for ML networks. Generally, Libfabric providers sort |
31 |
| - * their provider list by BDFs, which are indicitive of physical |
32 |
| - * ordering and good enough. However, on some platforms (especially |
33 |
| - * virtualized platforms), this might not actually be sufficient and |
34 |
| - * another sorting mechanism may be required to properly group NICs. |
| 19 | + * The Platform class provides an interface for platform-specific operations and configurations |
| 20 | + * in the NCCL OFI plugin. It defines virtual methods that must be implemented by concrete |
| 21 | + * platform implementations to handle platform-specific initialization, endpoint configuration, |
| 22 | + * and rail sorting operations. |
35 | 23 | *
|
36 |
| - * This interface is called in the topology initialization code to |
37 |
| - * order NICs that are behind the same PCIe root complex / switch. |
38 |
| - * The info_list will have num_rails providers listed, and will later |
39 |
| - * be split into num_groups groups (based on the number of |
40 |
| - * accelerators that are also behind the PCIe switch). |
| 24 | + * Each platform implementation can specify its priority level for selection, with higher |
| 25 | + * priority platforms being preferred over lower priority ones. |
41 | 26 | *
|
42 |
| - * Providers of this interface should sort the provided info_list such |
43 |
| - * that the Nth provider on this node will be assumed to talk to the |
44 |
| - * Nth provider on remote nodes (ie, identify the "rail id" and sort |
45 |
| - * by that). |
| 27 | + * Future platform are to be implemented by inheriting this class and overriding the |
| 28 | + * given functions. Look at PlatformAWS or Default as an example. |
46 | 29 | *
|
47 |
| - * @param info_list: pointer to list of `num_rails` info objects |
48 |
| - * @param num_rails: number of rails |
| 30 | + * @see Default |
| 31 | + * @see PlatformAWS |
| 32 | + * @see PlatformManager |
49 | 33 | */
|
50 |
| -void platform_sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) __attribute__((weak)); |
| 34 | +class Platform { |
| 35 | +public: |
| 36 | + virtual ~Platform() = default; |
| 37 | + |
| 38 | + /** |
| 39 | + * @brief Get platform name |
| 40 | + * |
| 41 | + * @return Platform name string |
| 42 | + */ |
| 43 | + virtual const char* get_name() const = 0; |
| 44 | + |
| 45 | + /** |
| 46 | + * @brief Get platform priority for selection. |
| 47 | + * |
| 48 | + * @return Priority value (higher values have higher priority) |
| 49 | + */ |
| 50 | + virtual int get_priority() = 0; |
| 51 | + |
| 52 | + /** |
| 53 | + * @brief Platform-specific initialization hook |
| 54 | + * |
| 55 | + * @param provider_filter Pointer to provider filter string |
| 56 | + * |
| 57 | + * @return 0 on success, error code on failure |
| 58 | + */ |
| 59 | + virtual int init(const char **provider_filter) = 0; |
| 60 | + |
| 61 | + /** |
| 62 | + * @brief Platform-specific endpoint configuration hook |
| 63 | + * |
| 64 | + * @param info Fabric info structure |
| 65 | + * @param ep Fabric endpoint |
| 66 | + * |
| 67 | + * @return 0 on success, error code on failure |
| 68 | + */ |
| 69 | + virtual int config_endpoint(struct fi_info *info, struct fid_ep *ep) = 0; |
| 70 | + |
| 71 | + /** |
| 72 | + * @brief Platform-specific hook to sort in the multi-rail protocol of the plugin |
| 73 | + * |
| 74 | + * Rail-oriented networks or traffic flows are a common performance |
| 75 | + * optimization for ML networks. Generally, Libfabric providers sort |
| 76 | + * their provider list by BDFs, which are indicitive of physical |
| 77 | + * ordering and good enough. However, on some platforms (especially |
| 78 | + * virtualized platforms), this might not actually be sufficient and |
| 79 | + * another sorting mechanism may be required to properly group NICs. |
| 80 | + * |
| 81 | + * This interface is called in the topology initialization code to |
| 82 | + * order NICs that are behind the same PCIe root complex / switch. |
| 83 | + * The info_list will have num_rails providers listed, and will later |
| 84 | + * be split into num_groups groups (based on the number of |
| 85 | + * accelerators that are also behind the PCIe switch). |
| 86 | + * |
| 87 | + * Providers of this interface should sort the provided info_list such |
| 88 | + * that the Nth provider on this node will be assumed to talk to the |
| 89 | + * Nth provider on remote nodes (ie, identify the "rail id" and sort |
| 90 | + * by that). |
| 91 | + * |
| 92 | + * @param info_list Array of fabric info pointers to sort |
| 93 | + * @param num_rails Number of rails in the list |
| 94 | + * @param num_groups Number of groups to split rails into |
| 95 | + */ |
| 96 | + virtual void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) = 0; |
| 97 | + |
| 98 | + /** |
| 99 | + * @brief Platform-specific device GUID setter |
| 100 | + * |
| 101 | + * Sets device GUID to uniquely identify the network device |
| 102 | + * |
| 103 | + * @param info Fabric info structure |
| 104 | + * @param device Network device to set GUID for |
| 105 | + */ |
| 106 | + virtual void device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) = 0; |
| 107 | +}; |
| 108 | + |
| 109 | +using PlatformPtr = std::unique_ptr<Platform>; |
| 110 | + |
| 111 | +class Default : public Platform { |
| 112 | +public: |
| 113 | + const char* get_name() const override { return "Default"; } |
| 114 | + int get_priority() override { return 0; } |
| 115 | + int init(const char **provider_filter) override { return 0; } |
| 116 | + int config_endpoint(struct fi_info *info, struct fid_ep *ep) override { return 0; } |
| 117 | + void sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups) override {} |
| 118 | + void device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) override { |
| 119 | + uint32_t node_id = nccl_ofi_get_unique_node_id(); |
| 120 | + /* |
| 121 | + * Use device_index as lower 8 bits |
| 122 | + * Use node_id as next 32 bits (bits 8-39) |
| 123 | + * Upper 24 bits remain 0 |
| 124 | + */ |
| 125 | + device->guid = (static_cast<uint64_t>(node_id) << 8) | device->dev_id; |
| 126 | + } |
| 127 | +}; |
| 128 | + |
| 129 | +class PlatformManager { |
| 130 | +public: |
| 131 | + /** |
| 132 | + * @brief Get global instance |
| 133 | + * |
| 134 | + * @return Reference to global instance |
| 135 | + */ |
| 136 | + static PlatformManager& get_global(); |
| 137 | + |
| 138 | + /** |
| 139 | + * @brief Register a platform with the manager |
| 140 | + * |
| 141 | + * Platforms are automatically sorted by priority in the internal map. |
| 142 | + * Higher priority values take precedence and duplicates are dropped. |
| 143 | + * |
| 144 | + * @param platform Platform instance to register (moved) |
| 145 | + */ |
| 146 | + void register_platform(PlatformPtr&& platform); |
| 147 | + |
| 148 | + /** |
| 149 | + * @brief Get the highest priority platform instance |
| 150 | + * |
| 151 | + * Returns the platform with the highest priority value. |
| 152 | + * |
| 153 | + * @return Reference to highest priority platform |
| 154 | + */ |
| 155 | + inline Platform& get_platform() { return *platforms_.rbegin()->second; } |
| 156 | + |
| 157 | + /** |
| 158 | + * @brief Get number of registered platforms (for testing) |
| 159 | + * |
| 160 | + * @return Number of platforms in the manager |
| 161 | + */ |
| 162 | + inline size_t get_platform_count() { return platforms_.size(); } |
| 163 | +protected: |
| 164 | + /** |
| 165 | + * @brief Default constructor |
| 166 | + * Register the default Platform by default. A static global |
| 167 | + * instance is meant to be used in the plugin and the unit |
| 168 | + * tests leverage the protected scope. |
| 169 | + */ |
| 170 | + PlatformManager() { |
| 171 | + register_platform(std::make_unique<Default>()); |
| 172 | + } |
| 173 | + |
| 174 | +private: |
| 175 | + std::map<int, PlatformPtr> platforms_; |
| 176 | +}; |
51 | 177 |
|
52 |
| -/* |
53 |
| - * Platform-specific guid property setter |
54 |
| - * |
55 |
| - * This overrides the default guid setter (nccl_net_ofi_device_set_guid()) which |
56 |
| - * is based on network device index and IP address. Platforms can set |
57 |
| - * device->guid to be any 64-bit value as they seem fit to uniquely identify the |
58 |
| - * network device. |
59 |
| - */ |
60 |
| -void platform_device_set_guid(struct fi_info *info, nccl_net_ofi_device_t *device) __attribute__((weak)); |
61 | 178 |
|
62 | 179 | #endif // End NCCL_OFI_PLATFORM_H_
|
0 commit comments