Model Training (Optional)
This section is optional. The rest of the project (Inference and Statistics) uses a pre-trained checkpoint you can download directly. Complete this section if you want to train the model yourself and experiment with hyperparameters.
1 The Full Training Pipeline
The training script src/main.py wires together the data pipeline, the SegmentationModule, and a PyTorch Lightning Trainer. This section walks through the key pieces.
1.1 Callbacks: early stopping and checkpointing
Two callbacks guard the training run:
- EarlyStopping: halts training if the validation loss does not improve for 5 consecutive epochs, which limits overfitting.
- ModelCheckpoint: saves only the checkpoint with the best validation loss, keeping storage usage low.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(monitor="validation_loss", patience=5, mode="min"),
    ModelCheckpoint(
        monitor="validation_loss",
        mode="min",
        save_top_k=1,
        filename="best-model",
    ),
]
```
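The patience rule behind early stopping can be sketched in a few lines of plain Python. This is a simplified illustration, not Lightning's implementation: the PatienceTracker class and its method names are hypothetical, the loss values are made up, and the real EarlyStopping callback offers further options (for example min_delta) that are ignored here.

```python
class PatienceTracker:
    """Hypothetical helper mimicking patience-based early stopping (mode="min")."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("inf")  # lowest validation loss seen so far
        self.wait = 0             # consecutive epochs without improvement

    def update(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Made-up validation losses: improvement stops after the third epoch.
losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.66, 0.68, 0.69, 0.70]
tracker = PatienceTracker(patience=5)
stopped_at = None
for epoch, loss in enumerate(losses):
    if tracker.update(loss):
        stopped_at = epoch
        break

print(f"best={tracker.best}, stopped at epoch index {stopped_at}")
```

The best loss (0.65) is reached at epoch index 2, and that is the epoch a checkpoint callback with save_top_k=1 would keep, even though training continues for five more epochs before stopping.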
1.2 Trainer and training loop
```python
# Random placeholder loaders so the loader setup needs no real data; replace with
# real Sentinel-2 tile loaders for actual training (see Exercise 4 / src/train.py).
# `module` is assumed to be an already constructed SegmentationModule.
import torch
from torch.utils.data import DataLoader

B, C, H, W = 4, 14, 64, 64
dummy_dataset = [
    {"pixel_values": torch.randn(C, H, W), "labels": torch.randint(0, 10, (H, W))}
    for _ in range(16)
]


def collate(batch):
    # Stack per-sample tensors into batched tensors, keeping the dict layout.
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }


train_loader = DataLoader(dummy_dataset[:12], batch_size=B, collate_fn=collate)
val_loader = DataLoader(dummy_dataset[12:], batch_size=B, collate_fn=collate)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",  # uses GPU if available, CPU otherwise
    devices="auto",
    callbacks=callbacks,
    log_every_n_steps=10,
)
trainer.fit(module, train_loader, val_loader)
```

Calling trainer.fit runs the full loop: for each epoch it iterates the training dataloader and calls training_step on every batch, then runs the validation dataloader and calls validation_step. Metrics logged inside those steps (via self.log(...)) are accumulated and flushed according to log_every_n_steps.
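To make the call order concrete, here is a deliberately simplified sketch in plain Python. RecordingModule and fit_sketch are hypothetical names, not part of Lightning, and the sketch leaves out everything else the Trainer does (backward pass, optimizer steps, the initial sanity-check validation, callbacks, logging).

```python
class RecordingModule:
    """Hypothetical stand-in for a LightningModule that records call order."""

    def __init__(self):
        self.calls = []

    def training_step(self, batch, batch_idx):
        self.calls.append(("train", batch_idx))

    def validation_step(self, batch, batch_idx):
        self.calls.append(("val", batch_idx))


def fit_sketch(module, train_batches, val_batches, max_epochs):
    """Simplified epoch loop: all training batches, then all validation batches."""
    for _ in range(max_epochs):
        for idx, batch in enumerate(train_batches):
            module.training_step(batch, idx)
        for idx, batch in enumerate(val_batches):
            module.validation_step(batch, idx)


module_sketch = RecordingModule()
fit_sketch(module_sketch, train_batches=["a", "b", "c"], val_batches=["d"], max_epochs=2)
print(module_sketch.calls)
```

Each epoch shows three training calls followed by one validation call, which is why validation_loss only changes once per epoch and why callbacks that monitor it react at epoch granularity.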
To avoid downloading the full pre-trained weights (~330 MB), we can build a small SegFormer from a minimal config and run a few training steps on random data.
- Create a SegformerConfig with num_channels=14, num_labels=3, and depths=[1, 1, 1, 1] (one transformer block per stage instead of the usual [3, 6, 40, 3] for MiT-B5).
- Instantiate a SemanticSegmentationSegformer from this config (no pre-trained weights; just call the constructor directly).
- Wrap it in a SegmentationModule with CrossEntropyLoss and AdamW.
- Run 3 training steps manually (without a Trainer) by calling module.training_step(batch, batch_idx) on a random batch.
```python
import torch
from transformers import SegformerConfig
from src.models.model import SemanticSegmentationSegformer
from src.models.module import SegmentationModule
from torch import nn, optim

# 1. Minimal config
config = SegformerConfig(
    num_channels=___,  # TODO
    num_labels=___,  # TODO
    depths=___,  # TODO: [1, 1, 1, 1]
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)

# 2. Model without pre-trained weights
tiny_model = SemanticSegmentationSegformer(config)

# 3. Lightning module
tiny_module = SegmentationModule(
    model=tiny_model,
    loss=___,  # TODO: CrossEntropyLoss(ignore_index=255)
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3},
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "validation_loss"},
    scheduler_interval="epoch",
)

# 4. Three manual training steps on random data
B, C, H, W = 2, 14, 128, 128
dummy_batch = {
    "pixel_values": torch.randn(B, C, H, W),
    "labels": torch.randint(0, 3, (B, H, W)),
}

tiny_module.train()
optimizer = optim.AdamW(tiny_module.parameters(), lr=1e-3)
for step in range(3):
    optimizer.zero_grad()
    loss = ___  # TODO: call tiny_module.training_step(dummy_batch, step)
    loss.backward()
    optimizer.step()
    print(f"Step {step + 1} loss: {loss.item():.4f}")
```

Solution:

```python
import torch
from transformers import SegformerConfig
from src.models.model import SemanticSegmentationSegformer
from src.models.module import SegmentationModule
from torch import nn, optim

config = SegformerConfig(
    num_channels=14,
    num_labels=3,
    depths=[1, 1, 1, 1],
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)
tiny_model = SemanticSegmentationSegformer(config)

tiny_module = SegmentationModule(
    model=tiny_model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    # weight_decay=1e-2 is the decoupled weight-decay strength (AdamW's counterpart
    # to L2 regularisation) used throughout the project. On 14-channel inputs this
    # helps the encoder generalise instead of memorising channel-specific quirks
    # of the training tiles.
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "validation_loss"},
    scheduler_interval="epoch",
)

B, C, H, W = 2, 14, 128, 128
dummy_batch = {
    "pixel_values": torch.randn(B, C, H, W),
    "labels": torch.randint(0, 3, (B, H, W)),
}

# In a normal run, pl.Trainer.fit drives the four-step loop below for every batch:
# zero_grad → forward (training_step computes loss) → backward → optimizer.step.
# We unroll it by hand here just to see the mechanics. The loss won't drop much in
# three steps on random data; the point is to confirm the wiring, not to converge.
tiny_module.train()
optimizer = optim.AdamW(tiny_module.parameters(), lr=1e-3, weight_decay=1e-2)
for step in range(3):
    optimizer.zero_grad()
    loss = tiny_module.training_step(dummy_batch, step)
    loss.backward()
    optimizer.step()
    print(f"Step {step + 1} loss: {loss.item():.4f}")
```

2 Tracking Experiments with MLflow
Training a deep learning model involves experimenting with many hyperparameters: learning rate, weight decay, number of epochs, whether to freeze the encoder, etc. MLflow is an open-source platform that tracks, compares, and stores training runs so that you can reproduce any past experiment.
2.1 Setting up the MLflow logger
PyTorch Lightning integrates with MLflow through MLFlowLogger. On SSP Cloud, an MLflow tracking server is available as a separate service you can launch from the catalogue.
Launch an MLflow service from the SSP Cloud service catalogue. Once it starts, its public URL appears in the service details. It typically looks like https://user-<username>-mlflow.user.lab.sspcloud.fr.
If you opened your MLflow service after your VSCode service, create a .env file at the root of your project and fill in your MLflow credentials:
```
MLFLOW_TRACKING_URI=""
MLFLOW_TRACKING_USERNAME=""
MLFLOW_TRACKING_PASSWORD=""
```

These values can be found in the panel that opens when you launch your MLflow service on SSP Cloud.
```python
import os

from dotenv import load_dotenv
from pytorch_lightning.loggers import MLFlowLogger

load_dotenv()

mlf_logger = MLFlowLogger(
    experiment_name="segformer-satellite",
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
    log_model=True,  # upload the best checkpoint as an MLflow artifact
)
```
MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD are picked up automatically by the MLflow client from the environment — load_dotenv() makes them available without passing them explicitly.
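Because the credentials are read implicitly from the environment, an empty or missing .env entry tends to surface only later as an authentication or connection error. A small pre-flight check can catch this early. The helper below is a hypothetical sketch (missing_mlflow_vars is not part of MLflow or Lightning) and only assumes the three variable names listed above.

```python
import os

REQUIRED_VARS = (
    "MLFLOW_TRACKING_URI",
    "MLFLOW_TRACKING_USERNAME",
    "MLFLOW_TRACKING_PASSWORD",
)


def missing_mlflow_vars(env=None):
    """Return the names of required MLflow variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


# Example with an explicit mapping instead of the real environment:
# the URI is set, the username is empty, the password is absent.
example_env = {
    "MLFLOW_TRACKING_URI": "https://example.invalid",
    "MLFLOW_TRACKING_USERNAME": "",
}
print(missing_mlflow_vars(example_env))
```

In a training script, calling missing_mlflow_vars() with no argument right after load_dotenv() and raising an error when the list is non-empty would stop the run before any compute is spent.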
2.2 Passing the logger to the Trainer
```python
# Same random placeholder loaders as in "Trainer and training loop" above;
# replace with real loaders for actual training (see Exercise 4 / src/train.py).
import torch
from torch.utils.data import DataLoader

B, C, H, W = 4, 14, 64, 64
dummy_dataset = [
    {"pixel_values": torch.randn(C, H, W), "labels": torch.randint(0, 10, (H, W))}
    for _ in range(16)
]


def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }


train_loader = DataLoader(dummy_dataset[:12], batch_size=B, collate_fn=collate)
val_loader = DataLoader(dummy_dataset[12:], batch_size=B, collate_fn=collate)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    callbacks=callbacks,
    logger=mlf_logger,  # <-- plug in here
    log_every_n_steps=10,
)
trainer.fit(module, train_loader, val_loader)
```

That single change is enough: every self.log(...) call inside SegmentationModule now goes to MLflow automatically.
2.3 What gets logged
The SegmentationModule logs the following metrics at every step and at the end of each epoch:
| Metric | Phase | Meaning |
|---|---|---|
| train_loss | train | Cross-entropy loss on the training batch |
| train_iou_all | train | Mean IoU across all classes |
| train_iou_building | train | IoU for the building class |
| train_building_rate | train | Fraction of pixels predicted as building |
| validation_loss | val | Cross-entropy loss on the validation batch |
| validation_IOU_all | val | Mean IoU across all classes |
| validation_IOU_building | val | IoU for the building class |
| validation_building_rate | val | Fraction of pixels predicted as building |
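For readers unfamiliar with these quantities, the following plain-Python sketch shows the standard definitions on a tiny, made-up set of pixel labels. It is not the project's implementation (how SegmentationModule computes its metrics is not shown here); the function names, the class index for building, and the unweighted averaging over classes are assumptions for illustration.

```python
def iou_for_class(preds, targets, cls):
    """Intersection over union for one class on flat lists of pixel labels."""
    intersection = sum(p == cls and t == cls for p, t in zip(preds, targets))
    union = sum(p == cls or t == cls for p, t in zip(preds, targets))
    return intersection / union if union else float("nan")


def predicted_rate(preds, cls):
    """Fraction of pixels predicted as the given class."""
    return sum(p == cls for p in preds) / len(preds)


BUILDING = 1  # assumed class index, matching the toy id2label mapping
preds = [0, 1, 1, 2, 1, 0, 2, 2]      # made-up predictions for 8 pixels
targets = [0, 1, 0, 2, 1, 1, 2, 0]    # made-up ground truth

building_iou = iou_for_class(preds, targets, BUILDING)
building_rate = predicted_rate(preds, BUILDING)
mean_iou = sum(iou_for_class(preds, targets, c) for c in (0, 1, 2)) / 3

print(f"building IoU={building_iou:.3f}, building rate={building_rate:.3f}, mean IoU={mean_iou:.3f}")
```

Note that the building rate depends only on the predictions, not on the labels, which is what makes it a cheap signal for spotting a model that never (or always) predicts a class.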
2.4 Logging hyperparameters manually
You can log any scalar or string value as an MLflow parameter before training starts:
```python
mlf_logger.log_hyperparams({
    "lr": CONFIG["lr"],
    "batch_size": CONFIG["batch_size"],
    "epochs": CONFIG["epochs"],
    "freeze_encoder": False,
    "n_bands": CONFIG["n_bands"],
})
```

Why monitor building_rate?
During early training, models often collapse to predicting the majority class everywhere. Monitoring the fraction of pixels predicted as “building” helps detect this: if building_rate ≈ 0 after a few epochs, the model is ignoring the minority class — likely a sign of too high a learning rate or severe class imbalance.
CLC+ classes are heavily imbalanced — Permanent herbaceous and Woody together cover the vast majority of European pixels, while Sealed, Lichens and mosses and Water are rare. The plain CrossEntropyLoss used here weighs every pixel equally, so the optimiser is implicitly rewarded for ignoring rare classes. If you find the minority classes are never predicted, common remedies are:
- Weighted cross-entropy: pass weight=class_weights to nn.CrossEntropyLoss, with one weight per class inversely proportional to its frequency.
- Focal loss: down-weights easy / majority pixels so gradients focus on hard ones.
- Oversampling: build a sampler that draws tiles containing rare classes more often.
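As an illustration of the first remedy, inverse-frequency weights can be computed from label counts before training. The sketch below uses plain Python and made-up pixel counts; inverse_frequency_weights is a hypothetical helper, and normalising the weights to average 1 is one common convention rather than a requirement.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """One weight per class, inversely proportional to its pixel frequency.

    Weights are normalised to average 1 over the classes that appear; classes
    absent from the sample get weight 0.0 instead of causing a division by zero.
    """
    counts = Counter(labels)
    total = len(labels)
    raw = [total / counts[c] if counts[c] else 0.0 for c in range(num_classes)]
    present = sum(1 for w in raw if w > 0)
    scale = present / sum(raw)
    return [w * scale for w in raw]


# Made-up pixel labels: class 0 dominates, class 2 is rare.
pixel_labels = [0] * 80 + [1] * 15 + [2] * 5
class_weights = inverse_frequency_weights(pixel_labels, num_classes=3)
print([round(w, 3) for w in class_weights])
```

The resulting list can then be converted with torch.tensor(class_weights) and passed as the weight argument of nn.CrossEntropyLoss. In practice the counts would come from the training labels, excluding any ignore_index pixels.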
2.5 Registering and loading the best model
Once training is complete, MLflow lets you version and compare runs, then promote the best model to a Model Registry. Follow these steps:
- In the top left, select Model Training.
- Go to Experiments and open "funathon-2026-project3".
- In the Runs section, compare runs using the logged metrics and select the best-performing one.
- Click Register model, name it "segmentation-sentinel2-model", and confirm.
- Verify it appears in the Model Registry.
Once registered, you can load the model directly from the registry:
```python
import mlflow.pyfunc

model = mlflow.pyfunc.load_model(
    model_uri="models:/segmentation-sentinel2-model/latest"
)
```

Starting from the toy model you built in Exercise 3, add MLflow tracking to a short training run.
- Instantiate MLFlowLogger with experiment_name="segformer-toy" and tracking_uri="mlruns" (a local directory; no server needed).
- Create a pl.Trainer with max_epochs=3, the MLflow logger, and no GPU (accelerator="cpu").
- Build minimal DataLoader objects from random tensors and call trainer.fit.
- Open the local MLflow UI (mlflow ui --backend-store-uri mlruns) and inspect the logged metrics.
```python
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

# 1. Logger writing to a local directory
mlf_logger = MLFlowLogger(
    experiment_name=___,  # TODO
    tracking_uri=___,  # TODO: "mlruns"
)

# 2. Tiny data loaders built from 16 random samples
B, C, H, W = 4, 14, 64, 64
images = torch.randn(16, C, H, W)
labels = torch.randint(0, 3, (16, H, W))
dataset = [{"pixel_values": images[i], "labels": labels[i]} for i in range(16)]


def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }


train_dl = DataLoader(dataset[:12], batch_size=4, collate_fn=collate)
val_dl = DataLoader(dataset[12:], batch_size=4, collate_fn=collate)

# 3. Trainer
trainer = pl.Trainer(
    max_epochs=___,  # TODO: 3
    accelerator="cpu",
    logger=___,  # TODO: mlf_logger
    log_every_n_steps=1,
    enable_progress_bar=False,
)
trainer.fit(tiny_module, train_dl, val_dl)
print("Run logged to:", mlf_logger.run_id)
```

Solution:

```python
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from transformers import SegformerConfig
from src.models.model import SemanticSegmentationSegformer
from src.models.module import SegmentationModule
from torch import nn, optim

# Rebuild the tiny model (or reuse from Exercise 3)
config = SegformerConfig(
    num_channels=14, num_labels=3, depths=[1, 1, 1, 1],
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)
tiny_model = SemanticSegmentationSegformer(config)
tiny_module = SegmentationModule(
    model=tiny_model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
    # ReduceLROnPlateau reduces the learning rate (multiplying it by `factor`,
    # 0.1 by default) when a metric stops improving. It needs to know *which*
    # metric to watch (monitor="validation_loss"), and because validation runs
    # once per epoch, the check has to happen at the same cadence
    # (scheduler_interval="epoch"); checking every step would compare against
    # stale values and trigger spurious LR drops.
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "validation_loss"},
    scheduler_interval="epoch",
)

mlf_logger = MLFlowLogger(
    experiment_name="segformer-toy",
    tracking_uri="mlruns",
)

B, C, H, W = 4, 14, 64, 64
images = torch.randn(16, C, H, W)
labels = torch.randint(0, 3, (16, H, W))
dataset = [{"pixel_values": images[i], "labels": labels[i]} for i in range(16)]


def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }


train_dl = DataLoader(dataset[:12], batch_size=4, collate_fn=collate)
val_dl = DataLoader(dataset[12:], batch_size=4, collate_fn=collate)

trainer = pl.Trainer(
    max_epochs=3,
    accelerator="cpu",
    logger=mlf_logger,
    log_every_n_steps=1,
    enable_progress_bar=False,
)
trainer.fit(tiny_module, train_dl, val_dl)
print("Run logged to:", mlf_logger.run_id)
```

To inspect the results:

```bash
mlflow ui --backend-store-uri mlruns
```

Open http://localhost:5000 in your browser to see the experiment.
3 The Complete Training Script
The full production script is intermediate_solutions/solution_step2/main.py. It follows the same structure as the exercises above but adds:
- Reproducibility: set_seed(42) fixes the random seeds so that repeated runs start from the same state.
- Data normalization: band-wise mean and standard deviation are computed over the training set before building the transforms.
- Multiple test regions: trainer.test evaluates on four held-out NUTS-3 regions to measure generalisation.
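The band-wise normalization step can be illustrated with a small plain-Python sketch. This is not the script's code (which is not reproduced here): band_statistics and normalize are hypothetical names, the tile values are made up, and a real pipeline would compute the same statistics with tensor operations over 14 bands.

```python
from statistics import fmean, pstdev


def band_statistics(tiles):
    """Per-band mean and standard deviation over a list of tiles.

    Each tile is a list of bands; each band is a flat list of pixel values.
    """
    num_bands = len(tiles[0])
    means, stds = [], []
    for b in range(num_bands):
        values = [v for tile in tiles for v in tile[b]]
        means.append(fmean(values))
        stds.append(pstdev(values))
    return means, stds


def normalize(tile, means, stds):
    """Standardise each band: (value - mean) / std."""
    return [[(v - m) / s for v in band] for band, m, s in zip(tile, means, stds)]


# Two toy training tiles with 2 bands and 4 pixels each (made-up values).
train_tiles = [
    [[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 20.0, 20.0]],
    [[5.0, 6.0, 7.0, 8.0], [30.0, 30.0, 40.0, 40.0]],
]
means, stds = band_statistics(train_tiles)
normalized = normalize(train_tiles[0], means, stds)
print(means, [round(s, 3) for s in stds])
```

The statistics come from the training tiles only and are then reused unchanged for validation and test tiles, so that no information from held-out regions leaks into preprocessing.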
To run it (requires a GPU):

```bash
uv run python -m src.train
```

To add MLflow tracking, pass a logger to the Trainer as shown in the previous section:
```python
import os

from dotenv import load_dotenv
from pytorch_lightning.loggers import MLFlowLogger

load_dotenv()

mlf_logger = MLFlowLogger(
    experiment_name="segformer-satellite",
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
    log_model=True,
)

trainer = pl.Trainer(
    max_epochs=CONFIG["epochs"],
    accelerator="auto",
    devices="auto",
    callbacks=callbacks,
    logger=mlf_logger,
    log_every_n_steps=10,
)
```

4 Summary
In this section you have:
- Set up a full PyTorch Lightning training loop with early stopping and model checkpointing.
- Trained a toy SegFormer from scratch on random data to verify the pipeline end-to-end.
- Integrated MLflow via MLFlowLogger to track metrics, hyperparameters, and model artifacts across runs.
- Learned what to watch during training, especially building_rate as an early indicator of class collapse.
The next section covers inference: using a pre-trained model checkpoint to produce land-cover maps over new regions.