@@ -191,11 +191,12 @@ def scan(f, init, xs, length=None):
191191 reverse: optional boolean specifying whether to run the scan iteration
192192 forward (the default) or in reverse, equivalent to reversing the leading
193193 axes of the arrays in both ``xs`` and in ``ys``.
194- unroll: optional positive int or bool specifying, in the underlying
194+ unroll: optional non-negative int or bool specifying, in the underlying
195195 operation of the scan primitive, how many scan iterations to unroll within
196196 a single iteration of a loop. If an integer is provided, it determines how
197197 many unrolled loop iterations to run within a single rolled iteration of
198- the loop. If a boolean is provided, it will determine if the loop is
198+ the loop. `unroll=0` unrolls the entire loop.
199+ If a boolean is provided, it will determine if the loop is
199200 completely unrolled (i.e. `unroll=True`) or left completely rolled (i.e.
200201 `unroll=False`).
201202 _split_transpose: experimental optional bool specifying whether to further
@@ -320,8 +321,8 @@ def _create_jaxpr(init):
320321 "value." )
321322 if isinstance (unroll , bool ):
322323 unroll = max (length , 1 ) if unroll else 1
323- if unroll < 1 :
324- raise ValueError ("`unroll` must be a `bool` or a positive `int`." )
324+ if unroll < 0 :
325+ raise ValueError ("`unroll` must be a `bool` or a non-negative `int`." )
325326
326327 # If the body forwards an input carry to an output carry, that input is
327328 # read-only and can be moved to be a const. Doing so can lead to efficiency
@@ -465,7 +466,10 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
465466 del _split_transpose
466467 consts , carry , xs_ = split_list (args , [num_consts , num_carry ])
467468 _ , y_avals = split_list (jaxpr .out_avals , [num_carry ])
468- num_trips , remainder = divmod (length , unroll )
469+ if unroll == 0 :
470+ num_trips , remainder = 0 , length
471+ else :
472+ num_trips , remainder = divmod (length , unroll )
469473
470474 if unroll != 1 and num_trips == 1 and remainder == 0 :
471475 # In that case, we explicitly want to fully unroll the loop. Put everything
@@ -1452,7 +1456,7 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
14521456 tc (jaxpr , 'jaxpr' , 'ClosedJaxpr' , type (jaxpr ) is ClosedJaxpr )
14531457 tc (linear , 'linear' , 'tuple of bool' ,
14541458 type (linear ) is tuple and all (type (x ) is bool for x in linear ))
1455- tc (unroll , 'unroll' , 'positive int' , type (unroll ) is int and unroll > 0 )
1459+ tc (unroll , 'unroll' , 'non-negative int' , type (unroll ) is int and unroll >= 0 )
14561460
14571461 tc (length , 'length' , 'non-negative int' , length >= 0 )
14581462
0 commit comments