@@ -68,28 +68,17 @@ def __init__(self, P, dt=1, random_state=None):
6868 self .P = np .array (P )
6969 self .n = self .P .shape [0 ]
7070
71- # initialize mu
72- self .mudist = None
73-
71+ if random_state is None :
72+ random_state = np .random .RandomState ()
7473 self .random_state = random_state
7574
76- # generate discrete random value generators for each line
77- self .rgs = np .ndarray (self .n , dtype = object )
78- from scipy .stats import rv_discrete
79- for i , row in enumerate (self .P ):
80- nz = row .nonzero ()[0 ]
81- self .rgs [i ] = rv_discrete (values = (nz , row [nz ]))
82-
8375 def _get_start_state (self ):
84- if self .mudist is None :
85- # compute mu, the stationary distribution of P
86- from ..analysis import stationary_distribution
87- from scipy .stats import rv_discrete
88-
89- mu = stationary_distribution (self .P )
90- self .mudist = rv_discrete (values = (np .arange (self .n ), mu ))
91- # sample starting point from mu
92- start = self .mudist .rvs (random_state = self .random_state )
76+ # compute mu, the stationary distribution of P
77+ from ..analysis import stationary_distribution
78+
79+ mu = stationary_distribution (self .P )
80+ start = self .random_state .choice (self .n , p = mu )
81+
9382 return start
9483
9584 def trajectory (self , N , start = None , stop = None ):
@@ -113,24 +102,19 @@ def trajectory(self, N, start=None, stop=None):
113102 if start is None :
114103 start = self ._get_start_state ()
115104
116- # evaluate stopping set
117- stopat = np .zeros (self .n , dtype = bool )
118- if stop is not None :
119- stopat [np .array (stop )] = True
120-
121105 # result
122106 traj = np .zeros (N , dtype = int )
123107 traj [0 ] = start
124108 # already at stopping state?
125- if stopat [ traj [0 ]] :
109+ if traj [0 ] == stop :
126110 return traj [:1 ]
127111 # else run until end or stopping state
128112 for t in range (1 , N ):
129- traj [t ] = self .rgs [traj [t - 1 ]]. rvs ( random_state = self . random_state )
130- if stopat [ traj [t ]] :
113+ traj [t ] = self .random_state . choice ( self . n , p = self . P [traj [t - 1 ]])
114+ if traj [t ] == stop :
131115 traj = np .resize (traj , t + 1 )
132116 break
133- # return
117+
134118 return traj
135119
136120 def trajectories (self , M , N , start = None , stop = None ):
0 commit comments