Transfer Learning and Multi-Task Fine-Tuning

Transfer learning is a fundamental technique in machine learning: a model trained on one task is adapted to another, reusing the knowledge it has already acquired. Multi-task fine-tuning extends this idea by training one model on several related tasks at the same time. Compared with training from scratch, both approaches can substantially reduce training time and the amount of labeled data needed, and they often improve generalization.

1. What is Transfer Learning?

Transfer learning enables a model trained on a source task to be fine-tuned for a target task, reducing the need for extensive labeled data. This is particularly valuable for large language models (LLMs) and other deep networks, whose pretraining is far too expensive to repeat for every new task.

Types of Transfer Learning

Feature-Based Transfer Learning

  • Extracts embeddings or features from a pretrained model and uses them in another model.
  • Example: Feeding frozen BERT embeddings into a separate lightweight classifier for text classification, without updating any of BERT's weights.
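
A minimal PyTorch sketch of feature-based transfer. It assumes a tiny, randomly initialized network as a hypothetical stand-in for a real pretrained encoder such as BERT, so it runs offline. The encoder's weights are frozen, features are extracted once, and only a new linear classifier is trained. The toy data and layer sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pretrained encoder such as BERT. A real setup would
# load pretrained weights; a small random MLP keeps this sketch self-contained.
pretrained_encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))

# Feature-based transfer: freeze every encoder weight...
for param in pretrained_encoder.parameters():
    param.requires_grad = False
pretrained_encoder.eval()

# ...and train only a new task-specific classifier on top of the extracted features.
classifier = nn.Linear(32, 2)
optimizer = torch.optim.AdamW(classifier.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# Random toy data, for illustration only.
inputs = torch.randn(64, 16)
labels = torch.randint(0, 2, (64,))
encoder_before = [p.clone() for p in pretrained_encoder.parameters()]

with torch.no_grad():
    features = pretrained_encoder(inputs)  # extracted once; could be cached to disk
    initial_loss = loss_fn(classifier(features), labels).item()

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(classifier(features), labels)
    loss.backward()
    optimizer.step()

encoder_unchanged = all(
    torch.equal(before, after)
    for before, after in zip(encoder_before, pretrained_encoder.parameters())
)
print("encoder unchanged:", encoder_unchanged)
print(f"classifier loss: {initial_loss:.3f} -> {loss.item():.3f}")
```

Because the encoder never receives gradients, its features can be computed a single time, which is what makes this approach cheap.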

Fine-Tuning-Based Transfer Learning

  • Adapts a pretrained model to a new task by updating its weights.
  • Example: Fine-tuning an open pretrained LLM such as LLaMA on a healthcare dataset for medical question answering.
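
For comparison, here is a sketch of fine-tuning-based transfer under the same assumptions: a toy stand-in backbone and random data. Every weight is trainable. The optimizer uses two parameter groups so the pretrained layers move more slowly than the freshly initialized head. The specific learning rates are a common heuristic chosen for illustration, not a prescription.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a "pretrained" backbone and a newly initialized task head.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
head = nn.Linear(32, 2)
model = nn.Sequential(backbone, head)

# Fine-tuning updates all weights. A smaller learning rate for the pretrained layers
# adjusts their knowledge instead of overwriting it; the new head learns faster.
optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(), "lr": 1e-4},
    {"params": head.parameters(), "lr": 1e-3},
])
loss_fn = nn.CrossEntropyLoss()

# Random toy data, for illustration only.
inputs = torch.randn(64, 16)
labels = torch.randint(0, 2, (64,))
backbone_before = [p.clone() for p in backbone.parameters()]

for _ in range(20):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.step()

backbone_changed = any(
    not torch.equal(before, after)
    for before, after in zip(backbone_before, backbone.parameters())
)
print("backbone updated:", backbone_changed)
print("learning rates:", [group["lr"] for group in optimizer.param_groups])
```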

Domain Adaptation

  • Transfers knowledge from a general dataset to a domain-specific dataset.
  • Example: Pretraining on Wikipedia and fine-tuning on legal documents for legal NLP tasks.

2. Multi-Task Fine-Tuning

Multi-task fine-tuning trains a single model on several related tasks at once, so that what it learns for one task can help the others. This is especially useful when some tasks have little labeled data, and it replaces several task-specific models with one.
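
When tasks are trained together, a practical question is how often each task should be sampled. The T5 paper explored two related schemes. In examples-proportional mixing, each task is sampled in proportion to its dataset size, optionally capped at a limit K. In temperature-scaled mixing, each rate is raised to the power 1/T and the rates are renormalized. The function below sketches that calculation. The task names and dataset sizes are hypothetical.

```python
import random

def mixing_rates(num_examples, cap=None, temperature=1.0):
    """Return a sampling probability for each task.

    num_examples: dict mapping task name -> number of training examples
    cap: optional limit K on each task's effective size (examples-proportional mixing)
    temperature: T >= 1; raising rates to the power 1/T flattens them toward uniform
    """
    sizes = {t: min(n, cap) if cap is not None else n for t, n in num_examples.items()}
    total = sum(sizes.values())
    scaled = {t: (n / total) ** (1.0 / temperature) for t, n in sizes.items()}
    norm = sum(scaled.values())
    return {t: s / norm for t, s in scaled.items()}

# Hypothetical dataset sizes: a large sentiment set and a small translation set.
sizes = {"sentiment": 25000, "translation": 500}
print(mixing_rates(sizes))                   # proportional: translation gets about 2%
print(mixing_rates(sizes, cap=1000))         # capped: about 67% / 33%
print(mixing_rates(sizes, temperature=2.0))  # temperature-scaled: translation gets about 12%

# Draw the task for each of the next 10 training examples.
rng = random.Random(0)
rates = mixing_rates(sizes, temperature=2.0)
sampled_tasks = rng.choices(list(rates), weights=list(rates.values()), k=10)
print(sampled_tasks)
```

Capping and temperature both keep a large task from drowning out a small one. Without them, the small task would appear in only a tiny fraction of batches.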

Key Methods for Multi-Task Learning

Soft Parameter Sharing

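  • Each task keeps its own model parameters; a regularization term (for example, an L2 penalty on the distance between corresponding weights) encourages the task-specific parameters to stay similar.
  • Example: Cross-stitch networks, which learn how much each task's network should mix in the activations of the other tasks' networks.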
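
Here is a minimal sketch of soft parameter sharing in its usual form, assuming two toy task networks and random data. Each task keeps its own weights, and an L2 penalty on the distance between corresponding weights pulls the two models toward each other during training. The penalty weight of 1.0 is an arbitrary illustrative choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_task_network():
    # Every task gets its own copy of the same architecture; no weights are shared.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))

def soft_sharing_penalty(net_a, net_b):
    # Squared L2 distance between corresponding parameters of two task networks.
    return sum(((pa - pb) ** 2).sum() for pa, pb in zip(net_a.parameters(), net_b.parameters()))

nets = nn.ModuleDict({"task_a": make_task_network(), "task_b": make_task_network()})
optimizer = torch.optim.SGD(nets.parameters(), lr=0.05)
loss_fn = nn.CrossEntropyLoss()
penalty_weight = 1.0  # hypothetical coupling strength; 0 gives fully independent models

# Random toy data for two tasks, for illustration only.
x_a, y_a = torch.randn(32, 16), torch.randint(0, 2, (32,))
x_b, y_b = torch.randn(32, 16), torch.randint(0, 2, (32,))

initial_distance = soft_sharing_penalty(nets["task_a"], nets["task_b"]).item()
for _ in range(200):
    optimizer.zero_grad()
    task_loss = loss_fn(nets["task_a"](x_a), y_a) + loss_fn(nets["task_b"](x_b), y_b)
    loss = task_loss + penalty_weight * soft_sharing_penalty(nets["task_a"], nets["task_b"])
    loss.backward()
    optimizer.step()
final_distance = soft_sharing_penalty(nets["task_a"], nets["task_b"]).item()
print(f"parameter distance: {initial_distance:.3f} -> {final_distance:.3f}")
```

The penalty weight controls the trade-off. Larger values push the networks toward identical weights, which approaches hard sharing. Smaller values let each task specialize.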

Hard Parameter Sharing

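  • A single backbone is shared across all tasks, with separate task-specific output layers (heads).
  • Example: T5, which takes hard sharing to its limit: every parameter is shared, and summarization, translation, and sentiment analysis are distinguished only by a text prefix on the input. Multilingual models such as mBERT and XLM-R likewise share one backbone across many languages.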
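
Hard parameter sharing can be sketched in PyTorch as one shared encoder plus a dictionary of task-specific heads. The layer sizes and the task names "sentiment" and "topic" are hypothetical. In practice the encoder would be a pretrained transformer.

```python
import torch
import torch.nn as nn

class HardSharingModel(nn.Module):
    """One shared encoder and one small output head per task (illustrative sizes)."""

    def __init__(self, input_dim, hidden_dim, task_output_dims):
        super().__init__()
        # Shared backbone: its weights receive gradients from every task.
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # Task-specific heads, keyed by task name.
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden_dim, out_dim) for task, out_dim in task_output_dims.items()
        })

    def forward(self, x, task):
        return self.heads[task](self.encoder(x))

torch.manual_seed(0)
model = HardSharingModel(16, 32, {"sentiment": 2, "topic": 5})
x = torch.randn(4, 16)
print(model(x, "sentiment").shape)  # torch.Size([4, 2])
print(model(x, "topic").shape)      # torch.Size([4, 5])

# One multi-task step: both task losses backpropagate into the same encoder.
loss_fn = nn.CrossEntropyLoss()
loss = (loss_fn(model(x, "sentiment"), torch.randint(0, 2, (4,)))
        + loss_fn(model(x, "topic"), torch.randint(0, 5, (4,))))
loss.backward()
shared_grad = model.encoder[0].weight.grad
print("encoder received gradients:", shared_grad is not None)
```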

Adapter-Based Multi-Task Learning

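  • Adapters are small bottleneck modules inserted into each layer of a pretrained model. Only the adapters are trained while the base model stays frozen, which sharply reduces the number of trainable parameters.
  • AdapterFusion first trains one adapter per task, then learns a fusion layer that combines these task adapters for a target task. This lets knowledge transfer between tasks at a fraction of the cost of full fine-tuning, which matters most for large models such as LLaMA.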
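
The adapter idea can be sketched in a few dozen lines of PyTorch. This is a simplified illustration, not the reference AdapterFusion implementation. It contains two pieces: a bottleneck adapter (down-projection, ReLU, up-projection, residual connection), and a fusion layer that attends over several adapters' outputs, using the base hidden state as the query. The single linear "base layer", the dimensions, and the task names are assumptions made to keep the example small.

```python
import torch
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    """Down-project, apply a nonlinearity, up-project, and add a residual connection."""

    def __init__(self, hidden_dim, bottleneck_dim):
        super().__init__()
        self.down = nn.Linear(hidden_dim, bottleneck_dim)
        self.up = nn.Linear(bottleneck_dim, hidden_dim)
        # Zero-initialize the up-projection so a new adapter starts as the identity map.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, h):
        return h + self.up(torch.relu(self.down(h)))

class SimpleFusion(nn.Module):
    """Simplified AdapterFusion-style layer: attention over adapter outputs, with the
    query taken from the base hidden state and keys/values from the adapters."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h, adapter_outputs):
        stacked = torch.stack(adapter_outputs, dim=1)   # [batch, n_adapters, hidden]
        q = self.query(h).unsqueeze(1)                   # [batch, 1, hidden]
        scores = (q * self.key(stacked)).sum(dim=-1)     # [batch, n_adapters]
        weights = torch.softmax(scores, dim=-1)          # one weight per adapter
        fused = (weights.unsqueeze(-1) * self.value(stacked)).sum(dim=1)  # [batch, hidden]
        return fused, weights

torch.manual_seed(0)
hidden_dim = 64

# Hypothetical stand-in for one frozen sublayer of a pretrained transformer.
base_layer = nn.Linear(hidden_dim, hidden_dim)
base_layer.requires_grad_(False)

# Stage 1 (not shown): train one adapter per task with the base frozen.
# Stage 2: freeze those adapters as well and train only the fusion layer.
adapters = nn.ModuleDict({t: BottleneckAdapter(hidden_dim, 8) for t in ["sentiment", "nli", "qa"]})
adapters.requires_grad_(False)
fusion = SimpleFusion(hidden_dim)

x = torch.randn(2, hidden_dim)
h = base_layer(x)
fused, weights = fusion(h, [adapters[t](h) for t in adapters])
print(fused.shape)          # torch.Size([2, 64])
print(weights.sum(dim=-1))  # each row of attention weights sums to 1
```

Each adapter in this sketch has 1,096 parameters (520 in the down-projection, 576 in the up-projection). In a real model the frozen base has millions to billions of parameters, and that gap is where the savings come from.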

Advantages of Multi-Task Fine-Tuning

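  • Often improves generalization, because the shared representation has to work for every task.
  • Can reduce overfitting, since training on diverse datasets acts as a form of regularization.
  • Saves compute and storage compared with training and deploying a separate model for each task.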

3. Implementing Transfer Learning and Multi-Task Fine-Tuning

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset, DatasetDict
import torch
import numpy as np
import os

# Step 1: Choose a Pretrained Model
model_name = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Step 2: Load and Prepare Multi-Task Data
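# For sentiment analysis, use a 500-review subset. The IMDB train split is sorted
# by label (negative reviews first), so shuffle before selecting; "train[:500]"
# would contain only negative reviews.
print("Loading sentiment analysis data...")
sentiment_data = load_dataset("imdb", split="train").shuffle(seed=42).select(range(500))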

# Create a small custom translation dataset to avoid disk space issues
print("Creating custom translation dataset...")
# Create German and English sentence lists
de_texts = [
    "Heute ist das Wetter sehr schön.",
    "Ich liebe Programmierung.",
    "Künstliche Intelligenz ist faszinierend.",
    "Die Sonne scheint heute.",
    "Wie heißt du?",
    "Ich möchte eine neue Sprache lernen.",
    "Das Buch ist sehr interessant.",
    "Kannst du mir helfen?",
    "Berlin ist die Hauptstadt von Deutschland.",
    "Ich trinke gerne Kaffee am Morgen.",
    "Mein Auto ist kaputt.",
    "Die Musik ist zu laut.",
    "Ich habe morgen einen wichtigen Termin.",
    "Das Essen schmeckt sehr gut.",
    "Wir treffen uns um 8 Uhr.",
    "Der Film war spannend.",
    "Ich wohne seit drei Jahren hier.",
    "Das ist ein schönes Haus.",
    "Ich arbeite von zu Hause aus.",
    "Meine Schwester kommt morgen zu Besuch."
]

en_texts = [
    "Today the weather is very nice.",
    "I love programming.",
    "Artificial intelligence is fascinating.",
    "The sun is shining today.",
    "What is your name?",
    "I want to learn a new language.",
    "The book is very interesting.",
    "Can you help me?",
    "Berlin is the capital of Germany.",
    "I like to drink coffee in the morning.",
    "My car is broken.",
    "The music is too loud.",
    "I have an important appointment tomorrow.",
    "The food tastes very good.",
    "We meet at 8 o'clock.",
    "The movie was exciting.",
    "I have been living here for three years.",
    "This is a beautiful house.",
    "I work from home.",
    "My sister is coming to visit tomorrow."
]

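# Repeat the 20 pairs to get 500 examples, roughly matching the 500 sentiment examples.
# Note: because of this repetition, the random validation split below will contain
# sentences that also appear in training, so translation scores on it overstate generalization.
n_copies = 25  # 20 pairs × 25 = 500 examples
de_texts = de_texts * n_copies
en_texts = en_texts * n_copies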

# Create dataset with proper structure
translation_data = Dataset.from_dict({
    "de": de_texts,
    "en": en_texts
})

# T5 is a text-to-text model: the task is specified by a prefix on the input string
print("Preprocessing sentiment data...")
def preprocess_sentiment_data(examples):
    # Custom prefix "sentiment: " (T5's own SST-2 task used "sst2 sentence: ") with text labels
    inputs = ["sentiment: " + text for text in examples["text"]]
    targets = ["positive" if label == 1 else "negative" for label in examples["label"]]
    
    # Use dynamic padding with DataCollator instead of padding here
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    # Add labels without padding
    labels = tokenizer(text_target=targets, max_length=8, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"]
    # Task identifier for bookkeeping (the Trainer drops this column before batching,
    # because the model's forward() does not accept it)
    model_inputs["task"] = ["sentiment"] * len(inputs)
    return model_inputs

print("Preprocessing translation data...")
def preprocess_translation_data(examples):
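    # Each row holds a German ("de") sentence and its English ("en") translation.
    # t5-small's original training mixture covered English→German but not German→English,
    # so this direction has to be learned during fine-tuning.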
    inputs = ["translate German to English: " + text for text in examples["de"]]
    targets = examples["en"]
    
    # Use dynamic padding with DataCollator instead of padding here
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    # Add labels without padding
    labels = tokenizer(text_target=targets, max_length=128, truncation=True)
    
    model_inputs["labels"] = labels["input_ids"]
    # Task identifier for bookkeeping (the Trainer drops this column before batching,
    # because the model's forward() does not accept it)
    model_inputs["task"] = ["translation"] * len(inputs)
    return model_inputs

# Process each dataset
sentiment_processed = sentiment_data.map(preprocess_sentiment_data, batched=True, remove_columns=sentiment_data.column_names)
translation_processed = translation_data.map(preprocess_translation_data, batched=True, remove_columns=translation_data.column_names)

print("Merging datasets...")
# Concatenate the two processed datasets row-wise; both have the same columns
# (input_ids, attention_mask, labels, task) and the same feature types
from datasets import concatenate_datasets
all_processed_data = concatenate_datasets([sentiment_processed, translation_processed])

print("Creating train-validation split...")
# Create a shuffled 90/10 train-validation split (seed fixed for reproducibility)
train_val = all_processed_data.train_test_split(test_size=0.1, seed=42)
datasets = DatasetDict({
    "train": train_val["train"],
    "validation": train_val["test"]
})

# Use DataCollatorForSeq2Seq for dynamic padding
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    return_tensors="pt"
)

# Step 3: Define evaluation metrics
def compute_metrics(eval_pred):
    pred_ids, label_ids = eval_pred.predictions, eval_pred.label_ids
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks ignored/padded positions; replace it with the pad token id so it can be decoded
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Exact-match accuracy pooled over both tasks: meaningful for the one-word sentiment
    # labels, but only a crude signal for translation (BLEU is the usual metric there)
    matches = [p.strip() == l.strip() for p, l in zip(decoded_preds, decoded_labels)]
    return {"accuracy": float(np.mean(matches))}

# Create output directory if it doesn't exist
output_dir = "./t5_multi_task_model"
os.makedirs(output_dir, exist_ok=True)

# Step 4: Configure training arguments
print("Setting up training arguments...")
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",  # named evaluation_strategy in transformers releases before 4.41
    learning_rate=3e-4,
    per_device_train_batch_size=4,  # Small batch size to save memory
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=1,  # Keep only the most recent checkpoint to save disk space
    num_train_epochs=3,
    predict_with_generate=True,  # Generate text during evaluation so compute_metrics can decode it
    # T5 checkpoints are prone to overflow (NaN loss) in fp16, so use bf16 only where the GPU supports it
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    report_to="none",  # Set to "wandb" or "tensorboard" to enable experiment logging
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=50
)

# Step 5: Initialize Trainer with data collator
print("Initializing trainer...")
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
    tokenizer=tokenizer,  # newer transformers releases call this argument processing_class
    data_collator=data_collator,  # Use data collator for dynamic padding
    compute_metrics=compute_metrics
)

# Step 6: Train the model
print("Starting training...")
trainer.train()

# Step 7: Save the fine-tuned model
print("Saving model...")
model_save_path = "./t5_multi_task_final"
os.makedirs(model_save_path, exist_ok=True)

model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

# Step 8: Example inference using the task prefixes
print("Testing model...")
model.eval()  # disable dropout for inference

def predict(text, task_prefix):
    input_text = f"{task_prefix}: {text}"
    # Tokenize and move the tensors to whatever device the model is on
    # (the Trainer has already moved the model to the GPU if one is available)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test sentiment analysis
sentiment_examples = [
    "This movie was amazing and I loved every minute of it",
    "The film was boring and I fell asleep halfway through"
]
print("Sentiment Analysis Results:")
for example in sentiment_examples:
    result = predict(example, "sentiment")
    print(f"Input: {example}\nResult: {result}\n")

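# Test translation (both sentences also appear in the training data,
# so this checks recall of seen pairs rather than generalization)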
translation_examples = [
    "Heute ist das Wetter sehr schön",
    "Ich habe morgen einen wichtigen Termin"
]
print("Translation Results:")
for example in translation_examples:
    result = predict(example, "translate German to English")
    print(f"Input: {example}\nResult: {result}\n")

print("Multi-task learning completed successfully!")
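
The script's compute_metrics reports a single accuracy pooled over both tasks. It cannot split that score by task on its own: with the default remove_unused_columns=True, the Trainer drops the "task" column before batching, because the model's forward() does not accept it. Below is a minimal sketch of a per-task exact-match helper. The sample strings are made up for illustration and are not outputs of the script.

```python
from collections import defaultdict

def per_task_exact_match(tasks, predictions, references):
    """Exact-match accuracy computed separately for each task label."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for task, pred, ref in zip(tasks, predictions, references):
        total[task] += 1
        correct[task] += int(pred.strip() == ref.strip())
    return {task: correct[task] / total[task] for task in total}

# Made-up decoded outputs, for illustration only.
tasks = ["sentiment", "sentiment", "translation", "translation"]
preds = ["positive", "negative", "The sun is shining today.", "I like coffee."]
refs = ["positive", "positive", "The sun is shining today.", "I like to drink coffee in the morning."]
print(per_task_exact_match(tasks, preds, refs))
# {'sentiment': 0.5, 'translation': 0.5}
```

To apply the helper to the script, one option is to take the predictions and label_ids fields returned by trainer.predict(datasets["validation"]). Replace -100 as in compute_metrics, decode both with tokenizer.batch_decode, and pass datasets["validation"]["task"] as the task list. The evaluation dataloader does not shuffle, so the rows line up. For translation, a BLEU score would be more informative than exact match.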

4. When to Use Transfer Learning vs. Multi-Task Fine-Tuning

Transfer learning and multi-task fine-tuning are powerful techniques for reusing existing models in new applications. Choose transfer learning when you need to adapt a model to a single specialized task or domain. Choose multi-task fine-tuning when you have several related tasks and want one model that serves all of them, especially when some tasks have little labeled data. Parameter-efficient methods such as adapters and AdapterFusion make either approach practical for large models without the computational cost of full fine-tuning.