How to Train BERT from Scratch using Transformers in Python

Learn how you can pretrain BERT and other transformers on the Masked Language Modeling (MLM) task on your custom dataset using the Huggingface Transformers library in Python.
  · 15 min read · Updated jun 2022 · Machine Learning · Natural Language Processing

A pre-trained model is a model that was previously trained on a large dataset and saved for direct use or fine-tuning. In this tutorial, you will learn how you can train BERT (or any other transformer model) from scratch on your custom raw text dataset with the help of the Huggingface transformers library in Python.

Transformers can be pre-trained with self-supervised tasks. Below are some of the popular tasks used for BERT:

  • Masked Language Modeling (MLM): This task consists of masking a certain percentage of the tokens in the sentence, and the model is trained to predict those masked words. We'll be using this one in this tutorial.
  • Next Sentence Prediction (NSP): The model receives pairs of sentences as input and learns to predict whether the second sentence in the pair is the subsequent sentence in the original document.
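
To make the NSP task more concrete, here is a minimal sketch of how sentence pairs could be built from raw documents. It assumes the 50/50 split between true and random next sentences described in the original BERT paper; the helper name make_nsp_pairs and the toy documents are made up for illustration and are not part of any library.

```python
import random

def make_nsp_pairs(documents, seed=0):
    """Build (sentence_a, sentence_b, is_next) triples from a list of documents,
    where each document is a list of sentences. Needs at least two documents."""
    rng = random.Random(seed)
    pairs = []
    for doc_idx, doc in enumerate(documents):
        for i in range(len(doc) - 1):
            if rng.random() < 0.5:
                # positive example: the sentence that really follows
                pairs.append((doc[i], doc[i + 1], True))
            else:
                # negative example: a random sentence from a different document
                other_idx = rng.choice(
                    [j for j in range(len(documents)) if j != doc_idx]
                )
                pairs.append((doc[i], rng.choice(documents[other_idx]), False))
    return pairs

# hypothetical toy corpus
docs = [
    ["The bank was placed under investigation.",
     "Its shares fell on Monday.",
     "Analysts expect a fine."],
    ["The storm hit the coast overnight.",
     "Thousands lost power."],
]
pairs = make_nsp_pairs(docs)
for a, b, is_next in pairs:
    print(is_next, "|", a, "->", b)
```

We won't use NSP in the rest of this tutorial; the sketch is only meant to show what the model sees as input and label for that task.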

To get started, we need to install 3 libraries:

$ pip install datasets transformers==4.18.0 sentencepiece

If you want to follow along, open up a new notebook, or Python file and import the necessary libraries:

from datasets import *
from transformers import *
from tokenizers import *
import os
import json

Picking a Dataset

If you're willing to pre-train a transformer, then you most likely have a custom dataset. For demonstration purposes in this tutorial, however, we're going to use the cc_news dataset, loaded with the huggingface datasets library. If you want to use your own data, make sure to follow this link to get your custom dataset loaded into the library.

CC-News dataset contains news articles from news sites all over the world. It contains 708,241 news articles in English published between January 2017 and December 2019.

Downloading and preparing the dataset:

# download and prepare cc_news dataset
dataset = load_dataset("cc_news", split="train")

There is only one split in the dataset, so we need to split it into training and testing sets:

# split the dataset into training (90%) and testing (10%)
d = dataset.train_test_split(test_size=0.1)
d["train"], d["test"]

(Dataset({
     features: ['title', 'text', 'domain', 'date', 'description', 'url', 'image_url'],
     num_rows: 637416
 }), Dataset({
     features: ['title', 'text', 'domain', 'date', 'description', 'url', 'image_url'],
     num_rows: 70825
 }))

You can also pass the seed parameter to the train_test_split() method so you get the same sets every time you run it.
Let's see what it looks like:

for t in d["train"]["text"][:3]:
  print(t)

Output (stripped):

Pretty sure women wish men did this better too!!
Q: A recent survey showed that 1/3 of men wish they did THIS better. What is...<STRIPPED>
× GoDaddy boots neo-Nazi site after a derogatory story on the Charlottesville victim
The Daily Stormer, a white supremacist and neo-Nazi website,...<STRIPPED>
French bank Natixis under investigation over subprime losses
PARIS, Feb 15 Natixis has been placed under formal investigation...<STRIPPED>

As mentioned previously, if you have your own dataset, you can either follow the link on setting up your dataset to be loaded as above, or use the LineByLineTextDataset class if your dataset is a text file where all sentences are separated by a new line.

However, a better way to set up your custom dataset is to split your text file into several chunk files using the split command or any other Python code, and load them using load_dataset() as we did above, like this:

# if you have a huge custom dataset separated into files
# load the split files
files = ["train1.txt", "train2.txt"] # train3.txt, etc.
dataset = load_dataset("text", data_files=files, split="train")

If you have your custom data as one massive file, you should divide it into a handful of text files (for example, with the split command on Linux or Colab) before loading them with the load_dataset() function, as the runtime will crash if the data exceeds the available memory.
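
If you'd rather do the splitting in Python than with the split command, the following is a minimal sketch. The function name split_text_file, the chunk naming scheme, and the lines_per_chunk value are my own assumptions, so adapt them to your data. It reads the source file line by line, so the whole file never has to fit in memory.

```python
import os

def split_text_file(path, lines_per_chunk, output_dir, prefix="train"):
    """Split a text file into numbered chunk files holding at most
    `lines_per_chunk` lines each, and return the list of chunk paths."""
    os.makedirs(output_dir, exist_ok=True)
    chunk_paths = []
    out = None
    with open(path, encoding="utf-8") as src:
        for line_number, line in enumerate(src):
            if line_number % lines_per_chunk == 0:
                # start a new chunk file: train1.txt, train2.txt, ...
                if out is not None:
                    out.close()
                chunk_path = os.path.join(
                    output_dir, f"{prefix}{len(chunk_paths) + 1}.txt"
                )
                chunk_paths.append(chunk_path)
                out = open(chunk_path, "w", encoding="utf-8")
            out.write(line)
    if out is not None:
        out.close()
    return chunk_paths
```

The returned list can then be passed as data_files to load_dataset("text", data_files=files, split="train"), just like the files list above.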

Training the Tokenizer

Next, we need to train our tokenizer. To do that, we need to write our dataset into text files, as that's what the tokenizers library requires the input to be:

# if you want to train the tokenizer from scratch (especially if you have custom
# dataset loaded as datasets object), then run this cell to save it as files
# but if you already have your custom data as text files, there is no point using this
def dataset_to_text(dataset, output_filename="data.txt"):
  """Utility function to save dataset text to disk,
  useful for using the texts to train the tokenizer 
  (as the tokenizer accepts files)"""
  with open(output_filename, "w") as f:
    for t in dataset["text"]:
      print(t, file=f)

# save the training set to train.txt
dataset_to_text(d["train"], "train.txt")
# save the testing set to test.txt
dataset_to_text(d["test"], "test.txt")

The main purpose of the above code cell is to save the dataset object as text files. If you already have your dataset as text files, then you should skip this step. Next, let's define some parameters:

special_tokens = [
  "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "<S>", "<T>"
]
# if you want to train the tokenizer on both sets
# files = ["train.txt", "test.txt"]
# training the tokenizer on the training set
files = ["train.txt"]
# 30,522 vocab is BERT's default vocab size, feel free to tweak
vocab_size = 30_522
# maximum sequence length, lowering it will result in faster training (when increasing batch size)
max_length = 512
# whether to truncate
truncate_longer_samples = False

The files list is the list of files to pass to the tokenizer for training. vocab_size is the vocabulary size of tokens. max_length is the maximum sequence length.

truncate_longer_samples is a boolean indicating whether we truncate sentences longer than max_length. If it's set to False, we don't truncate the sentences; instead, we group them together and split them into chunks of max_length, so all the resulting samples have a length of max_length.

Let's train the tokenizer now:

# initialize the WordPiece tokenizer
tokenizer = BertWordPieceTokenizer()
# train the tokenizer
tokenizer.train(files=files, vocab_size=vocab_size, special_tokens=special_tokens)
# enable truncation up to the maximum 512 tokens
tokenizer.enable_truncation(max_length=max_length)

Since this is BERT, the default tokenizer is WordPiece. As a result, we initialize the BertWordPieceTokenizer() class from the tokenizers library and train it with the train() method, which takes several minutes to finish. Let's save it now:

model_path = "pretrained-bert"
# make the directory if not already there
if not os.path.isdir(model_path):
  os.mkdir(model_path)
# save the tokenizer  
tokenizer.save_model(model_path)
# dumping some of the tokenizer config to config file, 
# including special tokens, whether to lower case and the maximum sequence length
with open(os.path.join(model_path, "config.json"), "w") as f:
  tokenizer_cfg = {
      "do_lower_case": True,
      "unk_token": "[UNK]",
      "sep_token": "[SEP]",
      "pad_token": "[PAD]",
      "cls_token": "[CLS]",
      "mask_token": "[MASK]",
      "model_max_length": max_length,
      "max_len": max_length,
  }
  json.dump(tokenizer_cfg, f)

The tokenizer.save_model() method saves the vocabulary file into that path. We also manually save some tokenizer configurations, such as special tokens:

  • unk_token: A special token that represents an out-of-vocabulary token. Even with a WordPiece tokenizer, unknown tokens are rare but not impossible.
  • sep_token: A special token that separates two different sentences in the same input.
  • pad_token: A special token that is used to fill sentences that do not reach the maximum sequence length (since the arrays of tokens must be the same size).
  • cls_token: A special token representing the class of the input.
  • mask_token: This is the mask token we use for the Masked Language Modeling (MLM) pretraining task.
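
To illustrate how these special tokens end up in a model input, here is a small pure-Python sketch. The function build_bert_input is hypothetical and works on already-split tokens; the real tokenizer does this for you (and returns IDs rather than strings), so treat it only as a picture of the layout.

```python
def build_bert_input(tokens_a, tokens_b=None, max_length=12):
    """Assemble a BERT-style sequence with [CLS], [SEP] and [PAD] tokens.
    Returns (tokens, attention_mask, special_tokens_mask)."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    if tokens_b is not None:
        tokens += tokens_b + ["[SEP]"]
    if len(tokens) > max_length:
        raise ValueError("sequence is longer than max_length; truncate it first")
    padding = max_length - len(tokens)
    # real tokens get attention 1, padding gets 0
    attention_mask = [1] * len(tokens) + [0] * padding
    tokens = tokens + ["[PAD]"] * padding
    # mark positions that hold special tokens (these are never masked for MLM)
    special = {"[CLS]", "[SEP]", "[PAD]"}
    special_tokens_mask = [int(t in special) for t in tokens]
    return tokens, attention_mask, special_tokens_mask

tokens, attention_mask, special_tokens_mask = build_bert_input(
    ["the", "weather", "was", "cloudy"], ["today", "it", "rains"]
)
print(tokens)
print(attention_mask)       # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
print(special_tokens_mask)  # [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1]
```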

After the training of the tokenizer is completed, let's load it now:

# when the tokenizer is trained and configured, load it as BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained(model_path)

Of course, if you want to use the tokenizer multiple times, you don't have to train it again, simply load it using the above cell.
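
If you're curious what a WordPiece tokenizer does with its vocabulary at encoding time, the sketch below shows the greedy longest-match-first idea on a single word. The toy vocabulary is made up, and the function is a simplified illustration rather than the library's implementation (it skips lowercasing, punctuation handling, and so on). It also shows why [UNK] can still appear: when no piece in the vocabulary fits, the whole word becomes unknown.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of a single word into WordPiece tokens."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # try the longest remaining substring first, then shrink it
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # nothing fits: the whole word is unknown
        pieces.append(match)
        start = end
    return pieces

# hypothetical toy vocabulary
toy_vocab = {"train", "##ing", "##ed", "token", "##izer", "##s"}
print(wordpiece_tokenize("training", toy_vocab))    # ['train', '##ing']
print(wordpiece_tokenize("tokenizers", toy_vocab))  # ['token', '##izer', '##s']
print(wordpiece_tokenize("zebra", toy_vocab))       # ['[UNK]']
```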

Tokenizing the Dataset

Now that we have the tokenizer ready, the below code is responsible for tokenizing the dataset:

def encode_with_truncation(examples):
  """Mapping function to tokenize the sentences passed with truncation"""
  return tokenizer(examples["text"], truncation=True, padding="max_length",
                   max_length=max_length, return_special_tokens_mask=True)

def encode_without_truncation(examples):
  """Mapping function to tokenize the sentences passed without truncation"""
  return tokenizer(examples["text"], return_special_tokens_mask=True)

# the encode function will depend on the truncate_longer_samples variable
encode = encode_with_truncation if truncate_longer_samples else encode_without_truncation
# tokenizing the train dataset
train_dataset = d["train"].map(encode, batched=True)
# tokenizing the testing dataset
test_dataset = d["test"].map(encode, batched=True)
if truncate_longer_samples:
  # remove other columns and set input_ids and attention_mask as PyTorch tensors
  train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
  test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
else:
  # remove other columns, and keep them as Python lists
  test_dataset.set_format(columns=["input_ids", "attention_mask", "special_tokens_mask"])
  train_dataset.set_format(columns=["input_ids", "attention_mask", "special_tokens_mask"])

The encode() callback that we use to tokenize our dataset depends on the truncate_longer_samples boolean variable. If set to True, then we truncate sentences that exceed the maximum sequence length (max_length parameter). Otherwise, we don't.

Next, in the case of setting truncate_longer_samples to False, we need to join our untruncated samples together and cut them into fixed-size vectors since the model expects a fixed-sized sequence during training:

from itertools import chain
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
# grabbed from:
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_length:
        total_length = (total_length // max_length) * max_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + max_length] for i in range(0, total_length, max_length)]
        for k, t in concatenated_examples.items()
    }
    return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
# To speed up this part, you can pass `num_proc` to use multiprocessing. See the documentation of the map method for more information.
if not truncate_longer_samples:
  train_dataset = train_dataset.map(group_texts, batched=True,
                                    desc=f"Grouping texts in chunks of {max_length}")
  test_dataset = test_dataset.map(group_texts, batched=True,
                                  desc=f"Grouping texts in chunks of {max_length}")
  # convert them from lists to torch tensors
  train_dataset.set_format("torch")
  test_dataset.set_format("torch")

Most of the above code was taken from a script in the huggingface transformers examples, so it is actually used by the library itself.
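
To see what the grouping does, here is the same logic applied to a tiny made-up batch. The only change is that max_length is passed as an argument (here 4 instead of 512) so the example is self-contained; the token IDs are arbitrary.

```python
from itertools import chain

def group_texts_toy(examples, max_length):
    """Same logic as group_texts, with max_length passed explicitly."""
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # drop the remainder that does not fill a whole chunk
    if total_length >= max_length:
        total_length = (total_length // max_length) * max_length
    return {
        k: [t[i : i + max_length] for i in range(0, total_length, max_length)]
        for k, t in concatenated.items()
    }

# three tokenized samples of lengths 4, 6 and 3 (13 tokens in total)
batch = {
    "input_ids": [[2, 11, 12, 3], [2, 21, 22, 23, 24, 3], [2, 31, 3]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1]],
}
grouped = group_texts_toy(batch, max_length=4)
print(grouped["input_ids"])
# [[2, 11, 12, 3], [2, 21, 22, 23], [24, 3, 2, 31]]
```

Notice that sample boundaries are ignored: the second and third samples are cut and merged across chunks, and the last leftover token is dropped.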

If you don't want to concatenate all texts and then split them into chunks of 512 tokens, then make sure you set truncate_longer_samples to True, so it will treat each line as an individual sample regardless of its length. If you set truncate_longer_samples to True, the above code cell won't be executed at all.

len(train_dataset), len(test_dataset)


(643843, 71357)

Loading the Model

For this tutorial, we're picking BERT, but feel free to pick any of the transformer models supported by huggingface transformers library, such as RobertaForMaskedLM or DistilBertForMaskedLM:

# initialize the model with the config
model_config = BertConfig(vocab_size=vocab_size, max_position_embeddings=max_length)
model = BertForMaskedLM(config=model_config)

We initialize the model config using BertConfig, and pass the vocabulary size as well as the maximum sequence length. We then pass the config to BertForMaskedLM to initialize the model itself.


Before we start pre-training our model, we need a way to randomly mask tokens in our dataset for the Masked Language Model (MLM) task. Luckily, the library makes this easy for us by simply constructing a DataCollatorForLanguageModeling object:

# initialize the data collator, randomly masking 20% (default is 15%) of the tokens for the Masked Language
# Modeling (MLM) task
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.2
)

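We pass the tokenizer, set mlm to True, and set mlm_probability to 0.2 so that each non-special token is selected for prediction with 20% probability. Most of the selected tokens are replaced with the [MASK] token; the collator replaces a small share of them with a random token or leaves them unchanged.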
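
For intuition, the masking rule can be sketched in plain Python for a single sequence. This is a simplified illustration and not the collator's actual code (which works on batched tensors); it assumes the usual BERT recipe of 80% [MASK], 10% random token, and 10% unchanged for the selected positions, and the token IDs used below (2 for [CLS], 3 for [SEP], 4 for [MASK]) are assumptions.

```python
import random

def mask_tokens(input_ids, special_tokens_mask, mask_id, vocab_size,
                mlm_probability=0.2, seed=0):
    """Return (masked_input_ids, labels) for one sequence.
    Labels are -100 (ignored by the loss) except at the selected positions."""
    rng = random.Random(seed)
    masked = list(input_ids)
    labels = [-100] * len(input_ids)
    for i, (token_id, is_special) in enumerate(zip(input_ids, special_tokens_mask)):
        if is_special or rng.random() >= mlm_probability:
            continue  # special tokens are never selected; skip unselected positions
        labels[i] = token_id  # the model has to predict the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = mask_id                    # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return masked, labels

# assumed IDs: 2 = [CLS], 3 = [SEP], 4 = [MASK]
input_ids = [2] + list(range(100, 120)) + [3]
special_tokens_mask = [1] + [0] * 20 + [1]
masked, labels = mask_tokens(input_ids, special_tokens_mask,
                             mask_id=4, vocab_size=30_522)
print(masked)
print(labels)
```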

Next, let's initialize our training arguments:

training_args = TrainingArguments(
    output_dir=model_path,          # output directory to where save model checkpoint
    evaluation_strategy="steps",    # evaluate each `logging_steps` steps
    num_train_epochs=10,            # number of training epochs, feel free to tweak
    per_device_train_batch_size=10, # the training batch size, put it as high as your GPU memory fits
    gradient_accumulation_steps=8,  # accumulating the gradients before updating the weights
    per_device_eval_batch_size=64,  # evaluation batch size
    logging_steps=1000,             # evaluate, log and save model checkpoints every 1000 step
    save_steps=1000,                # save a checkpoint every 1000 steps
    # load_best_model_at_end=True,  # whether to load the best model (in terms of loss) at the end of training
    # save_total_limit=3,           # whether you don't have much space so you let only 3 model weights saved in the disk
)

Each argument is explained in the comments; refer to the TrainingArguments docs for more details. Let's make our trainer now:

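# initialize the trainer and pass everything to it
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)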

We pass our training arguments to the Trainer, as well as the model, data collator, and the training sets. We simply call train() now to start training:

# train the model
trainer.train()

[10135/79670 18:53:08 < 129:35:53, 0.15 it/s, Epoch 1.27/10]
Step	Training Loss	Validation Loss
1000	6.904000	6.558231
2000	6.498800	6.401168
3000	6.362600	6.277831
4000	6.251000	6.172856
5000	6.155800	6.071129
6000	6.052800	5.942584
7000	5.834900	5.546123
8000	5.537200	5.248503
9000	5.272700	4.934949
10000	4.915900	4.549236

The training will take several hours to several days, depending on the dataset size, the training batch size (i.e., increase it as much as your GPU memory allows), and the GPU speed.

As you can see in the output, the model is still improving and the validation loss is still decreasing. You usually stop the training once the validation loss stops decreasing or decreases only very slowly.
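
If you'd like to turn that rule of thumb into a check, here is a small patience-based sketch that runs on the validation losses logged above. The function should_stop, the patience value, and the three extra "plateau" losses are all hypothetical; the transformers library also ships an EarlyStoppingCallback if you prefer to let the Trainer handle this.

```python
def should_stop(validation_losses, patience=3, min_delta=0.0):
    """Return True if none of the last `patience` evaluations improved on the
    best earlier loss by more than `min_delta`."""
    if len(validation_losses) <= patience:
        return False
    best_before = min(validation_losses[:-patience])
    recent_best = min(validation_losses[-patience:])
    return recent_best > best_before - min_delta

# validation losses logged every 1000 steps in the run above
logged = [6.558231, 6.401168, 6.277831, 6.172856, 6.071129,
          5.942584, 5.546123, 5.248503, 4.934949, 4.549236]
print(should_stop(logged))  # False, the loss is still going down

# step of the checkpoint with the lowest validation loss so far
best_step = 1000 * (logged.index(min(logged)) + 1)
print(best_step)  # 10000

# made-up continuation where the loss stops improving
plateau = logged + [4.55, 4.56, 4.55]
print(should_stop(plateau))  # True
```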

Since we set logging_steps and save_steps to 1000, the trainer evaluates and saves the model every 1,000 steps, that is, every steps × gradient_accumulation_steps × per_device_train_batch_size = 1000 × 8 × 10 = 80,000 samples. I canceled the training after about 19 hours, or 10,000 steps (about 1.27 epochs, or 800,000 samples), and started to use the model. In the next section, we'll see how we can use the model for inference.
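
The arithmetic behind these numbers can be checked with a few lines of Python. This is a rough sketch that assumes a single GPU and assumes the run used the 637,416-row training split printed earlier; under those assumptions it reproduces the 79,670 total steps shown in the progress bar.

```python
import math

per_device_train_batch_size = 10
gradient_accumulation_steps = 8
num_train_epochs = 10
num_train_samples = 637_416  # assumption: rows in the training split printed earlier

# samples consumed per optimizer step (single GPU assumed)
samples_per_step = per_device_train_batch_size * gradient_accumulation_steps
# batches per epoch, then optimizer steps per epoch
batches_per_epoch = math.ceil(num_train_samples / per_device_train_batch_size)
steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
total_steps = steps_per_epoch * num_train_epochs

print(samples_per_step)                     # 80
print(samples_per_step * 1000)              # 80000 samples between checkpoints
print(total_steps)                          # 79670
print(round(10_135 / steps_per_epoch, 2))   # 1.27 epochs at step 10135
```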

Using the Model

Before we use the model, let's assume we don't have model and tokenizer variables in the current runtime. Therefore, we need to load them again:

# load the model checkpoint
model = BertForMaskedLM.from_pretrained(os.path.join(model_path, "checkpoint-10000"))
# load the tokenizer
tokenizer = BertTokenizerFast.from_pretrained(model_path)

If you're on Google Colab, you have to save your checkpoints to Google Drive for later use. You can do that by setting model_path to a Drive path instead of a local path like we did here; just make sure you have enough space there.

Alternatively, you can push your model and tokenizer to the huggingface hub; check this useful guide to do it.

Let's use our model now:

fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

We use the simple pipeline API, and pass both the model and the tokenizer. Let's predict some examples:

# perform predictions
examples = [
  "Today's most trending hashtags on [MASK] is Donald Trump",
  "The [MASK] was cloudy yesterday, but today it's rainy.",
]
for example in examples:
  for prediction in fill_mask(example):
    print(f"{prediction['sequence']}, confidence: {prediction['score']}")


today's most trending hashtags on twitter is donald trump, confidence: 0.1027069091796875
today's most trending hashtags on monday is donald trump, confidence: 0.09271949529647827
today's most trending hashtags on tuesday is donald trump, confidence: 0.08099588006734848
today's most trending hashtags on facebook is donald trump, confidence: 0.04266013577580452
today's most trending hashtags on wednesday is donald trump, confidence: 0.04120611026883125
the weather was cloudy yesterday, but today it's rainy., confidence: 0.04445931687951088
the day was cloudy yesterday, but today it's rainy., confidence: 0.037249673157930374
the morning was cloudy yesterday, but today it's rainy., confidence: 0.023775646463036537
the weekend was cloudy yesterday, but today it's rainy., confidence: 0.022554103285074234
the storm was cloudy yesterday, but today it's rainy., confidence: 0.019406016916036606

That's impressive! I canceled the training early, and the model is still producing interesting results. If your model does not make good predictions, that's a good indicator that it wasn't trained enough.


And there you have the complete code for pretraining BERT or other transformers using Huggingface libraries. Below are some tips:

  • As mentioned above, the training speed depends on the GPU speed, the number of samples in the dataset, and the batch size. I set the training batch size to 10, as that's the maximum that fits in my GPU memory on Colab. If you have more memory, make sure to increase it so you speed up training significantly.
  • During training, if you see the validation loss start to increase, make sure to note the checkpoint with the lowest validation loss so you can load it later. You can also set load_best_model_at_end to True if you don't want to keep track of the loss, as it loads the best weights in terms of loss when the training ends.
  • The vocabulary size was chosen based on the original BERT configuration, which uses 30,522 tokens. Feel free to increase it if you feel the language of your dataset has a large vocabulary, or experiment with other values.
  • If you set truncate_longer_samples to False, the code assumes each line holds a long text, and you will notice that processing takes much longer, especially if you set a large batch_size on the map() method. If it takes many hours to process, you can either set truncate_longer_samples to True, so you truncate sentences that exceed max_length tokens, or save the dataset after processing using the save_to_disk() method, so you process it once and load it several times.
  • Newer versions of the transformers library have a parameter called auto_find_batch_size in the TrainingArguments() class. You can set it to True so it finds a batch size that fits in your GPU memory, avoiding Out-of-Memory errors. Make sure you have the accelerate library installed: pip install accelerate.

If you're interested in fine-tuning BERT for a downstream task such as text classification, then this tutorial guides you through it.

Happy learning ♥
