@@ -31,30 +31,41 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
31
31
}
32
32
33
33
MSCCLPP_API_CPP void Communicator::sendMemory (RegisteredMemory memory, int remoteRank, int tag) {
34
- pimpl_-> bootstrap_ ->send (memory.serialize (), remoteRank, tag);
34
+ bootstrap () ->send (memory.serialize (), remoteRank, tag);
35
35
}
36
36
37
37
MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory (int remoteRank, int tag) {
38
- return std::async (std::launch::deferred, [this , remoteRank, tag]() {
38
+ size_t numRecvItemsAhead = pimpl_->recvQueues_ .getNumRecvItems (remoteRank, tag);
39
+ auto future = std::async (std::launch::deferred, [this , remoteRank, tag, numRecvItemsAhead]() {
40
+ pimpl_->recvQueues_ .waitN (remoteRank, tag, numRecvItemsAhead);
39
41
std::vector<char > data;
40
42
bootstrap ()->recv (data, remoteRank, tag);
41
43
return RegisteredMemory::deserialize (data);
42
44
});
45
+ auto shared_future = std::shared_future<RegisteredMemory>(std::move (future));
46
+ pimpl_->recvQueues_ .addRecvItem (remoteRank, tag, shared_future);
47
+ return shared_future;
43
48
}
44
49
45
50
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect (int remoteRank, int tag,
46
51
EndpointConfig localConfig) {
47
- auto localEndpoint = pimpl_-> context_ ->createEndpoint (localConfig);
48
- pimpl_-> bootstrap_ ->send (localEndpoint.serialize (), remoteRank, tag);
52
+ auto localEndpoint = context () ->createEndpoint (localConfig);
53
+ bootstrap () ->send (localEndpoint.serialize (), remoteRank, tag);
49
54
50
- return std::async (std::launch::deferred, [this , remoteRank, tag, localEndpoint = std::move (localEndpoint)]() mutable {
55
+ size_t numRecvItemsAhead = pimpl_->recvQueues_ .getNumRecvItems (remoteRank, tag);
56
+ auto future = std::async (std::launch::deferred, [this , remoteRank, tag, numRecvItemsAhead,
57
+ localEndpoint = std::move (localEndpoint)]() mutable {
58
+ pimpl_->recvQueues_ .waitN (remoteRank, tag, numRecvItemsAhead);
51
59
std::vector<char > data;
52
60
bootstrap ()->recv (data, remoteRank, tag);
53
61
auto remoteEndpoint = Endpoint::deserialize (data);
54
62
auto connection = context ()->connect (localEndpoint, remoteEndpoint);
55
63
pimpl_->connectionInfos_ [connection.get ()] = {remoteRank, tag};
56
64
return connection;
57
65
});
66
+ auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move (future));
67
+ pimpl_->recvQueues_ .addRecvItem (remoteRank, tag, shared_future);
68
+ return shared_future;
58
69
}
59
70
60
71
MSCCLPP_API_CPP int Communicator::remoteRankOf (const Connection& connection) {
0 commit comments