Model Training

1 The SegFormer Architecture

In the previous section we downloaded and prepared Sentinel-2 multispectral images together with their CLC+ Backbone land-cover labels. We now have patches of shape (14, H, W) — 14 spectral bands, height, width — each paired with a pixel-wise class mask. The goal of this section is to understand the model that will turn those patches into segmentation maps.

We use SegFormer, a Vision Transformer architecture designed specifically for semantic segmentation. SegFormer was introduced by Xie et al. (2021) in “SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers”. Its two key design choices make it both powerful and practical:

  1. A hierarchical Mix Transformer encoder (MiT) that produces multi-scale feature representations without positional encodings.
  2. A lightweight all-MLP decoder that aggregates information from all encoder stages with almost no parameters.

1.1 Encoder: the Mix Transformer (MiT)

A classic ViT cuts the image into non-overlapping 16×16 patches and processes them as a flat sequence. SegFormer instead uses a hierarchical encoder inspired by convolutional networks: the image passes through four successive stages, each reducing the spatial resolution (by 4 in the first stage, then by 2 in each later stage) while increasing the channel depth, much like the stages in a ResNet.

Input (H × W × C)
    │
Stage 1 → feature map H/4  × W/4  × C₁     (high resolution, local details)
Stage 2 → feature map H/8  × W/8  × C₂
Stage 3 → feature map H/16 × W/16 × C₃
Stage 4 → feature map H/32 × W/32 × C₄     (low resolution, global context)
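
The shape pyramid above follows directly from the per-stage strides (4, 2, 2, 2). A quick plain-Python sketch, using the MiT-B5 channel widths (64, 128, 320, 512) for C₁…C₄, reproduces it:

```python
# Per-stage downsampling strides and channel widths (MiT-B5 values)
strides = [4, 2, 2, 2]
channels = [64, 128, 320, 512]

def stage_shapes(h, w, strides, channels):
    """Return the (C, H, W) feature-map shape after each encoder stage."""
    shapes = []
    for s, c in zip(strides, channels):
        h, w = h // s, w // s
        shapes.append((c, h, w))
    return shapes

for i, shape in enumerate(stage_shapes(512, 512, strides, channels), start=1):
    print(f"Stage {i}: {shape}")
```

For a 512×512 input this yields (64, 128, 128) down to (512, 16, 16), matching the diagram.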

Two design choices within each stage make the MiT especially efficient:

  • Efficient Self-Attention with Sequence Reduction (SR): standard self-attention is O(N²) in the sequence length N. MiT reduces the key/value sequence by a factor R before computing attention, cutting the cost to O(N²/R). The reduction ratio decreases across stages (R = 64, 16, 4, 1) so that deeper stages, which already have a shorter sequence, can afford more attention.
  • Mix-FFN: the feed-forward network in each transformer block includes a depth-wise 3×3 convolution, which implicitly encodes local positional information. This removes the need for fixed positional encodings and allows SegFormer to handle any input resolution at inference time.
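
To get a feel for the savings from sequence reduction, here is a back-of-the-envelope sketch (plain Python) comparing the number of attention scores computed with and without SR for a 512 × 512 input:

```python
# Sequence length N and reduction ratio R per stage, for a 512x512 input.
# R follows the SegFormer paper (64, 16, 4, 1); spatial sizes follow the
# H/4 ... H/32 pyramid above.
H = W = 512
sizes = [H // s for s in (4, 8, 16, 32)]   # 128, 64, 32, 16
ratios = [64, 16, 4, 1]

for i, (side, r) in enumerate(zip(sizes, ratios), start=1):
    n = side * side                 # sequence length N at this stage
    full = n * n                    # plain attention: N x N scores
    reduced = n * (n // r)          # SR attention: N x (N/R) scores
    print(f"Stage {i}: N={n:>6}  full={full:>12,}  reduced={reduced:>12,}")
```

At stage 1, reduction shrinks the attention matrix from ~268 M entries to ~4 M, which is what makes high-resolution inputs tractable.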

The largest variant, MiT-B5, has ~82 M parameters in the encoder alone and was pre-trained on ImageNet-1k.

1.2 Decoder: lightweight all-MLP

Once the encoder has produced four feature maps at different scales, a traditional approach (like FPN or UPerNet) would use complex convolutional blocks to fuse them. SegFormer’s decoder does something strikingly simple:

  1. Project each stage output to a common embedding dimension with a linear layer.
  2. Upsample all projected maps to the same resolution (H/4 × W/4).
  3. Concatenate them along the channel dimension.
  4. Predict class logits with a single linear layer.

This all-MLP decoder has ~4 M parameters — about 5 % of the total model size — and yet consistently matches or outperforms much heavier decoder heads. The key insight is that multi-scale context is already baked into the encoder features; the decoder just needs to combine them.

Stage 1 features  ──┐
Stage 2 features  ──┤  Linear proj → upsample → H/4 × W/4
Stage 3 features  ──┤                ↓
Stage 4 features  ──┘             Concat → Linear → logits (H/4 × W/4 × num_classes)

The final logits are at quarter resolution (H/4 × W/4). During training they are bilinearly upsampled to the label resolution before computing the loss; during inference they are upsampled to the original image resolution.
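
The four decoder steps can be sketched in a few lines of PyTorch. This is a toy stand-in, not the actual SegformerDecodeHead: the embedding width (256) and class count (11) are arbitrary choices for illustration, and the real head adds a fusion convolution, batch norm, and dropout.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AllMLPDecoderSketch(nn.Module):
    """Toy version of SegFormer's decoder: project, upsample, concat, classify."""

    def __init__(self, in_channels=(64, 128, 320, 512), embed_dim=256, num_classes=11):
        super().__init__()
        # Step 1: one 1x1 projection per encoder stage, to a common width
        self.proj = nn.ModuleList(nn.Conv2d(c, embed_dim, 1) for c in in_channels)
        # Step 4: a single linear (1x1 conv) classifier on the fused features
        self.classifier = nn.Conv2d(embed_dim * len(in_channels), num_classes, 1)

    def forward(self, features):          # features: list of (B, C_i, H_i, W_i)
        target = features[0].shape[-2:]   # finest resolution, H/4 x W/4
        # Steps 2-3: upsample every projected map to H/4 x W/4, then concat
        fused = torch.cat(
            [F.interpolate(p(f), size=target, mode="bilinear", align_corners=False)
             for p, f in zip(self.proj, features)],
            dim=1,
        )
        return self.classifier(fused)     # (B, num_classes, H/4, W/4)

feats = [torch.randn(2, c, s, s) for c, s in zip((64, 128, 320, 512), (128, 64, 32, 16))]
logits = AllMLPDecoderSketch()(feats)
print(tuple(logits.shape))   # (2, 11, 128, 128)
```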

1.3 Why SegFormer for satellite imagery?

Satellite imagery presents specific challenges compared to natural images:

  • 14 spectral bands (not 3): the encoder’s first patch embedding is replaced; num_channels is a config parameter.
  • Large images (512 × 512 patches): no positional encodings make the model resolution-agnostic, and SR attention keeps memory in check.
  • Small objects (buildings, roads): stage-1 features at H/4 preserve fine spatial details.
  • Large context (land-cover categories span hundreds of pixels): deep stages capture global context at H/32.

2 Instantiating the Model

The project wraps SegFormer in two classes defined in intermediate_solutions/solution_step2/models/model.py:

  • SemanticSegmentationSegformer: a thin subclass of HuggingFace’s SegformerPreTrainedModel that adds freeze/unfreeze helpers and a cleaner forward method.
  • SegformerB5: a factory that fetches the label mapping from an S3 bucket, builds a SegformerConfig, and downloads the pre-trained nvidia/mit-b5 weights.

Instantiating the full model looks like this:

import sys
sys.path.insert(0, "intermediate_solutions/solution_step2")

from models.model import SegformerB5

# This downloads ~330 MB of weights from HuggingFace on first run
model = SegformerB5(
    n_bands=14,              # 14 Sentinel-2 bands
    logits=True,             # return raw logits (not probabilities)
    freeze_encoder=False,    # keep encoder trainable
    type_labeler="CLCplus-Backbone",
)

Note: What happens inside SegformerB5.__new__?

SegformerB5 overrides __new__ (not __init__) because it returns an instance of the parent class SemanticSegmentationSegformer rather than of SegformerB5 itself. This Python idiom lets the ordinary-looking call SegformerB5(...) behave as a factory function.

Inside, three things happen:

  1. The id2label / label2id dicts are fetched from S3 (a JSON file mapping integer class ids to human-readable names).
  2. A SegformerConfig is built from the nvidia/mit-b5 preset and patched with num_channels=14 and the correct number of classes.
  3. SemanticSegmentationSegformer.from_pretrained("nvidia/mit-b5", ...) is called with ignore_mismatched_sizes=True so that the first patch-embedding layer (whose weight shape changes when num_channels differs from 3) is re-initialised from scratch while all other layers keep their ImageNet weights.
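
The __new__-as-factory idiom is easy to misread, so here is a stripped-down toy (all class names invented for illustration) showing why __init__ never runs on the factory class:

```python
class SemanticModel:
    """Toy stand-in for the parent model class."""
    def __init__(self, num_channels, num_labels):
        self.num_channels = num_channels
        self.num_labels = num_labels

class ModelFactory(SemanticModel):
    """Toy stand-in for the factory: calling it yields a configured parent instance."""
    def __new__(cls, n_bands):
        # The returned object is an instance of the *parent* class, not of
        # ModelFactory, so Python skips ModelFactory.__init__ entirely;
        # the call behaves like a plain function.
        return SemanticModel(num_channels=n_bands, num_labels=3)

model = ModelFactory(n_bands=14)
print(type(model).__name__)   # SemanticModel, not ModelFactory
print(model.num_channels)     # 14
```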

3 Exploring the Model

3.1 Printing the architecture

The model object is a standard torch.nn.Module. You can print it to see every sub-module:

print(model)

The output is long, but the top-level structure is clear:

SemanticSegmentationSegformer(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(...)   # 4 patch-embedding layers
      (block): ModuleList(...)              # 4 lists of transformer blocks
      (layer_norm): ModuleList(...)         # 4 layer norms
    )
  )
  (decode_head): SegformerDecodeHead(
    (linear_c): ModuleList(...)             # 4 projection layers
    (linear_fuse): Conv2d(...)              # channel fusion
    (batch_norm): BatchNorm2d(...)
    (classifier): Conv2d(...)              # final class prediction
  )
)

3.2 Counting parameters

A quick way to understand a model’s capacity is to count how many parameters it has, and how many are currently trainable (i.e. not frozen).

def count_params(module):
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable

total, trainable = count_params(model)
enc_total, enc_trainable = count_params(model.segformer)
dec_total, dec_trainable = count_params(model.decode_head)

print(f"{'Component':<20} {'Total params':>15} {'Trainable params':>18}")
print("-" * 55)
print(f"{'Encoder (MiT-B5)':<20} {enc_total:>15,} {enc_trainable:>18,}")
print(f"{'Decoder (MLP)':<20} {dec_total:>15,} {dec_trainable:>18,}")
print(f"{'Full model':<20} {total:>15,} {trainable:>18,}")

Expected output (approximate):

Component            Total params   Trainable params
-------------------------------------------------------
Encoder (MiT-B5)       82,030,720         82,030,720
Decoder (MLP)           4,217,248          4,217,248
Full model             86,247,968         86,247,968

Notice how lopsided the split is: the encoder carries ~95 % of the parameters, yet the decoder is what actually learns to combine multi-scale features for this specific task. This asymmetry is central to the SegFormer design philosophy.

3.3 Freezing the encoder

When fine-tuning on a new domain (e.g. switching from RGB natural images to 14-band satellite imagery), a common strategy is to freeze the encoder at first and only train the decoder. This is much faster and prevents overfitting when labelled data is scarce.

# Freeze all encoder parameters
model.freeze()

total, trainable = count_params(model)
print(f"After freezing encoder — trainable: {trainable:,} / {total:,}")

# Unfreeze for full fine-tuning
model.unfreeze()
total, trainable = count_params(model)
print(f"After unfreezing — trainable: {trainable:,} / {total:,}")
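
The freeze/unfreeze helpers are project-specific, but under the hood freezing is just flipping requires_grad on each parameter. A generic sketch on a toy nn.Sequential:

```python
import torch.nn as nn

def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Freeze or unfreeze every parameter of a module."""
    for p in module.parameters():
        p.requires_grad = trainable

# Tiny stand-in network: think of net[0] as the "encoder"
net = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))

set_trainable(net[0], False)   # freeze the first block
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
total = sum(p.numel() for p in net.parameters())
print(f"trainable: {trainable} / {total}")   # trainable: 68 / 212
```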

Exercise 1 — Parameter exploration

Implement the count_params helper from scratch, then:

  1. Count the number of parameters in each of the 4 encoder stages individually (model.segformer.encoder.block[i]).
  2. Which stage has the most parameters? Does this match your intuition about deeper stages being wider?
  3. Count the parameters in the decoder’s four projection layers (model.decode_head.linear_c) and in the final classifier (model.decode_head.classifier). What fraction of decoder parameters does the classifier represent?

# 1. Parameters per encoder stage
for i, stage in enumerate(model.segformer.encoder.block):
    n = ___  # TODO: count parameters in this stage
    print(f"Stage {i+1}: {n:,} parameters")

# 2. Decoder breakdown
for i, proj in enumerate(model.decode_head.linear_c):
    n = ___  # TODO
    print(f"Decoder projection {i+1}: {n:,} parameters")

clf_params = ___  # TODO: classifier layer
print(f"Classifier: {clf_params:,} parameters")

Hint: Use sum(p.numel() for p in module.parameters()) to count parameters for any nn.Module. Each encoder stage is a ModuleList of transformer blocks accessible at model.segformer.encoder.block[i].

Solution:

# 1. Parameters per encoder stage
for i, stage in enumerate(model.segformer.encoder.block):
    n = sum(p.numel() for p in stage.parameters())
    print(f"Stage {i+1}: {n:,} parameters")

# 2. Decoder breakdown
for i, proj in enumerate(model.decode_head.linear_c):
    n = sum(p.numel() for p in proj.parameters())
    print(f"Decoder projection {i+1}: {n:,} parameters")

clf_params = sum(p.numel() for p in model.decode_head.classifier.parameters())
dec_total = sum(p.numel() for p in model.decode_head.parameters())
print(f"Classifier: {clf_params:,} / {dec_total:,} decoder params "
      f"({100*clf_params/dec_total:.1f}%)")

4 Running a Forward Pass

Before training, it is useful to trace a dummy forward pass to verify that all shapes are consistent.

4.1 Input shape

Our dataset produces patches of shape (14, H, W) — 14 spectral bands. The DataLoader stacks them into batches (B, 14, H, W). Let us create a random dummy batch:

import torch

B, C, H, W = 2, 14, 512, 512   # batch size, channels, height, width
dummy_input = torch.randn(B, C, H, W)
dummy_labels = torch.randint(0, model.config.num_labels, (B, H, W))

print(f"Input  shape: {tuple(dummy_input.shape)}")
print(f"Labels shape: {tuple(dummy_labels.shape)}")

4.2 Output shapes

The model has two calling modes depending on whether you pass labels:

model.eval()
with torch.no_grad():
    # Without labels → raw logits at H/4 × W/4
    logits = model(dummy_input)
    print(f"Logits shape (no labels):   {tuple(logits.shape)}")
    # Expected: (2, num_classes, 128, 128)  — quarter resolution

    # With labels → logits upsampled to label resolution
    upsampled = model(dummy_input, dummy_labels)
    print(f"Logits shape (with labels): {tuple(upsampled.shape)}")
    # Expected: (2, num_classes, 512, 512)

Note: Why quarter resolution?

If each of the four stages simply halved the resolution, the finest feature map would sit at H/2 and the deepest at H/16. Instead, the first patch embedding in SegFormer uses stride 4 (not stride 2), which brings the resolution immediately to H/4; the three later stages then halve once each: H/4 → H/8 → H/16 → H/32. The decoder upsamples all four feature maps to the finest of these scales (H/4) before fusing them, so the output is at H/4 × W/4.

4.3 Inspecting intermediate hidden states

The encoder returns all four intermediate feature maps when output_hidden_states=True. This is exactly how the forward method feeds the decoder:

outputs = model.segformer(
    dummy_input,
    output_hidden_states=True,
    return_dict=True,
)

for i, hs in enumerate(outputs.hidden_states):
    print(f"Stage {i+1} hidden state: {tuple(hs.shape)}")

You should see something like:

Stage 1 hidden state: (2, 64,  128, 128)   # H/4  × W/4
Stage 2 hidden state: (2, 128,  64,  64)   # H/8  × W/8
Stage 3 hidden state: (2, 320,  32,  32)   # H/16 × W/16
Stage 4 hidden state: (2, 512,  16,  16)   # H/32 × W/32

Exercise 2 — Forward pass anatomy

  1. Run the dummy forward pass without labels, then manually upsample the logits to the original (512, 512) resolution using torch.nn.functional.interpolate with mode="bilinear". Verify the shape matches what you get when you pass labels directly.
  2. Apply torch.softmax(logits, dim=1) to the upsampled logits and call torch.argmax(..., dim=1) to obtain a predicted class map. What shape does it have? What do its values represent?
  3. (Bonus) For each of the 4 hidden states, compute the ratio of spatial area to the original image: (H_stage * W_stage) / (H * W). Does the pattern match the quarter-resolution rule?

import torch.nn.functional as F

# 1. Manual upsampling
logits_full = F.interpolate(
    ___,                      # TODO: the raw logits tensor
    size=___,                 # TODO: target (H, W)
    mode="bilinear",
    align_corners=False,
)
print(f"Manually upsampled: {tuple(logits_full.shape)}")

# 2. Predicted class map
probs = ___   # TODO: softmax over class dimension
pred  = ___   # TODO: argmax → (B, H, W)
print(f"Predicted map shape: {tuple(pred.shape)}")
print(f"Unique predicted classes: {pred.unique().tolist()}")

# 3. Spatial area ratios
for i, hs in enumerate(outputs.hidden_states):
    ratio = ___  # TODO
    print(f"Stage {i+1}: {ratio:.4f}")
Solution:

import torch.nn.functional as F

model.eval()
with torch.no_grad():
    logits = model(dummy_input)           # (B, num_classes, H/4, W/4)
    outputs = model.segformer(
        dummy_input, output_hidden_states=True, return_dict=True
    )

# 1. Manual upsampling
logits_full = F.interpolate(
    logits,
    size=(H, W),
    mode="bilinear",
    align_corners=False,
)
print(f"Manually upsampled: {tuple(logits_full.shape)}")

# 2. Predicted class map
probs = torch.softmax(logits_full, dim=1)
pred  = torch.argmax(probs, dim=1)
print(f"Predicted map shape: {tuple(pred.shape)}")
print(f"Unique predicted classes: {pred.unique().tolist()}")

# 3. Spatial area ratios
for i, hs in enumerate(outputs.hidden_states):
    ratio = (hs.shape[-2] * hs.shape[-1]) / (H * W)
    print(f"Stage {i+1}: {ratio:.4f}")

5 Wrapping the Model in a Lightning Module

PyTorch Lightning separates the model definition from the training logic. The SegmentationModule class in intermediate_solutions/solution_step2/models/module.py wraps SemanticSegmentationSegformer and implements the standard Lightning hooks.

5.1 Structure of SegmentationModule

from models.module import SegmentationModule
from torch import nn, optim

module = SegmentationModule(
    model=model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
    scheduler=optim.lr_scheduler.OneCycleLR,
    scheduler_params={},
    scheduler_interval="step",
)

The module’s key methods are:

  • forward(batch, labels): delegates to model.forward; upsamples the logits when labels are given.
  • training_step: computes the loss and logs IoU and building-rate metrics.
  • validation_step: same as training_step, on the validation set.
  • test_step: evaluates on the test dataloaders.
  • configure_optimizers: returns [optimizer], [scheduler] in the Lightning format.

5.2 Loss function: Cross-Entropy

For multi-class segmentation, the standard loss is pixel-wise cross-entropy:

\[ \mathcal{L} = -\frac{1}{HW} \sum_{h,w} \log \hat{p}_{y_{h,w},h,w} \]

where \(\hat{p}_{c,h,w}\) is the softmax probability assigned to class \(c\) at pixel \((h, w)\), and \(y_{h,w}\) is the true class label. The ignore_index=255 argument tells PyTorch to skip pixels labelled 255 (used here as a “no-data” sentinel).
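
It is worth seeing exactly what ignore_index does. The sketch below checks PyTorch’s result against a manual masked mean of the negative log-probabilities at the true class (shapes and values here are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 3, 4, 4)      # (B, num_classes, H, W)
labels = torch.randint(0, 3, (1, 4, 4))
labels[0, 0, 0] = 255                 # mark one pixel as no-data

loss = F.cross_entropy(logits, labels, ignore_index=255)

# Manual version: mean of -log p at the true class, over valid pixels only
log_probs = F.log_softmax(logits, dim=1)
valid = labels != 255
safe = labels.masked_fill(~valid, 0)  # placeholder index so gather is in range
picked = log_probs.gather(1, safe.unsqueeze(1)).squeeze(1)
manual = -picked[valid].mean()

print(torch.allclose(loss, manual))   # True
```

Note that with reduction="mean", PyTorch normalises by the number of non-ignored pixels, not by H·W.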

5.3 Metrics: Intersection over Union (IoU)

Raw accuracy is misleading in segmentation because classes are often heavily imbalanced (e.g. buildings occupy a small fraction of pixels). Intersection over Union (IoU) for a class \(c\) is:

\[ \text{IoU}_c = \frac{|\hat{M}_c \cap M_c|}{|\hat{M}_c \cup M_c|} \]

where \(\hat{M}_c\) is the set of pixels predicted as class \(c\) and \(M_c\) is the set of true pixels of class \(c\). An IoU of 1.0 means perfect overlap; 0.0 means no overlap at all. Mean IoU (mIoU) averages over all classes.
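
Before reaching for the project’s IOU helper, the formula can be hand-rolled on a 2 × 2 toy example (this sketch works on hard label maps and simply skips classes absent from both prediction and target):

```python
import torch

def iou_per_class(pred, target, num_classes):
    """Per-class IoU from hard label maps, skipping classes with empty union."""
    ious = []
    for c in range(num_classes):
        inter = ((pred == c) & (target == c)).sum().item()
        union = ((pred == c) | (target == c)).sum().item()
        if union > 0:
            ious.append(inter / union)
    return ious

pred   = torch.tensor([[0, 0], [1, 1]])
target = torch.tensor([[0, 1], [1, 1]])
ious = iou_per_class(pred, target, num_classes=2)
print(ious)                   # [0.5, 0.666...]
print(sum(ious) / len(ious))  # mIoU = 0.583...
```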

import sys
sys.path.insert(0, "intermediate_solutions/solution_step2")

from training.metrics import IOU, positive_rate

# Simulate model output and labels
B, num_classes, H, W = 2, 11, 512, 512
dummy_logits = torch.randn(B, num_classes, H, W)
dummy_labels = torch.randint(0, num_classes, (B, H, W))

iou_mean, iou_building = IOU(dummy_logits, dummy_labels, logits=True)
building_rate = positive_rate(dummy_logits, logits=True)

print(f"Mean IoU:      {iou_mean:.4f}")
print(f"Building IoU:  {iou_building:.4f}")
print(f"Building rate: {building_rate:.4f}  (fraction of pixels predicted as building)")

Note: Why track building rate?

During early training, the model often collapses to predicting the majority class everywhere. Monitoring the fraction of pixels predicted as “building” helps detect this: if building_rate ≈ 0 after a few epochs, the model is ignoring the minority class and something is wrong (learning rate too high, class imbalance, etc.).
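
The rate itself is a one-liner over the hard predictions. In this sketch the building class id is assumed to be 1; the real id comes from the label mapping fetched from S3:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 11, 8, 8)   # (B, num_classes, H, W)
pred = logits.argmax(dim=1)         # hard class map, (B, H, W)

building_class_id = 1               # assumption: depends on the label mapping
building_rate = (pred == building_class_id).float().mean().item()
print(f"building rate: {building_rate:.4f}")
```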

6 The Full Training Pipeline

The training script intermediate_solutions/solution_step2/main.py wires everything together:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(monitor="val_loss", patience=5, mode="min"),
    ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        filename="best-model",
    ),
]

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",   # uses GPU if available, CPU otherwise
    devices="auto",
    callbacks=callbacks,
    log_every_n_steps=10,
)

trainer.fit(module, train_loader, val_loader)

Two callbacks are used:

  • EarlyStopping: halts training if the validation loss does not improve for 5 consecutive epochs, preventing over-fitting.
  • ModelCheckpoint: saves only the checkpoint with the best validation loss, keeping storage requirements low.

Exercise 3 — Training a toy model

To avoid downloading the full pre-trained weights, we can build a small SegFormer from a minimal config and run a few training steps on random data.

  1. Create a SegformerConfig with num_channels=14, num_labels=3, and depths=[1, 1, 1, 1] (one transformer block per stage instead of the usual [3, 6, 40, 3] for B5).
  2. Instantiate a SemanticSegmentationSegformer from this config (no pre-trained weights — just call the constructor directly).
  3. Wrap it in a SegmentationModule with CrossEntropyLoss and AdamW.
  4. Run 3 training steps manually (without a Trainer) by calling module.training_step(batch, 0) on a random batch.

from transformers import SegformerConfig
from models.model import SemanticSegmentationSegformer
from models.module import SegmentationModule
from torch import nn, optim

# 1. Minimal config
config = SegformerConfig(
    num_channels=___,   # TODO
    num_labels=___,     # TODO
    depths=___,         # TODO: [1, 1, 1, 1]
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)

# 2. Model without pre-trained weights
tiny_model = SemanticSegmentationSegformer(config)

# 3. Lightning module
tiny_module = SegmentationModule(
    model=tiny_model,
    loss=___,               # TODO: CrossEntropyLoss
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3},
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "val_loss"},
    scheduler_interval="epoch",
)
Solution:

import torch
from transformers import SegformerConfig
from models.model import SemanticSegmentationSegformer
from models.module import SegmentationModule
from torch import nn, optim

config = SegformerConfig(
    num_channels=14,
    num_labels=3,
    depths=[1, 1, 1, 1],
    id2label={0: "background", 1: "building", 2: "vegetation"},
    label2id={"background": 0, "building": 1, "vegetation": 2},
)

tiny_model = SemanticSegmentationSegformer(config)

tiny_module = SegmentationModule(
    model=tiny_model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3},
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_params={"mode": "min", "patience": 3, "monitor": "val_loss"},
    scheduler_interval="epoch",
)

7 Summary

In this section you have:

  • Understood the two-component design of SegFormer: hierarchical MiT encoder (multi-scale features, efficient SR attention) and all-MLP decoder (lightweight fusion).
  • Learned how to instantiate and inspect the model, count parameters, and see the encoder/decoder split.
  • Traced a forward pass from a raw (B, 14, H, W) tensor through the four encoder stages to the final (B, num_classes, H/4, W/4) logits, and upsampled them to label resolution.
  • Seen how freezing the encoder reduces trainable parameters from ~86 M to ~4 M, which is useful when fine-tuning with limited data.
  • Wrapped the model in a PyTorch Lightning module with cross-entropy loss and IoU metrics, ready for a full training loop.

The next section covers inference and statistics: using a pre-trained model checkpoint to produce land-cover maps over new regions and computing summary statistics with administrative boundaries.