import io
import json
import logging
import os
import pickle
import tempfile
from collections import Counter, OrderedDict, defaultdict
import matplotlib.pyplot as plt
import nltk
import numpy as np
import PIL
import torch
import torch.nn as nn
from nltk.tokenize import sent_tokenize, word_tokenize
from pythae.data.datasets import DatasetOutput
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.datasets.utils import download_and_extract_archive
from multivae.data.datasets.base import DatasetOutput, MultimodalBaseDataset
# Module-level logger. A stream handler is attached and the level is set to
# INFO so that informational messages from this module reach the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)
class OrderedCounter(Counter, OrderedDict):
    """A ``Counter`` that keeps its keys in first-seen order."""

    def __reduce__(self):
        # Pickle as "class + ordered snapshot" so the insertion order
        # survives a dump/load round trip.
        snapshot = OrderedDict(self)
        return type(self), (snapshot,)

    def __repr__(self):
        class_name = type(self).__name__
        return f"{class_name}({OrderedDict(self)!r})"
class CUBSentences(Dataset):  # pragma: no cover
    """
    Dataset for the CUB captions only.

    On first use the raw caption file is tokenized with nltk and the result is
    cached as JSON in a sub-directory of ``<root_data_dir>/cub`` whose name
    encodes ``min_occ`` and ``max_sequence_length``. The vocabulary is built
    from the training captions and cached in the same directory. Later
    instantiations read the cached files instead of re-tokenizing.

    Args:
        root_data_dir (str): The path where to find the data. The caption
            files are looked up in ``<root_data_dir>/cub``.
        split (str): The split to use. Either 'train' (read from
            ``text_trainvalclasses.txt``) or 'test' (read from
            ``text_testclasses.txt``).
        output_type (str): The output type of the text. Either 'one_hot' or
            'tokens'.
        transform (callable): The transformation to apply to the text tensor.
            Default: None.
        **kwargs:
            max_sequence_length (int): Every caption is cut or padded to this
                many tokens, including the final ``<eos>``. Default: 32.
            min_occ (int): A word enters the vocabulary only if it occurs
                strictly more than ``min_occ`` times. Default: 3.

    Raises:
        ValueError: If ``split`` is neither 'train' nor 'test'.
    """

    def __init__(
        self, root_data_dir, split, output_type="one_hot", transform=None, **kwargs
    ):
        super().__init__()
        self.data_dir = os.path.join(root_data_dir, "cub")
        self.split = split
        self.max_sequence_length = kwargs.get("max_sequence_length", 32)
        self.min_occ = kwargs.get("min_occ", 3)
        self.transform = transform
        self.output_type = output_type

        # Cache directory; its name depends on the tokenization settings so
        # that different settings never share cached files.
        self.gen_dir = os.path.join(
            self.data_dir, "oc:{}_msl:{}".format(self.min_occ, self.max_sequence_length)
        )

        if split == "train":
            self.raw_data_path = os.path.join(self.data_dir, "text_trainvalclasses.txt")
        elif split == "test":
            self.raw_data_path = os.path.join(self.data_dir, "text_testclasses.txt")
        else:
            # ValueError is still caught by callers using `except Exception`.
            raise ValueError("Only train or test split is available")

        os.makedirs(self.gen_dir, exist_ok=True)
        self.data_file = "cub.{}.s{}".format(split, self.max_sequence_length)
        self.vocab_file = "cub.vocab"

        if not os.path.exists(os.path.join(self.gen_dir, self.data_file)):
            print(
                "Data file not found for {} split at {}. Creating new... (this may take a while)".format(
                    split.upper(), os.path.join(self.gen_dir, self.data_file)
                )
            )
            self._create_data()
        else:
            self._load_data()

    def __len__(self):
        """Return the number of cached sentences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one caption and its padding mask.

        Returns:
            dict: ``{'one_hot': ..., 'padding_mask': ...}`` or
            ``{'tokens': ..., 'padding_mask': ...}`` depending on
            ``output_type``. The mask holds 1.0 for the first ``length``
            positions and 0.0 for the padded ones.

        Raises:
            AttributeError: If ``output_type`` is not 'one_hot' or 'tokens'.
        """
        # The cache is JSON, so the sample ids are string keys.
        entry = self.data[str(idx)]
        sent = torch.LongTensor(entry["idx"])
        length = entry["length"]
        padding_mask = torch.FloatTensor(
            [1.0] * length + [0.0] * (self.max_sequence_length - length)
        )

        if self.output_type == "one_hot":
            sent = nn.functional.one_hot(sent, self.vocab_size).float()
            if self.transform is not None:
                sent = self.transform(sent)
            return dict(one_hot=sent, padding_mask=padding_mask)

        if self.output_type == "tokens":
            if self.transform is not None:
                sent = self.transform(sent)
            return dict(tokens=sent, padding_mask=padding_mask)

        raise AttributeError(
            "The output type should be either 'one_hot' or 'tokens' but is neither."
        )

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.w2i)

    @property
    def pad_idx(self):
        """Index of the ``<pad>`` token."""
        return self.w2i["<pad>"]

    @property
    def eos_idx(self):
        """Index of the ``<eos>`` token."""
        return self.w2i["<eos>"]

    @property
    def unk_idx(self):
        """Index used for out-of-vocabulary words.

        ``_create_vocab`` registers the out-of-vocabulary token as ``<exc>``
        and never adds ``<unk>``, so ``<exc>`` is the fallback; ``<unk>`` is
        honoured only if the loaded vocabulary defines it.
        """
        if "<unk>" in self.w2i:
            return self.w2i["<unk>"]
        return self.w2i["<exc>"]

    def get_w2i(self):
        """Return the word -> index mapping."""
        return self.w2i

    def get_i2w(self):
        """Return the index -> word mapping (keys are strings)."""
        return self.i2w

    def _load_data(self, vocab=True):
        """Read the cached sentences, and the vocabulary when ``vocab`` is True."""
        # json.load accepts binary files and detects the UTF encoding itself.
        with open(os.path.join(self.gen_dir, self.data_file), "rb") as file:
            self.data = json.load(file)

        if vocab:
            self._load_vocab()

    def _load_vocab(self):
        """Read the cached vocabulary, creating it first if it is missing."""
        vocab_path = os.path.join(self.gen_dir, self.vocab_file)
        if not os.path.exists(vocab_path):
            self._create_vocab()
        # The file is written as UTF-8, so read it back as UTF-8 rather than
        # with the platform default encoding.
        with open(vocab_path, "r", encoding="utf-8") as vocab_file:
            vocab = json.load(vocab_file)
        self.w2i, self.i2w = vocab["w2i"], vocab["i2w"]

    def _create_data(self):
        """Tokenize the raw captions, cache them as JSON and load the cache."""
        # Loads the vocabulary, building it first when it does not exist yet
        # (building is only allowed for the train split).
        self._load_vocab()

        with open(self.raw_data_path, "r") as file:
            text = file.read()
        sentences = sent_tokenize(text)

        data = {}
        truncated_count = 0
        max_words = self.max_sequence_length - 1  # one slot is kept for <eos>
        for line in sentences:
            words = word_tokenize(line)
            if len(words) > max_words:
                truncated_count += 1

            tok = words[:max_words] + ["<eos>"]
            length = min(len(tok), self.max_sequence_length)
            tok.extend(["<pad>"] * (self.max_sequence_length - length))

            # Words missing from the vocabulary map to <exc>.
            idx = [self.w2i.get(w, self.w2i["<exc>"]) for w in tok]
            data[len(data)] = {"tok": tok, "idx": idx, "length": length}

        print(
            "{} out of {} sentences are truncated with max sentence length {}.".format(
                truncated_count, len(sentences), self.max_sequence_length
            )
        )
        with io.open(os.path.join(self.gen_dir, self.data_file), "wb") as data_file:
            serialized = json.dumps(data, ensure_ascii=False)
            data_file.write(serialized.encode("utf8", "replace"))

        # Reload from disk so that the in-memory keys are the JSON string keys.
        self._load_data(vocab=False)

    def _create_vocab(self):
        """Build the vocabulary from the training captions and cache it.

        Besides the vocabulary file, the excluded words (``cub.unique``) and
        the full occurrence counter (``cub.all``) are pickled into the cache
        directory.
        """
        import nltk

        nltk.download("punkt_tab")
        nltk.download("punkt")

        assert (
            self.split == "train"
        ), "Vocablurary can only be created for training file."

        with open(self.raw_data_path, "r") as file:
            text = file.read()
        sentences = sent_tokenize(text)

        occ_register = OrderedCounter()
        w2i = dict()
        i2w = dict()

        special_tokens = ["<exc>", "<pad>", "<eos>"]
        for st in special_tokens:
            i2w[len(w2i)] = st
            w2i[st] = len(w2i)

        for line in sentences:
            occ_register.update(word_tokenize(line))

        unq_words = []
        for w, occ in occ_register.items():
            if occ > self.min_occ and w not in special_tokens:
                i2w[len(w2i)] = w
                w2i[w] = len(w2i)
            else:
                unq_words.append(w)

        assert len(w2i) == len(i2w)

        print(
            "Vocablurary of {} keys created, {} words are excluded (occurrence <= {}).".format(
                len(w2i), len(unq_words), self.min_occ
            )
        )

        vocab = dict(w2i=w2i, i2w=i2w)
        with io.open(os.path.join(self.gen_dir, self.vocab_file), "wb") as vocab_file:
            serialized = json.dumps(vocab, ensure_ascii=False)
            vocab_file.write(serialized.encode("utf8", "replace"))

        with open(os.path.join(self.gen_dir, "cub.unique"), "wb") as unq_file:
            pickle.dump(np.array(unq_words), unq_file)

        with open(os.path.join(self.gen_dir, "cub.all"), "wb") as a_file:
            pickle.dump(occ_register, a_file)

        # Reload from disk so that i2w uses the JSON string keys.
        self._load_vocab()

    def one_hot_to_string(self, data):
        """Convert an iterable of one-hot matrices to a list of sentences."""
        return [self._to_string(matrix) for matrix in data]

    def _to_string(self, matrix):
        """Convert one (sequence, vocabulary) score matrix to a sentence.

        Each row is mapped to the word with the highest score and the words
        are joined with single spaces.
        """
        words = []
        for i in range(matrix.shape[0]):
            # int() gives a plain integer so that str() matches the JSON keys
            # of i2w whatever scalar type argmax returns.
            idx = int(np.argmax(matrix[i, :]))
            words.append(self.i2w[str(idx)])
        return " ".join(words)
# "[docs]" -- link marker left over from the HTML documentation listing; not part of the module.
class CUB(MultimodalBaseDataset):  # pragma: no cover
    """
    A paired text image CUB dataset.

    The 'train' and 'eval' splits are both drawn from the training files: the
    training captions are shuffled with a fixed seed and 10% of them are held
    out as the 'eval' split.

    Args:
        path (str) : The path where the data is saved.
        split (str) : Either 'train', 'eval' or 'test'. Default: 'train'.
        max_words_in_caption (int): The number of words in the captions. Default: 32.
        im_size (Tuple[int]): The desired size of the images. Default: (64, 64)
        img_transform (Transforms): The transformations to be applied to the images
            after resizing and conversion to tensor. If None, nothing more is
            done. Default: None.
        output_type (Literal['one_hot','tokens']) : Default to 'one_hot'.
        download (bool): Whether to download the data if it is not found in the path. Default: False.
    """

    def __init__(
        self,
        path: str,
        split="train",
        max_words_in_caption=32,
        im_size=(64, 64),
        img_transform=None,
        output_type="one_hot",
        download=False,
    ):
        self.split = split
        self.check_or_download_data(path, download)
        self.max_words_in_caption = max_words_in_caption

        # Same call as `lambda data: torch.Tensor(data)`, without binding a
        # lambda to a name.
        transform_text = torch.Tensor

        img_steps = [transforms.Resize([*im_size]), transforms.ToTensor()]
        if img_transform is not None:
            img_steps.append(img_transform)
        transform_img = transforms.Compose(img_steps)

        # The 'eval' split is carved out of the training files.
        data_split = "train" if split == "eval" else split
        self.text_data = CUBSentences(
            path,
            data_split,
            output_type=output_type,
            transform=transform_text,
            max_sequence_length=max_words_in_caption,
        )
        self.image_data = datasets.ImageFolder(
            os.path.join(path, "cub", data_split), transform=transform_img
        )

        # Split the training data into train and validation.
        # (The test split is used as a whole and needs no index arrays.)
        if self.split in ("train", "eval"):
            self.train_idx, self.val_idx = train_test_split(
                np.arange(len(self.text_data)),
                test_size=0.1,
                random_state=0,
                shuffle=True,
            )

        self.vocab_size = self.text_data.vocab_size

    def check_or_download_data(self, data_path, download):
        """Make sure ``<data_path>/cub`` exists, downloading it if allowed.

        Args:
            data_path (str): Folder expected to contain the ``cub`` directory.
            download (bool): Whether the archive may be downloaded when the
                ``cub`` directory is missing.

        Raises:
            AttributeError: If the data is missing and ``download`` is False.
        """
        cub_dir = os.path.join(data_path, "cub")
        if os.path.exists(cub_dir):
            return

        if not download:
            raise AttributeError(
                "The CUB dataset is not available at the"
                " given datapath and download is set to False."
                " Set download to True or place the dataset"
                " in the data_path folder."
            )

        os.makedirs(cub_dir)
        logger.info(f"Downloading the CUB dataset into {data_path}")
        try:
            # The archive is only needed until it is extracted, so it is
            # fetched into a temporary directory that is removed afterwards.
            with tempfile.TemporaryDirectory() as tempdir:
                download_and_extract_archive(
                    url="http://www.robots.ox.ac.uk/~yshi/mmdgm/datasets/cub.zip",
                    download_root=tempdir,
                    extract_root=data_path,
                )
        except Exception:
            # Do not leave an empty placeholder behind: it would make the
            # next call believe the dataset is already present.
            if os.path.isdir(cub_dir) and not os.listdir(cub_dir):
                os.rmdir(cub_dir)
            raise

    def __getitem__(self, index):
        """Return the image/caption pair at position ``index`` of the split."""
        # For train/eval, translate the position into an index of the full
        # training caption list.
        if self.split == "train":
            index = self.train_idx[index]
        elif self.split == "eval":
            index = self.val_idx[index]

        # Caption i is paired with image i // 10.
        # NOTE(review): presumably ten consecutive captions per image, listed
        # in ImageFolder order -- confirm against the data files.
        image = self.image_data[index // 10][0]
        text = self.text_data[index]

        return DatasetOutput(data=dict(image=image, text=text))

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.split == "train":
            return len(self.train_idx)
        if self.split == "eval":
            return len(self.val_idx)
        return len(self.text_data)

    def plot_text(self, input_tensor, fig_size=(2, 1.5)):
        """Render a caption as an image tensor.

        Args:
            input_tensor (torch.Tensor): Scores of shape
                (max_sequence_length, vocab_size); each row is decoded to its
                arg-max word.
            fig_size (Tuple[float]): Size of the matplotlib figure.
                Default: (2, 1.5).

        Returns:
            torch.Tensor: The rendered figure as a channels-first float
            tensor with values divided by 255, on the device of
            ``input_tensor``.
        """
        device = input_tensor.device
        array = input_tensor.detach().cpu().numpy()
        sentence = self.text_data._to_string(array)

        # Drop <eos> and break the line after every third word.
        shown_words = [word for word in sentence.split() if word != "<eos>"]
        caption = " ".join(
            word + "\n" if (n + 1) % 3 == 0 else word
            for n, word in enumerate(shown_words)
        )

        fig = plt.figure(figsize=fig_size)
        plt.text(
            x=0.5,
            y=0.5,
            s="{}".format(caption),
            fontsize=7,
            verticalalignment="center_baseline",
            horizontalalignment="center",
        )
        plt.axis("off")
        fig.tight_layout()

        # Draw the canvas and retrieve the image as a NumPy array
        fig.canvas.draw()
        img_buf = io.BytesIO()
        fig.savefig(img_buf, format="png")
        image = PIL.Image.open(img_buf)
        image = np.array(image).transpose(2, 0, 1) / 255
        plt.close(fig=fig)

        return torch.from_numpy(image).float().to(device)