The Illustrated Dense Passage Retriever

8 minute read

Introduction

Open-domain question answering relies on efficient passage retrieval to select candidate contexts. Traditional sparse vector space models, such as TF-IDF or BM25, are the de facto method. Retrieval can instead be implemented with dense representations, where embeddings are learned from a small number of questions and passages by a simple dual-encoder framework.

Original Paper link

Setup

Install `transformers` and `faiss-cpu` via `pip install -q transformers faiss-cpu`.

import os
import json
import faiss
import random
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm_notebook
from dataclasses import dataclass
from tensorflow.keras import layers
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TFAutoModel

Download Dataset

First, download the training and dev datasets from a public data source.

wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-adv-hn-train.json.gz
gunzip biencoder-nq-adv-hn-train.json.gz
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz
gunzip biencoder-nq-dev.json.gz

Data & Model Configuration Setup

Let’s configure the data and model requirements. Each query comes with positive passages and hard negative passages, so we need to choose how many positives and hard negatives to use per query in order to learn a good representation.

For Model requirements, we need to specify things like seq_len for passage and query, pre-trained base models, learning rate, epochs, etc.

# Configure dataset
@dataclass
class DataConfig:
    """How many passages of each kind are kept per query.

    The fields are type-annotated so that they are real dataclass fields and
    can be overridden at construction time, e.g.
    ``DataConfig(num_hard_negatives=2)``.
    """

    num_positives: int = 1  # No. of positive passages per query
    num_hard_negatives: int = 1  # No. of hard negative passages per query


data_config = DataConfig()

# Configure models
@dataclass
class ModelConfig:
    """Sequence lengths, optimisation settings and base checkpoint for the bi-encoder.

    The fields are type-annotated so that they are real dataclass fields and
    can be overridden at construction time, e.g. ``ModelConfig(epochs=5)``.
    """

    passage_max_seq_len: int = 156  # Token length passages are padded/truncated to
    query_max_seq_len: int = 64  # Token length queries are padded/truncated to
    batch_size_per_replica: int = 128
    epochs: int = 40
    learning_rate: float = 2e-5
    # NOTE(review): the training script in this file computes its own warmup
    # step count and does not read this value.
    num_warmup_steps: int = 1234
    dropout: float = 0.1  # Dropout applied to the pooled encoder output
    model_name: str = "google/bert_uncased_L-4_H-512_A-8"


model_config = ModelConfig()

Load and Preprocess Dataset

Now, load the training data and return a list of dicts, each holding a query and its set of passages (positives and hard negatives). To keep the target labels simple, the positive passage is placed first in the passage list, followed by the hard negative passages.

training-data

Then, tokenize and encode the passages and queries. After encoding, the query arrays in the returned dict have shape (num_samples, query_seq_len) and the passage arrays have shape (num_samples, num_passages, seq_len).

def read_dpr_json(
    file,
    max_samples=None,
    num_hard_negatives=1,
    num_positives=1,
    shuffle_negatives=True,
    shuffle_positives=False,
):
    """Read a DPR-format JSON file and return a list of standardized samples.

    Each returned sample is ``{"query": ..., "passages": [...]}``. The
    ``passages`` list holds up to ``num_positives`` positive passages followed
    by up to ``num_hard_negatives`` hard-negative passages. Each passage is a
    dict with ``title``, ``text`` and ``label`` (``"positive"`` or
    ``"hard_negative"``). A sample is dropped unless it has a query and exactly
    ``num_positives + num_hard_negatives`` passages.

    Args:
        file: Path to the JSON file, which must contain a list of records.
        max_samples: Stop once this many samples have been collected. ``None``
            (or 0) means no limit.
        num_hard_negatives: Number of hard negatives kept per query.
        num_positives: Number of positives kept per query.
        shuffle_negatives: Shuffle each record's hard negatives before truncating.
        shuffle_positives: Shuffle each record's positives before truncating.

    Returns:
        List of sample dicts as described above.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(file, encoding="utf-8") as fh:
        records = json.load(fh)

    # Accepted spellings of the query / positive / hard-negative keys.
    query_json_keys = {"question", "questions", "query"}
    positive_context_json_keys = {
        "positive_contexts",
        "positive_ctxs",
        "positive_context",
        "positive_ctx",
    }
    hard_negative_json_keys = {
        "hard_negative_contexts",
        "hard_negative_ctxs",
        "hard_negative_context",
        "hard_negative_ctx",
    }
    expected_passages = num_positives + num_hard_negatives

    def _select(candidates, limit, shuffle, label):
        """Optionally shuffle ``candidates``, keep the first ``limit`` and tag them with ``label``."""
        if shuffle:
            random.shuffle(candidates)
        return [
            {"title": passage["title"], "text": passage["text"], "label": label}
            for passage in candidates[:limit]
        ]

    standard_dicts = []
    for record in tqdm_notebook(records):
        sample = {}
        positive_passages = []
        negative_passages = []
        for key, val in record.items():
            if key in query_json_keys:
                sample["query"] = val
            elif key in positive_context_json_keys:
                positive_passages.extend(
                    _select(val, num_positives, shuffle_positives, "positive")
                )
            elif key in hard_negative_json_keys:
                negative_passages.extend(
                    _select(val, num_hard_negatives, shuffle_negatives, "hard_negative")
                )
        # Place positive passages first and then negative passages.
        # This ordering is what the in-batch labels for the loss rely on.
        sample["passages"] = positive_passages + negative_passages
        # Keep only complete samples. A sample without a query would fail
        # later, during encoding.
        if "query" in sample and len(sample["passages"]) == expected_passages:
            standard_dicts.append(sample)
        if max_samples and len(standard_dicts) == max_samples:
            break
    return standard_dicts


# Read training json file
dicts = read_dpr_json(
    "biencoder-nq-adv-hn-train.json", max_samples=6400, num_hard_negatives=1
)


def encode_query_passage(tokenizer, dicts, model_config, data_config):
    """Tokenize queries and their passages into fixed-length token-id arrays.

    Only samples with exactly ``data_config.num_positives`` positive passages
    and ``data_config.num_hard_negatives`` hard-negative passages are encoded.
    Within each sample the positives come first, then the hard negatives.
    Passages are encoded as (title, text) pairs, and a ``None`` title is
    treated as "".

    Args:
        tokenizer: Hugging Face tokenizer used for both queries and passages.
        dicts: Samples as produced by ``read_dpr_json``.
        model_config: Provides ``query_max_seq_len`` and ``passage_max_seq_len``.
        data_config: Provides ``num_positives`` and ``num_hard_negatives``.

    Returns:
        Dict with ``query_{input_ids,token_type_ids,attention_mask}`` arrays of
        shape (num_samples, query_max_seq_len) and
        ``passage_{input_ids,token_type_ids,attention_mask}`` arrays of shape
        (num_samples, num_passages, passage_max_seq_len).

    Raises:
        ValueError: If no sample has the required number of passages.
    """

    def _as_pair(passage):
        """(title, text) tuple for the tokenizer, with a missing title replaced by ""."""
        title = "" if passage["title"] is None else passage["title"]
        return (title, passage["text"])

    passage_input_ids = []
    passage_token_type_ids = []
    passage_attention_mask = []
    queries = []

    for sample in tqdm_notebook(dicts):
        query = sample["query"]
        passages = sample["passages"]
        positives = [_as_pair(p) for p in passages if p["label"] == "positive"]
        negatives = [_as_pair(p) for p in passages if p["label"] == "hard_negative"]

        if (
            len(positives) != data_config.num_positives
            or len(negatives) != data_config.num_hard_negatives
        ):
            continue

        queries.append(query)
        # `truncation=True` already means "longest_first" for text pairs. The
        # deprecated `truncation_strategy` kwarg is therefore not passed.
        passage_inputs = tokenizer.batch_encode_plus(
            positives + negatives,
            max_length=model_config.passage_max_seq_len,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            return_token_type_ids=True,
        )
        passage_input_ids.append(np.array(passage_inputs["input_ids"]))
        passage_token_type_ids.append(np.array(passage_inputs["token_type_ids"]))
        passage_attention_mask.append(np.array(passage_inputs["attention_mask"]))

    if not queries:
        raise ValueError(
            "No sample has exactly "
            f"{data_config.num_positives} positive and "
            f"{data_config.num_hard_negatives} hard-negative passages."
        )

    query_inputs = tokenizer.batch_encode_plus(
        queries,
        max_length=model_config.query_max_seq_len,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        return_token_type_ids=True,
        return_tensors="np",
    )

    return {
        "query_input_ids": query_inputs["input_ids"],
        "query_token_type_ids": query_inputs["token_type_ids"],
        "query_attention_mask": query_inputs["attention_mask"],
        "passage_input_ids": np.array(passage_input_ids),
        "passage_token_type_ids": np.array(passage_token_type_ids),
        "passage_attention_mask": np.array(passage_attention_mask),
    }


# Load Pretrained tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name)
# Encode queries and passages
X = encode_query_passage(tokenizer, dicts, model_config, data_config)

Model Preparation

Now prepare the bi-encoder model using pre-trained base models (here the small BERT checkpoint `google/bert_uncased_L-4_H-512_A-8` from the model config). Different base models could be used for the passage and query encoders, but here we use the same base model for both.

Bi-Model-Architecture
class QueryModel(tf.keras.Model):
    """Encode tokenized queries into fixed-size vectors.

    Wraps the pretrained transformer named by ``model_config.model_name`` and
    applies dropout (rate ``model_config.dropout``) to the element at index 1
    of the encoder's output.
    """

    def __init__(self, model_config, **kwargs):
        super().__init__(**kwargs)
        checkpoint = model_config.model_name
        # Pretrained transformer backbone for queries.
        self.query_encoder = TFAutoModel.from_pretrained(checkpoint)
        # Regularisation applied to the pooled vector.
        self.dropout = layers.Dropout(model_config.dropout)

    def call(self, inputs, training=False, **kwargs):
        encoder_outputs = self.query_encoder(inputs, training=training, **kwargs)
        # Index 1 of the encoder output is used as the pooled query vector.
        return self.dropout(encoder_outputs[1], training=training)


class PassageModel(tf.keras.Model):
    """Encode tokenized passages into fixed-size vectors.

    Wraps the pretrained transformer named by ``model_config.model_name`` and
    applies dropout (rate ``model_config.dropout``) to the element at index 1
    of the encoder's output.
    """

    def __init__(self, model_config, **kwargs):
        super().__init__(**kwargs)
        checkpoint = model_config.model_name
        # Pretrained transformer backbone for passages.
        self.passage_encoder = TFAutoModel.from_pretrained(checkpoint)
        # Regularisation applied to the pooled vector.
        self.dropout = layers.Dropout(model_config.dropout)

    def call(self, inputs, training=False, **kwargs):
        encoder_outputs = self.passage_encoder(inputs, training=training, **kwargs)
        # Index 1 of the encoder output is used as the pooled passage vector.
        return self.dropout(encoder_outputs[1], training=training)


def cross_replica_concat(values, unused_shape=None):
    """Gather ``values`` from every replica and concatenate them along axis 0.

    The gathered tensor is rolled so that the calling replica's own rows come
    first. The relative order of all rows is otherwise unchanged (a cyclic
    shift).

    Args:
        values: Per-replica tensor. Its static first dimension is used as the
            roll step.
        unused_shape: Ignored. Accepted so that call sites passing a second
            positional argument keep working instead of raising ``TypeError``.

    Returns:
        Tensor with ``num_replicas * values.shape[0]`` rows.
    """
    context = tf.distribute.get_replica_context()
    gathered = context.all_gather(values, axis=0)

    # This replica's rows start at offset replica_id * rows in `gathered`.
    # Shift them to the front.
    return tf.roll(
        gathered,
        -context.replica_id_in_sync_group * values.shape[0],
        axis=0,
    )


class BiEncoderModel(tf.keras.Model):
    """Bi-Encoder Query & Passage Model trained with an in-batch softmax loss.

    Each training example supplies one query and ``num_passages_per_question``
    passages, laid out positives-first. In ``train_step`` the query and passage
    embeddings of all replicas are gathered. Every query is then scored
    against every passage of the global batch by dot product, and the first
    passage belonging to each query is used as its target class.
    """

    def __init__(
        self,
        query_encoder,
        passage_encoder,
        num_passages_per_question,
        model_config,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        # Query encoder model
        self.query_encoder = query_encoder
        # Passage encoder model
        self.passage_encoder = passage_encoder
        # No. positives plus No. of hard negatives
        self.num_passages_per_question = num_passages_per_question
        # Model configuration
        self.model_config = model_config

        # Loss tracker
        self.loss_tracker = keras.metrics.Mean(name="loss")
        # Per-query losses. Reduction and scaling happen in `calculate_loss`.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(
            reduction=keras.losses.Reduction.NONE, from_logits=True
        )

    def calculate_loss(self, logits):
        """Compute the in-batch softmax loss.

        Args:
            logits: (num_queries, num_candidates) similarity scores over the
                global batch. Query ``i``'s own passages start at column
                ``i * num_passages_per_question``, because passages are laid
                out per query with positives first.

        Returns:
            Scalar: the per-query cross-entropy summed and divided by the
            number of queries.
        """
        # Prefer the static size so labels are a compile-time constant
        # (friendlier to XLA/TPU). Fall back to the dynamic shape otherwise.
        num_queries = logits.shape[0]
        if num_queries is None:
            num_queries = tf.shape(logits)[0]

        # Make in-batch labels: the target for query i is its first passage.
        labels = tf.range(num_queries) * self.num_passages_per_question

        loss = self.loss_fn(labels, logits)
        return tf.reduce_sum(loss) / tf.cast(num_queries, loss.dtype)

    def passage_forward(self, X):
        """Encode this replica's passages into one row per passage."""
        # Reshape input (BS, num_passages_per_question, seq_len) -> (BS*num_passages_per_question, seq_len).
        # The leading -1 lets the batch size come from the data.
        input_shape = (-1, self.model_config.passage_max_seq_len)
        input_ids = tf.reshape(X["passage_input_ids"], input_shape)
        attention_mask = tf.reshape(X["passage_attention_mask"], input_shape)
        token_type_ids = tf.reshape(X["passage_token_type_ids"], input_shape)
        # Call passage encoder model
        return self.passage_encoder(
            [input_ids, attention_mask, token_type_ids], training=True
        )

    def query_forward(self, X):
        """Encode this replica's queries into one row per query."""
        # Reshape input (BS, seq_len) -> (BS, seq_len). The leading -1 lets
        # the batch size come from the data.
        input_shape = (-1, self.model_config.query_max_seq_len)
        input_ids = tf.reshape(X["query_input_ids"], input_shape)
        attention_mask = tf.reshape(X["query_attention_mask"], input_shape)
        token_type_ids = tf.reshape(X["query_token_type_ids"], input_shape)
        return self.query_encoder(
            [input_ids, attention_mask, token_type_ids], training=True
        )

    def train_step(self, X):
        """Run one forward/backward pass with the in-batch loss."""
        with tf.GradientTape() as tape:
            # Call encoder models
            passage_embeddings = self.passage_forward(X)
            query_embeddings = self.query_forward(X)

            # Gather all replicas' embeddings so that each query is scored
            # against every passage in the global batch.
            global_passage_embeddings = cross_replica_concat(passage_embeddings)
            global_query_embeddings = cross_replica_concat(query_embeddings)

            # Dot product similarity
            similarity_scores = tf.linalg.matmul(
                global_query_embeddings, global_passage_embeddings, transpose_b=True
            )

            loss = self.calculate_loss(similarity_scores)
            # Each replica computes the loss over the whole global batch.
            # Scale it by the replica count taken from the active replica
            # context instead of a module-level `strategy` global.
            loss = loss / tf.distribute.get_replica_context().num_replicas_in_sync

        # Backward pass
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Monitor loss
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    @property
    def metrics(self):
        return [self.loss_tracker]

Model Building and Training

Training is the most interesting part of the DPR model, because this bi-encoder is trained with an in-batch softmax loss. What is an in-batch softmax loss?

In ordinary distributed training, each pod or node holds a copy of the model. The global batch is split equally across the pods or nodes. Each one runs a forward pass and computes its loss separately. The per-replica losses are then aggregated with a reduction method (e.g. mean or sum) into the final loss for the global batch, and the resulting gradient update is applied to every model copy.

With the in-batch loss, however, the loss is calculated over the concatenated outputs of all pods or nodes, so each query is scored against every passage in the global batch. We can see this in the image below.

Bi-Model-Training-TPU
from transformers import create_optimizer  # returns (optimizer, lr_schedule)

# Resolve the distribution strategy. Use a TPU when one can be resolved;
# otherwise fall back to the default (single-device) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.get_strategy()

BATCH_SIZE_PER_REPLICA = model_config.batch_size_per_replica
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
N_EPOCHS = model_config.epochs
# Count the encoded samples, i.e. what the training dataset actually contains.
num_train_samples = len(X["query_input_ids"])
one_epoch_steps = num_train_samples // GLOBAL_BATCH_SIZE
if one_epoch_steps == 0:
    raise ValueError(
        f"Need at least {GLOBAL_BATCH_SIZE} encoded samples for one global batch, "
        f"got {num_train_samples}."
    )
num_train_steps = one_epoch_steps * N_EPOCHS
num_warmup_steps = num_train_steps // 10

# Define model under strategy scope
with strategy.scope():
    query_encoder = QueryModel(model_config)
    passage_encoder = PassageModel(model_config)
    bi_model = BiEncoderModel(
        query_encoder,
        passage_encoder,
        num_passages_per_question=data_config.num_positives
        + data_config.num_hard_negatives,
        model_config=model_config,
    )
    optimizer, lr_schedule = create_optimizer(
        init_lr=model_config.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
    )
    bi_model.compile(optimizer=optimizer)

with strategy.scope():
    train_ds = (
        tf.data.Dataset.from_tensor_slices(X)
        .shuffle(GLOBAL_BATCH_SIZE * 10)
        .batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
        # Prefetch last so whole batches are prepared ahead of each step.
        .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    )

# Train on TPU (or the fallback strategy)
bi_model.fit(train_ds, epochs=N_EPOCHS)

Model Evaluation

Now, after training, it’s time to evaluate the model on the dev dataset. We first preprocess the data into queries, passages, and answer indexes.

Queries are passed through the query encoder to get query embeddings, and passages are passed through the passage encoder to get passage embeddings.

The passage embeddings are then stored in a faiss index for similarity matching.

For each query, its embedding is passed to `faiss_index.search` to get the top-k passage indexes, from which we calculate the top-k accuracy.

# Read the dev set for evaluation: 30 hard negatives per query, kept in file
# order (no shuffling).
eval_dicts = read_dpr_json(
    "biencoder-nq-dev.json",
    shuffle_negatives=False,
    num_hard_negatives=30,
)


def combine_title_context(titles, texts):
    """Pair each title with its text, using "" in place of a ``None`` title.

    Pairing stops at the shorter of the two sequences (``zip`` semantics).

    Returns:
        List of ``(title, text)`` tuples.
    """
    return [
        ("" if title is None else title, text) for title, text in zip(titles, texts)
    ]


def process_single_example(passages):
    """Flatten one sample's passages and locate its positive passage.

    Args:
        passages: List of passage dicts with ``title``, ``text`` and ``label``.

    Returns:
        ``(pairs, answer_index)``. ``pairs`` is the ``(title, text)`` list from
        ``combine_title_context``. ``answer_index`` is the position of the last
        passage labelled ``"positive"``, or -1 if there is none.
    """
    titles, texts = [], []
    answer_index = -1
    for position, passage in enumerate(passages):
        titles.append(passage["title"])
        texts.append(passage["text"])
        if passage["label"] == "positive":
            answer_index = position

    return combine_title_context(titles, texts), answer_index


def process_examples(dicts):
    """Flatten all samples into one passage list with global answer indexes.

    Args:
        dicts: Samples with ``query`` and ``passages`` keys.

    Returns:
        ``(queries, answer_indexes, processed_passages)``:
        - ``queries``: one query per sample, in input order.
        - ``answer_indexes``: for each sample, the position in
          ``processed_passages`` of its positive passage (the sample's offset
          plus its local index). NOTE(review): a sample without a positive
          gets local index -1, so its entry points at the previous sample's
          last passage.
        - ``processed_passages``: every sample's ``(title, text)`` pairs,
          concatenated.
    """
    queries = []
    answer_indexes = []
    processed_passages = []
    offset = 0

    for sample in dicts:
        queries.append(sample["query"])
        sample_passages = sample["passages"]
        pairs, local_answer = process_single_example(sample_passages)
        processed_passages.extend(pairs)
        answer_indexes.append(offset + local_answer)
        offset += len(sample_passages)

    return queries, answer_indexes, processed_passages


# Flatten the dev samples for evaluation and report passage / query counts.
queries, answer_indexes, processed_passages = process_examples(eval_dicts)
print(len(processed_passages), len(queries))


def extracted_passage_embeddings(processed_passages, model_config):
    """Encode ``(title, text)`` passage pairs with the trained passage encoder.

    Uses the module-level ``tokenizer`` and ``passage_encoder``. Passages are
    padded/truncated to ``model_config.passage_max_seq_len`` tokens.

    Returns:
        The output of ``passage_encoder.predict``, one row per passage.
    """
    encoded = tokenizer.batch_encode_plus(
        processed_passages,
        max_length=model_config.passage_max_seq_len,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_token_type_ids=True,
    )
    # The encoder expects [input_ids, attention_mask, token_type_ids].
    model_inputs = [
        np.array(encoded[field])
        for field in ("input_ids", "attention_mask", "token_type_ids")
    ]
    return passage_encoder.predict(model_inputs, batch_size=512, verbose=1)


passage_embeddings = extracted_passage_embeddings(processed_passages, model_config)


def extracted_query_embeddings(queries, model_config):
    """Encode query strings with the trained query encoder.

    Uses the module-level ``tokenizer`` and ``query_encoder``. Queries are
    padded/truncated to ``model_config.query_max_seq_len`` tokens.

    Returns:
        The output of ``query_encoder.predict``, one row per query.
    """
    encoded = tokenizer.batch_encode_plus(
        queries,
        max_length=model_config.query_max_seq_len,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_token_type_ids=True,
    )
    # The encoder expects [input_ids, attention_mask, token_type_ids].
    model_inputs = [
        np.array(encoded[field])
        for field in ("input_ids", "attention_mask", "token_type_ids")
    ]
    return query_encoder.predict(model_inputs, batch_size=512, verbose=1)


query_embeddings = extracted_query_embeddings(queries, model_config)


# Load into Faiss.
# The bi-encoder is trained with dot-product similarity, so retrieval uses a
# maximum-inner-product index rather than L2 distance. The index dimension is
# taken from the embeddings themselves instead of a hard-coded hidden size,
# which would not match every checkpoint. Faiss expects contiguous float32
# arrays.
passage_embeddings = np.ascontiguousarray(passage_embeddings, dtype=np.float32)
query_embeddings = np.ascontiguousarray(query_embeddings, dtype=np.float32)
faiss_index = faiss.IndexFlatIP(passage_embeddings.shape[1])
faiss_index.add(passage_embeddings)


def get_k_accuracy(faiss_index, query_embeddings, answer_indexes, k):
    """Find the queries whose answer passage is among their top-``k`` results.

    Args:
        faiss_index: Index holding the passage embeddings.
        query_embeddings: One row per query, aligned with ``answer_indexes``.
        answer_indexes: For each query, the index of its positive passage.
        k: Number of results retrieved per query.

    Returns:
        List of ``(query_position, answer_index)`` tuples, one per hit.
    """
    _, top_k_ids = faiss_index.search(query_embeddings, k=k)

    hits = []
    for query_pos in tqdm_notebook(range(len(answer_indexes))):
        gold = answer_indexes[query_pos]
        if np.any(top_k_ids[query_pos] == gold):
            hits.append((query_pos, gold))
    return hits


# Calculate Top-k accuracy at several cut-offs.
top10_corrects = get_k_accuracy(faiss_index, query_embeddings, answer_indexes, k=10)
top20_corrects = get_k_accuracy(faiss_index, query_embeddings, answer_indexes, k=20)
top50_corrects = get_k_accuracy(faiss_index, query_embeddings, answer_indexes, k=50)
top100_corrects = get_k_accuracy(faiss_index, query_embeddings, answer_indexes, k=100)
top1000_corrects = get_k_accuracy(faiss_index, query_embeddings, answer_indexes, k=1000)

# Map each cut-off k to the queries answered correctly within it.
corrects_by_k = {
    10: top10_corrects,
    20: top20_corrects,
    50: top50_corrects,
    100: top100_corrects,
    1000: top1000_corrects,
}

results = pd.DataFrame(
    {
        "topK": list(corrects_by_k),
        "total": [len(query_embeddings)] * len(corrects_by_k),
        "correct_total": [len(hits) for hits in corrects_by_k.values()],
    }
)

# Show results
results["accuracy"] = results["correct_total"] / results["total"] * 100
print(results)

References

DPR Paper: https://arxiv.org/pdf/2004.04906.pdf
Blog: https://ankur3107.github.io/blogs/dense-passage-retriever/

Cited as

@article{kumar2021dprtpu,
title   = "The Illustrated Dense Passage Retriever on TPU",
author  = "Kumar, Ankur",
journal = "ankur3107.github.io",
year    = "2021",
url     = "https://ankur3107.github.io/blogs/dense-passage-retriever/"
}

Comments