Deep Learning Project: Training and Fine-Tuning a Language Model with Unsloth

Bayram EKER

In this article, we will walk through the process of training and fine-tuning a language model using the Unsloth library. We will break down each section of the provided code, explaining its functionality and purpose. Additionally, we will offer tips to enhance and optimize the project further.

1. Installing and Upgrading Required Libraries

Before starting the project, it’s essential to install and update the necessary libraries.

!pip install unsloth
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

Explanation

  • Installing Unsloth Library: The first command installs the unsloth library.
  • Upgrading Unsloth Library: The second command uninstalls the existing unsloth installation and reinstalls the latest version directly from the GitHub repository, ensuring you have the most recent updates and features.

Tips

  • Version Control: Regularly check and manage library versions to avoid compatibility issues; a quick version check is sketched after this list.
  • Dependency Management: While the --no-deps option speeds up installation by skipping dependencies, ensure all required dependencies are manually managed if necessary.
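
A minimal sketch of that version check, using only the standard library (it assumes the package is distributed under the name unsloth):

# Check which Unsloth version is installed (standard library only)
from importlib.metadata import version, PackageNotFoundError

try:
    print("unsloth version:", version("unsloth"))
except PackageNotFoundError:
    print("unsloth is not installed in this environment.")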

2. Importing Necessary Modules

We import the modules that will be used throughout the project.

from unsloth import FastLanguageModel
import torch
from datasets import load_dataset
import os
import json
import re
import random
from sklearn.model_selection import train_test_split

Explanation

  • FastLanguageModel: A class from the Unsloth library for handling language models efficiently.
  • torch: PyTorch library for deep learning operations.
  • datasets: Library to load and manage datasets.
  • os, json, re, random: Standard Python libraries for system operations, JSON handling, regular expressions, and random operations.
  • train_test_split: Function from scikit-learn to split datasets into training and validation sets.

Tips

  • Modular Code: Importing all necessary modules at the beginning enhances readability and simplifies debugging.

3. Setting Configuration Parameters

We define the configuration settings that will be used during the training process.

# Configuration
max_seq_length = 2048
load_in_4bit = True # Efficient memory usage
dtype = None

# Directory to store checkpoints
checkpoint_dir = "/content/drive/MyDrive/Defense/outputs_Meta-Llama-3.1-8B-bnb-4bit"

Explanation

  • max_seq_length: Maximum sequence length the model can process.
  • load_in_4bit: Enables 4-bit loading for efficient memory usage.
  • dtype: Data type; leaving it as None lets Unsloth auto-detect a suitable precision (e.g., float16 or bfloat16 depending on the GPU).
  • checkpoint_dir: Directory path where model checkpoints will be saved.

Tips

  • Memory Management: Using 4-bit loading significantly reduces memory consumption, allowing larger models to be trained on limited hardware.
  • Dynamic Configuration: Consider loading configuration parameters from an external file (e.g., JSON or YAML) for greater flexibility.
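
As a minimal sketch of the dynamic-configuration idea (the file name config.json and its keys are hypothetical), the settings above could be read from an external JSON file instead of being hard-coded:

# Load configuration from an external JSON file (hypothetical config.json)
import json

with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

max_seq_length = config.get("max_seq_length", 2048)
load_in_4bit = config.get("load_in_4bit", True)
dtype = config.get("dtype", None)  # keep None for auto-detection
checkpoint_dir = config.get("checkpoint_dir", "/content/drive/MyDrive/Defense/outputs_Meta-Llama-3.1-8B-bnb-4bit")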

4. Checkpoint Management Functions

We define functions to manage model checkpoints during training.

def is_model_processed(model_name):
    checkpoint_path = os.path.join(checkpoint_dir, model_name.split("/")[-1], "checkpoint-500", "trainer_state.json")
    print(f"Checking checkpoint path: {checkpoint_path}")
    return os.path.exists(checkpoint_path)

def mark_model_as_processed(model_name):
    checkpoint_file = os.path.join(checkpoint_dir, f"{model_name.replace('/', '_')}.done")
    print(f"Marking model as processed: {checkpoint_file}")
    with open(checkpoint_file, 'w') as f:
        f.write("")

Explanation

  • is_model_processed: Checks if a checkpoint exists for a specific model by verifying the existence of the trainer_state.json file.
  • mark_model_as_processed: Marks a model as processed by creating a .done file, indicating that the model has been successfully trained and saved.

Tips

  • State Tracking: Using checkpoint files helps in resuming training seamlessly in case of interruptions.
  • File Naming: Replace special characters in model names to prevent issues with file paths.

5. Data Preprocessing, Validation, and Augmentation

We preprocess the dataset to ensure it is clean, valid, and augmented for better model performance.

def preprocess_dataset(input_path, output_path, train_path, val_path, augmentation_factor=3):
    print("Preprocessing, validating, and augmenting dataset...")
    valid_entries = 0

    def clean_text(text):
        """Normalize and clean text."""
        text = re.sub(r"[^a-zA-Z0-9ğüşıöçĞÜŞİÖÇ.,!?\\-]", " ", text)  # Remove unwanted characters
        text = re.sub(r"\s+", " ", text).strip()  # Remove extra spaces
        return text.lower()  # Normalize to lowercase

    def augment_text(text):
        """Create variations of text for augmentation."""
        synonyms = {
            "highlight": ["emphasize", "focus on", "spotlight"],
            "identify": ["detect", "recognize", "pinpoint"],
            "discuss": ["elaborate on", "examine", "analyze"],
            "important": ["crucial", "key", "essential"]
        }
        for word, replacements in synonyms.items():
            if word in text:
                text = text.replace(word, random.choice(replacements))
        return text

    augmented_data = []
    with open(input_path, 'r', encoding='utf-8') as infile:
        for line in infile:
            try:
                data = json.loads(line)
                if 'instruction' in data and 'input' in data and 'output' in data:
                    cleaned_data = {
                        "instruction": clean_text(data.get("instruction", "")),
                        "input": clean_text(data.get("input", "")),
                        "output": clean_text(data.get("output", ""))
                    }
                    augmented_data.append(cleaned_data)
                    valid_entries += 1

                    for _ in range(augmentation_factor):
                        augmented_entry = {
                            "instruction": augment_text(cleaned_data['instruction']),
                            "input": augment_text(cleaned_data['input']),
                            "output": augment_text(cleaned_data['output'])
                        }
                        augmented_data.append(augmented_entry)
            except json.JSONDecodeError:
                print("Skipping invalid JSON line.")

    print(f"Dataset preprocessing complete. Valid entries: {valid_entries}")

    # Split into train and validation
    train_data, val_data = train_test_split(augmented_data, test_size=0.2, random_state=42)

    # Save datasets
    with open(output_path, 'w', encoding='utf-8') as outfile:
        for entry in augmented_data:
            outfile.write(json.dumps(entry, ensure_ascii=False) + '\n')
    with open(train_path, 'w', encoding='utf-8') as trainfile:
        for entry in train_data:
            trainfile.write(json.dumps(entry, ensure_ascii=False) + '\n')
    with open(val_path, 'w', encoding='utf-8') as valfile:
        for entry in val_data:
            valfile.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"Enhanced dataset saved to {output_path}. Train and validation sets saved to {train_path} and {val_path}.")

Explanation

  • clean_text: Cleans and normalizes text by removing unwanted characters, extra spaces, and converting text to lowercase.
  • augment_text: Enhances the dataset by replacing specific words with their synonyms to create variations.
  • preprocess_dataset: Reads the input dataset, cleans and augments the data, splits it into training and validation sets, and saves the processed data.

Tips

  • Data Cleaning: Adjust regular expressions to accommodate different languages or specific dataset requirements.
  • Data Augmentation: Incorporate more sophisticated augmentation techniques, such as paraphrasing or back-translation, to increase data diversity; a lightweight sketch follows this list.
  • Error Handling: Enhance error logging to capture more details about problematic data entries.
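
The sketch below shows one dependency-free option, EDA-style word dropout and word swapping, under the assumption that small random edits are acceptable for your data; back-translation would require an external translation model and is not shown here.

# Illustrative EDA-style augmentation: random word dropout and swaps.
# For higher-quality variation, consider paraphrasing models or back-translation.
import random

def random_word_dropout(text, p=0.1):
    """Drop each word with probability p."""
    words = text.split()
    kept = [w for w in words if random.random() > p]
    return " ".join(kept) if kept else text

def random_word_swap(text, n_swaps=1):
    """Swap the positions of n_swaps random word pairs."""
    words = text.split()
    for _ in range(n_swaps):
        if len(words) < 2:
            break
        i, j = random.sample(range(len(words)), 2)
        words[i], words[j] = words[j], words[i]
    return " ".join(words)

print(random_word_swap(random_word_dropout("identify the key risks and discuss them")))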

6. Defining Dataset Paths and Preprocessing

We specify the paths to the datasets and execute the preprocessing function.

# Paths to dataset
dataset_input_path = "/content/drive/MyDrive/output.jsonl"
dataset_cleaned_path = "/content/drive/MyDrive/cleaned_dataset.jsonl"
train_dataset_path = "/content/drive/MyDrive/train_dataset.jsonl"
val_dataset_path = "/content/drive/MyDrive/val_dataset.jsonl"

# Preprocess the dataset
preprocess_dataset(dataset_input_path, dataset_cleaned_path, train_dataset_path, val_dataset_path, augmentation_factor=3)

Explanation

  • dataset_input_path: Path to the original dataset.
  • dataset_cleaned_path: Path to save the cleaned and augmented dataset.
  • train_dataset_path & val_dataset_path: Paths to save the training and validation subsets.
  • preprocess_dataset: Calls the preprocessing function with specified paths and augmentation factor.

Tips

  • Data Management: Utilize cloud storage solutions like Google Drive for handling large datasets efficiently.
  • File Naming: Use descriptive file names to simplify dataset tracking and management.

7. Listing Models for Fine-Tuning

We define a list of models that will undergo fine-tuning.

# List of models to try for fine-tuning
fourbit_models = [
"unsloth/Meta-Llama-3.1-8B-bnb-4bit",
"unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
"unsloth/Meta-Llama-3.1-70B-bnb-4bit",
"unsloth/Meta-Llama-3.1-405B-bnb-4bit",
"unsloth/Mistral-Nemo-Base-2407-bnb-4bit",
"unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
"unsloth/mistral-7b-v0.3-bnb-4bit",
"unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
"unsloth/Phi-3.5-mini-instruct",
"unsloth/Phi-3-medium-4k-instruct",
"unsloth/gemma-2-9b-bnb-4bit",
"unsloth/gemma-2-27b-bnb-4bit",
]

Explanation

This list contains various 4-bit quantized models available from Unsloth that will be fine-tuned in turn. They differ in size and architecture, so you can match the model to your task and available hardware.

Tips

  • Model Selection: Choose models based on the specific requirements of your task, balancing between performance and computational resources.
  • Diversity: Experiment with different architectures to identify which performs best for your use case.

8. Loading and Testing Models Sequentially

We iterate through each model in the list, load it, and check for existing checkpoints to resume training if available.

# Load and test models sequentially
for model_name in fourbit_models:
print(f"Processing model: {model_name}")
model_dir = os.path.join(checkpoint_dir)
print(f"Model directory: {model_dir}")

checkpoint_path = os.path.join(model_dir, "checkpoint-500", "trainer_state.json")
print(f"Checkpoint path: {checkpoint_path}")

if os.path.exists(checkpoint_path):
print(f"Resuming from checkpoint: {checkpoint_path}")
print(f"Files in checkpoint directory: {os.listdir(os.path.dirname(checkpoint_path))}")
model, tokenizer = FastLanguageModel.from_pretrained(
os.path.dirname(checkpoint_path),
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)
else:
print(f"Starting training for model: {model_name}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)

print(f"Loaded model: {model_name}")

Explanation

  • Model Processing Loop: Iterates through each model in the fourbit_models list.
  • Checkpoint Check: Verifies if a checkpoint exists for the current model to resume training; otherwise, starts training from scratch.
  • Model and Tokenizer Loading: Uses FastLanguageModel.from_pretrained to load the model and tokenizer, either from a checkpoint or directly from the model name.

Tips

  • Parallel Processing: Consider processing multiple models in parallel to save time, provided sufficient computational resources are available.
  • Checkpoint Strategy: Implement a robust checkpointing strategy to minimize data loss and facilitate seamless training resumption.

9. Configuring LoRA for Fine-Tuning

We apply Low-Rank Adaptation (LoRA) to the model for efficient fine-tuning.

# LoRA Configuration for fine-tuning
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)
print(f"LoRA configuration done for model: {model_name}")

Explanation

  • PEFT Model: Applies LoRA to selected modules of the model to enable efficient fine-tuning.
  • Parameters:
  • r: LoRA rank; lower values use less memory but give the adapter less capacity.
  • target_modules: Specifies which layers to adapt.
  • lora_alpha: Scaling factor applied to the LoRA weight updates (not the learning rate); it is commonly set equal to r.
  • lora_dropout: Dropout rate for regularization.
  • bias: Bias configuration settings.
  • use_gradient_checkpointing: Enables gradient checkpointing for memory efficiency.
  • random_state: Ensures reproducibility by setting a random seed.

Tips

  • LoRA Hyperparameters: Experiment with different r and lora_alpha values to find the optimal balance between performance and resource usage; the parameter count check sketched after this list helps compare settings.
  • Layer Selection: Carefully choose which layers to adapt, as this can significantly impact both performance and training time.
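
Counting trainable versus total parameters right after get_peft_model shows how much a given r actually costs; this is plain PyTorch and makes no assumptions about Unsloth internals:

# Compare LoRA configurations by their trainable parameter count
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")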

10. Defining the Formatting Prompt

We create a template to format the dataset examples in a way that the model can understand.

# Define the formatting prompt
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token  # Ensures proper sequence termination

def formatting_prompts_func(examples):
    print("Formatting dataset prompts...")
    instructions = examples.get("instruction", "")
    inputs = examples.get("input", "")
    outputs = examples.get("output", "")
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}

Explanation

  • alpaca_prompt: A template that structures each dataset example with an instruction, input, and expected response.
  • formatting_prompts_func: Applies the alpaca_prompt to each example in the dataset, appending an End-Of-Sequence (EOS) token to ensure proper termination.

Tips

  • Prompt Engineering: Experiment with different prompt structures to determine which format yields the best model performance.
  • Language Consistency: Ensure the prompt language matches the dataset language to maintain consistency and improve model understanding.

11. Loading and Formatting Training and Validation Datasets

We load the training and validation datasets and apply the formatting function.

# Load the train and validation datasets
print(f"Loading train dataset from: {train_dataset_path}")
train_dataset = load_dataset("json", data_files=train_dataset_path, split="train")
print(f"Loading validation dataset from: {val_dataset_path}")
val_dataset = load_dataset("json", data_files=val_dataset_path, split="train")

train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
val_dataset = val_dataset.map(formatting_prompts_func, batched=True)
print("Datasets loaded and formatted.")

Explanation

  • load_dataset: Loads the training and validation datasets from JSON files.
  • map: Applies the formatting_prompts_func to each dataset in batches, formatting the data as per the defined prompt.

Tips

  • Batch Processing: Using batched=True speeds up the processing of large datasets.
  • Data Verification: After formatting, perform a quick check to ensure the prompts are correctly structured.
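
A quick spot-check of one mapped example is usually enough to catch a broken template:

# Print the first formatted training example to verify the prompt structure
print(train_dataset[0]["text"][:500])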

12. Configuring SFTTrainer for Training

We set up the training configuration using SFTTrainer to manage the training process.

from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

resume_checkpoint_dir = os.path.join(model_dir, "checkpoint-500")
print(f"Before Resuming from checkpoint: {resume_checkpoint_dir}")
if os.path.exists(os.path.join(resume_checkpoint_dir, "trainer_state.json")):
    print(f"Resuming from checkpoint: {resume_checkpoint_dir}")
    print(f"Checkpoint Files: {os.listdir(resume_checkpoint_dir)}")
    resume_from_checkpoint = resume_checkpoint_dir
else:
    print("No valid checkpoint found, starting from scratch.")
    resume_from_checkpoint = None

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        warmup_steps=50,
        max_steps=1000,
        save_steps=500,
        save_total_limit=2,
        learning_rate=3e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=42,
        output_dir=model_dir,
        report_to="none",
        resume_from_checkpoint=resume_from_checkpoint,  # Pass checkpoint path
    ),
)

Explanation

  • SFTTrainer: A trainer class from the trl library designed for Supervised Fine-Tuning (SFT).
  • TrainingArguments: Configures various training parameters such as batch size, learning rate, optimizer, and more.
  • per_device_train_batch_size: Batch size per device (GPU/CPU).
  • gradient_accumulation_steps: Number of steps to accumulate gradients before updating; here a per-device batch size of 4 with 8 accumulation steps gives an effective batch size of 32.
  • warmup_steps: Number of warmup steps for learning rate scheduler.
  • max_steps: Total number of training steps.
  • save_steps: Frequency of saving checkpoints.
  • learning_rate: Learning rate for the optimizer.
  • fp16 and bf16: Mixed precision training options for faster computation and reduced memory usage.
  • optim: Optimizer type (adamw_8bit for memory efficiency).
  • weight_decay: Weight decay for regularization.
  • lr_scheduler_type: Type of learning rate scheduler.
  • seed: Random seed for reproducibility.
  • output_dir: Directory to save training outputs.
  • resume_from_checkpoint: Path to resume training from a checkpoint if available.

Tips

  • Hyperparameter Tuning: Adjusting hyperparameters like learning rate, batch size, and gradient accumulation steps can significantly impact model performance.
  • Mixed Precision Training: Utilizing FP16 or BF16 can accelerate training and reduce memory usage without sacrificing model performance.
  • Checkpoint Management: Regularly saving checkpoints allows for resuming training in case of interruptions and facilitates experimentation.

13. Starting and Completing Model Training

We initiate the training process and monitor its completion.

print("Starting training...")
trainer_stats = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
print("Training completed.")

Explanation

  • trainer.train: Begins the training process. If a checkpoint is available, training resumes from that point.
  • trainer_stats: The object returned by trainer.train(), containing metrics such as the final training loss and runtime.

Tips

  • Monitoring: Use tools like TensorBoard or Weights & Biases to monitor training metrics in real-time.
  • Early Stopping: Implement early stopping to prevent overfitting by halting training when performance on the validation set stops improving.
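
For early stopping, transformers provides EarlyStoppingCallback, which also works with SFTTrainer. The sketch below shows how it could be wired in; note that it requires periodic evaluation and best-model tracking, which the configuration in section 12 does not enable, and that older transformers releases name the argument evaluation_strategy instead of eval_strategy.

# Sketch: early stopping on validation loss (requires periodic evaluation)
from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir=model_dir,
    eval_strategy="steps",             # evaluate on val_dataset every eval_steps
    eval_steps=100,
    load_best_model_at_end=True,       # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # ... remaining arguments as in section 12 ...
)

callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
# Pass args=args and callbacks=callbacks when constructing SFTTrainer.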

14. Saving the Fine-Tuned Model

After training, we save the fine-tuned model and tokenizer for future use.

# Save the fine-tuned model
print(f"Saving model to: {model_dir}")
model.save_pretrained(model_dir)
tokenizer.save_pretrained(model_dir)
print(f"Model {model_name} fine-tuned and saved to {model_dir}")
mark_model_as_processed(model_name)
print("----------------------------------------")

Explanation

  • save_pretrained: Saves the fine-tuned model and tokenizer to the specified directory.
  • mark_model_as_processed: Creates a .done file to indicate that the model has been successfully processed and saved.

Tips

  • Model Versioning: Implement a versioning system to keep track of different fine-tuned model versions.
  • Storage Optimization: Use compression techniques to save storage space, especially for large models.

15. Tips for Enhancing the Project

To further improve your project, consider the following tips:

a. Expand Data Augmentation Techniques

Enhance the dataset by incorporating more diverse augmentation methods, such as altering sentence structures or using advanced synonym replacement.

b. Hyperparameter Optimization

Use techniques like grid search or random search to find the optimal combination of hyperparameters, which can significantly boost model performance.
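
Even a small manual grid can be looped around the training code from sections 8-13; the values below are purely illustrative, not recommendations:

# Minimal manual grid-search sketch over learning rate and LoRA rank
import itertools
import os

learning_rates = [1e-4, 2e-4, 3e-4]  # illustrative values
lora_ranks = [8, 16, 32]

for lr, r in itertools.product(learning_rates, lora_ranks):
    run_dir = os.path.join(checkpoint_dir, f"lr{lr}_r{r}")
    print(f"Would train with learning_rate={lr}, r={r}, output_dir={run_dir}")
    # ... build the LoRA model with r=r, set learning_rate=lr and
    # output_dir=run_dir in TrainingArguments, train, and record eval_loss ...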

c. Model Parallelization

For training large models, implement model parallelization techniques to reduce training time. This can include distributed training or model sharding.

d. Performance Monitoring and Analysis

Integrate tools like TensorBoard or Weights & Biases to monitor and analyze model performance during training. This helps in understanding training dynamics and identifying issues early.
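
For TensorBoard in particular, it is usually enough to change report_to in the existing TrainingArguments and give it a logging directory (a minimal sketch, assuming the tensorboard package is installed in the environment):

# Sketch: route training metrics to TensorBoard instead of discarding them
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir=model_dir,
    report_to="tensorboard",                      # instead of "none"
    logging_dir=os.path.join(model_dir, "logs"),  # where event files are written
    logging_steps=10,
    # ... remaining arguments as in section 12 ...
)

In Colab, the dashboard can then be opened with %load_ext tensorboard followed by %tensorboard --logdir pointed at that logs folder.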

e. Experiment with More Models

Try different model architectures and sizes to determine which best suits your specific tasks. Stay updated with the latest models to leverage new advancements.

f. Automate Fine-Tuning Processes

Create scripts or pipelines to automate the training and fine-tuning processes. Automation reduces repetitive tasks and minimizes the risk of errors.

Conclusion

In this article, we thoroughly explored how to train and fine-tune a language model using the Unsloth library. We covered data preprocessing, model loading, LoRA configuration, and the training process. Additionally, we provided various tips to enhance and optimize your deep learning projects. We hope this guide serves as a valuable resource for your own projects.

Closing Remarks

Deep learning projects can be complex, but by taking a step-by-step approach and leveraging the right tools, you can achieve successful outcomes. Libraries like Unsloth simplify the processes of model training and fine-tuning, making them more accessible. We wish you success in your projects!
