# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset for bert."""
import collections
import math
import numpy as np
import oneflow as flow
from libai.data.data_utils import SentenceIndexedDataset
from libai.data.structures import DistTensorData, Instance
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"])
def is_start_piece(piece):
"""Check if the current word piece is the starting piece (BERT)."""
# When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
return not piece.startswith("##")
[docs]class BertDataset(flow.utils.data.Dataset):
"""Dataset containing sentence pairs for BERT training.
Each index corresponds to a randomly generated sentence pair.
Args:
tokenizer: Tokenizer to use.
data_prefix: Path to the training dataset.
indexed_dataset: Indexed dataset to use.
max_seq_length: Maximum length of the sequence. All values are padded to
this length. Defaults to 512.
mask_lm_prob: Probability to mask tokens. Defaults to 0.15.
short_seq_prob: Probability of producing a short sequence. Defaults to 0.0.
max_preds_per_seq: Maximum number of mask tokens in each sentence. Defaults to None.
seed: Seed for random number generator for reproducibility. Defaults to 1234.
binary_head: Specifies whether the underlying dataset
generates a pair of blocks along with a sentence_target or not.
Setting it to True assumes that the underlying dataset generates a
label for the pair of sentences which is surfaced as
sentence_target. Defaults to True.
"""
def __init__(
self,
tokenizer,
data_prefix,
indexed_dataset,
max_seq_length=512,
mask_lm_prob=0.15,
short_seq_prob=0.0,
max_preds_per_seq=None,
seed=1234,
binary_head=True,
):
self.seed = seed
self.mask_lm_prob = mask_lm_prob
self.max_seq_length = max_seq_length
self.short_seq_prob = short_seq_prob
self.binary_head = binary_head
if max_preds_per_seq is None:
max_preds_per_seq = math.ceil(max_seq_length * mask_lm_prob / 10) * 10
self.max_preds_per_seq = max_preds_per_seq
self.dataset = SentenceIndexedDataset(
data_prefix,
indexed_dataset,
max_seq_length=self.max_seq_length - 3,
short_seq_prob=self.short_seq_prob,
binary_head=self.binary_head,
)
self.tokenizer = tokenizer
self.vocab_id_list = list(tokenizer.get_vocab().values())
self.cls_id = tokenizer.cls_token_id
self.sep_id = tokenizer.sep_token_id
self.mask_id = tokenizer.mask_token_id
self.pad_id = tokenizer.pad_token_id
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
sents = self.dataset[idx]
if self.binary_head:
tokens_a, tokens_b, is_next_random = self.create_random_sentence_pair(sents, np_rng)
else:
tokens_a = []
for j in range(len(sents)):
tokens_a.extend(sents[j])
tokens_b = []
is_next_random = False
tokens_a, tokens_b = self.truncate_seq_pair(
tokens_a, tokens_b, self.max_seq_length - 3, np_rng
)
tokens, token_types = self.create_tokens_and_token_types(tokens_a, tokens_b)
tokens, masked_positions, masked_labels = self.create_masked_lm_predictions(tokens, np_rng)
(
tokens,
token_types,
labels,
padding_mask,
loss_mask,
) = self.pad_and_convert_to_tensor(tokens, token_types, masked_positions, masked_labels)
sample = Instance(
input_ids=DistTensorData(tokens),
attention_mask=DistTensorData(padding_mask),
tokentype_ids=DistTensorData(token_types),
ns_labels=DistTensorData(
flow.tensor(int(is_next_random), dtype=flow.long), placement_idx=-1
),
lm_labels=DistTensorData(labels, placement_idx=-1),
loss_mask=DistTensorData(loss_mask, placement_idx=-1),
)
return sample
def create_random_sentence_pair(self, sample, np_rng):
num_sentences = len(sample)
assert num_sentences > 1, "make sure each sample has at least two sentences."
a_end = 1
if num_sentences >= 3:
a_end = np_rng.randint(1, num_sentences)
tokens_a = []
for j in range(a_end):
tokens_a.extend(sample[j])
tokens_b = []
for j in range(a_end, num_sentences):
tokens_b.extend(sample[j])
is_next_random = False
if np_rng.random() < 0.5:
is_next_random = True
tokens_a, tokens_b = tokens_b, tokens_a
return tokens_a, tokens_b, is_next_random
[docs] def truncate_seq_pair(self, tokens_a, tokens_b, max_num_tokens, np_rng):
"""truncate sequence pair to a maximum sequence length"""
len_a, len_b = len(tokens_a), len(tokens_b)
while True:
total_length = len_a + len_b
if total_length <= max_num_tokens:
break
if len_a > len_b:
trunc_tokens = tokens_a
len_a -= 1
else:
trunc_tokens = tokens_b
len_b -= 1
if np_rng.random() < 0.5:
trunc_tokens.pop(0) # remove the first element
else:
trunc_tokens.pop() # remove the last element
return tokens_a, tokens_b
[docs] def create_tokens_and_token_types(self, tokens_a, tokens_b):
"""merge segments A and B, add [CLS] and [SEP] and build token types."""
tokens = [self.cls_id] + tokens_a + [self.sep_id]
token_types = [0] * (len(tokens_a) + 2)
if len(tokens_b) > 0:
tokens = tokens + tokens_b + [self.sep_id]
token_types = token_types + [1] * (len(tokens_b) + 1)
return tokens, token_types
[docs] def mask_token(self, idx, tokens, np_rng):
"""
helper function to mask `idx` token from `tokens` according to
section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
"""
label = tokens[idx]
if np_rng.random() < 0.8:
new_label = self.mask_id
else:
if np_rng.random() < 0.5:
new_label = label
else:
new_label = np_rng.choice(self.vocab_id_list)
tokens[idx] = new_label
return label
[docs] def create_masked_lm_predictions(
self,
tokens,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
geometric_dist=False,
):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = []
token_boundary = [0] * len(tokens)
new_tokens = []
for (i, token) in enumerate(tokens):
new_tokens.append(token % len(self.tokenizer))
if token == self.cls_id or token == self.sep_id:
token_boundary[i] = 1
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (
do_whole_word_mask
and len(cand_indexes) >= 1
and not is_start_piece(self.tokenizer._convert_id_to_token(token))
):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
if is_start_piece(self.tokenizer._convert_id_to_token(token)):
token_boundary[i] = 1
tokens = new_tokens
masked_positions = []
masked_labels = []
output_tokens = list(tokens)
if self.mask_lm_prob == 0:
return output_tokens, masked_positions, masked_labels
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == self.cls_id or token == self.sep_id:
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if do_whole_word_mask and len(cand_indexes) >= 1 and token_boundary[i] == 0:
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
num_to_predict = min(
self.max_preds_per_seq, max(1, int(round(len(tokens) * self.mask_lm_prob)))
)
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
if not geometric_dist:
# By default, we set the probilities to favor shorter ngram sequences.
pvals = 1.0 / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
if favor_longer_ngram:
pvals = pvals[::-1]
ngram_indexes = []
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx : idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
masked_lms = []
covered_indexes = set()
for cand_index_set in ngram_indexes:
if len(masked_lms) >= num_to_predict:
break
if not cand_index_set:
continue
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes:
continue
if not geometric_dist:
n = np_rng.choice(
ngrams[: len(cand_index_set)],
p=pvals[: len(cand_index_set)]
/ pvals[: len(cand_index_set)].sum(keepdims=True),
)
else:
# Sampling "n" from the geometric distribution and clipping it to
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
n = min(np_rng.geometric(0.2), max_ngrams)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# Repeatedly looking for a candidate that does not exceed the
# maximum number of predictions by trying shorter ngrams.
while len(masked_lms) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
label = self.mask_token(index, output_tokens, np_rng)
masked_lms.append(MaskedLmInstance(index=index, label=label))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
for p in masked_lms:
masked_positions.append(p.index)
masked_labels.append(p.label)
return output_tokens, masked_positions, masked_labels
[docs] def pad_and_convert_to_tensor(self, tokens, token_types, masked_positions, masked_labels):
"""pad sequences and convert them to tensor"""
# check
num_tokens = len(tokens)
num_pad = self.max_seq_length - num_tokens
assert num_pad >= 0
assert len(token_types) == num_tokens
assert len(masked_positions) == len(masked_labels)
# tokens and token types
filler = [self.pad_id] * num_pad
tokens = flow.tensor(tokens + filler, dtype=flow.long)
token_types = flow.tensor(token_types + filler, dtype=flow.long)
# padding mask
padding_mask = flow.tensor([1] * num_tokens + [0] * num_pad, dtype=flow.long)
# labels and loss mask
labels = [-1] * self.max_seq_length
loss_mask = [0] * self.max_seq_length
for idx, label in zip(masked_positions, masked_labels):
assert idx < num_tokens
labels[idx] = label
loss_mask[idx] = 1
labels = flow.tensor(labels, dtype=flow.long)
loss_mask = flow.tensor(loss_mask, dtype=flow.long)
return tokens, token_types, labels, padding_mask, loss_mask
@property
def supports_prefetch(self):
return self.dataset.supports_prefetch
def prefetch(self, indices):
self.dataset.prefetch(indices)