Skip to content

Commit e28c0f7

Browse files
benjefferymergify[bot]
authored andcommitted
Add variant id to simulate dataset
1 parent 0bccf6a commit e28c0f7

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

docs/changelog.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,16 @@ New Features
3535
- Add :func:`sgkit.io.vcf.read_vcf` convenience function.
3636
(:user:`tomwhite`, :pr:`1052`, :issue:`1004`)
3737

38-
- Add :func:`sgkit.hybrid_relationship`, :func:`sgkit.hybrid_inverse_relationship`
38+
- Add :func:`sgkit.hybrid_relationship`, :func:`sgkit.hybrid_inverse_relationship`
3939
and :func:`invert_relationship_matrix` methods.
4040
(:user:`timothymillar`, :pr:`1053`, :issue:`993`)
4141

4242
- Add :func:`sgkit.io.vcf.zarr_array_sizes` for determining array sizes for storage in Zarr.
4343
(:user:`tomwhite`, :pr:`1073`, :issue:`734`)
4444

45+
- Add `additional_variant_fields` to :func:`sgkit.simulate_genotype_call_dataset` function.
46+
(:user:`benjeffery`, :pr:`1056`)
47+
4548
Bug fixes
4649
~~~~~~~~~
4750

sgkit/testing.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def simulate_genotype_call_dataset(
1818
seed: Optional[int] = 0,
1919
missing_pct: Optional[float] = None,
2020
phased: Optional[bool] = None,
21+
additional_variant_fields: Optional[dict] = None,
2122
) -> Dataset:
2223
"""Simulate genotype calls and variant/sample data.
2324
@@ -50,6 +51,9 @@ def simulate_genotype_call_dataset(
5051
The percentage of missing calls, must be within [0.0, 1.0], optional
5152
phased
5253
Whether genotypes are phased, default is unphased, optional
54+
additional_variant_fields
55+
Additional variant fields to add to the dataset as a dictionary of
56+
{field_name: field_dtype}, optional
5357
5458
Returns
5559
-------
@@ -62,6 +66,7 @@ def simulate_genotype_call_dataset(
6266
- :data:`sgkit.variables.call_genotype_spec` (variants, samples, ploidy)
6367
- :data:`sgkit.variables.call_genotype_mask_spec` (variants, samples, ploidy)
6468
- :data:`sgkit.variables.call_genotype_phased_spec` (variants, samples), if ``phased`` is not None
69+
- Those specified in ``additional_variant_fields``, if provided
6570
"""
6671
if missing_pct and (missing_pct < 0.0 or missing_pct > 1.0):
6772
raise ValueError("missing_pct must be within [0.0, 1.0]")
@@ -87,7 +92,7 @@ def simulate_genotype_call_dataset(
8792
["A", "C", "G", "T"], size=(n_variant, n_allele)
8893
).astype("S")
8994
sample_id = np.array([f"S{i}" for i in range(n_sample)])
90-
return create_genotype_call_dataset(
95+
ds = create_genotype_call_dataset(
9196
variant_contig_names=contig_names,
9297
variant_contig=contig,
9398
variant_position=position,
@@ -96,3 +101,18 @@ def simulate_genotype_call_dataset(
96101
call_genotype=call_genotype,
97102
call_genotype_phased=call_genotype_phased,
98103
)
104+
# Add in each of the additional variant fields, if provided with random data
105+
if additional_variant_fields is not None:
106+
for field_name, field_dtype in additional_variant_fields.items():
107+
if field_dtype in (np.float32, np.float64):
108+
field = rs.rand(n_variant).astype(field_dtype)
109+
elif field_dtype in (np.int8, np.int16, np.int32, np.int64):
110+
field = rs.randint(0, 100, n_variant, dtype=field_dtype)
111+
elif field_dtype is np.bool:
112+
field = rs.rand(n_variant) > 0.5
113+
elif field_dtype is np.str:
114+
field = np.arange(n_variant).astype("S")
115+
else:
116+
raise ValueError(f"Unrecognized dtype {field_dtype}")
117+
ds[field_name] = (("variants",), field)
118+
return ds

sgkit/tests/test_testing.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,35 @@ def test_simulate_genotype_call_dataset__phased(tmp_path):
2929
ds = simulate_genotype_call_dataset(n_variant=10, n_sample=10, phased=False)
3030
assert "call_genotype_phased" in ds
3131
assert not np.any(ds["call_genotype_phased"])
32+
33+
34+
def test_simulate_genotype_call_dataset__additional_variant_fields():
35+
ds = simulate_genotype_call_dataset(
36+
n_variant=10,
37+
n_sample=10,
38+
phased=True,
39+
additional_variant_fields={
40+
"variant_id": np.str,
41+
"variant_filter": np.bool,
42+
"variant_quality": np.int8,
43+
"variant_yummyness": np.float32,
44+
},
45+
)
46+
assert "variant_id" in ds
47+
assert np.all(ds["variant_id"] == np.arange(10).astype("S"))
48+
assert "variant_filter" in ds
49+
assert ds["variant_filter"].dtype == np.bool
50+
assert "variant_quality" in ds
51+
assert ds["variant_quality"].dtype == np.int8
52+
assert "variant_yummyness" in ds
53+
assert ds["variant_yummyness"].dtype == np.float32
54+
55+
with pytest.raises(ValueError, match="Unrecognized dtype"):
56+
simulate_genotype_call_dataset(
57+
n_variant=10,
58+
n_sample=10,
59+
phased=True,
60+
additional_variant_fields={
61+
"variant_id": None,
62+
},
63+
)

0 commit comments

Comments
 (0)