How to Finetune the Entire RAG Architecture (including DPR retriever)

Shamane Siriwardhana
May 10, 2021

with Hugging Face, PyTorch Lightning, and Ray

In September 2020, Facebook open-sourced a new NLP model called Retrieval Augmented Generation (RAG) in the Hugging Face Transformers library. RAG uses a set of supporting documents from an external knowledge base as a latent variable to generate the final output. The RAG model consists of an Input Encoder, a Neural Retriever, and an Output Generator, and all three components are initialized with pre-trained transformers. However, the original Hugging Face implementation only fine-tunes the Input Encoder and the Output Generator end-to-end. To the best of our knowledge, no RAG implementation exists that trains all three components end-to-end. In this blog post, we introduce an extension of the RAG implementation that trains all three components end-to-end (PR to Hugging Face). Although this sounds straightforward, there are many engineering challenges to deal with.

So what is amazing about RAG-like architectures?

As the name suggests, RAG adds an intermediate information retrieval step before the final generation process. It combines the ideas of neural information retrieval with seq2seq generation in an end-to-end manner. There are several blogs that explain the core concepts of RAG (such as this, and this). So, after a quick review of RAG, we’ll walk through the additions we made to the original Hugging Face RAG implementation.

Figure 1: Original RAG model. Image courtesy RAG paper.

As illustrated in Figure 1, during training RAG encodes the input into a high-dimensional vector. This vector is then used as a query to select the most relevant passages from an external database with Maximum Inner Product Search (MIPS). The input and the selected documents are then fed into a generator that produces the final answer. Note that the components of RAG are initialized with pre-trained BERT and BART transformers. Finally, the probability of selecting a document given the input (p(z|x)), which appears in the RAG loss (Figure 2), is what propagates gradients back to the question encoder.

Figure 2: RAG loss function
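The retrieve-then-generate step and the role of p(z|x) can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions: toy 2-D vectors stand in for BERT embeddings, exact brute-force MIPS stands in for a FAISS index, and a hypothetical `gen_logprob` list stands in for BART's sequence log-likelihoods; the helper names are ours. It follows the RAG-Sequence form of the marginal likelihood described in the RAG paper, where p(z|x) is a softmax over the inner products of the top-k retrieved passages.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def mips_top_k(query, passage_embs, k):
    """Exact (brute-force) MIPS: indices of the k passages with the largest inner product."""
    scores = [dot(query, p) for p in passage_embs]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def rag_sequence_nll(query, passage_embs, gen_logprob, k):
    """-log sum_z p(z|x) * p(y|x, z), summed over the top-k retrieved passages z.

    gen_logprob[i] is a stand-in for log p(y | x, z_i), the generator's
    log-likelihood of the target sequence when conditioned on passage i.
    """
    top = mips_top_k(query, passage_embs, k)
    scores = [dot(query, passage_embs[i]) for i in top]
    # log p(z|x): softmax over the retrieved passages' inner-product scores.
    # In the real model this term depends on the question encoder's output,
    # which is how the loss sends gradients back into the question encoder.
    log_norm = log_sum_exp(scores)
    log_p_z = [s - log_norm for s in scores]
    # Marginalize over the latent document z.
    return -log_sum_exp([lp + gen_logprob[i] for lp, i in zip(log_p_z, top)])


q = [1.0, 0.0]
passages = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
gen_logprob = [math.log(0.5), math.log(0.1), math.log(0.2)]
print(mips_top_k(q, passages, k=2))  # → [0, 2]
print(rag_sequence_nll(q, passages, gen_logprob, k=2))
```

Note that the passage vectors here are constants read from the "index": with a frozen passage encoder, nothing in this loss can update them, which is exactly the limitation this post addresses.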

While making RAG end-to-end trainable, we mainly work on the Retriever, which uses a neural retrieval model called Dense Passage Retrieval (DPR).

RAG Retriever Model

The Retriever is essentially a pre-trained DPR model. As Figure 3 illustrates, the DPR model consists of two BERT models: one encodes questions and the other encodes documents, each using its CLS token output. The DPR model used in RAG has already been trained on passages and questions drawn from open-domain, Wikipedia-based datasets.

Figure 3: Dense Passage Retrieval (DPR) architecture.

RAG is known for combining a neural retriever and a reader end-to-end. In practice, however, the passage encoder is frozen and only the question encoder is trained (see Figure 1). In our approach, the passage encoder is trainable as well.

In RAG, before fine-tuning, a frozen Passage Encoder is used to encode your external Knowledge Base (KB), and the resulting embeddings are saved to disk. The dataset is then indexed. Imagine there are millions of documents that must be searched during training: exhaustive vector similarity search can be too slow, so approximate methods are used instead, for example clustering the vector space into subregions. Several libraries provide this, such as ScaNN and FAISS. In its RAG implementation, Hugging Face uses FAISS to make retrieval faster (see this blog for more details on FAISS).

See the use_own_knowledge_dataset.py script for how this is done in code. After this step, training starts with the indexed dataset, and only the parameters of the Question Encoder and Generator networks are updated. The Passage Encoder is not used, and the document embeddings are not changed during training.
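A toy inverted-file (IVF) index illustrates the "cluster the vector space into subregions" idea: k-means splits the passage vectors into cells, and a query scans only the `n_probe` cells whose centroids have the largest inner product with it, trading a little recall for far fewer comparisons. This is an illustrative plain-Python sketch (`ToyIVFIndex` and its parameters are our own names), not FAISS code; also note that the HNSW index used by the Hugging Face RAG implementation is graph-based rather than clustering-based.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


class ToyIVFIndex:
    """Toy inverted-file index: k-means cells, search only the most promising cells."""

    def __init__(self, vectors, n_clusters, n_iters=10, seed=0):
        self.vectors = vectors
        rng = random.Random(seed)
        self.centroids = [list(v) for v in rng.sample(vectors, n_clusters)]
        for _ in range(n_iters):  # a few rounds of plain k-means
            self._assign()
            for c, members in enumerate(self.lists):
                if members:  # keep the old centroid if a cell ends up empty
                    dim = len(self.centroids[c])
                    self.centroids[c] = [
                        sum(self.vectors[i][d] for i in members) / len(members)
                        for d in range(dim)
                    ]
        self._assign()  # final inverted lists match the final centroids

    def _assign(self):
        """Put every vector id into the list of its nearest centroid."""
        self.lists = [[] for _ in self.centroids]
        for i, v in enumerate(self.vectors):
            nearest = min(range(len(self.centroids)),
                          key=lambda c: sq_dist(v, self.centroids[c]))
            self.lists[nearest].append(i)

    def search(self, query, k, n_probe=1):
        """Approximate MIPS: scan only the n_probe cells whose centroids score highest."""
        cells = sorted(range(len(self.centroids)),
                       key=lambda c: dot(query, self.centroids[c]), reverse=True)[:n_probe]
        candidates = [i for c in cells for i in self.lists[c]]
        return sorted(candidates, key=lambda i: dot(query, self.vectors[i]), reverse=True)[:k]
```

With `n_probe` equal to the number of cells the search becomes exhaustive; smaller values scan fewer vectors at the risk of missing a high-scoring passage that landed in another cell.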

Why is it important to fine-tune the entire retriever?

The RAG authors showed that it is fine not to fine-tune the Passage Encoder for tasks like question answering and fact-checking. However, they conducted their experiments mainly on open-domain, Wikipedia-like datasets. Since DPR was also initially trained on Wikipedia data, this makes sense!

But what about other domains, such as finance, healthcare, and law? Our interest comes down to two points:

  • Does training the entire RAG retriever help with domain adaptation?
  • We are also motivated by similar retrieval-augmented models such as REALM, whose authors mention that it is advantageous to train the entire retriever.

So why is it hard to make the entire RAG-retriever end-to-end trainable?

In theory, there is no problem in propagating gradients to both the passage encoder and the question encoder BERT models. As described in the RAG loss function (Figure 2), we can compute the document scores for the latent documents (p(z|x)) by computing question encoder embeddings and document encoder embeddings during training. This involves the following three steps:

  1. Add a pre-trained Passage Encoder model to the RAG model before training.
  2. During training, first retrieve the relevant passages for a question using the indexed dataset.
  3. Calculate the document scores by re-computing the CLS embeddings of the retrieved documents with the initialized passage BERT model, and do the same for the input with the question encoder.
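Steps 2 and 3 can be sketched with toy linear "encoders" standing in for the two BERT models, which is an assumption for illustration only; `retrieve_and_rescore`, `linear_encoder`, `q_weights`, and `p_weights` are hypothetical names. The point is that the (possibly stale) index only decides which passages come back, while p(z|x) is computed from freshly encoded passages, so in an autograd framework the loss would reach the passage encoder too.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def linear_encoder(weights, x):
    """Toy 'encoder': multiply x by a weight matrix given as a list of rows."""
    return [dot(row, x) for row in weights]


def retrieve_and_rescore(question, passages, stale_index, q_weights, p_weights, k):
    """Retrieve with the (possibly stale) index, then score with the current encoders.

    stale_index[i] was produced by an OLD passage encoder; q_weights and p_weights
    are the CURRENT question and passage encoders.
    """
    q = linear_encoder(q_weights, question)
    # Step 2: the index only decides WHICH passages come back.
    top = sorted(range(len(passages)),
                 key=lambda i: dot(q, stale_index[i]), reverse=True)[:k]
    # Step 3: re-encode just the retrieved passages with the current passage
    # encoder, so p(z|x) depends on both encoders (and, under autograd,
    # would send gradients into both).
    scores = [dot(q, linear_encoder(p_weights, passages[i])) for i in top]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return top, [e / sum(exps) for e in exps]
```

Re-encoding only the k retrieved passages per question is cheap; the expensive part, discussed next, is keeping the index itself in sync with the passage encoder.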

Although propagating gradients to the passage encoder is straightforward, we also need to make sure the updated passage encoder is actually used throughout the training architecture. This means we somehow have to update the indexed dataset during training!

But then, there are a lot of engineering challenges that we need to solve…

Why is updating the indexed KB during training time-consuming and computationally expensive?

There are two main reasons. Let's say we have an external KB with 10 million passages.

  1. During training, we first need to send every passage in the external KB through the updated passage encoder to compute its CLS embedding. Let's call this step re-encoding the KB.
  2. Second, with the updated embeddings for every passage, we need to re-index the external KB with FAISS. Let's call this step re-indexing the KB.

So, similar to the REALM training process, we can re-encode and re-index the external KB every N training steps. The REALM authors mention that updating the KB more frequently gives better results (check section 3.3 in the REALM paper).

Even so, both of these steps can take a long time…

However, we do not need to pause training for hours until both steps finish, which would be really impractical. Training against a somewhat stale index still works (read about the effectiveness of stale gradients in section 3.3 of the REALM paper).
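To see what "stale" means in practice, the following sketch simulates a non-blocking refresh schedule and reports, for every training step, which snapshot of the passage encoder produced the index the retriever is currently using. The schedule itself (a refresh may start every `every_n` steps, each refresh takes `delay` steps of wall-clock time, and refreshes never overlap) is an assumption for illustration, and `index_snapshots` is a hypothetical helper, not part of any library.

```python
def index_snapshots(total_steps, every_n, delay):
    """For each training step, the step whose passage-encoder weights built the active index.

    Hypothetical schedule: a background refresh (re-encode + re-index) may start at
    any step that is a multiple of every_n, but only if no refresh is running. It
    snapshots the passage encoder at that step, and its index becomes usable
    `delay` steps later. Snapshot 0 is the initial, pre-computed index.
    """
    active, in_flight = 0, None  # in_flight = (snapshot_step, ready_step)
    history = []
    for step in range(total_steps):
        if in_flight is not None and step >= in_flight[1]:
            active, in_flight = in_flight[0], None  # swap in the finished index
        if in_flight is None and step > 0 and step % every_n == 0:
            in_flight = (step, step + delay)  # launch; training does not wait
        history.append(active)
    return history


snapshots = index_snapshots(total_steps=12, every_n=2, delay=3)
print(snapshots)  # → [0, 0, 0, 0, 0, 2, 2, 2, 2, 6, 6, 6]
print(max(step - snap for step, snap in enumerate(snapshots)))  # → 6
```

Under these assumptions training never blocks, and the active index always lags the current weights by fewer than `2 * delay + every_n` steps; with a two-hour index build, `delay` would be however many optimizer steps fit into two hours.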

Making the re-encoding and re-indexing practical during the training loop

The following three libraries make this practical:

  1. Hugging Face Datasets library
  2. PyTorch Lightning
  3. Python multiprocessing

First, we need to make sure the re-encoding step is fast and doesn’t make the entire training loop wait until it finishes the computation of embeddings for every passage in the KB.

How can we use a separate set of GPUs for re-encoding during the training loop?

Knowledge-base re-encoding with parallel GPU processes.

Usually, training uses the Distributed Data-Parallel (DDP) mechanism to handle multiple GPUs, and we do not want to run into out-of-memory (OOM) issues by also using the DDP GPUs for the re-encoding step.

So the question is how to use extra GPUs inside the training loop to complete the re-encoding step. This is where Python's parallel processing comes in: every N steps, the training loop starts separate parallel processes that focus only on re-encoding.

Now let's go through the code. PyTorch Lightning gives easy access to the training loop through the training_step hook.

In layman's terms, the logic is as follows (please refer to the gist file for a better understanding):

  • Since a multi-GPU training setup has several DDP processes, we start re-encoding only in the master DDP process.
  • Every N steps, we start parallel processes that do the re-encoding.
  • The number of re-encoding processes depends on the number of free GPUs, which can easily be found with fast.ai's GPU management library.
  • When using multiple GPUs for re-encoding, we shard the local KB into equal chunks, which is easy to do with the Hugging Face Datasets library.
  • We then pass a copy of the updated passage encoder (ctx_encoder in the code), along with the KB shards, to the re-encoding function.
  • We also need to make sure we do not start new re-encoding processes while old ones are still running. That is basic parallel processing.

Initialization of multiple processes that encode the dataset shards.
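The bullet points above can be combined into a framework-agnostic sketch. It assumes the KB is a plain Python list and that `encode_fn` stands in for a copy of the updated ctx_encoder; `ReencodeLauncher`, `shard_bounds`, and `encode_shard` are hypothetical names, not the gist's code. `shard_bounds` mirrors the contiguous split that Hugging Face Datasets performs with `Dataset.shard(num_shards, index, contiguous=True)`, and in a PyTorch Lightning `training_step` the `is_master` flag would typically come from `self.trainer.is_global_zero`.

```python
import json
import multiprocessing as mp
import os


def shard_bounds(n_items, num_shards, index):
    """[start, end) of contiguous shard `index`; the first n_items % num_shards shards get one extra item."""
    base, extra = divmod(n_items, num_shards)
    start = index * base + min(index, extra)
    return start, start + base + (1 if index < extra else 0)


def encode_shard(encode_fn, passages, shard_id, out_dir):
    """Worker body: embed one shard and write the result to its own file."""
    # In the real setup this would run a copy of the passage encoder on one free GPU.
    embeddings = [encode_fn(p) for p in passages]
    with open(os.path.join(out_dir, f"shard_{shard_id}.json"), "w") as f:
        json.dump(embeddings, f)


class ReencodeLauncher:
    """Every `every_n` steps, start one re-encoding worker per shard (master process only)."""

    def __init__(self, every_n, num_workers, out_dir, worker_cls=mp.Process):
        self.every_n = every_n
        self.num_workers = num_workers  # e.g. the number of currently free GPUs
        self.out_dir = out_dir
        self.worker_cls = worker_cls
        self.workers = []

    def busy(self):
        """True while any worker from the previous round is still running."""
        return any(w.is_alive() for w in self.workers)

    def maybe_launch(self, step, is_master, encode_fn, passages):
        """Call from the training step; returns True if a new round was started."""
        if not is_master or step == 0 or step % self.every_n != 0:
            return False
        if self.busy():  # never overlap two re-encoding rounds
            return False
        self.workers = []
        for shard_id in range(self.num_workers):
            start, end = shard_bounds(len(passages), self.num_workers, shard_id)
            worker = self.worker_cls(
                target=encode_shard,
                args=(encode_fn, passages[start:end], shard_id, self.out_dir),
            )
            worker.start()  # returns immediately; training continues
            self.workers.append(worker)
        return True
```

Because the launcher only relies on `start()` and `is_alive()`, a `threading.Thread` can be swapped in through `worker_cls` for quick local testing. With real processes, `encode_fn` must be picklable under the `spawn` start method, which PyTorch requires when child processes use CUDA.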

So one problem is solved: we can run the re-encoding process without obstructing training.

Re-indexing the knowledge-base

The next step is to update the index. The original Hugging Face RAG implementation uses a FAISS HNSW index, which is very fast at retrieval time.

However, building the index takes time. Our KB has 10 million passages, and building its index takes nearly two hours on a 48-core machine. There is no point in holding up training for that long, so, as with re-encoding, we start a separate parallel process for re-indexing.

Initializing the re-indexing parallel process once re-encoding processes are finished
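One way to express "start re-indexing once re-encoding has finished" without blocking training is a small poller that the master process calls every step. This is a hedged sketch: `ReindexLauncher`, `round_id`, and `build_index_fn` are hypothetical names, and the actual index construction (for example, building a FAISS HNSW index over the saved embedding shards) is left to the caller.

```python
import multiprocessing as mp


class ReindexLauncher:
    """Start one background index build after a re-encoding round has fully finished."""

    def __init__(self, build_index_fn, worker_cls=mp.Process):
        # build_index_fn(round_id) would load all embedding shards of that round,
        # build the index, and save it to disk.
        self.build_index_fn = build_index_fn
        self.worker_cls = worker_cls
        self.worker = None
        self.launched_round = None

    def poll(self, encoding_workers, round_id):
        """Call every training step; returns True when an index build was started."""
        if round_id is None or round_id == self.launched_round:
            return False  # no new embeddings to index
        if any(w.is_alive() for w in encoding_workers):
            return False  # some shards are still being written
        if self.worker is not None and self.worker.is_alive():
            return False  # previous build still running
        self.worker = self.worker_cls(target=self.build_index_fn, args=(round_id,))
        self.worker.start()  # may take hours; training keeps going
        self.launched_round = round_id
        return True
```

Polling once per step costs only a few `is_alive()` checks, so the training loop never waits on the index build.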

And now, the final part…

So how can we load the updated index into the training process?

We can do this easily by building on the Ray-based distributed retrieval for RAG that is available in the Transformers library. First, we start from the training loop of the master process (please refer to distributed_ray_retriever.py for more details).

Transferring the new index to Ray workers from the master DDP process.

Again, here's the simple logic:

  • First, check whether the re-indexing process has finished.
  • Update the KB with the new embeddings and the index files on disk.
  • Then, access the Ray workers, and reset and re-initialize their index.

To avoid memory leaks in the Ray workers, we delete the previous objects and replace them with the updated index.
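The three steps above can be sketched as follows. `RetrievalWorker` stands in for a Ray retrieval actor (in a real setup it would be declared with `@ray.remote` and its methods called via `.remote(...)`), while `load_index`, `IndexSwapper`, and `maybe_swap` are hypothetical names; the "index" is just a JSON list of vectors searched by brute force. The detail taken from the text is in `reset_index`: the worker drops its reference to the old index before loading the new one.

```python
import gc
import json


def load_index(path):
    """Hypothetical loader: here an 'index' is just a JSON list of passage vectors."""
    with open(path) as f:
        return json.load(f)


class RetrievalWorker:
    """Stand-in for one retrieval actor holding its own in-memory copy of the index."""

    def __init__(self, index_path):
        self.index = load_index(index_path)

    def retrieve(self, query, k):
        scores = [sum(a * b for a, b in zip(query, v)) for v in self.index]
        return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

    def reset_index(self, index_path):
        # Release the old index first so its memory can be reclaimed
        # before the new one is loaded.
        del self.index
        gc.collect()
        self.index = load_index(index_path)


class IndexSwapper:
    """Master-process helper: push a finished index to every worker exactly once."""

    def __init__(self, retrieval_workers):
        self.retrieval_workers = retrieval_workers
        self.applied = None  # the index-build worker whose result is already live

    def maybe_swap(self, build_worker, index_path):
        if build_worker is None or build_worker.is_alive() or build_worker is self.applied:
            return False
        for w in self.retrieval_workers:
            # With Ray actors this would be ray.get(w.reset_index.remote(index_path)).
            w.reset_index(index_path)
        self.applied = build_worker
        return True
```

Writing each new index to a fresh path, rather than overwriting the one the workers are reading, keeps retrieval consistent until the swap happens.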

Future work

We hope this implementation can be extended to other models such as REALM and MARGE.

We also hope this implementation will let us explore Retrieval Augmented Models in more depth.

Special Thanks

  • Amazing Hugging Face Crew!
  • Especially Patrick von Platen and Quentin Lhoest who helped by answering all our issues :)

Written By
Shamane Siri, Rivindu Weerasekera, Elliot Wen, Suranga Nanayakkara

Shamane Siriwardhana

Ph.D. Candidate — The University of Auckland, New Zealand