Skip to content

Commit ca9795a

Browse files
authored
[UR][Offload] Implement buffer map/unmap (#19054)
Implement buffer map/unmap for Offload. Pretty much just a copy of the CUDA/HIP adapters implementation, with some workarounds for current offload limitations.
1 parent a9fada6 commit ca9795a

File tree

5 files changed

+149
-7
lines changed

5 files changed

+149
-7
lines changed

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,85 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
165165

166166
return UR_RESULT_SUCCESS;
167167
}
168+
169+
ur_result_t enqueueNoOp(ur_queue_handle_t hQueue, ur_event_handle_t *phEvent) {
170+
// This path is a no-op, but we can't output a real event because
171+
// Offload doesn't currently support creating arbitrary events, and we
172+
// don't know the last real event in the queue. Instead we just have to
173+
// wait on the whole queue and then return an empty (implicitly
174+
// finished) event.
175+
*phEvent = ur_event_handle_t_::createEmptyEvent();
176+
return urQueueFinish(hQueue);
177+
}
178+
179+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
180+
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingMap,
181+
ur_map_flags_t mapFlags, size_t offset, size_t size,
182+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
183+
ur_event_handle_t *phEvent, void **ppRetMap) {
184+
185+
auto &BufferImpl = std::get<BufferMem>(hBuffer->Mem);
186+
auto MapPtr = BufferImpl.mapToPtr(size, offset, mapFlags);
187+
188+
if (!MapPtr) {
189+
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
190+
}
191+
192+
const bool IsPinned =
193+
BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;
194+
195+
ur_result_t Result = UR_RESULT_SUCCESS;
196+
if (!IsPinned &&
197+
((mapFlags & UR_MAP_FLAG_READ) || (mapFlags & UR_MAP_FLAG_WRITE))) {
198+
// Pinned host memory is already on host so it doesn't need to be read.
199+
Result = urEnqueueMemBufferRead(hQueue, hBuffer, blockingMap, offset, size,
200+
MapPtr, numEventsInWaitList,
201+
phEventWaitList, phEvent);
202+
} else {
203+
if (IsPinned) {
204+
// TODO: Ignore the event waits list for now. When urEnqueueEventsWait is
205+
// implemented we can call it on the wait list.
206+
}
207+
208+
if (phEvent) {
209+
enqueueNoOp(hQueue, phEvent);
210+
}
211+
}
212+
*ppRetMap = MapPtr;
213+
214+
return Result;
215+
}
216+
217+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
218+
ur_queue_handle_t hQueue, ur_mem_handle_t hMem, void *pMappedPtr,
219+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
220+
ur_event_handle_t *phEvent) {
221+
auto &BufferImpl = std::get<BufferMem>(hMem->Mem);
222+
223+
auto *Map = BufferImpl.getMapDetails(pMappedPtr);
224+
UR_ASSERT(Map != nullptr, UR_RESULT_ERROR_INVALID_MEM_OBJECT);
225+
226+
const bool IsPinned =
227+
BufferImpl.MemAllocMode == BufferMem::AllocMode::AllocHostPtr;
228+
229+
ur_result_t Result = UR_RESULT_SUCCESS;
230+
if (!IsPinned && ((Map->MapFlags & UR_MAP_FLAG_WRITE) ||
231+
(Map->MapFlags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION))) {
232+
// Pinned host memory is only on host so it doesn't need to be written to.
233+
Result = urEnqueueMemBufferWrite(
234+
hQueue, hMem, true, Map->MapOffset, Map->MapSize, pMappedPtr,
235+
numEventsInWaitList, phEventWaitList, phEvent);
236+
} else {
237+
if (IsPinned) {
238+
// TODO: Ignore the event waits list for now. When urEnqueueEventsWait is
239+
// implemented we can call it on the wait list.
240+
}
241+
242+
if (phEvent) {
243+
enqueueNoOp(hQueue, phEvent);
244+
}
245+
}
246+
BufferImpl.unmap(pMappedPtr);
247+
248+
return Result;
249+
}

unified-runtime/source/adapters/offload/event.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "event.hpp"
1515
#include "ur2offload.hpp"
1616

17-
UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hKernel,
17+
UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
1818
ur_event_info_t propName,
1919
size_t propSize,
2020
void *pPropValue,
@@ -23,7 +23,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hKernel,
2323

2424
switch (propName) {
2525
case UR_EVENT_INFO_REFERENCE_COUNT:
26-
return ReturnValue(hKernel->RefCount.load());
26+
return ReturnValue(hEvent->RefCount.load());
2727
default:
2828
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
2929
}
@@ -42,9 +42,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetProfilingInfo(ur_event_handle_t,
4242
UR_APIEXPORT ur_result_t UR_APICALL
4343
urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
4444
for (uint32_t i = 0; i < numEvents; i++) {
45-
auto Res = olWaitEvent(phEventWaitList[i]->OffloadEvent);
46-
if (Res) {
47-
return offloadResultToUR(Res);
45+
if (phEventWaitList[i]->OffloadEvent) {
46+
auto Res = olWaitEvent(phEventWaitList[i]->OffloadEvent);
47+
if (Res) {
48+
return offloadResultToUR(Res);
49+
}
4850
}
4951
}
5052
return UR_RESULT_SUCCESS;

unified-runtime/source/adapters/offload/event.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,13 @@
1717

1818
struct ur_event_handle_t_ : RefCounted {
1919
ol_event_handle_t OffloadEvent;
20+
ur_command_t Type;
21+
22+
static ur_event_handle_t createEmptyEvent() {
23+
auto *Event = new ur_event_handle_t_();
24+
// Null event represents an empty event. Waiting on it is a no-op.
25+
Event->OffloadEvent = nullptr;
26+
27+
return Event;
28+
}
2029
};

unified-runtime/source/adapters/offload/memory.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ struct BufferMem {
2222
AllocHostPtr,
2323
};
2424

25+
struct BufferMap {
26+
size_t MapSize;
27+
size_t MapOffset;
28+
ur_map_flags_t MapFlags;
29+
// Allocated host memory used exclusively for this map.
30+
std::unique_ptr<unsigned char[]> MapMem;
31+
32+
BufferMap(size_t MapSize, size_t MapOffset, ur_map_flags_t MapFlags)
33+
: MapSize(MapSize), MapOffset(MapOffset), MapFlags(MapFlags),
34+
MapMem(nullptr) {}
35+
36+
BufferMap(size_t MapSize, size_t MapOffset, ur_map_flags_t MapFlags,
37+
std::unique_ptr<unsigned char[]> &&MapMem)
38+
: MapSize(MapSize), MapOffset(MapOffset), MapFlags(MapFlags),
39+
MapMem(std::move(MapMem)) {}
40+
};
41+
2542
ur_mem_handle_t Parent;
2643
// Underlying device pointer
2744
void *Ptr;
@@ -30,6 +47,7 @@ struct BufferMem {
3047
size_t Size;
3148

3249
AllocMode MemAllocMode;
50+
std::unordered_map<void *, BufferMap> PtrToBufferMap;
3351

3452
BufferMem(ur_mem_handle_t Parent, BufferMem::AllocMode Mode, void *Ptr,
3553
void *HostPtr, size_t Size)
@@ -38,6 +56,37 @@ struct BufferMem {
3856

3957
void *get() const noexcept { return Ptr; }
4058
size_t getSize() const noexcept { return Size; }
59+
60+
BufferMap *getMapDetails(void *Map) {
61+
auto Details = PtrToBufferMap.find(Map);
62+
if (Details != PtrToBufferMap.end()) {
63+
return &Details->second;
64+
}
65+
return nullptr;
66+
}
67+
68+
void *mapToPtr(size_t MapSize, size_t MapOffset,
69+
ur_map_flags_t MapFlags) noexcept {
70+
71+
void *MapPtr = nullptr;
72+
// If the buffer already has a host pointer we can just use it, otherwise
73+
// create a new host allocation
74+
if (HostPtr == nullptr) {
75+
auto MapMem = std::make_unique<unsigned char[]>(MapSize);
76+
MapPtr = MapMem.get();
77+
PtrToBufferMap.insert(
78+
{MapPtr, BufferMap(MapSize, MapOffset, MapFlags, std::move(MapMem))});
79+
} else {
80+
MapPtr = static_cast<char *>(HostPtr) + MapOffset;
81+
PtrToBufferMap.insert({MapPtr, BufferMap(MapSize, MapOffset, MapFlags)});
82+
}
83+
return MapPtr;
84+
}
85+
86+
void unmap(void *MapPtr) noexcept {
87+
assert(MapPtr != nullptr);
88+
PtrToBufferMap.erase(MapPtr);
89+
}
4190
};
4291

4392
struct ur_mem_handle_t_ : RefCounted {

unified-runtime/source/adapters/offload/ur_interface_loader.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,15 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
176176
pDdiTable->pfnMemBufferCopy = nullptr;
177177
pDdiTable->pfnMemBufferCopyRect = nullptr;
178178
pDdiTable->pfnMemBufferFill = nullptr;
179-
pDdiTable->pfnMemBufferMap = nullptr;
179+
pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap;
180180
pDdiTable->pfnMemBufferRead = urEnqueueMemBufferRead;
181181
pDdiTable->pfnMemBufferReadRect = nullptr;
182182
pDdiTable->pfnMemBufferWrite = urEnqueueMemBufferWrite;
183183
pDdiTable->pfnMemBufferWriteRect = nullptr;
184184
pDdiTable->pfnMemImageCopy = nullptr;
185185
pDdiTable->pfnMemImageRead = nullptr;
186186
pDdiTable->pfnMemImageWrite = nullptr;
187-
pDdiTable->pfnMemUnmap = nullptr;
187+
pDdiTable->pfnMemUnmap = urEnqueueMemUnmap;
188188
pDdiTable->pfnUSMFill2D = urEnqueueUSMFill2D;
189189
pDdiTable->pfnUSMFill = nullptr;
190190
pDdiTable->pfnUSMAdvise = nullptr;

0 commit comments

Comments
 (0)