Skip to content

Commit f9528b0

Browse files
committed
Fix #514 & minor updates
1 parent a464b9f commit f9528b0

File tree

4 files changed

+61 −8 lines changed

4 files changed

+61 −8 lines changed

python/mscclpp/comm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def make_connection(
107107
return connections
108108

109109
def register_tensor_with_connections(
110-
self, tensor: Type[cp.ndarray] or Type[np.ndarray], connections: dict[int, Connection]
110+
self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
111111
) -> dict[int, RegisteredMemory]:
112112
transport_flags = TransportFlags()
113113
for rank in connections:
@@ -135,7 +135,7 @@ def register_tensor_with_connections(
135135
def make_semaphore(
136136
self,
137137
connections: dict[int, Connection],
138-
semaphore_type: Type[Host2HostSemaphore] or Type[Host2DeviceSemaphore] or Type[MemoryDevice2DeviceSemaphore],
138+
semaphore_type: Type[Host2HostSemaphore] | Type[Host2DeviceSemaphore] | Type[MemoryDevice2DeviceSemaphore],
139139
) -> dict[int, Host2HostSemaphore]:
140140
semaphores = {}
141141
for rank in connections:

src/bootstrap/bootstrap.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -528,6 +528,7 @@ std::shared_ptr<Socket> TcpBootstrap::Impl::getPeerRecvSocket(int peer, int tag)
528528
if (recvPeer == peer && recvTag == tag) {
529529
return sock;
530530
}
531+
// TODO(chhwang): set an exit condition or timeout
531532
}
532533
}
533534

src/communicator.cc

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,30 +31,41 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
3131
}
3232

3333
MSCCLPP_API_CPP void Communicator::sendMemory(RegisteredMemory memory, int remoteRank, int tag) {
34-
pimpl_->bootstrap_->send(memory.serialize(), remoteRank, tag);
34+
bootstrap()->send(memory.serialize(), remoteRank, tag);
3535
}
3636

3737
MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(int remoteRank, int tag) {
38-
return std::async(std::launch::deferred, [this, remoteRank, tag]() {
38+
size_t numRecvItemsAhead = pimpl_->recvQueues_.getNumRecvItems(remoteRank, tag);
39+
auto future = std::async(std::launch::deferred, [this, remoteRank, tag, numRecvItemsAhead]() {
40+
pimpl_->recvQueues_.waitN(remoteRank, tag, numRecvItemsAhead);
3941
std::vector<char> data;
4042
bootstrap()->recv(data, remoteRank, tag);
4143
return RegisteredMemory::deserialize(data);
4244
});
45+
auto shared_future = std::shared_future<RegisteredMemory>(std::move(future));
46+
pimpl_->recvQueues_.addRecvItem(remoteRank, tag, shared_future);
47+
return shared_future;
4348
}
4449

4550
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
4651
EndpointConfig localConfig) {
47-
auto localEndpoint = pimpl_->context_->createEndpoint(localConfig);
48-
pimpl_->bootstrap_->send(localEndpoint.serialize(), remoteRank, tag);
52+
auto localEndpoint = context()->createEndpoint(localConfig);
53+
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
4954

50-
return std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint = std::move(localEndpoint)]() mutable {
55+
size_t numRecvItemsAhead = pimpl_->recvQueues_.getNumRecvItems(remoteRank, tag);
56+
auto future = std::async(std::launch::deferred, [this, remoteRank, tag, numRecvItemsAhead,
57+
localEndpoint = std::move(localEndpoint)]() mutable {
58+
pimpl_->recvQueues_.waitN(remoteRank, tag, numRecvItemsAhead);
5159
std::vector<char> data;
5260
bootstrap()->recv(data, remoteRank, tag);
5361
auto remoteEndpoint = Endpoint::deserialize(data);
5462
auto connection = context()->connect(localEndpoint, remoteEndpoint);
5563
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
5664
return connection;
5765
});
66+
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
67+
pimpl_->recvQueues_.addRecvItem(remoteRank, tag, shared_future);
68+
return shared_future;
5869
}
5970

6071
MSCCLPP_API_CPP int Communicator::remoteRankOf(const Connection& connection) {

src/include/communicator.hpp

Lines changed: 42 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,49 @@
99
#include <unordered_map>
1010
#include <vector>
1111

12+
#include "utils_internal.hpp"
13+
1214
namespace mscclpp {
1315

14-
class ConnectionBase;
16+
class BaseRecvItem {
17+
public:
18+
virtual ~BaseRecvItem() = default;
19+
virtual void wait() = 0;
20+
};
21+
22+
template <typename T>
23+
class RecvItem : public BaseRecvItem {
24+
public:
25+
RecvItem(std::shared_future<T> future) : future_(future) {}
26+
27+
void wait() { future_.wait(); }
28+
29+
private:
30+
std::shared_future<T> future_;
31+
};
32+
33+
class RecvQueues {
34+
public:
35+
RecvQueues() = default;
36+
37+
template <typename T>
38+
void addRecvItem(int remoteRank, int tag, std::shared_future<T> future) {
39+
auto& queue = queues_[std::make_pair(remoteRank, tag)];
40+
queue.emplace_back(std::make_shared<RecvItem<T>>(future));
41+
}
42+
43+
size_t getNumRecvItems(int remoteRank, int tag) { return queues_[std::make_pair(remoteRank, tag)].size(); }
44+
45+
void waitN(int remoteRank, int tag, size_t n) {
46+
auto& queue = queues_[std::make_pair(remoteRank, tag)];
47+
for (size_t i = 0; i < n; ++i) {
48+
queue[i]->wait();
49+
}
50+
}
51+
52+
private:
53+
std::unordered_map<std::pair<int, int>, std::vector<std::shared_ptr<BaseRecvItem>>, PairHash> queues_;
54+
};
1555

1656
struct ConnectionInfo {
1757
int remoteRank;
@@ -22,6 +62,7 @@ struct Communicator::Impl {
2262
std::shared_ptr<Bootstrap> bootstrap_;
2363
std::shared_ptr<Context> context_;
2464
std::unordered_map<const Connection*, ConnectionInfo> connectionInfos_;
65+
RecvQueues recvQueues_;
2566

2667
Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context);
2768

0 commit comments

Comments
 (0)