|
7 | 7 |
|
8 | 8 | #include <memory> |
9 | 9 | #include <map> |
| 10 | +#include <limits> |
10 | 11 |
|
11 | 12 | #include <rdma/fabric.h> |
12 | 13 | #include <rdma/fi_endpoint.h> |
13 | 14 |
|
| 15 | +#include "nccl_ofi_param.h" |
14 | 16 | #include "nccl_ofi_system.h" |
15 | 17 |
|
16 | 18 | /** |
@@ -135,44 +137,57 @@ class PlatformManager { |
135 | 137 | */ |
136 | 138 | static PlatformManager& get_global(); |
137 | 139 |
|
138 | | - /** |
139 | | - * @brief Register a platform with the manager |
140 | | - * |
141 | | - * Platforms are automatically sorted by priority in the internal map. |
142 | | - * Higher priority values take precedence and duplicates are dropped. |
143 | | - * |
144 | | - * @param platform Platform instance to register (moved) |
145 | | - */ |
146 | | - void register_platform(PlatformPtr&& platform); |
147 | | - |
148 | 140 | /** |
149 | 141 | * @brief Get the highest priority platform instance |
150 | 142 | * |
151 | 143 | * Returns the platform with the highest priority value. |
152 | 144 | * |
153 | 145 | * @return Reference to highest priority platform |
154 | 146 | */ |
155 | | - inline Platform& get_platform() { return *platforms_.rbegin()->second; } |
| 147 | + inline Platform& get_platform() { |
| 148 | + NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Selected platform: %s", |
| 149 | + platform_->get_name()); |
| 150 | + return *platform_; |
| 151 | + } |
156 | 152 |
|
157 | | - /** |
158 | | - * @brief Get number of registered platforms (for testing) |
159 | | - * |
160 | | - * @return Number of platforms in the manager |
161 | | - */ |
162 | | - inline size_t get_platform_count() { return platforms_.size(); } |
163 | 153 | protected: |
164 | 154 | /** |
165 | 155 | * @brief Default constructor |
166 | 156 | * Register the default Platform by default. A static global |
167 | 157 | * instance is meant to be used in the plugin and the unit |
168 | 158 | * tests leverage the protected scope. |
169 | 159 | */ |
170 | | - PlatformManager() { |
171 | | - register_platform(std::make_unique<Default>()); |
| 160 | + PlatformManager(); |
| 161 | + |
| 162 | + /** |
| 163 | + * @brief Register a platform with the manager |
| 164 | + * |
| 165 | + * Platforms are selected by priority. Higher priority values take |
| 166 | + * precedence. This can only be done in the constructor as all platforms |
| 167 | + * must be added during object creation to allow the tuner and plugin |
| 168 | + * to operate consistently. |
| 169 | + * |
| 170 | + * @param platform Platform instance to register (moved) |
| 171 | + */ |
| 172 | + void register_platform(PlatformPtr&& platform) { |
| 173 | + int priority = platform->get_priority(); |
| 174 | + // Replace if no current platform or higher priority |
| 175 | + if (!platform_ || priority > current_priority_) { |
| 176 | + platform_ = std::move(platform); |
| 177 | + current_priority_ = priority; |
| 178 | + |
| 179 | + // Set max priority if this matches the override so no other platform can override |
| 180 | + if (!override_.empty() && override_ == platform_->get_name()) { |
| 181 | + current_priority_ = std::numeric_limits<int>::max(); |
| 182 | + } |
| 183 | + } |
172 | 184 | } |
173 | 185 |
|
| 186 | + |
174 | 187 | private: |
175 | | - std::map<int, PlatformPtr> platforms_; |
| 188 | + std::string override_ = ""; |
| 189 | + int current_priority_ = -1; |
| 190 | + PlatformPtr platform_ = nullptr; |
176 | 191 | }; |
177 | 192 |
|
178 | 193 |
|
|
0 commit comments