Skip to content

Commit 85df81c

Browse files
committed
Corrected recursively retaining a connection and lazy connection handling
* The variable storing the initializing thread id must be atomic. * Introduced a connection pointer in order to be able to handle both, connected and unconnected states.
1 parent 76309d1 commit 85df81c

File tree

5 files changed

+322
-190
lines changed

5 files changed

+322
-190
lines changed

dev/connection_holder.h

Lines changed: 92 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
#include <atomic> // std::atomic_int, memory order flags
66
#include <mutex> // std::mutex, std::lock_guard
77
#include <thread> // std::thread::id
8+
#include <utility> // std::swap, std::exchange
89
#include <functional> // std::function
910
#include <string> // std::string
1011
#endif
1112

1213
#include "functional/cxx_new.h"
14+
#include "functional/cxx_scope_guard.h"
1315
#include "functional/gsl.h"
1416
#include "error_code.h"
1517
#include "vfs_name.h"
@@ -45,7 +47,7 @@ namespace sqlite_orm {
4547
explicit connection_holder(const connection_holder& other, std::function<void(sqlite3*)> didOpenDb) :
4648
_control{other._control.openedForeverHint}, dbArgs{other.dbArgs}, _didOpenDb{std::move(didOpenDb)} {}
4749

48-
explicit connection_holder(const connection_holder& other, std::true_type /*openForever*/) :
50+
explicit connection_holder(const connection_holder& other, std::true_type /*openedForeverHint*/) :
4951
_control{true}, dbArgs{other.dbArgs}, _didOpenDb{other._didOpenDb} {}
5052

5153
/*
@@ -107,30 +109,20 @@ namespace sqlite_orm {
107109
_do_close();
108110
}
109111

110-
sqlite3* retain() {
111-
// optimize for permanently opened connections
112+
sqlite3* retain_if_open() {
113+
// optional marginal optimization for permanently opened connections;
112114
if (_control.openedForeverHint) {
113115
#ifdef SQLITE_ORM_CONTRACTS_SUPPORTED
114116
contract_assert(_control.db);
115117
#endif
116118
return _control.db;
117119
}
118120

119-
// recursive and fast path
120-
{
121-
int currentCount = _control.retainCount.load(std::memory_order_acquire);
122-
123-
// test for recursion from the same thread
124-
if (currentCount == 0) {
125-
if (_control.lockingThread == std::this_thread::get_id()) SQLITE_ORM_CPP_UNLIKELY {
126-
return _control.db;
127-
}
128-
}
129-
130-
// optional fast path: if connection is already open, just increment counter;
131-
// this can make a difference while a transaction is active where all things happen in memory only;
132-
// it makes a difference if the `_didOpenDb` callback has a lot of work to do.
133-
while (currentCount > 0) {
121+
// optional fast path: if connection is already open, just increment counter;
122+
// this can make a difference while a transaction is active where all things happen in memory only;
123+
// it makes a difference if the `_didOpenDb` callback has a lot of work to do.
124+
if (int currentCount = _control.retainCount.load(std::memory_order_acquire)) {
125+
do {
134126
if (_control.retainCount.compare_exchange_weak(currentCount,
135127
currentCount + 1,
136128
std::memory_order_release,
@@ -139,33 +131,50 @@ namespace sqlite_orm {
139131
return _control.db;
140132
}
141133
// CAS failed - retry
134+
} while (currentCount > 0);
135+
}
136+
// test for recursion from the same thread
137+
else {
138+
const std::thread::id threadId = _control.initializingThreadId.load(std::memory_order_acquire);
139+
if (threadId != std::thread::id{} && std::this_thread::get_id() == threadId)
140+
SQLITE_ORM_CPP_UNLIKELY {
141+
return _control.db;
142142
}
143143
}
144144

145+
return nullptr;
146+
}
147+
148+
sqlite3* retain() {
149+
if (sqlite3* db = retain_if_open()) {
150+
return db;
151+
}
152+
145153
// slow path: need to open connection or wait for it
146-
{
147-
const std::lock_guard _{_sync};
148154

149-
// double-check: another thread might have opened it
150-
const bool needsToBeOpened = _control.retainCount == 0;
151-
if (needsToBeOpened) {
152-
_do_open();
153-
if (_didOpenDb) {
154-
_control.lockingThread = std::this_thread::get_id();
155-
// note: may incur recursion in user-provided `on_open` callback
156-
_didOpenDb(_control.db);
157-
_control.lockingThread = std::thread::id{};
158-
}
155+
const std::lock_guard _{_sync};
156+
157+
// double-check: another thread might have opened it
158+
const bool needsToBeOpened = _control.retainCount == 0;
159+
if (needsToBeOpened) {
160+
_do_open();
161+
if (_didOpenDb) {
162+
_control.initializingThreadId.store(std::this_thread::get_id(), std::memory_order_release);
163+
const scope_guard threadIdGuard{[&threadId = _control.initializingThreadId] {
164+
threadId.store(std::thread::id{}, std::memory_order_release);
165+
}};
166+
// note: may incur recursion in user-provided `on_open` callback
167+
_didOpenDb(_control.db);
159168
}
160-
161-
// attention: only increase the reference count after successful open in order to propagate a fully setup connection to other threads
162-
_control.retainCount.fetch_add(1, std::memory_order_release);
163-
return _control.db;
164169
}
170+
171+
// attention: only increase the reference count after successful open in order to propagate a fully setup connection to other threads
172+
_control.retainCount.fetch_add(1, std::memory_order_release);
173+
return _control.db;
165174
}
166175

167176
void release() {
168-
// optimize for permanently opened connections
177+
// optional marginal optimization for permanently opened connections;
169178
if (_control.openedForeverHint) {
170179
#ifdef SQLITE_ORM_CONTRACTS_SUPPORTED
171180
contract_assert(_control.db);
@@ -175,12 +184,13 @@ namespace sqlite_orm {
175184

176185
// test for recursion from the same thread;
177186
// testing against an empty thread id is sufficient because recursion is only possible while calling the `_didOpenDb` callback in `retain()`
178-
if (_control.lockingThread != std::thread::id{}) SQLITE_ORM_CPP_UNLIKELY {
187+
if (_control.initializingThreadId.load(std::memory_order_acquire) != std::thread::id{})
188+
SQLITE_ORM_CPP_UNLIKELY {
179189
return;
180190
}
181191

182-
const int previous = _control.retainCount.fetch_sub(1, std::memory_order_release);
183-
if (previous == 1) {
192+
const int previousCount = _control.retainCount.fetch_sub(1, std::memory_order_release);
193+
if (previousCount == 1) {
184194
// last one closes the connection
185195

186196
const std::lock_guard _{_sync};
@@ -192,21 +202,19 @@ namespace sqlite_orm {
192202
}
193203
}
194204

195-
/*
196-
Precondition: Call from a single-threaded context or after `retain()`.
197-
*/
198-
sqlite3* get() const {
199-
return _control.db;
200-
}
201-
202205
// note: members of the `control_block` are deliberately put on the same cache-line
203206
SQLITE_ORM_MSVC_SUPPRESS_OVERALIGNMENT(alignas(polyfill::hardware_destructive_interference_size))
204207
struct control_block {
208+
// the optimization gain is very small;
209+
// at some design point it served as a flag to not use a mutex at all;
210+
// now it merely saves all the atomic operations, which actually perform without noticeable difference;
211+
// however it may be kept for conveying logic or future optimizations.
205212
const bool openedForeverHint = false;
206213
std::atomic_int retainCount{};
214+
// `db` synchronizes with `retainCount`
207215
orm_gsl::owner<sqlite3*> db = nullptr;
208216
// we don't know what the user-provided `on_open` callback might do, so we need to track recursion;
209-
std::thread::id lockingThread;
217+
std::atomic<std::thread::id> initializingThreadId{};
210218
} _control;
211219

212220
SQLITE_ORM_MSVC_SUPPRESS_OVERALIGNMENT(alignas(polyfill::hardware_destructive_interference_size))
@@ -223,13 +231,10 @@ namespace sqlite_orm {
223231
/*
224232
Rebind connection reference;
225233
This function is actually unused in the library, but required for concepts compliance (moveable type).
226-
Unfortunately it is not `noexcept` because of the `release()` call.
227234
*/
228-
connection_ref& operator=(connection_ref&& other) {
229-
this->holder->release();
230-
this->holder = other.holder;
231-
this->db = other.db;
232-
this->holder->retain();
235+
connection_ref& operator=(connection_ref&& other) noexcept {
236+
std::swap(this->holder, other.holder);
237+
std::swap(this->db, other.db);
233238
return *this;
234239
}
235240

@@ -243,7 +248,41 @@ namespace sqlite_orm {
243248

244249
private:
245250
connection_holder* holder;
246-
sqlite3* db = nullptr;
251+
sqlite3* db;
252+
};
253+
254+
struct connection_ptr {
255+
connection_ptr(connection_holder& holder) : holder{&holder}, db{holder.retain_if_open()} {}
256+
257+
connection_ptr(connection_ptr&& other) noexcept :
258+
holder{other.holder}, db{std::exchange(other.db, nullptr)} {}
259+
260+
/*
261+
Rebind connection pointer;
262+
*/
263+
connection_ptr& operator=(connection_ptr&& other) noexcept {
264+
std::swap(this->holder, other.holder);
265+
std::swap(this->db, other.db);
266+
return *this;
267+
}
268+
269+
~connection_ptr() {
270+
if (this->db) {
271+
this->holder->release();
272+
}
273+
}
274+
275+
explicit operator bool() const {
276+
return this->db || false;
277+
}
278+
279+
sqlite3* get() const {
280+
return this->db;
281+
}
282+
283+
private:
284+
connection_holder* holder;
285+
sqlite3* db;
247286
};
248287
}
249288
}

dev/functional/cxx_scope_guard.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
namespace sqlite_orm::internal {
4+
/*
5+
Poor-man's scope (exit) guard.
6+
*/
7+
template<class F>
8+
struct scope_guard {
9+
explicit scope_guard(F f) : f{std::move(f)} {}
10+
~scope_guard() {
11+
f();
12+
}
13+
14+
F f;
15+
};
16+
}

0 commit comments

Comments
 (0)