Skip to content

Commit a61ed59

Browse files
committed
Switch from None to raising exceptions in ASI2.
1 parent b85bab7 commit a61ed59

File tree

6 files changed

+80
-103
lines changed

6 files changed

+80
-103
lines changed

pyvdrm/asi2.py

+28-45
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,10 @@
55
from functools import reduce, total_ordering
66
from pyparsing import (Literal, nums, Word, Forward, Optional, Regex,
77
infixNotation, delimitedList, opAssoc, ParseException)
8-
from pyvdrm.drm import AsiExpr, AsiBinaryExpr, DRMParser
8+
from pyvdrm.drm import AsiExpr, AsiBinaryExpr, DRMParser, MissingPositionError
99
from pyvdrm.vcf import MutationSet
1010

1111

12-
def maybe_foldl(func, noneable):
13-
"""Safely fold a function over a potentially empty list of
14-
potentially null values"""
15-
if noneable is None:
16-
return None
17-
clean = [x for x in noneable if x is not None]
18-
if not clean:
19-
return None
20-
return reduce(func, clean)
21-
22-
23-
def maybe_map(func, noneable):
24-
if noneable is None:
25-
return None
26-
r_list = []
27-
for x in noneable:
28-
if x is None:
29-
continue
30-
result = func(x)
31-
if result is None:
32-
continue
33-
r_list.append(result)
34-
if not r_list:
35-
return None
36-
return r_list
37-
38-
3912
@total_ordering
4013
class Score(object):
4114
"""Encapsulate a score and the residues that support it"""
@@ -169,10 +142,17 @@ class ScoreList(AsiExpr):
169142
def __call__(self, mutations):
170143
operation, *rest = self.children
171144
if operation == 'MAX':
172-
return maybe_foldl(max, [f(mutations) for f in rest])
173-
174-
# the default operation is sum
175-
return maybe_foldl(lambda x, y: x+y, [f(mutations) for f in self.children])
145+
terms = rest
146+
func = max
147+
else:
148+
# the default operation is sum
149+
terms = self.children
150+
func = sum
151+
scores = [f(mutations) for f in terms]
152+
matched_scores = [score.score for score in scores if score.score]
153+
residues = reduce(lambda x, y: x | y,
154+
(score.residues for score in scores))
155+
return Score(bool(matched_scores) and func(matched_scores), residues)
176156

177157

178158
class SelectFrom(AsiExpr):
@@ -186,15 +166,13 @@ def typecheck(self, tokens):
186166
def __call__(self, mutations):
187167
operation, *rest = self.children
188168
# the head of the arg list must be an equality expression
189-
190-
scored = list(maybe_map(lambda f: f(mutations), rest))
191-
passing = len(scored)
192169

193-
if operation(passing):
194-
return Score(True, maybe_foldl(
195-
lambda x, y: x.residues.union(y.residues), scored))
196-
else:
197-
return None
170+
scored = [f(mutations) for f in rest]
171+
passing = sum(bool(score.score) for score in scored)
172+
173+
return Score(operation(passing),
174+
reduce(lambda x, y: x | y,
175+
(item.residues for item in scored)))
198176

199177

200178
class AsiScoreCond(AsiExpr):
@@ -204,7 +182,7 @@ class AsiScoreCond(AsiExpr):
204182

205183
def __call__(self, args):
206184
"""Score conditions evaluate a list of expressions and sum scores"""
207-
return maybe_foldl(lambda x, y: x+y, map(lambda x: x(args), self.children))
185+
return sum((f(args) for f in self.children), Score(False, set()))
208186

209187

210188
class AsiMutations(object):
@@ -213,19 +191,24 @@ class AsiMutations(object):
213191
def __init__(self, _label=None, _pos=None, args=None):
214192
"""Initialize set of mutations from a potentially ambiguous residue
215193
"""
216-
self.mutations = args and MutationSet(''.join(args))
194+
self.mutations = MutationSet(''.join(args))
217195

218196
def __repr__(self):
219-
if self.mutations is None:
220-
return "AsiMutations()"
221197
return "AsiMutations(args={!r})".format(str(self.mutations))
222198

223199
def __call__(self, env):
200+
is_found = False
224201
for mutation_set in env:
202+
is_found |= mutation_set.pos == self.mutations.pos
225203
intersection = self.mutations.mutations & mutation_set.mutations
226204
if len(intersection) > 0:
227205
return Score(True, intersection)
228-
return None
206+
207+
if not is_found:
208+
# Some required positions were not found in the environment.
209+
raise MissingPositionError('Missing position {}.'.format(
210+
self.mutations.pos))
211+
return Score(False, set())
229212

230213

231214
class ASI2(DRMParser):

pyvdrm/hcvr.py

-2
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,6 @@ def __init__(self, _label=None, _pos=None, args=None):
232232
self.mutations = MutationSet(''.join(args))
233233

234234
def __repr__(self):
235-
if self.mutations is None:
236-
return "AsiMutations()"
237235
return "AsiMutations(args={!r})".format(str(self.mutations))
238236

239237
def __call__(self, env):

pyvdrm/tests/test_asi2.py

+32-39
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,31 @@
44
from pyparsing import ParseException
55

66
from pyvdrm.asi2 import ASI2, AsiMutations, Score
7+
from pyvdrm.drm import MissingPositionError
78
from pyvdrm.vcf import Mutation, MutationSet, VariantCalls
89

10+
from pyvdrm.tests.test_vcf import add_mutations
11+
912

1013
# noinspection SqlNoDataSourceInspection,SqlDialectInspection
1114
class TestRuleParser(unittest.TestCase):
1215

1316
def test_stanford_ex1(self):
1417
ASI2("151M OR 69i")
1518

16-
def test_stanford_ex2(self):
19+
def test_atleast_true(self):
20+
rule = ASI2("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
21+
self.assertTrue(rule(VariantCalls('41L 67N 70d 210d 215d 219d')))
22+
23+
def test_atleast_false(self):
1724
rule = ASI2("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
18-
m1 = MutationSet('41L')
19-
m2 = MutationSet('67N')
20-
m3 = MutationSet('70N')
21-
self.assertTrue(rule([m1, m2]))
22-
self.assertFalse(rule([m1, m3]))
25+
self.assertFalse(rule(VariantCalls('41L 67d 70d 210d 215d 219d')))
26+
27+
def test_atleast_missing(self):
28+
rule = ASI2("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
29+
with self.assertRaisesRegex(MissingPositionError,
30+
r'Missing position 70.'):
31+
rule(VariantCalls('41L 67N'))
2332

2433
def test_stanford_ex3(self):
2534
ASI2("SELECT ATLEAST 2 AND NOTMORETHAN 2 FROM (41L, 67N, 70R, 210W, 215FY, 219QE)")
@@ -53,50 +62,47 @@ def test_asi2_compat(self):
5362
class TestRuleSemantics(unittest.TestCase):
5463
def test_score_from(self):
5564
rule = ASI2("SCORE FROM ( 100G => 10, 101D => 20 )")
56-
self.assertEqual(rule(VariantCalls("100G 102G")), 10)
65+
self.assertEqual(rule(VariantCalls("100G 101d")), 10)
5766

5867
def test_score_negate(self):
5968
rule = ASI2("SCORE FROM ( NOT 100G => 10, NOT 101SD => 20 )")
60-
self.assertEqual(rule(VariantCalls("100G 102G")), 20)
69+
self.assertEqual(rule(VariantCalls("100G 101d")), 20)
6170
self.assertEqual(rule(VariantCalls("100S 101S")), 10)
6271

6372
def test_score_residues(self):
6473
rule = ASI2("SCORE FROM ( 100G => 10, 101D => 20 )")
6574
expected_residue = repr({Mutation('S100G')})
6675

67-
result = rule.dtree(VariantCalls("S100G R102G"))
76+
result = rule.dtree(VariantCalls("S100G R101d"))
6877

6978
self.assertEqual(expected_residue, repr(result.residues))
7079

7180
def test_score_from_max(self):
7281
rule = ASI2("SCORE FROM (MAX (100G => 10, 101D => 20, 102D => 30))")
73-
self.assertEqual(rule(VariantCalls("100G 101D")), 20)
74-
self.assertEqual(rule(VariantCalls("10G 11D")), False)
82+
self.assertEqual(rule(VariantCalls("100G 101D 102d")), 20)
83+
self.assertEqual(rule(VariantCalls("100d 101d 102d")), False)
7584

7685
def test_score_from_max_neg(self):
7786
rule = ASI2("SCORE FROM (MAX (100G => -10, 101D => -20, 102D => 30))")
78-
self.assertEqual(rule(VariantCalls("100G 101D")), -10)
79-
self.assertEqual(rule(VariantCalls("10G 11D")), False)
87+
self.assertEqual(rule(VariantCalls("100G 101D 102d")), -10)
8088

8189
def test_bool_and(self):
8290
rule = ASI2("1G AND (2T AND 7Y)")
8391
self.assertEqual(rule(VariantCalls("2T 7Y 1G")), True)
84-
self.assertEqual(rule(VariantCalls("2T 3Y 1G")), False)
92+
self.assertEqual(rule(VariantCalls("2T 7d 1G")), False)
8593
self.assertEqual(rule(VariantCalls("7Y 1G 2T")), True)
86-
self.assertEqual(rule([]), False)
8794

8895
def test_bool_or(self):
8996
rule = ASI2("1G OR (2T OR 7Y)")
90-
self.assertTrue(rule(VariantCalls("2T")))
91-
self.assertFalse(rule(VariantCalls("3T")))
92-
self.assertTrue(rule(VariantCalls("1G")))
93-
self.assertFalse(rule([]))
97+
self.assertTrue(rule(VariantCalls("1d 2T 7d")))
98+
self.assertFalse(rule(VariantCalls("1d 2d 7d")))
99+
self.assertTrue(rule(VariantCalls("1G 2d 7d")))
94100

95101
def test_select_from_atleast(self):
96102
rule = ASI2("SELECT ATLEAST 2 FROM (2T, 7Y, 3G)")
97-
self.assertTrue(rule(VariantCalls("2T 7Y 1G")))
98-
self.assertFalse(rule(VariantCalls("2T 4Y 5G")))
99-
self.assertTrue(rule(VariantCalls("3G 9Y 2T")))
103+
self.assertTrue(rule(VariantCalls("2T 7Y 3d")))
104+
self.assertFalse(rule(VariantCalls("2T 7d 3d")))
105+
self.assertTrue(rule(VariantCalls("3G 7d 2T")))
100106

101107
def test_score_from_exactly(self):
102108
rule = ASI2("SELECT EXACTLY 1 FROM (2T, 7Y)")
@@ -155,10 +161,10 @@ def test_chained_and(self):
155161
215FY) => 10), MAX ((41L AND 215ACDEILNSV) => 5, (41L AND 215FY) =>
156162
15))
157163
""")
158-
self.assertEqual(rule(VariantCalls("40F 41L 210W 215Y")), 65)
159-
self.assertEqual(rule(VariantCalls("41L 210W 215F")), 60)
160-
self.assertEqual(rule(VariantCalls("40F 210W 215Y")), 25)
161-
self.assertEqual(rule(VariantCalls("40F 67G 215Y")), 15)
164+
self.assertEqual(rule(add_mutations("40F 41L 210W 215Y")), 65)
165+
self.assertEqual(rule(add_mutations("41L 210W 215F")), 60)
166+
self.assertEqual(rule(add_mutations("40F 210W 215Y")), 25)
167+
self.assertEqual(rule(add_mutations("40F 67G 215Y")), 15)
162168

163169

164170
class TestAsiMutations(unittest.TestCase):
@@ -169,11 +175,6 @@ def test_init_args(self):
169175
self.assertEqual(expected_mutation_set, m.mutations)
170176
self.assertEqual(expected_mutation_set.wildtype, m.mutations.wildtype)
171177

172-
def test_init_none(self):
173-
m = AsiMutations()
174-
175-
self.assertIsNone(m.mutations)
176-
177178
def test_repr(self):
178179
expected_repr = "AsiMutations(args='Q80KR')"
179180
m = AsiMutations(args='Q80KR')
@@ -182,14 +183,6 @@ def test_repr(self):
182183

183184
self.assertEqual(expected_repr, r)
184185

185-
def test_repr_none(self):
186-
expected_repr = "AsiMutations()"
187-
m = AsiMutations()
188-
189-
r = repr(m)
190-
191-
self.assertEqual(expected_repr, r)
192-
193186

194187
class TestScore(unittest.TestCase):
195188
def test_init(self):

pyvdrm/tests/test_hcvr.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from pyvdrm.hcvr import HCVR, AsiMutations, Score
88
from pyvdrm.vcf import Mutation, MutationSet, VariantCalls
99

10+
from pyvdrm.tests.test_vcf import add_mutations
11+
1012

1113
# noinspection SqlNoDataSourceInspection,SqlDialectInspection
1214
class TestRuleParser(unittest.TestCase):
1315

1416
def test_stanford_ex1(self):
1517
HCVR("151M OR 69i")
1618

17-
def test(self):
19+
def test_atleast_true(self):
1820
rule = HCVR("SELECT ATLEAST 2 FROM (41L, 67N, 70R, 210W, 215F, 219Q)")
1921
self.assertTrue(rule(VariantCalls('41L 67N 70d 210d 215d 219d')))
2022

@@ -158,21 +160,6 @@ def test_parse_exception_multiline(self):
158160
self.assertEqual(expected_error_message, str(context.exception))
159161

160162

161-
def add_mutations(text):
162-
""" Add a small set of mutations to an RT wild type. """
163-
164-
# Start of RT reference.
165-
ref = ("PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGKISKIGPENPYNTPVFA"
166-
"IKKKDSTKWRKLVDFRELNKRTQDFWEVQLGIPHPAGLKKKKSVTVLDVGDAYFSVPLDEDF"
167-
"RKYTAFTIPSINNETPGIRYQYNVLPQGWKGSPAIFQSSMTKILEPFRKQNPDIVIYQYMDD"
168-
"LYVGSDLEIGQHRTKIEELRQHLLRWGLTTPDKKHQK")
169-
seq = list(ref)
170-
changes = VariantCalls(text)
171-
for mutation_set in changes:
172-
seq[mutation_set.pos - 1] = [m.variant for m in mutation_set]
173-
return VariantCalls(reference=ref, sample=seq)
174-
175-
176163
class TestActualRules(unittest.TestCase):
177164
def test_hivdb_rules_parse(self):
178165
folder = os.path.dirname(__file__)

pyvdrm/tests/test_vcf.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
from pyvdrm.vcf import Mutation, MutationSet, VariantCalls
3+
from vcf import VariantCalls
34

45

56
class TestMutation(unittest.TestCase):
@@ -429,3 +430,18 @@ def test_immutable(self):
429430

430431
if __name__ == '__main__':
431432
unittest.main()
433+
434+
435+
def add_mutations(text):
436+
""" Add a small set of mutations to an RT wild type. """
437+
438+
# Start of RT reference.
439+
ref = ("PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTEMEKEGKISKIGPENPYNTPVFA"
440+
"IKKKDSTKWRKLVDFRELNKRTQDFWEVQLGIPHPAGLKKKKSVTVLDVGDAYFSVPLDEDF"
441+
"RKYTAFTIPSINNETPGIRYQYNVLPQGWKGSPAIFQSSMTKILEPFRKQNPDIVIYQYMDD"
442+
"LYVGSDLEIGQHRTKIEELRQHLLRWGLTTPDKKHQK")
443+
seq = list(ref)
444+
changes = VariantCalls(text)
445+
for mutation_set in changes:
446+
seq[mutation_set.pos - 1] = [m.variant for m in mutation_set]
447+
return VariantCalls(reference=ref, sample=seq)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='pyvdrm',
5-
version='0.2.0',
5+
version='0.3.0',
66
description='',
77

88
url='',

0 commit comments

Comments
 (0)