Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Unreleased
- Add column comment to get_columns method (#253), thanks to @dotan-mor
- Fix autogenerate with ON UPDATE / DELETE (#258, #262), thanks to @idumitrescu-dn
- Improve support for table/column comments (via SQLA 2.0.36)

- Add nested transaction support (#267), thanks to @mfmarche

# Version 2.0.2
January 10, 2023
Expand Down
22 changes: 22 additions & 0 deletions cockroach_helper.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash
COCKROACHDB=cockroach-v23.1.13.linux-amd64
CROACHDB=~/.cache/$COCKROACHDB/cockroach

quit_cockroachdb() {
OLDPIDNS=$(ps -o pidns -C cockroach | awk 'NR==2 {print $0}')
if [ -n "$OLDPIDNS" ]; then
pkill --ns $$ $OLDPIDNS
fi
return 0
}

[ -n "$HOST" ] || HOST=localhost
mkdir -p $(dirname $CROACHDB)
[[ -f "$CROACHDB" ]] || wget -qO- https://binaries.cockroachdb.com/$COCKROACHDB.tgz | tar xvz --directory ~/.cache
if [ $1 == "start" ]; then
quit_cockroachdb
$CROACHDB start-single-node --background --insecure --store=type=mem,size=10% --log-dir /tmp/ --listen-addr=$HOST:26257 --http-addr=$HOST:26301
#$CROACHDB sql --host=$HOST:26257 --insecure -e "set sql_safe_updates=false; drop database if exists apibuilder; create database if not exists apibuilder; create user if not exists apibuilder; grant all on database apibuilder to apibuilder;"
else
quit_cockroachdb
fi
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[tool.black]
line-length = 100

[tool.pytest.ini_options]
addopts = "--tb native -v -r sfxX --maxfail=250 -p warnings -p logging --strict-markers"
markers = [
Expand Down
118 changes: 76 additions & 42 deletions sqlalchemy_cockroachdb/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@
from .base import savepoint_state


def run_transaction(transactor, callback, max_retries=None, max_backoff=0):
class ChainTransaction:
def __init__(self, transactions=None):
self.results = []
self.transactions = transactions or []

def add_result(self, result):
self.results.append(result)


def run_transaction(transactor, callback, max_retries=None, max_backoff=0, **kwargs):
"""Run a transaction with retries.

``callback()`` will be called with one argument to execute the
Expand All @@ -26,15 +35,18 @@ def run_transaction(transactor, callback, max_retries=None, max_backoff=0):
transaction should be retried before giving up.
``max_backoff`` is an optional integer that specifies the capped number of seconds
for the exponential back-off.
``inject_error`` forces retry loop to run via SET inject_retry_errors_enabled = 'true'
``use_cockroach_restart``, default true, utilizes the special cockroach_restart protocol,
as outlined in: https://www.cockroachlabs.com/blog/nested-transactions-in-cockroachdb-20-1/
"""
if isinstance(transactor, (sqlalchemy.engine.Connection, sqlalchemy.orm.Session)):
return _txn_retry_loop(transactor, callback, max_retries, max_backoff)
return _txn_retry_loop(transactor, callback, max_retries, max_backoff, **kwargs)
elif isinstance(transactor, sqlalchemy.engine.Engine):
with transactor.connect() as connection:
return _txn_retry_loop(connection, callback, max_retries, max_backoff)
return _txn_retry_loop(connection, callback, max_retries, max_backoff, **kwargs)
elif isinstance(transactor, sqlalchemy.orm.sessionmaker):
session = transactor()
return _txn_retry_loop(session, callback, max_retries, max_backoff)
return _txn_retry_loop(session, callback, max_retries, max_backoff, **kwargs)
else:
raise TypeError("don't know how to run a transaction on %s", type(transactor))

Expand All @@ -46,27 +58,32 @@ class _NestedTransaction:
loop to be rewritten by the dialect.
"""

def __init__(self, conn):
def __init__(self, conn, use_cockroach_restart=True):
self.conn = conn
self.use_cockroach_restart = use_cockroach_restart

def __enter__(self):
try:
savepoint_state.cockroach_restart = True
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = True
self.txn = self.conn.begin_nested()
if isinstance(self.conn, sqlalchemy.orm.Session):
if self.use_cockroach_restart and isinstance(self.conn, sqlalchemy.orm.Session):
# Sessions are lazy and don't execute the savepoint
# query until you ask for the connection.
self.conn.connection()
finally:
savepoint_state.cockroach_restart = False
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = False
return self

def __exit__(self, typ, value, tb):
try:
savepoint_state.cockroach_restart = True
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = True
self.txn.__exit__(typ, value, tb)
finally:
savepoint_state.cockroach_restart = False
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = False


def retry_exponential_backoff(retry_count: int, max_backoff: int = 0) -> None:
Expand All @@ -81,45 +98,62 @@ def retry_exponential_backoff(retry_count: int, max_backoff: int = 0) -> None:
:return: None
"""

sleep_secs = uniform(0, min(max_backoff, 0.1 * (2 ** retry_count)))
sleep_secs = uniform(0, min(max_backoff, 0.1 * (2**retry_count)))
sleep(sleep_secs)


def _txn_retry_loop(conn, callback, max_retries, max_backoff):
"""Inner transaction retry loop.

``conn`` may be either a Connection or a Session, but they both
have compatible ``begin()`` and ``begin_nested()`` methods.
"""
def run_in_nested_transaction(
conn, callback, max_retries, max_backoff, inject_error=False, **kwargs
):
if isinstance(conn, sqlalchemy.orm.Session):
dbapi_name = conn.bind.driver
else:
dbapi_name = conn.engine.driver

retry_count = 0
with conn.begin():
while True:
try:
with _NestedTransaction(conn):
ret = callback(conn)
return ret
except sqlalchemy.exc.DatabaseError as e:
if max_retries is not None and retry_count >= max_retries:
raise
do_retry = False
if dbapi_name == "psycopg2":
import psycopg2
import psycopg2.errorcodes
if isinstance(e.orig, psycopg2.OperationalError):
if e.orig.pgcode == psycopg2.errorcodes.SERIALIZATION_FAILURE:
do_retry = True
else:
import psycopg
if isinstance(e.orig, psycopg.errors.SerializationFailure):
do_retry = True
if do_retry:
retry_count += 1
if max_backoff > 0:
retry_exponential_backoff(retry_count, max_backoff)
continue
while True:
if inject_error and retry_count == 0:
conn.execute(sqlalchemy.text("SET inject_retry_errors_enabled = 'true'"))
elif inject_error:
conn.execute(sqlalchemy.text("SET inject_retry_errors_enabled = 'false'"))
try:
with _NestedTransaction(conn, **kwargs):
return callback(conn)
except sqlalchemy.exc.DatabaseError as e:
if max_retries is not None and retry_count >= max_retries:
raise
do_retry = False
if dbapi_name == "psycopg2":
import psycopg2
import psycopg2.errorcodes

if isinstance(e.orig, psycopg2.OperationalError):
if e.orig.pgcode == psycopg2.errorcodes.SERIALIZATION_FAILURE:
do_retry = True
else:
import psycopg

if isinstance(e.orig, psycopg.errors.SerializationFailure):
do_retry = True
if do_retry:
retry_count += 1
if max_backoff > 0:
retry_exponential_backoff(retry_count, max_backoff)
continue
raise


def _txn_retry_loop(conn, callback, max_retries, max_backoff, **kwargs):
"""Inner transaction retry loop.

``conn`` may be either a Connection or a Session, but they both
have compatible ``begin()`` and ``begin_nested()`` methods.
"""
with conn.begin():
result = run_in_nested_transaction(conn, callback, max_retries, max_backoff, **kwargs)
if isinstance(result, ChainTransaction):
for transaction in result.transactions:
result.add_result(
run_in_nested_transaction(conn, transaction, max_retries, max_backoff, **kwargs)
)
return result
37 changes: 36 additions & 1 deletion test/test_run_transaction_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from sqlalchemy.testing import fixtures
from sqlalchemy.types import Integer
import threading
from sqlalchemy.orm import sessionmaker, scoped_session


from sqlalchemy_cockroachdb import run_transaction
from sqlalchemy_cockroachdb.transaction import ChainTransaction

meta = MetaData()

Expand All @@ -25,7 +28,9 @@ def setup_method(self, method):
)

def teardown_method(self, method):
meta.drop_all(testing.db)
session = scoped_session(sessionmaker(bind=testing.db))
session.query(account_table).delete()
session.commit()

def get_balances(self, conn):
"""Returns the balances of the two accounts as a list."""
Expand Down Expand Up @@ -134,3 +139,33 @@ def txn_body(conn):
with testing.db.connect() as conn:
rs = run_transaction(conn, txn_body)
assert rs[0] == (1, 100)

def test_run_transaction_retry_with_nested(self):
def txn_body(conn):
rs = conn.execute(text("select acct, balance from account where acct = 1"))
conn.execute(text("select crdb_internal.force_retry('1s')"))
return [r for r in rs]

with testing.db.connect() as conn:
rs = run_transaction(conn, txn_body, use_cockroach_restart=False)
assert rs[0] == (1, 100)

def test_run_chained_transaction(self):
def txn_body(conn):
# first transaction inserts
conn.execute(account_table.insert(), [dict(acct=99, balance=100)])
conn.execute(text("select crdb_internal.force_retry('1s')"))

def _get_val(s):
rs = s.execute(text("select acct, balance from account where acct = 99"))
return [r for r in rs]

# chain the get into a separate nested transaction, so that the value
# in the previous nested transaction is flushed and available
return ChainTransaction([lambda s: _get_val(s), lambda s: _get_val(s)])

with testing.db.connect() as conn:
rs = run_transaction(conn, txn_body, use_cockroach_restart=False)
assert len(rs.results) == 2
assert rs.results[0][0] == (99, 100)
assert rs.results[1][0] == (99, 100)