How To Fine-Tune FLAN-T5 For Question Answering

Greg Bizup
Aug 24, 2023

Background on Flan-T5

Flan-T5 is a powerful text-to-text language model. It excels in a range of tasks including summarization, translation, and question answering. Google first introduced it in the paper Scaling Instruction-Finetuned Language Models.

Google created Flan-T5 by fine-tuning the original T5 model on a wide variety of tasks phrased as natural-language instructions. This method is called instruction finetuning, and it has been shown to improve model performance significantly on previously unseen tasks.

Flan-T5 is free to use and relatively lightweight compared to models such as Llama 2 and GPT-NeoX. The base model can easily be fine-tuned on a free Colab GPU.

Closed-book question answering

Today, we will fine-tune Flan-T5 for closed-book question answering using question-answer pairs scraped from Quora. Closed-book question answering, a form of open-domain question answering, is the task of answering a question from memory, that is, from what the model has stored in its weights, without being given any background information.

You can find the code on our GitHub, and the model and Quora dataset on our HuggingFace repository.

Quora question answer dataset

We will fine-tune on our dataset of 56.4k question/answer pairs scraped from the question board Quora. I collected this data using a rotating proxy session from Bright Data.

For every question, there is one answer. Here are some sample rows from the dataset:

[Image: sample rows from the Quora question answer dataset]

Fine-tune Flan-T5 for closed-book question answering

We will be using the HuggingFace transformers library to fine-tune our model. The code presented here is based on the Seq2SeqTrainer documentation.

First, install the necessary libraries by entering the following command in the terminal.

pip install "transformers[torch]" tokenizers datasets evaluate rouge_score sentencepiece huggingface_hub nltk --upgrade

Start by importing the following Python libraries:

# Import the necessary libraries
import nltk
from datasets import load_dataset
import evaluate
import numpy as np
from transformers import T5Tokenizer, DataCollatorForSeq2Seq
from transformers import T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer

Load the Quora dataset from the hub:

# Load and split the dataset
dataset = load_dataset("toughdata/quora-question-answer-dataset")
dataset = dataset["train"].train_test_split(test_size=0.2)

Next, load Flan-T5-Base from the hub.

# Load the tokenizer, model, and data collator
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
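
One detail about the data collator is worth knowing before we get to evaluation. Besides padding the inputs, DataCollatorForSeq2Seq pads the labels in each batch with -100 by default, a value the loss function ignores. That is why the evaluation code swaps -100 back to the pad token id before decoding. The sketch below mimics that behavior in plain Python on made-up token ids. `pad_labels` and `restore_pad_ids` are hypothetical helpers for illustration, not library functions.

```python
IGNORE_INDEX = -100  # default label padding value; ignored by the loss
PAD_TOKEN_ID = 0     # T5's pad token id


def pad_labels(batch, pad_value=IGNORE_INDEX):
    """Right-pad every label sequence to the longest one in the batch."""
    width = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (width - len(seq)) for seq in batch]


def restore_pad_ids(batch, pad_token_id=PAD_TOKEN_ID):
    """Replace the ignore index with a real token id so the batch can be decoded."""
    return [[pad_token_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in batch]


# Made-up token ids; each sequence ends with 1, T5's end-of-sequence id
labels = [[71, 1712, 1], [27, 1]]
padded = pad_labels(labels)
print(padded)
print(restore_pad_ids(padded))
```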

We use the following function to tokenize and preprocess the data. It adds the prefix "answer the question: " to every question, then tokenizes the questions and the answers separately. The tokenized, prefixed questions become the "inputs" for training, and the tokenized answers become the "labels". Questions are truncated to 128 tokens and answers to 512.

# We prefix our tasks with "answer the question"
prefix = "answer the question: "

# Define our preprocessing function
def preprocess_function(examples):
    """Add prefix to the sentences, tokenize the text, and set the labels"""
    # The "inputs" are the tokenized, prefixed questions:
    inputs = [prefix + doc for doc in examples["question"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True)
    
    # The "labels" are the tokenized outputs:
    labels = tokenizer(text_target=examples["answer"], max_length=512, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Map the preprocessing function across our dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True)
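
To see what the preprocessing function receives, note that with batched=True, map passes a dictionary of lists rather than a single row. The sketch below reproduces the prefix-and-truncate logic on a toy batch. `toy_tokenize` is a hypothetical whitespace tokenizer standing in for the real tokenizer, so the output contains words rather than token ids, and the example questions and length limits are made up.

```python
PREFIX = "answer the question: "


def toy_tokenize(texts, max_length):
    """Split on whitespace and truncate each text to max_length tokens."""
    return [text.split()[:max_length] for text in texts]


def toy_preprocess(examples, max_input_len=8, max_label_len=4):
    """Mirror the structure of preprocess_function on a dict-of-lists batch."""
    inputs = [PREFIX + question for question in examples["question"]]
    return {
        "input_ids": toy_tokenize(inputs, max_input_len),
        "labels": toy_tokenize(examples["answer"], max_label_len),
    }


# A batch as map(..., batched=True) would pass it: one list per column
batch = {
    "question": ["Why is the sky blue?", "What is a proxy?"],
    "answer": [
        "Because air scatters blue light more than red.",
        "A relay between client and server.",
    ],
}
out = toy_preprocess(batch)
print(out["input_ids"][0])
print(out["labels"][0])
```

Note how the answers are cut off at the label length limit, just as long Quora answers are cut off at 512 tokens in the real function.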

We will use the ROUGE score to track training progress. ROUGE compares the n-grams (sequences of n consecutive words) in the generated text with the n-grams in the target text, and reports how much they overlap.

# Set up Rouge score for evaluation
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)  # sentence tokenizer data used by newer NLTK releases
metric = evaluate.load("rouge")

def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # decode preds and labels
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLSum expects newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return result
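
To make the n-gram comparison concrete, here is a minimal sketch of a ROUGE-N style score in plain Python. It is a simplification for illustration, not the implementation behind evaluate.load("rouge"): it tokenizes on whitespace and skips stemming. The `rouge_n` function and the example sentences are made up for this post.

```python
from collections import Counter


def ngrams(tokens, n):
    """Count every run of n consecutive tokens."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(prediction, reference, n=1):
    """Precision, recall, and F1 of the n-gram overlap between two texts."""
    pred = ngrams(prediction.lower().split(), n)
    ref = ngrams(reference.lower().split(), n)
    # Counter intersection keeps the minimum count of each shared n-gram
    overlap = sum((pred & ref).values())
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


generated = "the cat sat on the mat"
target = "the cat lay on the mat"
print(rouge_n(generated, target, n=1))  # 5 of 6 unigrams match
print(rouge_n(generated, target, n=2))  # 3 of 5 bigrams match
```

The real metric returns rouge1, rouge2, rougeL, and rougeLsum, where the last two are based on the longest common subsequence rather than fixed-length n-grams.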

The final step is to set up the trainer. The parameters here are optimized to maximize batch size and training speed on a GPU with 16GB VRAM. You can adjust them as needed for your hardware.

# Set up training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    push_to_hub=False
)

# Set up trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,  # newer transformers releases prefer processing_class=tokenizer
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()
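
If you want a rough idea of how long training will run, you can estimate the number of optimizer steps from the settings above. The sketch below assumes the dataset has exactly 56,400 rows (the post only says 56.4k), that the test share of the split is rounded up, and that you train on a single GPU without gradient accumulation. `training_steps` is a hypothetical helper, not part of transformers.

```python
import math


def training_steps(n_rows, test_size, batch_size, epochs):
    """Estimate train-set size, steps per epoch, and total optimizer steps."""
    n_test = math.ceil(n_rows * test_size)             # held-out rows (assumed rounded up)
    n_train = n_rows - n_test                          # rows left for training
    steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch still counts
    return n_train, steps_per_epoch, steps_per_epoch * epochs


n_train, per_epoch, total = training_steps(
    n_rows=56_400, test_size=0.2, batch_size=8, epochs=2
)
print(f"{n_train} training rows, {per_epoch} steps per epoch, {total} steps in total")
```

Halving the batch size to fit a smaller GPU doubles the step count, so the total gives you a quick way to compare settings before committing to a run.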

Resource requirements to fine-tune Flan-T5

You can run this script on a free Colab notebook with 16GB of RAM and 16GB of GPU memory. Larger versions such as Flan-T5-XXL have higher resource requirements, and you will need more computational power to run them.

Trying out the model

We trained the model and pushed it to the hub. You can visit our HuggingFace repo to see the model card, and experiment with the hosted inference API.

Here is a sample from the hosted inference API:

[Image: sample output from the hosted inference API]

Not bad. The Inference API cuts the outputs short, but it seems to answer the questions in a conversational manner. We can probably get some sensible responses by downloading it from the hub and increasing the output length. But is it imperative to be sensible on Quora?
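
If you want to try longer outputs yourself, the sketch below shows one way to do it. It is a minimal example, not the code behind the hosted API. The `build_prompt` and `answer` helpers, the 256-token limit, and the checkpoint path are assumptions for illustration; pass whichever checkpoint directory or hub model id you want to test.

```python
PREFIX = "answer the question: "


def build_prompt(question):
    """Apply the same task prefix that was used during fine-tuning."""
    return PREFIX + question.strip()


def answer(question, checkpoint, max_new_tokens=256):
    """Generate an answer with a saved checkpoint or hub model id."""
    # Imported lazily so build_prompt works even without transformers installed
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    inputs = tokenizer(build_prompt(question), return_tensors="pt")
    # max_new_tokens raises the cap on how many tokens the model may generate
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


print(build_prompt("  What is the best way to learn Python? "))
# To generate, point `checkpoint` at your own model (placeholder path below):
# print(answer("What is the best way to learn Python?", "path/to/your/checkpoint"))
```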