Retrieval-Augmented Generation — Deep Dive into its Components

Subir Verma
11 min read · Jan 30, 2024

Introduction

When Large Language Models are pre-trained on a huge text corpus, they learn and store knowledge in their parameters. To access that knowledge, we adapt these pre-trained models to downstream tasks, where they perform well to varying degrees depending on the task. However, these pre-trained large neural models still cannot easily access or manipulate specific pieces of knowledge: their memory cannot be expanded or revised after training, and this can lead to "hallucinations".

This is where we explore Retrieval-Augmented Generation (RAG) models, which use the input sequence x to retrieve text documents z and use them as additional context when generating the target sequence y.

We endow pre-trained, parametric-memory generation models with a non-parametric memory through a general-purpose fine-tuning approach which we refer to as retrieval-augmented generation (RAG). We introduce RAG models where the parametric memory is a pre-trained seq2seq model and the non-parametric memory is a dense vector index, accessed with a pre-trained neural retriever. RAG models generate responses that are more factual, specific, and diverse.

This blog is a long read, so take your time.

Photo by DJ Johnson on Unsplash

Overview of approach

We combine a pre-trained retriever (Query Encoder + Document Index) with a pre-trained seq2seq model (Generator) and fine-tune end-to-end. For query x, we use Maximum Inner Product Search (MIPS) to find the top-K documents zi. For the final prediction y, we treat z as a latent variable and marginalize over seq2seq predictions given different documents.

About MIPS: Maximum Inner Product Search (MIPS) is a problem in information retrieval and similarity search where the goal is to find the vector in a given database that maximizes the inner product with a query vector. MIPS has various applications, including:

Information Retrieval: Given a large dataset of documents represented as vectors, MIPS can be used to find the document that is most relevant to a query.

Recommendation Systems: In collaborative filtering, where user preferences are represented as vectors, MIPS can help find users or items with similar preferences.

Computer Vision: MIPS is applicable in image retrieval tasks, where visual features of images are represented as vectors.
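As a concrete illustration, here is a minimal sketch of exact MIPS with NumPy: a toy database of document vectors, a query vector, and a top-k selection by inner product. The vectors and dimensions are invented for the example; a real system would use learned embeddings.

```python
import numpy as np

# Toy "database": 5 document vectors of dimension 4 (values are arbitrary).
doc_vectors = np.array([
    [0.1, 0.3, 0.5, 0.2],
    [0.9, 0.1, 0.0, 0.4],
    [0.2, 0.8, 0.1, 0.1],
    [0.4, 0.4, 0.4, 0.4],
    [0.0, 0.2, 0.9, 0.3],
])

query = np.array([0.3, 0.1, 0.8, 0.2])

# MIPS: score every document by its inner product with the query ...
scores = doc_vectors @ query

# ... and keep the indices of the top-k highest-scoring documents.
k = 2
top_k = np.argsort(-scores)[:k]
print(top_k, scores[top_k])
```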


About Marginalizing

Marginalization: In probability theory, marginalization involves summing or integrating a joint distribution over one or more variables to obtain the distribution of interest. In this case, we sum over the retrieved documents z to obtain the distribution over the generated text y.
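For intuition, here is a tiny numeric sketch (the probabilities below are invented): with two retrieved documents, the probability of a generated text y is the retrieval-weighted sum of the generator's probability of y under each document.

```python
# p(y | x) = sum_z p(z | x) * p(y | x, z)
p_z_given_x = [0.7, 0.3]         # retriever: how relevant each document is
p_y_given_xz = [0.60, 0.20]      # generator: probability of y under each document
p_y_given_x = sum(pz * py for pz, py in zip(p_z_given_x, p_y_given_xz))
print(p_y_given_x)               # 0.7*0.60 + 0.3*0.20 = 0.48
```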

RAG’s internal knowledge can be easily altered or even supplemented on the fly, enabling researchers and engineers to control what RAG knows and doesn’t know without wasting time or computing power retraining the entire model.

RAG looks and acts like a standard seq2seq model, meaning it takes in one sequence and outputs a corresponding sequence. An intermediary step, though, differentiates and elevates RAG above the usual seq2seq methods. Rather than passing the input directly to the generator, RAG instead uses the input to retrieve a set of relevant documents.

Parametric and Non-parametric Memory

The RAG model combines a generative model, often based on a sequence-to-sequence (seq2seq) architecture, with a retrieval component that utilizes external memory. Let’s break down the components:

  1. Parametric Memory (Pre-trained Seq2Seq Model):
  • Parametric: "Parametric" means that the memory has a fixed number of parameters. In this case, the memory is represented by a pre-trained seq2seq model: whatever knowledge it holds is encoded in its weights, which stay fixed during retrieval and generation.
  • Pre-trained Seq2Seq Model: This refers to a sequence-to-sequence model that has been trained on a large dataset for a specific natural language processing (NLP) task. Seq2seq models are often used for tasks like machine translation, summarization, or question answering.

2. Non-parametric Memory (Dense Vector Index of Wikipedia, accessed with a Pre-trained Neural Retriever):

  • Non-parametric: "Non-parametric" implies that the memory structure is not fixed but can grow or change with the data or task. In this case, the non-parametric memory is represented by a dense vector index of Wikipedia, which can be extended or swapped without retraining the model.
  • Dense Vector Index of Wikipedia: This refers to a representation of the content in Wikipedia using dense vectors, produced by a neural document encoder. Each document or passage in Wikipedia is represented by a dense vector.
  • Pre-trained Neural Retriever: This is a neural network model trained to retrieve relevant information from the dense vector index. The retriever takes a query or context as input and returns the most relevant documents or vectors from the index.
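As an illustration, here is a minimal sketch of this split using the Hugging Face transformers DPR checkpoints released by Facebook AI (the passages and question are made up; in RAG the index covers all of Wikipedia, not two sentences):

```python
import torch
from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
                          DPRQuestionEncoder, DPRQuestionEncoderTokenizer)

# Non-parametric memory: encode passages into dense vectors (the index).
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
passages = ["The Eiffel Tower is in Paris.", "BART is a seq2seq model from Facebook AI."]
with torch.no_grad():
    ctx_inputs = ctx_tokenizer(passages, padding=True, return_tensors="pt")
    passage_vectors = ctx_encoder(**ctx_inputs).pooler_output   # shape: (2, 768)

# Pre-trained neural retriever: a *different* encoder maps the query to the same space.
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
with torch.no_grad():
    q_inputs = q_tokenizer("Where is the Eiffel Tower?", return_tensors="pt")
    query_vector = q_encoder(**q_inputs).pooler_output           # shape: (1, 768)

# Relevance is the inner product between the query vector and each passage vector.
scores = query_vector @ passage_vectors.T
print(scores)
```

Note how updating the non-parametric memory only means encoding new passages and adding them to the index, whereas changing the parametric memory would require retraining the model.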

RAG Components

Figure 1: Overview of the RAG Approach

We build RAG models where the parametric memory is a pre-trained seq2seq transformer, and the non-parametric memory is a dense vector index, accessed with a pre-trained neural retriever. We combine these components in a probabilistic model trained end-to-end.

The retriever (Dense Passage Retriever, henceforth DPR) provides latent documents conditioned on the input, and the seq2seq model (BART) then conditions on these latent documents together with the input to generate the output.

What do we mean when we say “latent documents conditioned on the input, and the seq2seq model then conditions on these latent documents together with the input to generate the output?”

  1. Latent Documents:
  • Retrieval: In the first step, the retriever produces a set of documents based on the input. These documents carry the context or features the model deems relevant for the subsequent generation task. The term "latent" means that the "correct" document is not directly observed in the training data; which documents to rely on is learned by the model.

2. Seq2Seq Model:

  • Conditioning on Latent Documents: After generating the latent documents, a sequence-to-sequence (seq2seq) model is employed. Seq2seq models are neural architectures designed for sequence generation tasks. They consist of an encoder-decoder structure where the encoder processes input sequences and the decoder generates output sequences.
  • Joint Conditioning: In this specific scenario, the seq2seq model is conditioned not only on the original input but also on the generated latent documents. This means that the latent documents become additional contextual information for the seq2seq model during the generation process.
  • Output Generation: The seq2seq model takes the input and the latent documents as input and generates an output sequence based on this joint information. The output could be in the form of text, translation, summarization, or any other sequence-based task, depending on the specific application.

This two-step approach, involving generating latent documents followed by the seq2seq model conditioned on these documents, allows the model to capture complex relationships and contextual information that may not be directly present in the input data. It leverages the latent space to encode relevant information before the final sequence generation, potentially improving the model’s performance on tasks requiring a deeper understanding of context and semantics.

As we observe in Figure 1, there are broadly two components:

  1. Retriever: A retriever pη(z|x) with parameters η that returns (top-K truncated) distributions over text passages given a query x. The retrieval component pη(z|x) is based on DPR, from the paper "Dense Passage Retrieval for Open-Domain Question Answering".
  • The dense passage retriever (DPR) uses a dense passage encoder EP(·), which maps any text passage to a d-dimensional real-valued vector, and builds an index over all the passages z that we will use for retrieval.
  • At run-time, DPR applies a different encoder EQ(·) that maps the input question to a d-dimensional vector and retrieves the k passages whose vectors are closest to the question vector.

In RAG, this retrieval distribution takes the form

pη(z|x) ∝ exp( d(z)ᵀ q(x) )

where:

  • d(z) is a dense representation of a document produced by a BERT BASE document encoder, and
  • q(x) is a query representation produced by a query encoder, also based on BERT BASE.

Calculating top-k(pη(·|x)), the list of the k documents z with the highest prior probability pη(z|x), is a Maximum Inner Product Search (MIPS) problem, which can be approximately solved in sub-linear time using Approximate Nearest Neighbour (ANN) search.
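Here is a minimal sketch of how top-k(pη(·|x)) might be computed with FAISS, assuming the document and query vectors have already been produced by the two BERT BASE encoders (random vectors stand in for them below). IndexFlatIP performs exact inner-product search; FAISS also provides approximate (ANN) indexes for very large corpora.

```python
import numpy as np
import faiss

d, n_docs, k = 768, 10_000, 5

# Stand-ins for d(z) and q(x); in RAG these come from the document/query encoders.
doc_vectors = np.random.rand(n_docs, d).astype("float32")
query_vector = np.random.rand(1, d).astype("float32")

index = faiss.IndexFlatIP(d)     # exact Maximum Inner Product Search
index.add(doc_vectors)           # build the document index offline

scores, doc_ids = index.search(query_vector, k)   # top-k documents for the query
print(doc_ids, scores)
```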

2. Generator: A generator pθ(yi|x, z, y1:i−1) parametrized by θ that generates a current token based on a context of the previous i − 1 tokens y1:i−1, the original input x and a retrieved passage z.

  • The generator component pθ(yi|x, z, y1:i−1) could be modeled using any encoder-decoder. We use BART-large.
  • To combine the input x with the retrieved content z when generating from BART, we concatenate them.
  • BART was pre-trained using a denoising objective and a variety of different noising functions.
  • We refer to the BART generator parameters θ as the parametric memory henceforth.

BART (Bidirectional and Auto-Regressive Transformers) is a natural language processing (NLP) model introduced by Facebook AI Research. It belongs to the family of transformer-based models and is designed for various sequence-to-sequence tasks in NLP.

BART utilizes a transformer-based architecture, similar to models like GPT (Generative Pre-trained Transformer) and T5 (Text-to-Text Transfer Transformer). The transformer architecture allows BART to capture contextual information and relationships within sequences of text.
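A minimal sketch of the generator step described above, assuming the Hugging Face transformers BART checkpoint: the retrieved passage z is simply concatenated with the input x before generation. The passage and question are invented for the example, and the raw bart-large checkpoint is not fine-tuned for question answering; in RAG the generator is fine-tuned end-to-end.

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

retrieved_passage = "The Eiffel Tower, in Paris, was completed in 1889."   # z
question = "When was the Eiffel Tower completed?"                          # x

# Concatenate the retrieved passage z with the input x, then let BART generate y.
inputs = tokenizer(retrieved_passage + " " + question, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20, num_beams=4)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```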

RAG Model

To train the retriever and generator end-to-end, we treat the retrieved document as a latent variable. We propose two models that marginalize the latent documents in different ways to produce a distribution over generated text.

Treating the Retrieved Document as a Latent Variable:

Latent Variable: In the context of generative models, a latent variable is an unobserved or hidden variable that influences the generation process. It captures information that is not explicitly observed in the training data but is essential for generating diverse and contextually relevant outputs.

Retrieved Document as Latent Variable: In this training approach, the document retrieved by the retriever is considered a latent variable. This means that during training, the model treats the specific document chosen by the retriever as an unobserved variable that affects the generation of the final output.

Marginalize the latent documents: The statement mentions that two models are proposed for marginalizing the latent documents in different ways. This implies that during training, the model considers multiple possible retrieved documents and sums over them to obtain a distribution over the generated text. The two models differ in how they incorporate information from the latent variable (the retrieved document) during marginalization: RAG-Sequence uses the same document for the whole output sequence, while RAG-Token can draw on a different document for each generated token.

What are those two models?

  1. RAG-Sequence
  • The retriever selects relevant sequences (documents) from a larger corpus, and the generator generates text based on both the input query/context and the retrieved sequences.

The equation for RAG-Sequence is

p(y|x) ≈ Σz∈top-k(pη(·|x)) pη(z|x) · Πi pθ(yi|x, z, y1:i−1)

and it says that:

  • Top-K Approximation: The model only considers the top k most relevant documents retrieved by the neural retriever, as indicated by the ‘z ∈ top-k(p(.| x))’ notation
  • Generate Output Sequence for Each Document: For each retrieved document z, the neural generator calculates the probability of the target sequence y given the input sequence x and the document z.
  • Weighted Sum of Probabilities: The model sums the probabilities of the target sequence y over all retrieved documents, weighting each by the retriever probability pη(z|x), to obtain a final probability score for the entire sequence.
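A small sketch of that computation, using invented numbers for the retriever and per-token generator probabilities of a single candidate sequence y over k = 2 retrieved documents:

```python
import math

# pη(z | x): retriever probabilities for the top-2 documents
p_z = [0.6, 0.4]

# pθ(y_i | x, z, y_1:i-1): per-token generator probabilities of the same
# candidate sequence y (3 tokens) under each document; numbers are invented
p_tokens = [
    [0.9, 0.8, 0.7],   # under document z1
    [0.5, 0.4, 0.6],   # under document z2
]

# RAG-Sequence: p(y|x) ≈ Σ_z pη(z|x) · Π_i pθ(y_i|x, z, y_1:i-1)
p_y_given_x = sum(pz * math.prod(toks) for pz, toks in zip(p_z, p_tokens))
print(p_y_given_x)     # 0.6*0.504 + 0.4*0.12 = 0.3504
```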

2. RAG-Token

  • In RAG-Token, the marginalization happens at the level of individual output tokens rather than entire sequences: the retriever still returns whole passages, but the generator can draw on a different retrieved document for each token it generates. The corresponding equation is p(y|x) ≈ Πi Σz∈top-k(pη(·|x)) pη(z|x) · pθ(yi|x, z, y1:i−1).
  • Retrieve Top K Documents: The model retrieves the top k most relevant documents (Z) to the input sequence X.
  • Iterate through Target Tokens: The model processes each token Y_i in the target sequence Y:

a. Generate Distribution for Each Retrieved Document: For each retrieved document Z, the generator calculates the probability of the next token Y_i given the input X, the document Z, and the previously generated tokens.

b. Marginalize Probabilities: The probabilities for Y_i from all retrieved documents are combined in a weighted sum (weighted by pη(z|x)) to obtain a single probability for the token.

  • Repeat for All Tokens: The process of generating a distribution for each document and marginalizing is repeated for each token in the target sequence.
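Contrast this with RAG-Sequence: in RAG-Token the mixing over documents happens inside each per-step token distribution, as in this small sketch with the same invented numbers as before.

```python
import math

p_z = [0.6, 0.4]                      # pη(z | x)
p_tokens = [
    [0.9, 0.8, 0.7],                  # pθ(y_i | x, z1, y_1:i-1)
    [0.5, 0.4, 0.6],                  # pθ(y_i | x, z2, y_1:i-1)
]

# RAG-Token: p(y|x) ≈ Π_i Σ_z pη(z|x) · pθ(y_i|x, z, y_1:i-1)
per_token = [sum(pz * p_tokens[j][i] for j, pz in enumerate(p_z))
             for i in range(3)]       # [0.74, 0.64, 0.66]
p_y_given_x = math.prod(per_token)
print(p_y_given_x)                    # ≈ 0.3126
```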

RAG Training

The approach involves the joint training of retriever and generator components without explicit supervision on document retrieval. The training data comprises input/output pairs (xj, yj), and the objective is to minimize the negative marginal log-likelihood of each target, Σj −log p(yj | xj), utilizing stochastic gradient descent with the Adam optimizer.

Unlike some methods that involve periodic updates to the document encoder (BERTd) during training, which can be computationally expensive, the proposed approach opts to keep the document encoder (and its index) fixed. In contrast, only the query encoder (BERTq) and the BART generator undergo fine-tuning. This deviation from updating the document encoder is motivated by the belief that such a step is not essential for achieving robust performance, in contrast to the REALM method that employs such updates during pre-training.
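A hedged sketch of this objective in PyTorch for a single training pair (x, y), assuming we already have the retriever's log-probabilities for the top-k documents and the generator's sequence log-likelihoods under each document (the numbers below are placeholders; in practice they come from BERT_q and BART, so gradients flow only into those two components while the document index stays fixed):

```python
import torch

# log pη(z | x) for the top-k retrieved documents (placeholder values; in
# training these depend on the query encoder BERT_q and therefore carry gradients)
log_p_z = torch.tensor([-0.51, -1.20, -2.30], requires_grad=True)

# log pθ(y | x, z) for the target sequence under each document (from BART)
log_p_y_given_xz = torch.tensor([-2.1, -3.5, -4.0], requires_grad=True)

# Negative marginal log-likelihood:
#   -log p(y|x) = -log Σ_z pη(z|x) · pθ(y|x,z) = -logsumexp(log pη + log pθ)
loss = -torch.logsumexp(log_p_z + log_p_y_given_xz, dim=0)
loss.backward()
print(loss.item())
```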

RAG Decoding

At test time, RAG-Sequence and RAG-Token require different ways to approximate arg maxy p(y|x).

The RAG-Token model can be seen as a standard, autoregressive seq2seq generator with per-token transition probability p′θ(yi|x, y1:i−1) = Σz∈top-k pη(z|x) · pθ(yi|x, z, y1:i−1), so it can be decoded with a standard beam search.

For RAG-Sequence, the likelihood p(y|x) does not break into a conventional per-token likelihood, hence we cannot solve it with a single beam search.

Key Challenges for Decoding:

  • Non-Decomposable Likelihood: In RAG-Sequence, the likelihood of the target sequence y given the input x (p(y|x)) doesn’t break down into per-token likelihoods, making it incompatible with conventional single-pass beam search algorithms.
  • Document-Specific Generation: The model generates output sequences based on specific retrieved documents, requiring a decoding process that considers multiple documents and their relevance.

Here’s a detailed explanation of “Thorough Decoding” in the context of RAG-Sequence decoding:


Thorough Decoding Process:

  1. Beam Search per Document:
  • The model runs a separate beam search for each retrieved document z.
  • Each beam search generates a set of candidate hypotheses (partial or complete output sequences).
  • Hypotheses are scored using pθ(yi|x, z, y1:i-1) (probability of generating the i-th token given the input x, retrieved document z, and previously generated tokens).

2. Collect Hypotheses:

  • All generated hypotheses from all beam searches are pooled together into a set Y.

3. Additional Forward Passes:

  • For each hypothesis y in Y that didn’t appear in the beam of a particular document z:
  • Run an additional forward pass of the model using that document z and the hypothesis y.
  • Calculate the probability of hypothesis y given the document z and the input x.

4. Probability Calculation:

  • Multiply the generator probability (from step 3) with the retriever probability pη(z|x) (probability of retrieving document z given the input x).
  • Sum the probabilities across all beams and documents to obtain a final, marginalized probability score for each hypothesis y.

5. Select Best Hypothesis:

  • The hypothesis with the highest final probability is selected as the most likely output sequence.
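Putting the five steps together, here is a compact sketch of Thorough Decoding; `beam_search` and `sequence_log_prob` are hypothetical helpers standing in for a per-document beam decoder and a forward pass of the generator.

```python
import math

def thorough_decode(x, retrieved_docs, p_z_given_x,
                    beam_search, sequence_log_prob, beam_size=4):
    """Approximate arg max_y p(y|x) for RAG-Sequence.

    retrieved_docs: top-k documents z
    p_z_given_x:    dict mapping z -> pη(z|x)
    beam_search(x, z, beam_size): hypothetical beam decoder -> list of hypotheses y
    sequence_log_prob(y, x, z):   hypothetical forward pass -> log pθ(y|x, z)
    """
    # Steps 1-2: run a beam search per document and pool all hypotheses into one set Y.
    candidates = set()
    for z in retrieved_docs:
        candidates.update(beam_search(x, z, beam_size))

    # Steps 3-4: score every hypothesis against every document (extra forward
    # passes for pairs the beams never produced) and marginalize over documents.
    best_y, best_score = None, float("-inf")
    for y in candidates:
        p_y = sum(p_z_given_x[z] * math.exp(sequence_log_prob(y, x, z))
                  for z in retrieved_docs)
        # Step 5: keep the hypothesis with the highest marginal probability.
        if p_y > best_score:
            best_y, best_score = y, p_y
    return best_y
```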

Purpose of Thorough Decoding in Generator

  • Ensures comprehensive consideration of all possible hypotheses, even those that might not have initially appeared in the beam searches of certain documents.
  • Enhances accuracy and diversity of generated outputs by exploring a broader range of options and incorporating information from multiple retrieved documents.
  • Mitigates potential biases or limitations of individual beam searches, leading to more robust and reliable results.
