Model Training (Optional)

Note

This section is optional. The rest of the project (Inference and Statistics) uses a pre-trained checkpoint you can download directly. Complete this section if you want to train the model yourself and experiment with hyperparameters.

1 The Full Training Pipeline

The training script src/main.py wires together the data pipeline, the SegmentationModule, and a PyTorch Lightning Trainer. This section walks through the key pieces.

1.1 Callbacks: early stopping and checkpointing

Two callbacks guard the training run:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(monitor="validation_loss", patience=5, mode="min"),
    ModelCheckpoint(
        monitor="validation_loss",
        mode="min",
        save_top_k=1,
        filename="best-model",
    ),
]
  • EarlyStopping: halts training once validation_loss has failed to improve for 5 consecutive validation runs (one per epoch by default), which limits overfitting and avoids wasted compute.
  • ModelCheckpoint: keeps only the checkpoint with the lowest validation_loss (save_top_k=1), which keeps storage usage low.
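
The interplay of the two callbacks can be illustrated without Lightning. The following is a minimal pure-Python sketch, not Lightning's implementation: the function name simulate_callbacks and the loss values are made up for illustration, and it assumes that any strict decrease in validation loss counts as an improvement.

```python
def simulate_callbacks(val_losses, patience=5):
    """Mimic EarlyStopping(patience) plus ModelCheckpoint(save_top_k=1)."""
    best_loss = float("inf")
    best_epoch = None   # epoch whose checkpoint would be kept
    wait = 0            # validation runs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best: the checkpoint is overwritten and the counter resets.
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return {"stopped_epoch": epoch, "best_epoch": best_epoch}
    return {"stopped_epoch": None, "best_epoch": best_epoch}


# Hypothetical validation losses: the best value is reached at epoch 2.
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.74, 0.73, 0.9]
result = simulate_callbacks(losses, patience=5)
print(result)
```

In this made-up run, training stops five validation runs after the best epoch, and the checkpoint that survives is the one from the best epoch rather than the last one.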

1.2 Trainer and training loop

# Random placeholder loaders so the snippet runs standalone — replace with
# real Sentinel-2 tile loaders for actual training (see Exercise 4 / src/train.py).
import torch
from torch.utils.data import DataLoader

B, C, H, W = 4, 14, 64, 64
dummy_dataset = [
    {"pixel_values": torch.randn(C, H, W), "labels": torch.randint(0, 10, (H, W))}
    for _ in range(16)
]

def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

train_loader = DataLoader(dummy_dataset[:12], batch_size=B, collate_fn=collate)
val_loader   = DataLoader(dummy_dataset[12:], batch_size=B, collate_fn=collate)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",   # uses GPU if available, CPU otherwise
    devices="auto",
    callbacks=callbacks,
    log_every_n_steps=10,
)

trainer.fit(module, train_loader, val_loader)

Calling trainer.fit runs the full loop: for each epoch it iterates the training dataloader, calls training_step, then runs the validation dataloader and calls validation_step. Metrics logged inside those steps (via self.log(...)) are accumulated and flushed according to log_every_n_steps.
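
The order of calls described above can be sketched in plain Python. This is a deliberately simplified illustration, not Lightning's actual loop: CountingModule and fit_sketch are hypothetical names, and the backward pass, optimizer step, sanity-check validation and callbacks are all omitted. The exact step at which Lightning flushes metrics may also differ slightly.

```python
class CountingModule:
    """Stand-in for a LightningModule that only records which hooks ran."""

    def __init__(self):
        self.train_calls = 0
        self.val_calls = 0

    def training_step(self, batch, batch_idx):
        self.train_calls += 1

    def validation_step(self, batch, batch_idx):
        self.val_calls += 1


def fit_sketch(module, train_batches, val_batches, max_epochs, log_every_n_steps):
    global_step = 0
    flushed_at = []  # global steps at which logged metrics would be flushed
    for _ in range(max_epochs):
        for batch_idx, batch in enumerate(train_batches):
            module.training_step(batch, batch_idx)
            global_step += 1
            if global_step % log_every_n_steps == 0:
                flushed_at.append(global_step)
        # Validation runs once the training batches of the epoch are done.
        for batch_idx, batch in enumerate(val_batches):
            module.validation_step(batch, batch_idx)
    return flushed_at


# Same shape as the placeholder loaders: 12 training samples and 4 validation
# samples with batch_size=4 give 3 training batches and 1 validation batch.
module_sketch = CountingModule()
flushed_at = fit_sketch(
    module_sketch,
    train_batches=["t0", "t1", "t2"],
    val_batches=["v0"],
    max_epochs=20,
    log_every_n_steps=10,
)
print(module_sketch.train_calls, module_sketch.val_calls, flushed_at)
```

With only 3 training batches per epoch, log_every_n_steps=10 means step-level metrics are written roughly once every three epochs, which is worth keeping in mind when the training curves look sparse.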

Exercise 3: Training a toy model

To avoid downloading the full pre-trained weights (~330 MB), we can build a small SegFormer from a minimal config and run a few training steps on random data.

  1. Create a SegformerConfig with num_channels=14, num_labels=3, and depths=[1, 1, 1, 1] (one transformer block per stage instead of the usual [3, 6, 40, 3] for MiT-B5).
  2. Instantiate a SemanticSegmentationSegformer from this config (no pre-trained weights — just call the constructor directly).
  3. Wrap it in a SegmentationModule with CrossEntropyLoss and AdamW.
  4. Run 3 training steps manually (without a Trainer) by calling module.training_step(batch, batch_idx) on a random batch.
import torch
from transformers import SegformerConfig
from src.models.model import SemanticSegmentationSegformer
from src.models.module import SegmentationModule
from torch import nn, optim

# 1. Minimal config
config = SegformerConfig(
    num_channels=___,   # TODO
    num_labels=___,     # TODO
    depths=___,         # TODO: [1, 1, 1, 1]
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)

# 2. Model without pre-trained weights
tiny_model = SemanticSegmentationSegformer(config)

# 3. Lightning module
tiny_module = SegmentationModule(
    model=tiny_model,
    loss=___,               # TODO: CrossEntropyLoss(ignore_index=255)
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3},
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "validation_loss"},
    scheduler_interval="epoch",
)

# 4. Three manual training steps on random data
B, C, H, W = 2, 14, 128, 128
dummy_batch = {
    "pixel_values": torch.randn(B, C, H, W),
    "labels": torch.randint(0, 3, (B, H, W)),
}

tiny_module.train()
optimizer = optim.AdamW(tiny_module.parameters(), lr=1e-3)
for step in range(3):
    optimizer.zero_grad()
    loss = ___  # TODO: call tiny_module.training_step(dummy_batch, step)
    loss.backward()
    optimizer.step()
    print(f"Step {step + 1} loss: {loss.item():.4f}")

Solution:

import torch
from transformers import SegformerConfig
from src.models.model import SemanticSegmentationSegformer
from src.models.module import SegmentationModule
from torch import nn, optim

config = SegformerConfig(
    num_channels=14,
    num_labels=3,
    depths=[1, 1, 1, 1],
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)

tiny_model = SemanticSegmentationSegformer(config)

tiny_module = SegmentationModule(
    model=tiny_model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    # weight_decay=1e-2 is the AdamW weight-decay strength used throughout the
    # project. AdamW applies it as decoupled weight decay rather than as an L2 term
    # in the loss. On 14-channel inputs this helps the encoder generalise instead of
    # memorising channel-specific quirks of the training tiles.
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "validation_loss"},
    scheduler_interval="epoch",
)

B, C, H, W = 2, 14, 128, 128
dummy_batch = {
    "pixel_values": torch.randn(B, C, H, W),
    "labels": torch.randint(0, 3, (B, H, W)),
}

# In a normal run, pl.Trainer.fit drives the four-step loop below for every batch:
# zero_grad → forward (training_step computes loss) → backward → optimizer.step.
# We unroll it by hand here just to see the mechanics. The loss won't drop much in
# three steps on random data — the point is to confirm the wiring, not to converge.
tiny_module.train()
optimizer = optim.AdamW(tiny_module.parameters(), lr=1e-3, weight_decay=1e-2)
for step in range(3):
    optimizer.zero_grad()
    loss = tiny_module.training_step(dummy_batch, step)
    loss.backward()
    optimizer.step()
    print(f"Step {step + 1} loss: {loss.item():.4f}")
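
A quick way to confirm the wiring is to compare the first printed loss with the value expected from an untrained classifier. If the model assigns roughly equal scores to all K classes, the cross-entropy is close to ln(K), so about 1.10 for the three classes used here. The helper below is a pure-Python illustration of that arithmetic (cross_entropy is a hypothetical name, not part of the project); a freshly initialised network will only land near this value, not exactly on it.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one pixel: log-sum-exp of the logits minus the target logit."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Equal logits mean a uniform prediction over the 3 classes.
uniform_loss = cross_entropy([0.0, 0.0, 0.0], target=1)
# A confident, correct prediction gives a much smaller loss.
confident_loss = cross_entropy([0.0, 5.0, 0.0], target=1)
print(f"uniform: {uniform_loss:.4f}, ln(3): {math.log(3):.4f}")
print(f"confident: {confident_loss:.4f}")
```

A first-step loss far above ln(3) usually points to a label or shape mismatch rather than to a modelling problem.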

2 Tracking Experiments with MLflow

Training a deep learning model involves experimenting with many hyperparameters: learning rate, weight decay, number of epochs, whether to freeze the encoder, etc. MLflow is an open-source platform that tracks, compares, and stores training runs so that you can reproduce any past experiment.

2.1 Setting up the MLflow logger

PyTorch Lightning integrates with MLflow through MLFlowLogger. On SSP Cloud, an MLflow tracking server is available as a separate service you can launch from the catalogue.

Note: Finding your MLflow URI on SSP Cloud

Launch an MLflow service from the SSP Cloud service catalogue. Once it starts, its public URL appears in the service details. It typically looks like https://user-<username>-mlflow.user.lab.sspcloud.fr.

If you opened your MLflow service after your VSCode service, create a .env file at the root of your project and fill in your MLflow credentials:

MLFLOW_TRACKING_URI=""
MLFLOW_TRACKING_USERNAME=""
MLFLOW_TRACKING_PASSWORD=""

This information can be found in the panel that opens when you launch your MLflow service on SSP Cloud.

import os
from dotenv import load_dotenv
from pytorch_lightning.loggers import MLFlowLogger

load_dotenv()

mlf_logger = MLFlowLogger(
    experiment_name="segformer-satellite",
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
    log_model=True,   # upload the best checkpoint as an MLflow artifact
)
MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD are picked up automatically by the MLflow client from the environment — load_dotenv() makes them available without passing them explicitly.
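
Because the credentials are read silently from the environment, a missing or empty variable tends to surface only later as a connection or authentication error. A small check before creating the logger can make this explicit. The sketch below uses only the standard library; missing_mlflow_settings is a hypothetical helper, not part of MLflow or of the project.

```python
import os

REQUIRED = (
    "MLFLOW_TRACKING_URI",
    "MLFLOW_TRACKING_USERNAME",
    "MLFLOW_TRACKING_PASSWORD",
)


def missing_mlflow_settings(env=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED if not env.get(name)]


# Illustration with a hand-written mapping instead of the real environment:
# the URI is filled in, the username is empty, the password is absent.
example_env = {
    "MLFLOW_TRACKING_URI": "https://user-<username>-mlflow.user.lab.sspcloud.fr",
    "MLFLOW_TRACKING_USERNAME": "",
}
missing = missing_mlflow_settings(example_env)
print("Missing settings:", missing)
```

In a real script the function would be called without arguments, after load_dotenv(), so that it inspects os.environ.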

2.2 Passing the logger to the Trainer

# Same random placeholder loaders as in "Trainer and training loop" above —
# replace with real loaders for actual training (see Exercise 4 / src/train.py).
import torch
from torch.utils.data import DataLoader

B, C, H, W = 4, 14, 64, 64
dummy_dataset = [
    {"pixel_values": torch.randn(C, H, W), "labels": torch.randint(0, 10, (H, W))}
    for _ in range(16)
]

def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

train_loader = DataLoader(dummy_dataset[:12], batch_size=B, collate_fn=collate)
val_loader   = DataLoader(dummy_dataset[12:], batch_size=B, collate_fn=collate)

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    callbacks=callbacks,
    logger=mlf_logger,          # <-- plug in here
    log_every_n_steps=10,
)

trainer.fit(module, train_loader, val_loader)

That single change is enough: every self.log(...) call inside SegmentationModule now goes to MLflow automatically.

2.3 What gets logged

The SegmentationModule logs the following metrics at every step and at the end of each epoch:

Metric                     Phase   Meaning
train_loss                 train   Cross-entropy loss on the training batch
train_iou_all              train   Mean IoU across all classes
train_iou_building         train   IoU for the building class
train_building_rate        train   Fraction of pixels predicted as building
validation_loss            val     Cross-entropy loss on the validation batch
validation_IOU_all         val     Mean IoU across all classes
validation_IOU_building    val     IoU for the building class
validation_building_rate   val     Fraction of pixels predicted as building
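
To make the IoU and building-rate definitions concrete, here is a pure-Python sketch on a handful of flattened pixels. It illustrates the definitions only: how SegmentationModule computes these metrics internally is not shown in this section, and the function names, the building class id of 1, and the choice to skip classes absent from both prediction and target are assumptions.

```python
def iou_per_class(pred, target, num_classes):
    """Intersection over union for each class, on flattened label lists."""
    ious = {}
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        if union > 0:  # skip classes absent from both prediction and target
            ious[c] = inter / union
    return ious


def building_rate(pred, building_id=1):
    """Fraction of pixels predicted as the building class."""
    return sum(1 for p in pred if p == building_id) / len(pred)


# Six made-up pixels; classes: 0 = background, 1 = building, 2 = vegetation.
pred = [0, 1, 1, 2, 0, 1]
target = [0, 1, 2, 2, 0, 0]

ious = iou_per_class(pred, target, num_classes=3)
mean_iou = sum(ious.values()) / len(ious)
rate = building_rate(pred)
print(ious, mean_iou, rate)
```

Here the building class is over-predicted (three pixels predicted, one actually present), which shows up as a building rate of 0.5 alongside a building IoU of only one third.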

2.4 Logging hyperparameters manually

You can log any scalar or string value as an MLflow parameter before training starts:

mlf_logger.log_hyperparams({
    "lr": CONFIG["lr"],
    "batch_size": CONFIG["batch_size"],
    "epochs": CONFIG["epochs"],
    "freeze_encoder": False,
    "n_bands": CONFIG["n_bands"],
})
Note: Why watch building_rate?

During early training, models often collapse to predicting the majority class everywhere. Monitoring the fraction of pixels predicted as “building” helps detect this: if building_rate ≈ 0 after a few epochs, the model is ignoring the minority class — likely a sign of too high a learning rate or severe class imbalance.
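
The rule of thumb above can be turned into a simple check on the logged values. This is a hedged sketch: looks_collapsed is a hypothetical helper, and both the three-epoch warm-up and the 0.001 threshold are arbitrary choices for illustration, not values taken from the project.

```python
def looks_collapsed(building_rates, warmup_epochs=3, threshold=0.001):
    """True if building_rate stays below the threshold after the warm-up epochs."""
    late = building_rates[warmup_epochs:]
    return bool(late) and all(rate < threshold for rate in late)


# Made-up per-epoch values of validation_building_rate.
collapsed_run = [0.2, 0.05, 0.01, 0.0, 0.0002, 0.0]
healthy_run = [0.0, 0.0, 0.01, 0.04, 0.08]

print(looks_collapsed(collapsed_run))
print(looks_collapsed(healthy_run))
```

A run flagged this way is a candidate for a lower learning rate or for one of the class-imbalance remedies discussed in this section.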

Note: Class imbalance in CLC+

CLC+ classes are heavily imbalanced — Permanent herbaceous and Woody together cover the vast majority of European pixels, while Sealed, Lichens and mosses and Water are rare. The plain CrossEntropyLoss used here weighs every pixel equally, so the optimiser is implicitly rewarded for ignoring rare classes. If you find the minority classes are never predicted, common remedies are:

  • Weighted cross-entropy: pass weight=class_weights to nn.CrossEntropyLoss, with one weight per class inversely proportional to its frequency.
  • Focal loss: down-weights easy / majority pixels so gradients focus on hard ones.
  • Oversampling: build a sampler that draws tiles containing rare classes more often.
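
The first two remedies can be sketched numerically. The example below is pure Python with made-up pixel counts. The weighting formula (total / (num_classes * count)) is one common "balanced" heuristic among several, and focal_loss implements the standard per-pixel focal term with a focusing parameter gamma. Neither function is taken from the project.

```python
import math


def inverse_frequency_weights(pixel_counts):
    """One weight per class, inversely proportional to its pixel count.

    Assumes every class appears at least once (no zero counts).
    """
    total = sum(pixel_counts)
    num_classes = len(pixel_counts)
    return [total / (num_classes * count) for count in pixel_counts]


def focal_loss(p_true, gamma=2.0):
    """Focal loss for one pixel, given the predicted probability of its true class.

    With gamma=0 this reduces to ordinary cross-entropy, -log(p_true).
    """
    return -((1.0 - p_true) ** gamma) * math.log(p_true)


counts = [700, 200, 100]  # made-up pixel counts for three classes
weights = inverse_frequency_weights(counts)

easy_ce = focal_loss(0.9, gamma=0.0)  # plain cross-entropy on an easy pixel
easy_focal = focal_loss(0.9)          # same pixel, strongly down-weighted
hard_focal = focal_loss(0.1)          # hard pixel keeps most of its loss
print(weights, easy_ce, easy_focal, hard_focal)
```

In a PyTorch setting, weights such as these would typically be converted to a tensor and passed as the weight argument of nn.CrossEntropyLoss.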

2.5 Registering and loading the best model

Once training is complete, MLflow lets you version and compare runs, then promote the best model to a Model Registry. Follow these steps:

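  1. In the top left of the MLflow UI, select Model Training.
  2. Go to Experiments and open your experiment: "funathon-2026-project3", or whichever experiment_name you passed to MLFlowLogger (for example "segformer-satellite" in the snippets above).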
  3. In the Runs section, compare runs using the logged metrics and select the best-performing one.
  4. Click Register model, name it "segmentation-sentinel2-model", and confirm.
  5. Verify it appears in the Model Registry.

Once registered, you can load the model directly from the registry:

import mlflow.pyfunc

model = mlflow.pyfunc.load_model(
    model_uri="models:/segmentation-sentinel2-model/latest"
)
Exercise 4: Instrument a run with MLflow

Starting from the toy model you built in Exercise 3, add MLflow tracking to a short training run.

  1. Instantiate MLFlowLogger with experiment_name="segformer-toy" and tracking_uri="mlruns" (a local directory — no server needed).
  2. Create a pl.Trainer with max_epochs=3, the MLflow logger, and no GPU (accelerator="cpu").
  3. Build minimal DataLoader objects from random tensors (shaped like the dummy_batch of Exercise 3) and call trainer.fit.
  4. Open the local MLflow UI (mlflow ui --backend-store-uri mlruns) and inspect the logged metrics.
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

# 1. Logger writing to a local directory
mlf_logger = MLFlowLogger(
    experiment_name=___,       # TODO
    tracking_uri=___,          # TODO: "mlruns"
)

# 2. Tiny data loaders built from 16 random samples
B, C, H, W = 4, 14, 64, 64
images = torch.randn(16, C, H, W)
labels = torch.randint(0, 3, (16, H, W))
dataset = [{"pixel_values": images[i], "labels": labels[i]} for i in range(16)]

def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

train_dl = DataLoader(dataset[:12], batch_size=4, collate_fn=collate)
val_dl   = DataLoader(dataset[12:], batch_size=4, collate_fn=collate)

# 3. Trainer
trainer = pl.Trainer(
    max_epochs=___,            # TODO: 3
    accelerator="cpu",
    logger=___,                # TODO: mlf_logger
    log_every_n_steps=1,
    enable_progress_bar=False,
)

trainer.fit(tiny_module, train_dl, val_dl)
print("Run logged to:", mlf_logger.run_id)

Solution:

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from transformers import SegformerConfig
from src.models.model import SemanticSegmentationSegformer
from src.models.module import SegmentationModule
from torch import nn, optim

# Rebuild the tiny model (or reuse from Exercise 3)
config = SegformerConfig(
    num_channels=14, num_labels=3, depths=[1, 1, 1, 1],
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)
tiny_model = SemanticSegmentationSegformer(config)
tiny_module = SegmentationModule(
    model=tiny_model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
    # ReduceLROnPlateau multiplies the learning rate by a factor (0.1 by default)
    # when a metric stops improving. It needs to know *which* metric to watch
    # (monitor="validation_loss"), and because validation runs once per epoch, the
    # check has to happen at the same cadence (scheduler_interval="epoch"):
    # checking every step would compare against stale values and trigger spurious
    # LR drops.
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "validation_loss"},
    scheduler_interval="epoch",
)

mlf_logger = MLFlowLogger(
    experiment_name="segformer-toy",
    tracking_uri="mlruns",
)

B, C, H, W = 4, 14, 64, 64
images = torch.randn(16, C, H, W)
labels = torch.randint(0, 3, (16, H, W))
dataset = [{"pixel_values": images[i], "labels": labels[i]} for i in range(16)]

def collate(batch):
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.stack([x["labels"] for x in batch]),
    }

train_dl = DataLoader(dataset[:12], batch_size=4, collate_fn=collate)
val_dl   = DataLoader(dataset[12:], batch_size=4, collate_fn=collate)

trainer = pl.Trainer(
    max_epochs=3,
    accelerator="cpu",
    logger=mlf_logger,
    log_every_n_steps=1,
    enable_progress_bar=False,
)

trainer.fit(tiny_module, train_dl, val_dl)
print("Run logged to:", mlf_logger.run_id)

To inspect the results:

mlflow ui --backend-store-uri mlruns

Open http://localhost:5000 in your browser to see the experiment.
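
The scheduler used in the solution, ReduceLROnPlateau, follows a simple rule that can be mimicked in a few lines. The sketch below is a simplified pure-Python model, not PyTorch's implementation: plateau_schedule is a hypothetical name, any strict decrease counts as an improvement (PyTorch additionally applies a small relative threshold), and cooldown and min_lr are ignored. As in PyTorch, the rate is reduced once the number of non-improving epochs exceeds patience.

```python
def plateau_schedule(val_losses, lr=1e-3, factor=0.1, patience=3):
    """Return the learning rate in effect after each epoch's validation loss."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor      # reduce the rate and start counting again
            bad_epochs = 0
        history.append(lr)
    return history


# Made-up validation losses: no improvement from epoch 2 to epoch 5.
losses = [1.0, 0.9, 0.95, 0.93, 0.92, 0.91, 0.85]
history = plateau_schedule(losses)
print(history)
```

With patience=3 the rate only drops after four consecutive epochs without improvement, so in the three-epoch toy run of Exercise 4 the scheduler never has a chance to act.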

3 The Complete Training Script

The full production script is intermediate_solutions/solution_step2/main.py. It follows the same structure as the exercises above but adds:

  • Reproducibility: set_seed(42) fixes the random seeds so that runs are repeatable (some GPU operations can still introduce small run-to-run differences).
  • Data normalization: band-wise mean and standard deviation are computed over the training set before building the transforms.
  • Multiple test regions: trainer.test evaluates on four held-out NUTS-3 regions to measure generalisation.
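
The normalization step can be illustrated on a toy example. The sketch below is pure Python and is not the project's implementation: band_stats and normalize are hypothetical names, tiles are represented as nested lists (one list of pixel values per band), and the small eps term guarding against a zero standard deviation is an assumption.

```python
from statistics import fmean, pstdev


def band_stats(tiles):
    """Per-band mean and standard deviation over all pixels of all tiles."""
    n_bands = len(tiles[0])
    means, stds = [], []
    for b in range(n_bands):
        values = [v for tile in tiles for v in tile[b]]
        means.append(fmean(values))
        stds.append(pstdev(values))
    return means, stds


def normalize(tile, means, stds, eps=1e-8):
    """Standardise each band of a tile with the training-set statistics."""
    return [
        [(v - m) / (s + eps) for v in band]
        for band, m, s in zip(tile, means, stds)
    ]


# Two made-up training tiles with 2 bands of 2 pixels each.
train_tiles = [
    [[0.0, 2.0], [10.0, 10.0]],
    [[4.0, 6.0], [30.0, 30.0]],
]
means, stds = band_stats(train_tiles)
normalized = normalize(train_tiles[0], means, stds)
print(means, stds, normalized)
```

The statistics are computed on the training tiles only and then reused unchanged for validation and test tiles, so that no information from held-out regions leaks into preprocessing.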

To run it (requires a GPU):

uv run python -m src.train

To add MLflow tracking, pass a logger to the Trainer as shown in the previous section:

import os
from dotenv import load_dotenv
from pytorch_lightning.loggers import MLFlowLogger

load_dotenv()

mlf_logger = MLFlowLogger(
    experiment_name="segformer-satellite",
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
    log_model=True,
)

trainer = pl.Trainer(
    max_epochs=CONFIG["epochs"],
    accelerator="auto",
    devices="auto",
    callbacks=callbacks,
    logger=mlf_logger,
    log_every_n_steps=10,
)

4 Summary

In this section you have:

  • Set up a full PyTorch Lightning training loop with early stopping and model checkpointing.
  • Trained a toy SegFormer from scratch on random data to verify the pipeline end-to-end.
  • Integrated MLflow via MLFlowLogger to track metrics, hyperparameters, and model artifacts across runs.
  • Learned what to watch during training — especially building_rate as an early indicator of class collapse.

The next section covers inference: using a pre-trained model checkpoint to produce land-cover maps over new regions.