55#include < atomic> // std::atomic_int, memory order flags
66#include < mutex> // std::mutex, std::lock_guard
77#include < thread> // std::thread::id
8+ #include < utility> // std::swap, std::exchange
89#include < functional> // std::function
910#include < string> // std::string
1011#endif
1112
1213#include " functional/cxx_new.h"
14+ #include " functional/cxx_scope_guard.h"
1315#include " functional/gsl.h"
1416#include " error_code.h"
1517#include " vfs_name.h"
@@ -45,7 +47,7 @@ namespace sqlite_orm {
4547 explicit connection_holder (const connection_holder& other, std::function<void (sqlite3*)> didOpenDb) :
4648 _control{other._control .openedForeverHint }, dbArgs{other.dbArgs }, _didOpenDb{std::move (didOpenDb)} {}
4749
48- explicit connection_holder (const connection_holder& other, std::true_type /* openForever */ ) :
50+ explicit connection_holder (const connection_holder& other, std::true_type /* openedForeverHint */ ) :
4951 _control{true }, dbArgs{other.dbArgs }, _didOpenDb{other._didOpenDb } {}
5052
5153 /*
@@ -107,30 +109,20 @@ namespace sqlite_orm {
107109 _do_close ();
108110 }
109111
110- sqlite3* retain () {
111- // optimize for permanently opened connections
112+ sqlite3* retain_if_open () {
113+ // optional marginal optimization for permanently opened connections;
112114 if (_control.openedForeverHint ) {
113115#ifdef SQLITE_ORM_CONTRACTS_SUPPORTED
114116 contract_assert (_control.db );
115117#endif
116118 return _control.db ;
117119 }
118120
119- // recursive and fast path
120- {
121- int currentCount = _control.retainCount .load (std::memory_order_acquire);
122-
123- // test for recursion from the same thread
124- if (currentCount == 0 ) {
125- if (_control.lockingThread == std::this_thread::get_id ()) SQLITE_ORM_CPP_UNLIKELY {
126- return _control.db ;
127- }
128- }
129-
130- // optional fast path: if connection is already open, just increment counter;
131- // this can make a difference while a transaction is active where all things happen in memory only;
132- // it makes a difference if the `_didOpenDb` callback has a lot of work to do.
133- while (currentCount > 0 ) {
121+ // optional fast path: if connection is already open, just increment counter;
122+ // this can make a difference while a transaction is active where all things happen in memory only;
123+ // it makes a difference if the `_didOpenDb` callback has a lot of work to do.
124+ if (int currentCount = _control.retainCount .load (std::memory_order_acquire)) {
125+ do {
134126 if (_control.retainCount .compare_exchange_weak (currentCount,
135127 currentCount + 1 ,
136128 std::memory_order_release,
@@ -139,33 +131,50 @@ namespace sqlite_orm {
139131 return _control.db ;
140132 }
141133 // CAS failed - retry
134+ } while (currentCount > 0 );
135+ }
136+ // test for recursion from the same thread
137+ else {
138+ const std::thread::id threadId = _control.initializingThreadId .load (std::memory_order_acquire);
139+ if (threadId != std::thread::id{} && std::this_thread::get_id () == threadId)
140+ SQLITE_ORM_CPP_UNLIKELY {
141+ return _control.db ;
142142 }
143143 }
144144
145+ return nullptr ;
146+ }
147+
148+ sqlite3* retain () {
149+ if (sqlite3* db = retain_if_open ()) {
150+ return db;
151+ }
152+
145153 // slow path: need to open connection or wait for it
146- {
147- const std::lock_guard _{_sync};
148154
149- // double-check: another thread might have opened it
150- const bool needsToBeOpened = _control.retainCount == 0 ;
151- if (needsToBeOpened) {
152- _do_open ();
153- if (_didOpenDb) {
154- _control.lockingThread = std::this_thread::get_id ();
155- // note: may incur recursion in user-provided `on_open` callback
156- _didOpenDb (_control.db );
157- _control.lockingThread = std::thread::id{};
158- }
155+ const std::lock_guard _{_sync};
156+
157+ // double-check: another thread might have opened it
158+ const bool needsToBeOpened = _control.retainCount == 0 ;
159+ if (needsToBeOpened) {
160+ _do_open ();
161+ if (_didOpenDb) {
162+ _control.initializingThreadId .store (std::this_thread::get_id (), std::memory_order_release);
163+ const scope_guard threadIdGuard{[&threadId = _control.initializingThreadId ] {
164+ threadId.store (std::thread::id{}, std::memory_order_release);
165+ }};
166+ // note: may incur recursion in user-provided `on_open` callback
167+ _didOpenDb (_control.db );
159168 }
160-
161- // attention: only increase the reference count after successful open in order to propagate a fully setup connection to other threads
162- _control.retainCount .fetch_add (1 , std::memory_order_release);
163- return _control.db ;
164169 }
170+
171+ // attention: only increase the reference count after successful open in order to propagate a fully setup connection to other threads
172+ _control.retainCount .fetch_add (1 , std::memory_order_release);
173+ return _control.db ;
165174 }
166175
167176 void release () {
168- // optimize for permanently opened connections
177+ // optional marginal optimization for permanently opened connections;
169178 if (_control.openedForeverHint ) {
170179#ifdef SQLITE_ORM_CONTRACTS_SUPPORTED
171180 contract_assert (_control.db );
@@ -175,12 +184,13 @@ namespace sqlite_orm {
175184
176185 // test for recursion from the same thread;
177186 // testing against an empty thread id is sufficient because recursion is only possible while calling the `_didOpenDb` callback in `retain()`
178- if (_control.lockingThread != std::thread::id{}) SQLITE_ORM_CPP_UNLIKELY {
187+ if (_control.initializingThreadId .load (std::memory_order_acquire) != std::thread::id{})
188+ SQLITE_ORM_CPP_UNLIKELY {
179189 return ;
180190 }
181191
182- const int previous = _control.retainCount .fetch_sub (1 , std::memory_order_release);
183- if (previous == 1 ) {
192+ const int previousCount = _control.retainCount .fetch_sub (1 , std::memory_order_release);
193+ if (previousCount == 1 ) {
184194 // last one closes the connection
185195
186196 const std::lock_guard _{_sync};
@@ -192,21 +202,19 @@ namespace sqlite_orm {
192202 }
193203 }
194204
195- /*
196- Precondition: Call from a single-threaded context or after `retain()`.
197- */
198- sqlite3* get () const {
199- return _control.db ;
200- }
201-
202205 // note: members of the `control_block` are deliberately put on the same cache-line
203206 SQLITE_ORM_MSVC_SUPPRESS_OVERALIGNMENT (alignas (polyfill::hardware_destructive_interference_size))
204207 struct control_block {
208+ // the optimization gain is very small;
209+ // at some design point it served as a flag to not use a mutex at all;
210+ // now it merely saves all the atomic operations, which actually perform without noticeable difference;
211+ // however it may be kept for conveying logic or future optimizations.
205212 const bool openedForeverHint = false ;
206213 std::atomic_int retainCount{};
214+ // `db` synchronizes with `retainCount`
207215 orm_gsl::owner<sqlite3*> db = nullptr ;
208216 // we don't know what the user-provided `on_open` callback might do, so we need to track recursion;
209- std::thread::id lockingThread ;
217+ std::atomic<std:: thread::id> initializingThreadId{} ;
210218 } _control;
211219
212220 SQLITE_ORM_MSVC_SUPPRESS_OVERALIGNMENT (alignas (polyfill::hardware_destructive_interference_size))
@@ -223,13 +231,10 @@ namespace sqlite_orm {
223231 /*
224232 Rebind connection reference;
225233 This function is actually unused in the library, but required for concepts compliance (moveable type).
226- Unfortunately it is not `noexcept` because of the `release()` call.
227234 */
228- connection_ref& operator =(connection_ref&& other) {
229- this ->holder ->release ();
230- this ->holder = other.holder ;
231- this ->db = other.db ;
232- this ->holder ->retain ();
235+ connection_ref& operator =(connection_ref&& other) noexcept {
236+ std::swap (this ->holder , other.holder );
237+ std::swap (this ->db , other.db );
233238 return *this ;
234239 }
235240
@@ -243,7 +248,41 @@ namespace sqlite_orm {
243248
244249 private:
245250 connection_holder* holder;
246- sqlite3* db = nullptr ;
251+ sqlite3* db;
252+ };
253+
254+ struct connection_ptr {
255+ connection_ptr (connection_holder& holder) : holder{&holder}, db{holder.retain_if_open ()} {}
256+
257+ connection_ptr (connection_ptr&& other) noexcept :
258+ holder{other.holder }, db{std::exchange (other.db , nullptr )} {}
259+
260+ /*
261+ Rebind connection pointer;
262+ */
263+ connection_ptr& operator =(connection_ptr&& other) noexcept {
264+ std::swap (this ->holder , other.holder );
265+ std::swap (this ->db , other.db );
266+ return *this ;
267+ }
268+
269+ ~connection_ptr () {
270+ if (this ->db ) {
271+ this ->holder ->release ();
272+ }
273+ }
274+
275+ explicit operator bool () const {
276+ return this ->db || false ;
277+ }
278+
279+ sqlite3* get () const {
280+ return this ->db ;
281+ }
282+
283+ private:
284+ connection_holder* holder;
285+ sqlite3* db;
247286 };
248287 }
249288}
0 commit comments