Description
Would like to extend the pytensor backend of Pathfinder to compile using JAX by setting `compile_kwargs=dict(mode="JAX")` in `pmx.fit`. I'm not yet entirely sure what the speed advantage (if any) would be, but I think the solution to the problem below might not be too difficult.
A required fix may be to implement a JAX conversion for the `LogLike` Op below. (The reason for having the `LogLike` Op was to vectorise an existing compiled `model.logp()` function, which takes in a flattened array of the model parameters.)
`pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py`, lines 693 to 716 at commit 00a4ca3:
```python
class LogLike(Op):
    """
    Op that computes the densities using vectorised operations.
    """

    __props__ = ("logp_func",)

    def __init__(self, logp_func: Callable):
        self.logp_func = logp_func
        super().__init__()

    def make_node(self, inputs):
        inputs = pt.as_tensor(inputs)
        outputs = pt.tensor(dtype="float64", shape=(None, None))
        return Apply(self, [inputs], [outputs])

    def perform(self, node: Apply, inputs, outputs) -> None:
        phi = inputs[0]
        logP = np.apply_along_axis(self.logp_func, axis=-1, arr=phi)
        # replace nan with -inf since np.argmax will return the first index at nan
        mask = np.isnan(logP) | np.isinf(logP)
        if np.all(mask):
            raise PathInvalidLogP()
        outputs[0][0] = np.where(mask, -np.inf, logP)
```
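For reference, the missing piece would be a `jax_funcify` registration for `LogLike`, roughly along the lines of the sketch below. This is only a sketch, not working code: the function name `jax_funcify_LogLike` is hypothetical, and it assumes `op.logp_func` is JAX-traceable, which it currently is not (it is a compiled PyTensor function), so the real fix would also need a JAX-callable logp. It also drops the `PathInvalidLogP` check, since raising on traced values is not possible inside a JAX-traced function.

```python
# Hypothetical sketch only: registers a JAX conversion for LogLike,
# mirroring LogLike.perform. Assumes op.logp_func can be traced by JAX,
# which would require model.logp() to be available as a JAX-compatible
# callable rather than a compiled PyTensor function.
import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify


@jax_funcify.register(LogLike)
def jax_funcify_LogLike(op, **kwargs):
    logp_func = op.logp_func

    def loglike(phi):
        # vectorise logp_func over the last axis, as in perform()
        logP = jnp.apply_along_axis(logp_func, axis=-1, arr=phi)
        # replace nan/inf with -inf so argmax does not pick them;
        # the all-invalid PathInvalidLogP check is omitted because it
        # cannot be raised under JAX tracing
        mask = jnp.isnan(logP) | jnp.isinf(logP)
        return jnp.where(mask, -jnp.inf, logP)

    return loglike
```

Even with such a registration, the `logp_func` closed over by the Op would have to be produced with JAX in mind, so this is probably only part of the fix.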
Minimum working example:
```python
import numpy as np
import pymc as pm
import pymc_extras as pmx


def eight_schools_model():
    J = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
    with pm.Model() as model:
        mu = pm.Normal("mu", mu=0.0, sigma=10.0)
        tau = pm.HalfCauchy("tau", 5.0)
        theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
        obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
    return model


model = eight_schools_model()
with model:
    idata = pmx.fit(
        method="pathfinder",
        num_paths=20,
        jitter=12.0,
        random_seed=41,
        inference_backend="pymc",
        compile_kwargs=dict(mode="JAX"),  # <--- enable JAX mode
    )
```

Output:

```
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[4], line 2
1 with model:
----> 2 idata = pmx.fit(
3 method="pathfinder",
4 num_paths=20,
5 jitter=12.0,
6 random_seed=41,
7 inference_backend="pymc",
8 compile_kwargs=dict(mode="JAX"),
9 )
File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/fit.py:35, in fit(method, **kwargs)
32 if method == "pathfinder":
33 from pymc_extras.inference.pathfinder import fit_pathfinder
---> 35 return fit_pathfinder(**kwargs)
37 if method == "laplace":
38 from pymc_extras.inference.laplace import fit_laplace
File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:1685, in fit_pathfinder(model, num_paths, num_draws, num_draws_per_path, maxcor, maxiter, ftol, gtol, maxls, num_elbo_draws, jitter, epsilon, importance_sampling, progressbar, concurrent, random_seed, postprocessing_backend, inference_backend, pathfinder_kwargs, compile_kwargs)
1682 maxcor = max(maxcor, 5)
1684 if inference_backend == "pymc":
-> 1685 mp_result = multipath_pathfinder(
1686 model,
1687 num_paths=num_paths,
1688 num_draws=num_draws,
1689 num_draws_per_path=num_draws_per_path,
1690 maxcor=maxcor,
1691 maxiter=maxiter,
1692 ftol=ftol,
1693 gtol=gtol,
1694 maxls=maxls,
1695 num_elbo_draws=num_elbo_draws,
1696 jitter=jitter,
1697 epsilon=epsilon,
1698 importance_sampling=importance_sampling,
1699 progressbar=progressbar,
1700 concurrent=concurrent,
1701 random_seed=random_seed,
1702 pathfinder_kwargs=pathfinder_kwargs,
1703 compile_kwargs=compile_kwargs,
1704 )
1705 pathfinder_samples = mp_result.samples
1706 elif inference_backend == "blackjax":
File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:1506, in multipath_pathfinder(model, num_paths, num_draws, num_draws_per_path, maxcor, maxiter, ftol, gtol, maxls, num_elbo_draws, jitter, epsilon, importance_sampling, progressbar, concurrent, random_seed, pathfinder_kwargs, compile_kwargs)
1493 pathfinder_config = PathfinderConfig(
1494 num_draws=num_draws_per_path,
1495 maxcor=maxcor,
(...)
1502 epsilon=epsilon,
1503 )
1505 compile_start = time.time()
-> 1506 single_pathfinder_fn = make_single_pathfinder_fn(
1507 model,
1508 **asdict(pathfinder_config),
1509 pathfinder_kwargs=pathfinder_kwargs,
1510 compile_kwargs=compile_kwargs,
1511 )
1512 compile_end = time.time()
1514 # NOTE: from limited tests, no concurrency is faster than thread, and thread is faster than process. But I suspect this also depends on the model size and maxcor setting.
File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:939, in make_single_pathfinder_fn(model, num_draws, maxcor, maxiter, ftol, gtol, maxls, num_elbo_draws, jitter, epsilon, pathfinder_kwargs, compile_kwargs)
936 lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
938 # pathfinder body
--> 939 pathfinder_body_fn = make_pathfinder_body(
940 logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
941 )
942 rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)
944 def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
File ~/projects/pymc-devs/pymc-extras/pymc_extras/inference/pathfinder/pathfinder.py:857, in make_pathfinder_body(logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs)
853 logP_psi = loglike(psi)
855 # return psi, logP_psi, logQ_psi, elbo_argmax
--> 857 pathfinder_body_fn = compile_pymc(
858 [x_full, g_full],
859 [psi, logP_psi, logQ_psi, elbo_argmax],
860 **compile_kwargs,
861 )
862 pathfinder_body_fn.trust_input = True
863 return pathfinder_body_fn
File ~/projects/pymc-devs/pymc/pymc/pytensorf.py:956, in compile_pymc(*args, **kwargs)
951 def compile_pymc(*args, **kwargs):
952 warnings.warn(
953 "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
954 FutureWarning,
955 )
--> 956 return compile(*args, **kwargs)
File ~/projects/pymc-devs/pymc/pymc/pytensorf.py:941, in compile(inputs, outputs, random_seed, mode, **kwargs)
939 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
940 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 941 pytensor_function = pytensor.function(
942 inputs,
943 outputs,
944 updates={**rng_updates, **kwargs.pop("updates", {})},
945 mode=mode,
946 **kwargs,
947 )
948 return pytensor_function
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/__init__.py:318, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
312 fn = orig_function(
313 inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
314 )
315 else:
316 # note: pfunc will also call orig_function -- orig_function is
317 # a choke point that all compilation must pass through
--> 318 fn = pfunc(
319 params=inputs,
320 outputs=outputs,
321 mode=mode,
322 updates=updates,
323 givens=givens,
324 no_default_updates=no_default_updates,
325 accept_inplace=accept_inplace,
326 name=name,
327 rebuild_strict=rebuild_strict,
328 allow_input_downcast=allow_input_downcast,
329 on_unused_input=on_unused_input,
330 profile=profile,
331 output_keys=output_keys,
332 )
333 return fn
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/pfunc.py:465, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
451 profile = ProfileStats(message=profile)
453 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
454 params,
455 outputs,
(...)
462 fgraph=fgraph,
463 )
--> 465 return orig_function(
466 inputs,
467 cloned_outputs,
468 mode,
469 accept_inplace=accept_inplace,
470 name=name,
471 profile=profile,
472 on_unused_input=on_unused_input,
473 output_keys=output_keys,
474 fgraph=fgraph,
475 )
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/types.py:1769, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
1757 m = Maker(
1758 inputs,
1759 outputs,
(...)
1766 fgraph=fgraph,
1767 )
1768 with config.change_flags(compute_test_value="off"):
-> 1769 fn = m.create(defaults)
1770 finally:
1771 if profile and fn:
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/compile/function/types.py:1661, in FunctionMaker.create(self, input_storage, storage_map)
1658 start_import_time = pytensor.link.c.cmodule.import_time
1660 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1661 _fn, _i, _o = self.linker.make_thunk(
1662 input_storage=input_storage_lists, storage_map=storage_map
1663 )
1665 end_linker = time.perf_counter()
1667 linker_time = end_linker - start_linker
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
238 def make_thunk(
239 self,
240 input_storage: Optional["InputStorageType"] = None,
(...)
243 **kwargs,
244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245 return self.make_all(
246 input_storage=input_storage,
247 output_storage=output_storage,
248 storage_map=storage_map,
249 )[:3]
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
692 for k in storage_map:
693 compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
696 compute_map, nodes, input_storage, output_storage, storage_map
697 )
699 [fn] = thunks
700 fn.jit_fn = jit_fn
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
644 # This is a bit hackish, but we only return one of the output nodes
645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
648 self.fgraph,
649 order=order,
650 input_storage=input_storage,
651 output_storage=output_storage,
652 storage_map=storage_map,
653 )
655 thunk_inputs = self.create_thunk_inputs(storage_map)
656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/jax/linker.py:67, in JAXLinker.fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs)
64 fgraph.inputs.remove(new_inp)
65 fgraph.inputs.insert(old_inp_fgrap_index, new_inp)
---> 67 return jax_funcify(
68 fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
69 )
File ~/miniconda3/envs/python-3.10/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:54, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
47 @jax_funcify.register(FunctionGraph)
48 def jax_funcify_FunctionGraph(
49 fgraph,
(...)
52 **kwargs,
53 ):
---> 54 return fgraph_to_python(
55 fgraph,
56 jax_funcify,
57 type_conversion_fn=jax_typify,
58 fgraph_name=fgraph_name,
59 **kwargs,
60 )
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
734 body_assigns = []
735 for node in order:
--> 736 compiled_func = op_conversion_fn(
737 node.op, node=node, storage_map=storage_map, **kwargs
738 )
740 # Create a local alias with a unique name
741 local_compiled_func_name = unique_name(compiled_func)
File ~/miniconda3/envs/python-3.10/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/miniconda3/envs/python-3.10/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:44, in jax_funcify(op, node, storage_map, **kwargs)
41 @singledispatch
42 def jax_funcify(op, node=None, storage_map=None, **kwargs):
43 """Create a JAX compatible function from an PyTensor `Op`."""
---> 44 raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: LogLike{logp_func=<function make_single_pathfinder_fn.<locals>.logp_func at 0x7f7cac13f010>}
```