@@ -18,6 +18,7 @@ def simulate_genotype_call_dataset(
18
18
seed : Optional [int ] = 0 ,
19
19
missing_pct : Optional [float ] = None ,
20
20
phased : Optional [bool ] = None ,
21
+ additional_variant_fields : Optional [dict ] = None ,
21
22
) -> Dataset :
22
23
"""Simulate genotype calls and variant/sample data.
23
24
@@ -50,6 +51,9 @@ def simulate_genotype_call_dataset(
50
51
The percentage of missing calls, must be within [0.0, 1.0], optional
51
52
phased
52
53
Whether genotypes are phased, default is unphased, optional
54
+ additional_variant_fields
55
+ Additional variant fields to add to the dataset as a dictionary of
56
+ {field_name: field_dtype}, optional
53
57
54
58
Returns
55
59
-------
@@ -62,6 +66,7 @@ def simulate_genotype_call_dataset(
62
66
- :data:`sgkit.variables.call_genotype_spec` (variants, samples, ploidy)
63
67
- :data:`sgkit.variables.call_genotype_mask_spec` (variants, samples, ploidy)
64
68
- :data:`sgkit.variables.call_genotype_phased_spec` (variants, samples), if ``phased`` is not None
69
+ - Those specified in ``additional_variant_fields``, if provided
65
70
"""
66
71
if missing_pct and (missing_pct < 0.0 or missing_pct > 1.0 ):
67
72
raise ValueError ("missing_pct must be within [0.0, 1.0]" )
@@ -87,7 +92,7 @@ def simulate_genotype_call_dataset(
87
92
["A" , "C" , "G" , "T" ], size = (n_variant , n_allele )
88
93
).astype ("S" )
89
94
sample_id = np .array ([f"S{ i } " for i in range (n_sample )])
90
- return create_genotype_call_dataset (
95
+ ds = create_genotype_call_dataset (
91
96
variant_contig_names = contig_names ,
92
97
variant_contig = contig ,
93
98
variant_position = position ,
@@ -96,3 +101,18 @@ def simulate_genotype_call_dataset(
96
101
call_genotype = call_genotype ,
97
102
call_genotype_phased = call_genotype_phased ,
98
103
)
104
+ # Add each of the additional variant fields, populated with random data, if provided
105
+ if additional_variant_fields is not None :
106
+ for field_name , field_dtype in additional_variant_fields .items ():
107
+ if field_dtype in (np .float32 , np .float64 ):
108
+ field = rs .rand (n_variant ).astype (field_dtype )
109
+ elif field_dtype in (np .int8 , np .int16 , np .int32 , np .int64 ):
110
+ field = rs .randint (0 , 100 , n_variant , dtype = field_dtype )
111
+ elif field_dtype is np .bool :
112
+ field = rs .rand (n_variant ) > 0.5
113
+ elif field_dtype is np .str :
114
+ field = np .arange (n_variant ).astype ("S" )
115
+ else :
116
+ raise ValueError (f"Unrecognized dtype { field_dtype } " )
117
+ ds [field_name ] = (("variants" ,), field )
118
+ return ds
0 commit comments