# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset for gpt."""
import numpy as np
import oneflow as flow
from libai.data.data_utils import BlockIndexedDataset
from libai.data.structures import DistTensorData, Instance
class GPT2Dataset(flow.utils.data.Dataset):
    """Dataset containing sentences for GPT2 training.

    Each item is one token sequence read from a ``BlockIndexedDataset`` and
    split into a next-token prediction pair: the inputs are every token but
    the last, and the labels are every token but the first.

    Args:
        tokenizer: Tokenizer to use. It is stored on the instance but is not
            otherwise used by this class.
        data_prefix (str): Path to the training dataset.
        indexed_dataset: Indexed dataset to use.
        max_seq_length (int, optional): Maximum length of the sequence passing into encoder.
            All values are padded to this length. Defaults to 512.
    """

    def __init__(self, tokenizer, data_prefix, indexed_dataset, max_seq_length=512):
        self.dataset = BlockIndexedDataset(
            data_prefix, indexed_dataset, max_seq_length=max_seq_length
        )
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of items in the underlying block dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the (input_ids, labels) training instance for item ``idx``.

        Args:
            idx: Index forwarded unchanged to the underlying block dataset.

        Returns:
            Instance: holds ``input_ids`` (tokens ``[:-1]``) and ``labels``
            (tokens ``[1:]``), each wrapped in ``DistTensorData``.
        """
        # np.int64 rather than np.long: the np.long alias was removed in
        # NumPy 1.24, and int64 is the width flow.long expects below.
        text = np.array(self.dataset[idx], dtype=np.int64)

        # Shift by one token so that position i is trained to predict i + 1.
        input_ids = flow.tensor(text[:-1], dtype=flow.long)
        lm_labels = flow.tensor(text[1:], dtype=flow.long)

        sample = Instance(
            input_ids=DistTensorData(input_ids),
            # NOTE(review): placement_idx=-1 presumably targets the last
            # pipeline stage -- confirm against DistTensorData.
            labels=DistTensorData(lm_labels, placement_idx=-1),
        )
        return sample

    @property
    def supports_prefetch(self):
        """Whether the underlying block dataset reports prefetch support."""
        return self.dataset.supports_prefetch

    def prefetch(self, indices):
        """Forward a prefetch request for ``indices`` to the underlying dataset."""
        self.dataset.prefetch(indices)