Fine-tuning pretrained models like BERT can yield excellent results on downstream tasks such as sentiment classification. However, training the entire model can be resource-intensive. A common solution is to freeze most of the model's layers and fine-tune only those closest to the output, which typically carry the most task-specific information.
In this section, we’ll demonstrate how to freeze layers in BERT while fine-tuning the last two encoder blocks and the classification head for a sentiment analysis task using the Amazon Polarity dataset.
We’ll use the Hugging Face transformers and datasets libraries to handle the entire training pipeline.
We load a small subset of the Amazon Polarity dataset, tokenize it, and prepare it for binary classification.
from datasets import load_dataset
# Load and split the dataset
dataset = load_dataset("amazon_polarity", split="train[:10000]")
dataset = dataset.train_test_split(test_size=0.2)
train_data = dataset["train"]
test_data = dataset["test"]
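Each Amazon Polarity record carries a label (0 = negative, 1 = positive), a review title, and the review body in a content field. A record looks roughly like this (the field values here are illustrative, not taken from the dataset):

```python
# Illustrative Amazon Polarity record; real records have the same fields.
sample = {
    "label": 1,  # 0 = negative, 1 = positive
    "title": "Great headphones",
    "content": "Sound quality is excellent and they are comfortable to wear.",
}

# The tokenization step later reads the "content" field.
print(sorted(sample.keys()))
```

Knowing the schema explains why the tokenizer is applied to `batch["content"]` rather than the title.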
Next, we load the BERT model and tokenizer.
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

Tokenize the input data:
def tokenize_batch(batch):
    return tokenizer(batch["content"], truncation=True)
tokenized_train = train_data.map(tokenize_batch, batched=True)
tokenized_test = test_data.map(tokenize_batch, batched=True)

To reduce training time and memory usage, we freeze all layers except the last two encoder blocks, the pooler, and the classification head:
trainable_layers = ["encoder.layer.10", "encoder.layer.11", "pooler", "classifier"]
for name, param in model.named_parameters():
    param.requires_grad = any(layer in name for layer in trainable_layers)

We use Hugging Face’s Trainer API to train the model.
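Before wiring everything together, it helps to see what the substring matching actually selects. Parameter names from `named_parameters()` follow a `bert.encoder.layer.<i>...` pattern, so a plain `in` check is enough. Here is a self-contained sketch using illustrative parameter names in that style:

```python
trainable_layers = ["encoder.layer.10", "encoder.layer.11", "pooler", "classifier"]

# Illustrative names in the style of BERT's named_parameters()
param_names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.10.output.dense.weight",
    "bert.encoder.layer.11.attention.self.key.weight",
    "bert.pooler.dense.weight",
    "classifier.weight",
]

# Same test the freezing loop applies to each parameter
flags = {name: any(layer in name for layer in trainable_layers) for name in param_names}
for name, trainable in flags.items():
    print(f"{'train ' if trainable else 'freeze'}  {name}")
```

One caveat with substring matching: a pattern like `"encoder.layer.1"` would also match layers 10 and 11, so keep the patterns as specific as the ones above.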
# fine_tune_bert_amazon.py
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from sklearn.metrics import f1_score
import numpy as np
import torch
# 1. Load and prepare the dataset
dataset = load_dataset("amazon_polarity", split="train[:10000]")
dataset = dataset.train_test_split(test_size=0.2)
train_data = dataset["train"]
test_data = dataset["test"]
# 2. Load tokenizer and model
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# 3. Tokenization function
def tokenize_batch(batch):
    return tokenizer(batch["content"], truncation=True)
tokenized_train = train_data.map(tokenize_batch, batched=True)
tokenized_test = test_data.map(tokenize_batch, batched=True)
# 4. Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# 5. Metric computation
def compute_metrics(pred):
    preds = np.argmax(pred.predictions, axis=1)
    return {"f1": f1_score(pred.label_ids, preds)}
# 6. Freeze all layers except the last two encoder blocks and classifier
trainable_layers = ["encoder.layer.10", "encoder.layer.11", "pooler", "classifier"]
for name, param in model.named_parameters():
    if any(layer in name for layer in trainable_layers):
        param.requires_grad = True
    else:
        param.requires_grad = False
# 7. Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",  # disable wandb and other reporters
)
# 8. Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
# 9. Train and evaluate
trainer.train()
results = trainer.evaluate()
print("\n--- Evaluation Results ---")
for key, value in results.items():
    print(f"{key}: {value:.4f}")

After training for one epoch, the model yielded an F1 score of approximately 0.80, with significantly reduced training time compared to fine-tuning the full model.
--- Evaluation Results ---
eval_loss: 0.42
eval_f1: 0.80
eval_runtime: 3.8
eval_samples_per_second: 263.15
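The `compute_metrics` step can also be exercised offline, without a trained model, by feeding it dummy logits. The sketch below mirrors what the function does, but replaces `sklearn.metrics.f1_score` with a hand-rolled binary F1 so it runs with NumPy alone (the logits and labels are made up for illustration):

```python
import numpy as np

def f1_binary(labels, preds):
    # F1 for the positive class: 2 * precision * recall / (precision + recall)
    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Dummy model outputs: 4 examples, 2 classes of raw logits, plus true labels
logits = np.array([[0.1, 0.9], [2.0, 0.3], [0.2, 1.5], [1.0, 0.0]])
labels = np.array([1, 0, 0, 0])

preds = np.argmax(logits, axis=1)  # same argmax step as compute_metrics
score = f1_binary(labels, preds)   # one true positive, one false positive
print(round(score, 4))
```

This kind of check is a quick way to confirm the metric wiring before paying for a full training run.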