Image captioning is the task of generating a caption, i.e. a textual description, for an input image. It combines computer vision, to understand the image, with natural language processing, to produce the caption.

Popular benchmark datasets that pair images with captions include:

  • Common Objects in Context (COCO): A collection of more than 120 thousand images with descriptions.
  • Flickr 8K: A collection of 8 thousand described images taken from flickr.com.
  • Flickr 30K: A collection of 30 thousand described images taken from flickr.com.

Try the trained model: https://huggingface.co/nlpconnect/vit-gpt2-image-captioning


Vision Encoder Decoder Architecture


The Vision Encoder Decoder Model can be used to initialize an image-to-text model with any pre-trained Transformer-based vision model as the encoder (e.g. ViT, BEiT, DeiT, Swin) and any pre-trained language model as the decoder (e.g. RoBERTa, GPT2, BERT, DistilBERT).

Image captioning is one such application: the encoder model encodes the image, and an autoregressive language model, i.e. the decoder model, then generates the caption token by token.
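
To make the decoding step concrete, here is a minimal sketch of greedy autoregressive generation. It is not the library's generation code: `toy_next_token` is a hypothetical stand-in for the decoder (a real decoder would attend to the encoder's image features and pick the most likely next token), and the token ids are made up for illustration.

```python
from typing import Callable, List, Sequence

BOS_ID, EOS_ID = 0, 1  # hypothetical special-token ids


def greedy_decode(
    next_token: Callable[[Sequence[float], List[int]], int],
    image_features: Sequence[float],
    max_length: int = 16,
) -> List[int]:
    """Generate token ids one at a time until EOS or max_length new tokens."""
    tokens = [BOS_ID]
    for _ in range(max_length):
        token = next_token(image_features, tokens)
        tokens.append(token)
        if token == EOS_ID:
            break
    return tokens


def toy_next_token(image_features: Sequence[float], tokens: List[int]) -> int:
    """Stand-in decoder: replays a fixed caption, ignoring the image features."""
    canned_caption = [7, 8, 9, EOS_ID]
    # len(tokens) - 1 is the number of tokens generated so far (BOS excluded).
    return canned_caption[min(len(tokens) - 1, len(canned_caption) - 1)]


caption_ids = greedy_decode(toy_next_token, image_features=[0.1, 0.2, 0.3])
print(caption_ids)
```

Each step feeds all previously generated tokens back into the decoder, which is what "autoregressive" means here.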

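import os
import datasets
from transformers import VisionEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer
os.environ["WANDB_DISABLED"] = "true"
import nltk

# download the punkt sentence tokenizer only if it is not already available
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    nltk.download("punkt", quiet=True)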

Initialize VisionEncoderDecoderModel

from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor

image_encoder_model = "google/vit-base-patch16-224-in21k"
text_decode_model = "gpt2"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)
  • The feature extractor prepares images for the encoder, e.g. by resizing and normalizing them; this ViT checkpoint then works on image patches of 16x16 pixels.

  • The tokenizer splits captions into tokens and encodes them as token ids.
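
The checkpoint name `google/vit-base-patch16-224-in21k` encodes the patch size (16) and the input resolution (224). As a small worked example, the sketch below computes how many tokens the encoder sees per image under those two numbers; `vit_sequence_length` is a hypothetical helper written for this illustration, not a library function.

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of encoder tokens: one per patch plus one [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


num_tokens = vit_sequence_length()
print(num_tokens)  # 14 * 14 patches + 1 = 197
```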

# image feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
# text tokenizer
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)
# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
tokenizer.pad_token = tokenizer.eos_token

# update the model config
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
output_dir = "vit-gpt-model"
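
The `decoder_start_token_id` and `pad_token_id` settings matter because, during training, the decoder inputs are derived from the labels by shifting them one position to the right. The following is a simplified sketch of that shift on plain Python lists, not the library's implementation; the caption token ids are made up, while 50256 is GPT-2's end-of-text id, which serves as bos, eos and pad token in this setup.

```python
from typing import List

IGNORE_INDEX = -100  # label value that the loss ignores


def shift_tokens_right(
    labels: List[int], pad_token_id: int, decoder_start_token_id: int
) -> List[int]:
    """Prepend the start token, drop the last label, and replace -100 by pad."""
    shifted = [decoder_start_token_id] + labels[:-1]
    return [pad_token_id if t == IGNORE_INDEX else t for t in shifted]


GPT2_EOS_ID = 50256
labels = [32, 4286, 286, GPT2_EOS_ID, IGNORE_INDEX, IGNORE_INDEX]  # made-up ids
decoder_input_ids = shift_tokens_right(labels, GPT2_EOS_ID, GPT2_EOS_ID)
print(decoder_input_ids)
```

At position i the decoder therefore sees the tokens before i and is trained to predict label i.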

Data Loading and Preparation

We are going to use a sample dataset from ydshieh/coco_dataset_script.

To use the full COCO dataset (2017), you need to download it manually first:

wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip

Then to load the dataset:

COCO_DIR = ...(path to the downloaded dataset directory)...
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=COCO_DIR)
# load the small sample dataset used in the rest of this walkthrough
import datasets
ds = datasets.load_dataset("ydshieh/coco_dataset_script", "2017", data_dir="./dummy_data/")
ds
DatasetDict({
    train: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 80
    })
    test: Dataset({
        features: ['image_id', 'caption_id', 'caption', 'height', 'width', 'file_name', 'coco_url', 'image_path'],
        num_rows: 16
    })
})
# print single example
ds['train'][0]
{'image_id': 74,
 'caption_id': 145996,
 'caption': 'A picture of a dog laying on the ground.',
 'height': 426,
 'width': 640,
 'file_name': '000000000074.jpg',
 'coco_url': 'http://images.cocodataset.org/train2017/000000000074.jpg',
 'image_path': '/.cache/huggingface/datasets/downloads/extracted/f1122be5b6fbdb4a45c67365345f5639d2e11a25094285db1348c3872189a0f6/train2017/000000000074.jpg'}
from PIL import Image

# text preprocessing step
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions."""
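    # pad (and truncate) every caption to max_target_length and keep only the token ids
    labels = tokenizer(captions,
                       padding="max_length",
                       truncation=True,
                       max_length=max_target_length).input_ids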

    return labels

# image preprocessing step
def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images.
    If `check_image` is `True`, examples that fail during `Image.open()` are caught and discarded.
    Otherwise, an exception is thrown.
    """

    if check_image:
        images = []
        to_keep = []  # marks which examples loaded successfully
        for image_file in image_paths:
            try:
                img = Image.open(image_file)
                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [Image.open(image_file) for image_file in image_paths]

    encoder_inputs = feature_extractor(images=images, return_tensors="np")

    return encoder_inputs.pixel_values

def preprocess_fn(examples, max_target_length, check_image=True):
    """Run tokenization + image feature extraction"""
    # each batch carries an image path column and a caption column
    image_paths = examples['image_path']
    captions = examples['caption']

    model_inputs = {}
    model_inputs['labels'] = tokenization_fn(captions, max_target_length)
    model_inputs['pixel_values'] = feature_extraction_fn(image_paths, check_image=check_image)

    return model_inputs
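
When unreadable images are discarded, the matching captions need the same filtering, otherwise `labels` and `pixel_values` end up with different lengths. One way to keep the two aligned is sketched below. `load_aligned` and `fake_loader` are hypothetical names for this illustration; `fake_loader` stands in for `Image.open` so the example runs without image files.

```python
from typing import Callable, List, Tuple


def load_aligned(
    image_paths: List[str],
    captions: List[str],
    loader: Callable[[str], object],
) -> Tuple[List[object], List[str]]:
    """Load images and drop the caption of every image that fails to load."""
    images, kept_captions = [], []
    for path, caption in zip(image_paths, captions):
        try:
            images.append(loader(path))
        except Exception:
            continue  # skip the broken image together with its caption
        kept_captions.append(caption)
    return images, kept_captions


def fake_loader(path: str) -> str:
    """Hypothetical stand-in for `Image.open` that fails on one path."""
    if path.endswith("broken.jpg"):
        raise OSError(f"cannot open {path}")
    return f"<image {path}>"


images, kept = load_aligned(
    ["a.jpg", "broken.jpg", "c.jpg"],
    ["first caption", "second caption", "third caption"],
    loader=fake_loader,
)
print(kept)
```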
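processed_dataset = ds.map(
    function=preprocess_fn,
    batched=True,  # preprocess_fn works on lists of examples
    fn_kwargs={"max_target_length": 128},
    remove_columns=ds['train'].column_names  # keep only labels and pixel_values
)
processed_dataset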
DatasetDict({
    train: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 80
    })
    test: Dataset({
        features: ['labels', 'pixel_values'],
        num_rows: 16
    })
})
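
Padding to `max_target_length` gives every `labels` entry the same length, so examples can be stacked into batches. A minimal sketch of that padding and truncation on plain lists of token ids follows. `pad_to_max_length` is a hypothetical helper, the caption ids are made up, and 50256 is GPT-2's end-of-text id, which this tutorial reuses as the pad token.

```python
from typing import List


def pad_to_max_length(token_ids: List[int], max_length: int, pad_token_id: int) -> List[int]:
    """Truncate to max_length, then right-pad with pad_token_id."""
    clipped = token_ids[:max_length]
    return clipped + [pad_token_id] * (max_length - len(clipped))


PAD_ID = 50256
padded = pad_to_max_length([32, 4286, 286], max_length=8, pad_token_id=PAD_ID)
print(padded)
```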

Define seq2seq training arguments

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

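training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,  # generate captions during evaluation so ROUGE can be computed
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
)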
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).

Define metric

import evaluate
metric = evaluate.load("rouge")
import numpy as np

ignore_pad_token_for_loss = True

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    # report scores as percentages
    result = {k: round(v * 100, 4) for k, v in result.items()}
    # average number of non-pad tokens per generated caption
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result
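
For intuition about what the ROUGE scores measure, here is a simplified sketch of ROUGE-1 as a unigram-overlap F-measure. It is an illustration only and will not match the `rouge` metric exactly, which applies its own tokenization and optional stemming; `rouge1_f1` is a hypothetical helper.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1 between a predicted and a reference caption."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1(
    "a dog lying on the ground",
    "a picture of a dog laying on the ground",
)
print(round(score * 100, 4))  # 66.6667
```

Five of the six predicted words appear in the nine-word reference, giving a precision of 5/6, a recall of 5/9 and an F1 of 2/3.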


from transformers import default_data_collator

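# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,  # saved with the model as preprocessor_config.json
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['validation'],
    data_collator=default_data_collator,
)
trainer.train()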
***** Running training *****
  Num examples = 80
  Num Epochs = 3
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 60
[60/60 05:39, Epoch 3/3]
Epoch  Training Loss  Validation Loss  Rouge1     Rouge2    Rougel     Rougelsum  Gen Len
1      No log         0.341840         18.242900  0.548200  18.205800  18.229100  7.000000
2      No log         0.333408         17.483000  2.052800  16.658800  16.677000  15.750000
3      No log         0.333202         19.999400  3.316300  18.790000  18.768300  13.500000
***** Running Evaluation *****
  Num examples = 80
  Batch size = 4
***** Running Evaluation *****
  Num examples = 80
  Batch size = 4
***** Running Evaluation *****
  Num examples = 80
  Batch size = 4

Training completed. Do not forget to share your model on huggingface.co/models =)

TrainOutput(global_step=60, training_loss=0.6188589096069336, metrics={'train_runtime': 344.1556, 'train_samples_per_second': 0.697, 'train_steps_per_second': 0.174, 'total_flos': 4.331133386883072e+16, 'train_loss': 0.6188589096069336, 'epoch': 3.0})
# save the model (with the feature extractor) and the tokenizer
trainer.save_model("./image-captioning-output")
tokenizer.save_pretrained("./image-captioning-output")
Saving model checkpoint to ./image-captioning-output
Configuration saved in ./image-captioning-output/config.json
Model weights saved in ./image-captioning-output/pytorch_model.bin
Feature extractor saved in ./image-captioning-output/preprocessor_config.json
tokenizer config file saved in ./image-captioning-output/tokenizer_config.json
Special tokens file saved in ./image-captioning-output/special_tokens_map.json
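
The step count and throughput reported in the training log above can be checked with a little arithmetic. The sketch below assumes a single device and no gradient accumulation, as the log states; all numbers are taken from that log.

```python
import math

num_examples, batch_size, num_epochs = 80, 4, 3
train_runtime = 344.1556  # seconds, from TrainOutput

steps_per_epoch = math.ceil(num_examples / batch_size)  # 20
total_steps = steps_per_epoch * num_epochs              # 60
samples_per_second = num_examples * num_epochs / train_runtime
steps_per_second = total_steps / train_runtime

print(total_steps, round(samples_per_second, 3), round(steps_per_second, 3))
# 60 0.697 0.174
```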



from transformers import pipeline
# full dataset trained model can be found at https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
image_captioner = pipeline("image-to-text", model="./image-captioning-output")
loading weights file ./image-captioning-output/pytorch_model.bin
All model checkpoint weights were used when initializing VisionEncoderDecoderModel.


