
Nearest Neighbor Speculative Decoding for LLM Generation and Attribution

License: CC BY-NC 4.0 | Arxiv | Tweet

This is the official Huggingface-based implementation of Nearest Neighbor Speculative Decoding for LLM Generation and Attribution (NeurIPS 2024).

Dependencies

First, make sure you have Anaconda3 installed. Then use conda to create a new environment and activate it:

conda create -n nest python=3.10
conda activate nest

Now let's install the packages. First, follow the official PyTorch installation instructions for your machine. Then install faiss:

conda install faiss

Finally, install the packages in requirements.txt. Comment out any packages in the .txt file that you have already installed to avoid conflicts.

pip install -r requirements.txt

For FlashAttention, run:

pip install flash-attn --no-build-isolation

Demo

To get started, we provide an example script that uses NEST for generation:

python demo.py

You can play with the hyperparameters in the demo to see how they affect NEST generation.
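
To give a feel for what such hyperparameters do, here is a minimal conceptual sketch (not the repo's demo.py; all names such as lam are hypothetical) of the kNN-LM-style token mixing that NEST builds on: an interpolation coefficient trades off the base LM's next-token distribution against a retrieval-based distribution.

import torch

def interpolate_next_token(lm_logits: torch.Tensor,
                           knn_probs: torch.Tensor,
                           lam: float = 0.3) -> torch.Tensor:
    # Mix the base LM's next-token distribution with a retrieval-based one;
    # a larger lam leans more heavily on the retrieved neighbors.
    lm_probs = torch.softmax(lm_logits, dim=-1)
    return (1.0 - lam) * lm_probs + lam * knn_probs

vocab_size = 8
lm_logits = torch.randn(vocab_size)
knn_probs = torch.zeros(vocab_size)
knn_probs[3] = 1.0                      # all retrieved neighbors vote for token 3
mixed = interpolate_next_token(lm_logits, knn_probs, lam=0.5)
print(mixed.argmax().item())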

Models

The default models we use are Llama-2-chat 7B, 13B, and 70B. You can switch to other CausalLM models on Huggingface; remember to change the start_tag and end_tag values for your model. The Llama-2-chat models we used are listed below:

| Models | Huggingface Tag |
| --- | --- |
| Llama-2-chat-7b | meta-llama/Llama-2-7b-chat-hf |
| Llama-2-chat-13b | meta-llama/Llama-2-13b-chat-hf |
| Llama-2-chat-70b | meta-llama/Llama-2-70b-chat-hf |

Remember to obtain authorization from Meta before using the above models.
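
If you just want to check that one of the checkpoints above loads on your machine, a minimal Huggingface transformers sketch looks like the following. It is not tied to the repo's own loading code and assumes you have been granted access by Meta, are logged in (e.g. via huggingface-cli login), and have accelerate installed for device_map="auto".

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",          # requires the accelerate package
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))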

Data Prep

Corpus

We use the Wikipedia 2021 dump from the Atlas repo. Download the corpus following the instructions in the repo.

Indexes

The default embedders/retrievers we used are DRAGON+ and BM25 (Pyserini). To build the sparse index using BM25, run:

python data/convert_atlas_corpus_to_pyserini_format.py your/path/to/downloaded/corpus collections/enwiki-dec2021
bash data/pyserini_index.sh
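
Once the sparse index is built, a quick way to sanity-check it is to query it directly with Pyserini. The sketch below assumes a recent Pyserini version (LuceneSearcher API) and guesses the output directory of data/pyserini_index.sh; adjust the path to wherever your index actually lands.

from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher("indexes/enwiki-dec2021")   # hypothetical index path
hits = searcher.search("Who wrote The Old Man and the Sea?", k=5)
for hit in hits:
    print(hit.docid, round(hit.score, 2))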

For the dense index, we use FAISS and the index string "IVF65536,PQ256" to build the index. Please see the DRAGON+ repo for more detailed index building instructions. We also open-source the dense index we used.
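
For reference, a minimal FAISS sketch of building an index with the "IVF65536,PQ256" factory string is shown below. The embedding dimension (768, assumed for DRAGON+), metric, and file names are our assumptions; the DRAGON+ repo is the authoritative guide.

import faiss
import numpy as np

d = 768                                                           # DRAGON+ embedding dim (assumed)
embeddings = np.load("passage_embeddings.npy").astype("float32")  # hypothetical file

index = faiss.index_factory(d, "IVF65536,PQ256", faiss.METRIC_INNER_PRODUCT)
index.train(embeddings)                                  # IVF/PQ require a training pass
index.add(embeddings)
faiss.write_index(index, "enwiki-dec2021.dragon.index")  # hypothetical output name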

| Indexes | Size |
| --- | --- |
| DRAGON+ | 8.96GB |
| BM25 (Pyserini) | 3.48GB |

By default, both BM25 and DRAGON+ run retrieval on CPUs, where latency is controlled by the number of threads. For the dense index, it is also possible to accelerate search with GPUs, which we leave to custom implementations.
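
In FAISS terms, this looks roughly like the sketch below; the thread count and nprobe values are illustrative, and the GPU line requires a faiss-gpu build, which the default setup does not use.

import faiss

index = faiss.read_index("enwiki-dec2021.dragon.index")  # hypothetical index file
faiss.omp_set_num_threads(16)    # more threads -> lower CPU search latency
index.nprobe = 64                # IVF probes: recall vs. latency trade-off

# Optional GPU acceleration (needs a faiss-gpu build; not the default path):
# index = faiss.index_cpu_to_all_gpus(index)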

Tasks

| Task | Description | Size |
| --- | --- | --- |
| WikiText-103 | Text completion | 2357 (test) |
| NQ | Question Answering | 3610 (test) |
| TriviaQA | Question Answering | 11313 (test) |
| HotpotQA | Question Answering | 2500 (dev, sub-sample) |
| MedMCQA | Question Answering | 2500 (dev, sub-sample) |
| TruthfulQA | Question Answering | 817 (test) |
| FactScore | Fact Verification | 500 (unlabeled) |

We evaluate NEST on the above tasks based on the Wikipedia corpus. To preprocess the data, we provide a script in the data/ folder for each task.

To evaluate NEST on the above tasks, run:

bash scripts/eval_gen_task.sh

Change the task argument (multiple tasks separated by ",") and other arguments before running the evaluation. Note that the Llama-2-chat models use special tags ("[INST]" and "[/INST]") for instruction fine-tuning; these tags may not be necessary for other CausalLMs.
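
For reference, the tags are typically applied around the instruction as in the hedged illustration below; the exact prompt template used by the evaluation scripts may differ.

def wrap_instruction(question: str, use_chat_tags: bool = True) -> str:
    # Llama-2-chat checkpoints expect [INST] ... [/INST] around the instruction;
    # other CausalLMs may want the bare question instead.
    return f"[INST] {question} [/INST]" if use_chat_tags else question

print(wrap_instruction("When was the Eiffel Tower built?"))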

Due to compatibility issues between the CC-BY-NC license of NEST's code and the GPL-3 license of MAUVE, we do not include this metric in the evaluations on WikiText-103 and Pile-of-Law.

For FactScore, please take the generation results and follow the instructions in the FactScore repo for evaluation. In the paper, we use an internal LLM fine-tuned for fact decomposition and fact checking, which may not be publicly released.

Pre-retrieved

You can also pre-fetch the supporting documents (as provided in the task data) to avoid passage retrieval during generation by running:

bash scripts/eval_retrieval_task.sh

The results will be saved in the same format as the input data with an extra field "support". Remember to move the results to the data input path before running the evaluation. You need to add the --pre_retrieved argument for document pre-fetching.
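
Purely for illustration, a pre-retrieved record might look like the sketch below. Only the "support" field name comes from the note above; the surrounding keys and passage format are hypothetical, so inspect the actual output of scripts/eval_retrieval_task.sh for the real schema.

example = {
    "question": "Who wrote The Old Man and the Sea?",   # hypothetical input field
    "support": [                                        # added by pre-fetching
        {"title": "Ernest Hemingway", "text": "..."},   # hypothetical passage format
    ],
}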

Pre-tokenized

You can also pre-encode the corpus into tokens to save time during generation. You need to add the --pre_tokenized argument for corpus pre-encoding. See the encode function in knn_transformer.py for more details on pre-encoding.
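
The sketch below shows the general idea of pre-encoding, turning each passage into token ids once so generation does not have to re-tokenize retrieved text. It is not the repo's actual encode function, and the output file name is made up.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
passages = ["Paris is the capital of France.", "The Nile flows through Egypt."]

encoded = [tokenizer(p, add_special_tokens=False)["input_ids"] for p in passages]
torch.save(encoded, "corpus_tokens.pt")   # hypothetical output file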

Citation

@misc{li2024nearestneighborspeculativedecoding,
      title={Nearest Neighbor Speculative Decoding for LLM Generation and Attribution}, 
      author={Minghan Li and Xilun Chen and Ari Holtzman and Beidi Chen and Jimmy Lin and Wen-tau Yih and Xi Victoria Lin},
      year={2024},
      eprint={2405.19325},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.19325}, 
}

License

The code of NEST is licensed under CC-BY-NC.
