diff --git a/jukebox/hparams.py b/jukebox/hparams.py index eb74584aa1..3870509d18 100644 --- a/jukebox/hparams.py +++ b/jukebox/hparams.py @@ -207,6 +207,13 @@ def setup_hparams(hparam_set_names, kwargs): ) HPARAMS_REGISTRY["small_vqvae"] = small_vqvae +custom_vqvae = Hyperparams( + restore_vqvae="https://genxx.s3.us-east-1.amazonaws.com/small_vqvae/checkpoint_step_200001.pth.tar", +) +custom_vqvae.update(small_vqvae) +HPARAMS_REGISTRY["custom_vqvae"] = custom_vqvae + + small_prior = Hyperparams( n_ctx=8192, prior_width=1024, @@ -219,6 +226,18 @@ def setup_hparams(hparam_set_names, kwargs): ) HPARAMS_REGISTRY["small_prior"] = small_prior +custom_prior = Hyperparams( + restore_prior="https://genxx.s3.us-east-1.amazonaws.com/small_prior/checkpoint_latest.pth.tar", + level=2, + labels=False, + alignment_layer=None, + alignment_head=None, +) +custom_prior.update(small_prior) +HPARAMS_REGISTRY["custom_prior"] = custom_prior + + + small_labelled_prior = Hyperparams( labels=True, labels_v3=True, @@ -231,6 +250,8 @@ def setup_hparams(hparam_set_names, kwargs): small_labelled_prior.update(small_prior) HPARAMS_REGISTRY["small_labelled_prior"] = small_labelled_prior + + small_single_enc_dec_prior = Hyperparams( n_ctx=6144, prior_width=1024, @@ -303,6 +324,15 @@ def setup_hparams(hparam_set_names, kwargs): HPARAMS_REGISTRY["small_upsampler"] = small_upsampler + +custom_upsampler = Hyperparams( + restore_prior="https://genxx.s3.us-east-1.amazonaws.com/small_upsampler/checkpoint_latest.pth.tar", + level=0, + labels=False, +) +custom_upsampler.update(small_upsampler) +HPARAMS_REGISTRY["custom_upsampler"] = custom_upsampler + all_fp16 = Hyperparams( fp16=True, fp16_params=True, @@ -484,7 +514,7 @@ def setup_hparams(hparam_set_names, kwargs): ) DEFAULTS["opt"] = Hyperparams( - epochs=10000, + epochs=50, lr=0.0003, clip=1.0, beta1=0.9,