Source code for hfai.datasets.clue
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset,
)

import numpy as np
import torch
import json

from ffrecord import FileReader
@register_dataset
class CLUEForMLM(BaseDataset):
"""
这是一个用于掩蔽语言模型预训练的CLUE数据集
该数据集包含了三个 CLUE 子数据集,为训练 BERT 模型所定制,它们分别是:news, comment, web。该数据集共有 43926640 条语句。语句中有12%的单词会被 mask 掉,3% 的单词会随机替换成其他单词。
Args:
check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``)
Returns:
text, pad_mask, label (torch.LongTensor, torch.BoolTensor, torch.LongTensor): 返回的每条样本是一个三元组,包括每个单词已经通过词汇表转成 id 的训练文本,单词掩蔽信息和对应的标签
Examples:
.. code-block:: python
from hfai.datasets import CLUEForMLM
dataset = CLUEForMLM()
loader = dataset.loader(batch_size=32, num_workers=4)
for text, pad_mask, label in loader:
# training model
"""
    def __init__(self, check_data=True):
        super(CLUEForMLM, self).__init__()
        # standard BERT masking scheme: select 15% of the tokens; of those,
        # mask 80% and replace 10% with random tokens
        self.select_prob = 0.15
        self.mask_prob = 0.8
        self.random_prob = 0.1

        root_dir = get_data_dir() / "CLUE"
        data_dir = root_dir / "mlm"

        with open(root_dir / "vocab.txt", encoding='utf-8') as f:
            self.vocab = {w.strip(): i for i, w in enumerate(f.readlines())}
        self.vocab_size = len(self.vocab)
        assert self.vocab_size == 8021

        self.special = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]

        self.reader = FileReader(data_dir, check_data)

    def __len__(self):
        return self.reader.n
    def __getitem__(self, indices):
        bytes_ = self.reader.read(indices)
        seqs = []
        for b in bytes_:
            seq = np.frombuffer(b, dtype=np.int16)
            seqs.append(seq)
        seqs = torch.from_numpy(np.stack(seqs, axis=0))  # [N, L]
        pad_masks = torch.eq(seqs, self.vocab["[PAD]"])
        labels = seqs.clone()

        # select 15% of the tokens for prediction; special tokens are never selected
        prob = torch.full(labels.shape, self.select_prob)
        special = torch.zeros_like(labels, dtype=torch.bool)  # masked_fill_ requires a bool mask
        for word in self.special:
            special |= torch.eq(labels, self.vocab[word])
        prob.masked_fill_(special, 0)
        mask = torch.bernoulli(prob).bool()
        labels[~mask] = -100  # -100 positions are ignored by cross-entropy by default

        # replace 80% of the selected tokens with [MASK]
        masked = torch.bernoulli(torch.full(labels.shape, self.mask_prob)).bool() & mask
        seqs[masked] = self.vocab["[MASK]"]

        # replace 10% of the selected tokens with a random token;
        # 0.1 / (1 - 0.8) = 0.5 is the conditional probability among the
        # selected-but-unmasked tokens
        random_prob = self.random_prob / (1.0 - self.mask_prob)
        random = torch.bernoulli(torch.full(labels.shape, random_prob)).bool() & mask & ~masked
        seqs[random] = torch.randint_like(labels, self.vocab_size)[random]

        return seqs.long(), pad_masks, labels.long()
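
The ``-100`` written into ``labels`` for unselected positions matches the default ``ignore_index`` of PyTorch's cross-entropy loss, so an MLM head trained on this dataset only pays for the selected ~15% of tokens. Below is a minimal sketch of how the returned triple plugs into a loss computation; the toy model and the hand-built batch are illustrative stand-ins, not part of hfai:

import torch
import torch.nn.functional as F

vocab_size = 8021  # size asserted by CLUEForMLM

# Toy stand-in for a BERT-style MLM head: embed tokens, project back onto the vocab.
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, 64),
    torch.nn.Linear(64, vocab_size),
)

# Fake batch with the shapes CLUEForMLM yields: text [N, L] and label [N, L].
text = torch.randint(0, vocab_size, (2, 16))
label = torch.full((2, 16), -100)
label[:, :3] = text[:, :3]  # pretend the first 3 positions were selected

logits = model(text)  # [N, L, vocab_size]
# cross_entropy skips positions where label == -100 (its default ignore_index),
# so only the selected tokens contribute to the loss.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), label.reshape(-1))
loss.backward()
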
@register_dataset
class CLUEForCLS(BaseDataset):
"""
这是用于用于文本分类的 CLUE 数据集。
该数据集包含了多个子数据集:``afqmc``, ``cmnli``, ``csl``, ``iflytek``, ``ocnli``, ``tnews``, ``wsc``,可用来训练文本分类模型。
Args:
dataset_name (str): 子数据集的名字,比如 ``afqmc``
check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``)
Returns:
text, pad_mask, label (torch.LongTensor, torch.BoolTensor, torch.LongTensor): 返回的每条样本是一个三元组,包括每个单词已经通过词汇表转成 id 的训练文本,单词掩蔽信息和对应的标签
Examples:
.. code-block:: python
from hfai.datasets import CLUEForMLM
dataset = CLUEForCLS("afqmc", split="train")
loader = dataset.loader(batch_size=32, num_workers=4)
for text, pad_mask, label in loader:
# training model
"""
    def __init__(self, dataset_name, split):
        super(CLUEForCLS, self).__init__()
        self.max_length = 128
        self.dataset_name = dataset_name
        self.data_dir = get_data_dir() / "CLUE"

        with open(self.data_dir / "vocab.txt", encoding='utf-8') as f:
            self.vocab = {w.strip(): i for i, w in enumerate(f.readlines())}
        self.vocab_size = len(self.vocab)
        assert self.vocab_size == 8021

        self.seq, self.label = self.preprocess(self.vocab, split)
        self.mask = torch.eq(self.seq, self.vocab["[PAD]"])
    def __getitem__(self, i):
        return self.seq[i], self.mask[i], self.label[i]

    def __len__(self):
        return len(self.seq)

    def loader(self, *args, **kwargs) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(self, *args, **kwargs)
    def preprocess(self, vocab: dict, split: str):
        with open(self.data_dir / "clue-cls.json", encoding='utf-8') as fp:
            meta = json.load(fp)[self.dataset_name]
        keys = meta["keys"]
        if split != "test":
            keys.pop()  # the last key is only used for the test split
        self.classes = meta["classes"]

        vecs = []
        labels = []
        filename = self.data_dir / self.dataset_name / f"{split}.json"
        with open(filename, encoding='utf-8') as f:
            for line in f.readlines():
                data = json.loads(line.strip())
                if data["label"] == "-":
                    continue  # skip samples without a valid gold label
                labels.append(self.classes.index(data["label"]) if split != "test" else -1)
                if self.dataset_name != "wsc":
                    # [CLS] + characters of each key field + [SEP], truncated to max_length
                    vec = [vocab["[CLS]"]]
                    for key in keys:
                        vec.extend([vocab.get(word, vocab["[UNK]"]) for word in data[key]])
                        vec.append(vocab["[SEP]"])
                    vecs.append(torch.tensor(vec[: self.max_length]))

        vecs = torch.nn.utils.rnn.pad_sequence(vecs, batch_first=True)
        labels = torch.tensor(labels)
        return vecs, labels
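
For reference, ``preprocess`` encodes each sample at the character level: a ``[CLS]`` id, then the ids of the characters in each key field (``[UNK]`` for out-of-vocabulary characters) followed by a ``[SEP]`` id, truncated to 128 tokens and padded batch-wise with ``[PAD]``. A minimal sketch of the same scheme on a made-up toy vocabulary (the real one is the 8021-entry vocab.txt shipped with the dataset):

import torch

# Toy vocabulary for illustration only.
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "你": 4, "好": 5}
max_length = 128

def encode(fields):
    # [CLS] + characters of each field + [SEP], truncated to max_length,
    # mirroring CLUEForCLS.preprocess
    vec = [vocab["[CLS]"]]
    for field in fields:
        vec.extend(vocab.get(ch, vocab["[UNK]"]) for ch in field)
        vec.append(vocab["[SEP]"])
    return torch.tensor(vec[:max_length])

vecs = [encode(["你好"]), encode(["好", "你好吗"])]  # one- and two-field samples; 吗 maps to [UNK]
seqs = torch.nn.utils.rnn.pad_sequence(vecs, batch_first=True)  # pads with 0 == [PAD]
pad_mask = torch.eq(seqs, vocab["[PAD]"])
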