From be8be2c5c849d47e376f49a7c8ab0f9b23d6a2da Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Tue, 8 Nov 2022 17:36:27 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- UniprotDB/AsyncMongoDB.py | 13 +++---------- UniprotDB/BaseDatabase.py | 5 +---- UniprotDB/LMDB.py | 20 +++++++++++--------- UniprotDB/MongoDB.py | 10 ++-------- UniprotDB/SwissProtUtils.py | 5 ++--- UniprotDB/UniprotDB.py | 15 ++++++++++----- UniprotDB/_utils.py | 4 ++-- UniprotDB/data_loader.py | 7 ++----- UniprotDBTests.py | 13 +++++-------- 9 files changed, 38 insertions(+), 54 deletions(-) diff --git a/UniprotDB/AsyncMongoDB.py b/UniprotDB/AsyncMongoDB.py index 3c94f32..b91db6b 100644 --- a/UniprotDB/AsyncMongoDB.py +++ b/UniprotDB/AsyncMongoDB.py @@ -21,18 +21,13 @@ def __init__(self, database: str, host: Tuple[str] = ('localhost',), **kwargs): def get_item(self, item: str) -> Union[SeqRecord, None]: t = self.loop.run_until_complete(self.col.find_one({'$or': [{i: item} for i in self.ids]})) - if t is None: - return None - r = self._extract_seqrecord(t['raw_record']) - return r + return None if t is None else self._extract_seqrecord(t['raw_record']) def get_iter(self) -> Generator[SeqRecord, None, None]: q = asyncio.Queue() self.loop.create_task(self._get_iter(q)) - r = self.loop.run_until_complete(q.get()) - while r: + while r := self.loop.run_until_complete(q.get()): yield r - r = self.loop.run_until_complete(q.get()) async def _get_iter(self, q: asyncio.Queue) -> None: async for entry in self.col.find({'_id': {'$exists': True}}): @@ -42,10 +37,8 @@ async def _get_iter(self, q: asyncio.Queue) -> None: def get_iterkeys(self) -> Generator[str, None, None]: q = asyncio.Queue() self.loop.create_task(self._get_iterkeys(q)) - r = self.loop.run_until_complete(q.get()) - while r: + while r := self.loop.run_until_complete(q.get()): yield r - r = self.loop.run_until_complete(q.get()) async def _get_iterkeys(self, q: asyncio.Queue) -> None: async for i in self.col.find({'_id': {'$exists': True}}, {'_id': 1}): diff --git a/UniprotDB/BaseDatabase.py b/UniprotDB/BaseDatabase.py index 170e9e6..601bf0f 100644 --- a/UniprotDB/BaseDatabase.py +++ b/UniprotDB/BaseDatabase.py @@ -30,7 +30,6 @@ def __init__(self, database: str, host: Union[tuple, str], self.create_protein_func = partial(create_protein_func, compressor=self.compressor) from UniprotDB._utils import _extract_seqrecord self._extract_seqrecord = partial(_extract_seqrecord, decompressor=self.decompressor) - pass def initialize(self, seq_handles: Iterable, filter_fn: Callable[[bytes], bool] = None, @@ -91,9 +90,7 @@ def update(self, handles: Iterable, def add_record(self, raw_record: bytes, test: str = None, test_attr: str = None) -> bool: protein = self.create_protein_func(raw_record) if test: - good = False - if test == protein['_id']: - good = True + good = test == protein['_id'] if not good: for ref in ([test_attr] if test_attr else self.ids): if test in protein.get(ref, []): diff --git a/UniprotDB/LMDB.py b/UniprotDB/LMDB.py index 50469d7..5bcec1c 100644 --- a/UniprotDB/LMDB.py +++ b/UniprotDB/LMDB.py @@ -64,15 +64,20 @@ def _setup_dbs(self) -> None: self.db: Dict[str] = {} for i in range(self.db_splits): - self.db[str(i)] = lmdb.open(os.path.join(self.host, str(i) + '.lmdb'), - map_size=self.map_size / self.db_splits, - writemap=True, map_async=True, readahead=False) + self.db[str(i)] = lmdb.open( + os.path.join(self.host, f'{str(i)}.lmdb'), + map_size=self.map_size / self.db_splits, + writemap=True, + map_async=True, + readahead=False, + ) + if self.has_index: self.index_dbs: Dict[str] = {} for index in self.indices: for i in range(self.index_db_splits): self.index_dbs[index + str(i)] = \ - lmdb.open(os.path.join(self.host, index + str(i) + '.lmdb'), + lmdb.open(os.path.join(self.host, index + str(i) + '.lmdb'), map_size=self.map_size / self.index_db_splits, writemap=True, map_async=True, readahead=False) with open(os.path.join(self.host, 'db_info.json'), 'w') as o: @@ -109,9 +114,7 @@ def get_item(self, item: str) -> Union[SeqRecord, None]: with self.db[self._get_subdb(t.decode())].begin() as txn: t = txn.get(t) break - if t is None: - return None - return self._extract_seqrecord(t) + return None if t is None else self._extract_seqrecord(t) def get_iter(self) -> Generator[SeqRecord, None, None]: for i in range(self.db_splits): @@ -148,8 +151,7 @@ def get_by(self, attr: str, value: str) -> List[SeqRecord]: db = self.index_dbs[subdb].open_db(dupsort=True) cur = txn.cursor(db=db) if cur.set_key(value.encode()): - for i in cur.iternext_dup(): - ret.append(self.get_item(i.decode())) + ret.extend(self.get_item(i.decode()) for i in cur.iternext_dup()) return ret def _create_indices(self, background: bool = False) -> None: diff --git a/UniprotDB/MongoDB.py b/UniprotDB/MongoDB.py index 5fe4a6c..32f9627 100644 --- a/UniprotDB/MongoDB.py +++ b/UniprotDB/MongoDB.py @@ -18,10 +18,7 @@ def __init__(self, database: str, host: Union[str, tuple] = ('localhost',), **kw def get_item(self, item: str) -> Union[SeqRecord, None]: t = self.col.find_one({'$or': [{i: item} for i in self.ids]}, {'raw_record': True}) - if t is None: - return None - r = self._extract_seqrecord(t['raw_record']) - return r + return None if t is None else self._extract_seqrecord(t['raw_record']) def get_iter(self) -> Generator[SeqRecord, None, None]: for entry in self.col.find({}, {'raw_record': True}): @@ -38,11 +35,8 @@ def length(self) -> int: return self.col.count_documents({}) def get_by(self, attr: str, value: str) -> List[SeqRecord]: - ret = [] res = self.col.find({attr: value}, {'raw_record': True}) - for i in res: - ret.append(self._extract_seqrecord(i['raw_record'])) - return ret + return [self._extract_seqrecord(i['raw_record']) for i in res] def _reset(self) -> None: self.client[self.database].proteins.drop() diff --git a/UniprotDB/SwissProtUtils.py b/UniprotDB/SwissProtUtils.py index 1f21b74..6f7afd4 100644 --- a/UniprotDB/SwissProtUtils.py +++ b/UniprotDB/SwissProtUtils.py @@ -8,7 +8,7 @@ def _get_record(handle: BinaryIO, ignore: Collection[bytes] = (b'R', b'C')): """ lines = [] for line in handle: - if not line[0] in ignore: + if line[0] not in ignore: lines.append(line) if line.startswith(b'//'): yield b''.join(lines) @@ -22,8 +22,7 @@ def filter_proks(record: bytes): good_taxa = {b'Archaea', b'Bacteria', } taxa = re.search(b'OC.*\n', record).group()[5:] base_taxa = taxa.split(b'; ')[0] - good = base_taxa in good_taxa - return good + return base_taxa in good_taxa def parse_raw_swiss(handle: BinaryIO, filter_fn: Callable[[bytes], bool] = None): diff --git a/UniprotDB/UniprotDB.py b/UniprotDB/UniprotDB.py index 50908bb..11161f5 100644 --- a/UniprotDB/UniprotDB.py +++ b/UniprotDB/UniprotDB.py @@ -101,7 +101,7 @@ def update_trembl_taxa(self, taxa: Iterable, filter_fn: Callable[[bytes], bool] import urllib.request for taxon in taxa: taxon_handle = gzip.open(urllib.request.urlopen(trembl_taxa_prefix.format(taxon))) - print("Updating {}".format(taxon)) + print(f"Updating {taxon}") self.update([taxon_handle], filter_fn, loud, workers=workers) taxon_handle.close() @@ -130,7 +130,12 @@ def create_index(flatfiles: Iterable, host: Union[str, tuple] = (), host/filename + dbtype, fill the database with the protein entries and returns a SeqDB object. """ from .data_loader import process_main - s = process_main(flatfiles, host, - dbtype=dbtype, initialize=True, verbose=True, n_jobs=n_jobs, **kwargs) - - return s + return process_main( + flatfiles, + host, + dbtype=dbtype, + initialize=True, + verbose=True, + n_jobs=n_jobs, + **kwargs + ) diff --git a/UniprotDB/_utils.py b/UniprotDB/_utils.py index 6224af1..ed57389 100644 --- a/UniprotDB/_utils.py +++ b/UniprotDB/_utils.py @@ -78,7 +78,7 @@ def _extract_seqrecord(raw_record: bytes, decompressor: zstd.ZstdDecompressor) - def search_uniprot(value: str, retries: int = 3) -> Generator[bytes, None, None]: possible_ids = [] - for x in range(retries): + for _ in range(retries): try: possible_ids = requests.get(query_req.format(value)).content.split() break @@ -87,7 +87,7 @@ def search_uniprot(value: str, retries: int = 3) -> Generator[bytes, None, None] raw_record = None for pid in possible_ids[:5]: - for x in range(retries): + for _ in range(retries): try: raw_record = requests.get(fetch_req.format(pid.decode())).content break diff --git a/UniprotDB/data_loader.py b/UniprotDB/data_loader.py index f9c15f7..a69e142 100644 --- a/UniprotDB/data_loader.py +++ b/UniprotDB/data_loader.py @@ -162,10 +162,7 @@ def process_main(dats: Iterable[str], if verbose: from tqdm import tqdm - if num_seqs: - pbar = tqdm(total=num_seqs) - else: - pbar = tqdm() + pbar = tqdm(total=num_seqs) if num_seqs else tqdm() current = len(seqdb) with tempfile.TemporaryDirectory() as directory: @@ -180,7 +177,7 @@ def process_main(dats: Iterable[str], logging.debug('Started async processing') logging.debug(f'Opening fifos {fifos}') output_handles = [open(f, 'wb') for f in fifos] - logging.debug(f'Opening fifos') + logging.debug('Opening fifos') feed_steps = feed_files(fh, output_handles) all_fed = False logging.debug(f'Starting fifo feeding from {dat}') diff --git a/UniprotDBTests.py b/UniprotDBTests.py index f6f64d5..d62620f 100644 --- a/UniprotDBTests.py +++ b/UniprotDBTests.py @@ -62,14 +62,14 @@ def test_update(self): with gzip.open('TestFiles/testbig.dat.gz', 'rb') as h: self.db.update([h]) with gzip.open('TestFiles/testbig.dat.gz', 'rb') as h: - ids = set(line.split()[1].decode() for line in h if line.startswith(b'ID')) - inserted_ids = set(e.name for e in self.db) + ids = {line.split()[1].decode() for line in h if line.startswith(b'ID')} + inserted_ids = {e.name for e in self.db} self.assertEqual(inserted_ids, ids) def test_update_filtered(self): with gzip.open('TestFiles/testbig.dat.gz', 'rb') as h: self.db.update([h], filter_fn=filter_proks) - self.assertEqual(len(set(e.name for e in self.db)), 70) + self.assertEqual(len({e.name for e in self.db}), 70) @unittest.skipUnless(HAS_MONGO, "requires pymongo") @@ -79,8 +79,7 @@ def setUp(self): self.database = 'test_uni2' import os - db_host = os.environ.get('TEST_DB_HOST') - if db_host: + if db_host := os.environ.get('TEST_DB_HOST'): self.db = UniprotDB.create_index(['TestFiles/test.dat.bgz'], host=(db_host,), database=self.database, @@ -101,9 +100,7 @@ def setUp(self): self.database = 'test_uni2' import os - db_host = os.environ.get('TEST_DB_HOST') - - if db_host: + if db_host := os.environ.get('TEST_DB_HOST'): self.db = UniprotDB.create_index(['TestFiles/test.dat.bgz'], host=(db_host,), database=self.database, dbtype='mongoasync') else: