[Feature] support visual dialogue dataset #1678

Open
wants to merge 5 commits into base: dev
3 changes: 2 additions & 1 deletion mmpretrain/datasets/__init__.py
@@ -46,6 +46,7 @@
from .refcoco import RefCOCO
from .scienceqa import ScienceQA
from .textvqa import TextVQA
from .visdial import VisDial
from .visual_genome import VisualGenomeQA
from .vizwiz import VizWiz
from .vsr import VSR
@@ -54,5 +55,5 @@
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA'
'VSR', 'VizWiz', 'OCRVQA', 'VisDial'
])
96 changes: 96 additions & 0 deletions mmpretrain/datasets/visdial.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class VisDial(BaseDataset):
"""VisDial dataset.

Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self,
data_root: str,
data_prefix: str,
ann_file: str = '',
                 **kwargs):
super().__init__(
data_root=data_root,
data_prefix=dict(img_path=data_prefix),
ann_file=ann_file,
            **kwargs,
)

def load_data_list(self) -> List[dict]:
"""Load data list."""
annotations = mmengine.load(self.ann_file)['data']

dialogs = annotations['dialogs']
answers = annotations['answers']
questions = annotations['questions']

data_list = []

for dialog in dialogs:
image_id = dialog['image_id']
caption = dialog['caption']

            histories = ['Caption:' + caption + '.']
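            # histories[i] accumulates the caption plus every question/answer
            # pair asked before round i (the example text below is made up):
            # 'Caption:a man on a horse. Question:is it sunny? Answer:yes.'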

            for i in range(1, len(dialog['dialog'])):
                histories.append('')

                previous_idx = i - 1
                question_id = dialog['dialog'][previous_idx]['question']
                answer_id = dialog['dialog'][previous_idx]['answer']

                history = ' Question:{question}? Answer:{answer}.' \
                    .format(question=questions[question_id],
                            answer=answers[answer_id])

                histories[i] = histories[previous_idx] + history

# get question and answer options for each dialog round
for dialog_id, dialog_round in enumerate(dialog['dialog']):
question_id = dialog_round['question']
answer_id = dialog_round['answer']
answer_options = [
answers[answer_id]
for answer_id in dialog_round['answer_options']
]

data_info = dict(image_id=image_id)

img_prefix = self.data_prefix['img_path']
file_backend = get_file_backend(img_prefix)
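                # The image file name is assumed to follow the official
                # VisDial layout '<split dir>_<12-digit image id>.jpg', e.g.
                # 'VisualDialog_val2018/VisualDialog_val2018_000000254080.jpg'.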

data_info['img_path'] = file_backend.join_path(
img_prefix,
img_prefix.split('/')[-1] + '_' + str(image_id).zfill(12) +
'.jpg')

                data_info['dialog_history'] = histories[dialog_id]

data_info['question'] = questions[question_id] + '?'
data_info['answer'] = answers[answer_id]
data_info['answer_options'] = answer_options
data_info['gt_answer_index'] = data_info[
'answer_options'].index(data_info['answer'])

data_list.append(data_info)

return data_list
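
A minimal usage sketch for this dataset (not part of the diff; the data_root, data_prefix and ann_file values below are assumptions based on the official VisDial v1.0 download layout, not paths required by this PR):

# Hypothetical usage; the paths below are assumptions.
from mmpretrain.datasets import VisDial

dataset = VisDial(
    data_root='data/visdial',            # assumed download root
    data_prefix='VisualDialog_val2018',  # assumed image directory
    ann_file='visdial_1.0_val.json',     # assumed annotation file name
    pipeline=[],                         # transforms would normally go here
)
# One data sample is produced per dialog round; each carries the question,
# the accumulated dialog history and the candidate answer options.
sample = dataset[0]
print(sample['question'], sample['gt_answer_index'])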
3 changes: 2 additions & 1 deletion mmpretrain/evaluation/metrics/__init__.py
@@ -7,6 +7,7 @@
from .retrieval import RetrievalAveragePrecision, RetrievalRecall
from .scienceqa import ScienceQAMetric
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .visual_dialog import SparseGTMetrics
from .visual_grounding_eval import VisualGroundingMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
from .vqa import ReportVQA, VQAAcc
@@ -16,5 +17,5 @@
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
'RetrievalAveragePrecision'
'RetrievalAveragePrecision', 'SparseGTMetrics'
]
92 changes: 92 additions & 0 deletions mmpretrain/evaluation/metrics/visual_dialog.py
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmengine.evaluator import BaseMetric

from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
_process_punctuation)
from mmpretrain.registry import METRICS


@METRICS.register_module()
class SparseGTMetrics(BaseMetric):
"""Visual Dialog Acc metric.

Compute Visual Dialogaccuracy.

Args:
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Should be modified according to the
`retrieval_type` for unambiguous results. Defaults to TR.
"""
default_prefix = 'Visual Dialog'

def __init__(self,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)

def process(self, data_batch, data_samples) -> None:
"""Process one batch of data samples.

The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for sample in data_samples:
answer_options = sample.get('answer_options')
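            # Assign every answer option a deterministic pseudo-random rank,
            # then promote the predicted answer (if it appears among the
            # options) to rank 1; the recorded value is the resulting rank
            # of the ground-truth answer.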

G = torch.Generator()
G.manual_seed(0)
rank = 1 + torch.randperm(len(answer_options), generator=G)

pred_answer = sample.get('pred_answer')

if pred_answer in answer_options:
answer_index = answer_options.index(pred_answer)
rank[answer_index] = 1

gt_index = sample.get('gt_answer_index')
gt_rank = rank[gt_index]

self.results.append(gt_rank)

def compute_metrics(self, results: List) -> dict:
"""Compute the metrics from processed results.

Args:
            results (List): The processed results of each batch.

Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""

        ranks = torch.stack(results).float()

        R1 = (ranks <= 1).float().mean()
        R5 = (ranks <= 5).float().mean()
        R10 = (ranks <= 10).float().mean()
        Mean = ranks.mean()
        MRR = ranks.reciprocal().mean()

metrics = {
'R@1': R1.item(),
'R@5': R5.item(),
'R@10': R10.item(),
'Mean': Mean.item(),
'MRR': MRR.item()
}
return metrics

def _process_answer(self, answer) -> str:
answer = _process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
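
A quick standalone sanity check of the metric (the toy sample below is invented for illustration and is not part of this PR):

# Hypothetical example; the answer options and prediction are made up.
from mmpretrain.evaluation.metrics import SparseGTMetrics

metric = SparseGTMetrics()
metric.process(
    data_batch=None,  # unused by this metric
    data_samples=[dict(
        answer_options=['yes', 'no', 'maybe', 'two', 'red'],
        pred_answer='yes',   # the predicted option is forced to rank 1
        gt_answer_index=0,   # the ground truth is also 'yes'
    )])
print(metric.compute_metrics(metric.results))
# Expected for this single sample: R@1, R@5, R@10, Mean and MRR all 1.0.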