This is the official Huggingface-based implementation of *Nearest Neighbor Speculative Decoding for LLM Generation and Attribution* (NEST), NeurIPS 2024.
First, make sure you have Anaconda3 installed. Then use conda to create a new environment and activate it:
```bash
conda create -n nest python=3.10
conda activate nest
```
Now let's install the packages. First, install PyTorch on your machine following the official instructions. Then install faiss:

```bash
conda install faiss
```
Finally, install the packages in requirements.txt. Remember to comment out any packages in the file that you have already installed to avoid conflicts.

```bash
pip install -r requirements.txt
```
For flash attention, run

```bash
pip install flash-attn --no-build-isolation
```
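As an optional sanity check, you can confirm that flash-attn imports and a CUDA device is visible (a quick sketch, assuming a CUDA GPU is present):

```python
# Optional sanity check: flash attention requires a CUDA GPU.
import torch
import flash_attn

print(flash_attn.__version__)
print(torch.cuda.is_available())  # should print True on a usable GPU machine
```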
To get started, we provide an example script that uses NEST for generation:

```bash
python demo.py
```
You can play with the hyperparameters in the demo to understand how they affect NEST generation.
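For intuition, here is a minimal, self-contained sketch of the span-level draft-and-verify idea behind NEST-style generation. This is not the repo's implementation: the helper name, acceptance rule, threshold, and toy data are all illustrative.

```python
import torch

def accept_prefix(step_probs, draft_tokens, threshold=0.3):
    """Accept the longest prefix of a retrieved draft span whose tokens the
    base LM itself assigns probability >= threshold (illustrative rule only)."""
    accepted = []
    for probs, tok in zip(step_probs, draft_tokens):
        if probs[tok] < threshold:  # LM disagrees with the retrieved token: stop
            break
        accepted.append(tok)
    return accepted

# Toy example with a 5-token vocabulary.
draft = [2, 4, 1]  # a span proposed from the retrieval datastore
step_probs = torch.tensor([
    [0.1, 0.1, 0.6, 0.1, 0.1],  # LM agrees with token 2
    [0.1, 0.1, 0.1, 0.1, 0.6],  # LM agrees with token 4
    [0.6, 0.1, 0.1, 0.1, 0.1],  # LM disagrees with token 1
])
print(accept_prefix(step_probs, draft))  # -> [2, 4]
```

Accepting a long prefix of a retrieved span lets the model emit several tokens per LM step, which is where the speculative speedup comes from.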
The default models we use are Llama-2-chat-7b, -13b, and -70b. You can switch to other CausalLM models on Huggingface; remember to change the `start_tag` and `end_tag` values for your model. Links to the Llama-2-chat models we used:
| Models | Huggingface Tag |
|---|---|
| Llama-2-chat-7b | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama-2-chat-13b | [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) |
| Llama-2-chat-70b | [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) |
Remember to get authorization from Meta before using the above models.
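For reference, loading one of these models with Huggingface `transformers` looks roughly like this (a sketch; the `attn_implementation="flash_attention_2"` line is optional and assumes flash-attn was installed above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # gated: requires Meta's authorization
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # optional; needs flash-attn installed
)
```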
We use the Wikipedia 2021 dump from the Atlas repo. Download the corpus following the instructions in the repo.
The default embedders/retrievers we used are DRAGON+ and BM25 (Pyserini). To build the sparse index using BM25, run:

```bash
python data/convert_atlas_corpus_to_pyserini_format.py your/path/to/downloaded/corpus collections/enwiki-dec2021
bash data/pyserini_index.sh
```
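Once the sparse index is built, you can sanity-check it by querying it directly with Pyserini (a sketch; the index path is an assumption — point it at wherever `data/pyserini_index.sh` writes the index):

```python
from pyserini.search.lucene import LuceneSearcher

# Index path is an assumption; use the output directory of data/pyserini_index.sh.
searcher = LuceneSearcher("indexes/enwiki-dec2021")
hits = searcher.search("who wrote Hamlet", k=5)
for hit in hits:
    print(hit.docid, hit.score)
```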
For the dense index, we use FAISS with the index string "IVF65536,PQ256". Please see the DRAGON+ repo for more detailed index-building instructions. We also open-source the dense index we used:
| Indexes | Size |
|---|---|
| DRAGON+ | 8.96GB |
| BM25 (Pyserini) | 3.48GB |
By default, both BM25 and DRAGON+ retrieval run on CPUs, with latency controlled by the number of threads. For the dense index, it is also possible to accelerate search with GPUs, which we leave to custom implementations; a sketch follows below.
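As an illustration, building an "IVF65536,PQ256" index and optionally moving it to GPUs looks roughly like this. This is a sketch with random placeholder embeddings, not the real pipeline (see the DRAGON+ repo for that); the 768 dimension assumes DRAGON+'s BERT-base-sized encoder.

```python
import faiss
import numpy as np

d = 768  # assumption: DRAGON+ embeddings are 768-dimensional (BERT-base sized)
xb = np.random.rand(100_000, d).astype("float32")  # placeholder for real embeddings

# DRAGON+ scores passages by dot product, hence METRIC_INNER_PRODUCT.
index = faiss.index_factory(d, "IVF65536,PQ256", faiss.METRIC_INNER_PRODUCT)
# Note: IVF65536 wants far more training vectors than this toy sample;
# training on 100k points will warn and give poor recall.
index.train(xb)
index.add(xb)

# Optional GPU acceleration (requires faiss-gpu):
# index = faiss.index_cpu_to_all_gpus(index)

index.nprobe = 64  # search-time speed/recall trade-off
scores, ids = index.search(xb[:5], 10)
print(ids.shape)  # (5, 10)
```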
| Task | Description | Size |
|---|---|---|
| WikiText-103 | Text completion | 2357 (test) |
| NQ | Question Answering | 3610 (test) |
| TriviaQA | Question Answering | 11313 (test) |
| HotpotQA | Question Answering | 2500 (dev, sub-sample) |
| MedMCQA | Question Answering | 2500 (dev, sub-sample) |
| TruthfulQA | Question Answering | 817 (test) |
| FactScore | Fact Verification | 500 (unlabeled) |
We evaluate NEST on the above tasks based on the Wikipedia corpus.
To preprocess the data, we provide a script in the `data/` folder for each task.
To evaluate NEST on the above tasks, run

```bash
bash scripts/eval_gen_task.sh
```

Change the `task` argument (multiple tasks are separated by ",") and other arguments before running the evaluation. Note that the Llama-2-chat models use special tags ("[INST]" and "[/INST]") for instruction fine-tuning; these tags might not be necessary for other CausalLMs, as the sketch below illustrates.
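For reference, a prompt wrapped with these tags looks like this (a sketch; the evaluation scripts may apply a fuller template, e.g. with a system prompt):

```python
start_tag, end_tag = "[INST]", "[/INST]"  # Llama-2-chat instruction tags

question = "Who wrote Hamlet?"
prompt = f"{start_tag} {question} {end_tag}"
print(prompt)  # [INST] Who wrote Hamlet? [/INST]

# For other CausalLMs, set start_tag / end_tag to that model's template
# (or to empty strings if the model was not tuned with instruction tags).
```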
Due to the incompatibility between the CC-BY-NC license of NEST's code and the GPL-3.0 license of MAUVE, we do not include this metric in the evaluation on WikiText-103 and Pile-of-Law.
For FactScore, please take the generation results and follow the instructions in the FactScore repo for evaluation. In the paper, we use an internal LLM fine-tuned for fact decomposition and fact checking, which may not be publicly released.
You can also pre-fetch the supporting documents (as provided in the task data) to avoid passage retrieval during generation by running:

```bash
bash scripts/eval_retrieval_task.sh
```
The results will be saved in the same format as the input data, with an extra field "support". Remember to move the results to the data input path before running evaluation, and add the `--pre_retrieved` argument to use the pre-fetched documents.
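For orientation, a pre-retrieved record plausibly looks like the following. This is illustrative only; the exact field names inside "support" are an assumption, so check the files emitted by `scripts/eval_retrieval_task.sh`.

```python
import json

# Illustrative record only: the field names inside "support" are assumptions.
record = {
    "question": "Who wrote Hamlet?",
    "answers": ["William Shakespeare"],
    "support": [  # pre-fetched passages, so no retrieval runs at generation time
        {
            "id": "enwiki-dec2021-12345",
            "text": "Hamlet is a tragedy written by William Shakespeare ...",
            "score": 87.2,
        }
    ],
}
print(json.dumps(record, indent=2))
```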
You can also pre-encode the corpus into tokens to save time during generation; add the `--pre_tokenized` argument to use the pre-encoded corpus. See the `encode` function in `knn_transformer.py` for more details on pre-encoding.
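Conceptually, pre-encoding just tokenizes each passage once and caches the ids (a sketch of the idea only; see `encode` in `knn_transformer.py` for the actual logic and storage format):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

passages = [
    "Hamlet is a tragedy written by William Shakespeare.",
    "The Eiffel Tower is located in Paris.",
]

# Tokenize each passage once and cache the ids so generation-time code
# can skip re-tokenizing the corpus.
pre_tokenized = [tokenizer(p, add_special_tokens=False)["input_ids"] for p in passages]
print(pre_tokenized[0][:5])
```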
```bibtex
@misc{li2024nearestneighborspeculativedecoding,
      title={Nearest Neighbor Speculative Decoding for LLM Generation and Attribution},
      author={Minghan Li and Xilun Chen and Ari Holtzman and Beidi Chen and Jimmy Lin and Wen-tau Yih and Xi Victoria Lin},
      year={2024},
      eprint={2405.19325},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.19325},
}
```
The code of NEST is licensed under CC-BY-NC.