import sys
sys.path.insert(0, "intermediate_solutions/solution_step2")
from models.model import SegformerB5
# This downloads ~330 MB of weights from HuggingFace on first run
model = SegformerB5(
n_bands=14, # 14 Sentinel-2 bands
logits=True, # return raw logits (not probabilities)
freeze_encoder=False, # keep encoder trainable
type_labeler="CLCplus-Backbone",
)
Model Training
1 The SegFormer Architecture
In the previous section we downloaded and prepared Sentinel-2 multispectral images together with their CLC+ Backbone land-cover labels. We now have patches of shape (14, H, W) — 14 spectral bands, height, width — each paired with a pixel-wise class mask. The goal of this section is to understand the model that will turn those patches into segmentation maps.
We use SegFormer, a Vision Transformer architecture designed specifically for semantic segmentation. SegFormer was introduced by Xie et al. (2021) in “SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers”. Its two key design choices make it both powerful and practical:
- A hierarchical Mix Transformer encoder (MiT) that produces multi-scale feature representations without positional encodings.
- A lightweight all-MLP decoder that aggregates information from all encoder stages with almost no parameters.
1.1 Encoder: the Mix Transformer (MiT)
A classic ViT cuts the image into non-overlapping 16×16 patches and processes them as a flat sequence. SegFormer instead uses a hierarchical encoder inspired by convolutional networks: the image passes through four successive stages, each one halving the spatial resolution while doubling the channel depth, very much like the stages in a ResNet.
Input (H × W × C)
│
Stage 1 → feature map H/4 × W/4 × C₁ (high resolution, local details)
Stage 2 → feature map H/8 × W/8 × C₂
Stage 3 → feature map H/16 × W/16 × C₃
Stage 4 → feature map H/32 × W/32 × C₄ (low resolution, global context)
Two design choices within each stage make the MiT especially efficient:
- Efficient Self-Attention with Sequence Reduction (SR): standard self-attention is O(N²) in the sequence length N. MiT reduces the key/value sequence by a factor R before computing attention, cutting the cost to O(N²/R). The reduction ratio decreases across stages (R = 64, 16, 4, 1) so that deeper stages, which already have a shorter sequence, can afford more attention.
- Mix-FFN: the feed-forward network in each transformer block includes a depth-wise 3×3 convolution, which implicitly encodes local positional information. This removes the need for fixed positional encodings and allows SegFormer to handle any input resolution at inference time.
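To make the sequence-reduction idea concrete, here is a minimal single-head sketch of SR attention (illustrative only; the HuggingFace implementation differs in details such as multi-head splitting). A spatial stride `s` shrinks the key/value sequence by a factor s², so s = 8 corresponds to R = 64:

```python
import torch
from torch import nn

class SRAttention(nn.Module):
    """Single-head self-attention with sequence reduction (illustrative sketch)."""
    def __init__(self, dim, s):
        super().__init__()
        self.s = s                              # spatial stride; sequence shrinks by s**2
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.sr = nn.Conv2d(dim, dim, kernel_size=s, stride=s) if s > 1 else nn.Identity()
        self.norm = nn.LayerNorm(dim)
        self.scale = dim ** -0.5

    def forward(self, x, h, w):
        b, n, d = x.shape                       # n = h * w tokens
        q = self.q(x)                           # queries keep the full sequence
        if self.s > 1:
            x_ = x.transpose(1, 2).reshape(b, d, h, w)   # tokens back to a 2-D map
            x_ = self.sr(x_).flatten(2).transpose(1, 2)  # (b, n // s**2, d)
            x = self.norm(x_)
        k, v = self.kv(x).chunk(2, dim=-1)      # keys/values on the reduced sequence
        attn = (q @ k.transpose(-2, -1)) * self.scale    # (b, n, n // s**2)
        return attn.softmax(dim=-1) @ v         # (b, n, d): same shape as the input
```

The attention matrix becomes n × n/s² instead of n × n, which is what makes stage 1 (longest sequence, largest reduction) affordable.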
The largest variant, MiT-B5, has ~82 M parameters in the encoder alone and was pre-trained on ImageNet-1k.
1.2 Decoder: lightweight all-MLP
Once the encoder has produced four feature maps at different scales, a traditional approach (like FPN or UPerNet) would use complex convolutional blocks to fuse them. SegFormer’s decoder does something strikingly simple:
- Project each stage output to a common embedding dimension with a linear layer.
- Upsample all projected maps to the same resolution (H/4 × W/4).
- Concatenate them along the channel dimension.
- Predict class logits with a single linear layer.
This all-MLP decoder has ~4 M parameters — about 5 % of the total model size — and yet consistently matches or outperforms much heavier decoder heads. The key insight is that multi-scale context is already baked into the encoder features; the decoder just needs to combine them.
Stage 1 features ──┐
Stage 2 features ──┤ Linear proj → upsample → H/4 × W/4
Stage 3 features ──┤ ↓
Stage 4 features ──┘ Concat → Linear → logits (H/4 × W/4 × num_classes)
The final logits are at quarter resolution (H/4 × W/4). During training they are bilinearly upsampled to the label resolution before computing the loss; during inference they are upsampled to the original image resolution.
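The four decoder steps can be sketched in a few lines. This is a simplified stand-in (the real head also applies batch norm and an activation after fusion, and the dimensions below are illustrative defaults, not necessarily those of the B5 head); 1×1 convolutions play the role of per-pixel linear layers:

```python
import torch
from torch import nn
import torch.nn.functional as F

class AllMLPDecoder(nn.Module):
    """Sketch of an all-MLP decode head in the SegFormer style."""
    def __init__(self, in_dims=(64, 128, 320, 512), embed_dim=256, num_classes=11):
        super().__init__()
        # 1. One linear projection (1x1 conv) per encoder stage
        self.proj = nn.ModuleList(nn.Conv2d(d, embed_dim, 1) for d in in_dims)
        # 3.-4. Fuse concatenated maps, then predict per-pixel class logits
        self.fuse = nn.Conv2d(4 * embed_dim, embed_dim, 1)
        self.classifier = nn.Conv2d(embed_dim, num_classes, 1)

    def forward(self, feats):
        # feats: four maps at H/4, H/8, H/16, H/32
        target = feats[0].shape[-2:]            # finest scale: H/4 x W/4
        # 2. Upsample every projected map to the finest resolution
        ups = [F.interpolate(p(f), size=target, mode="bilinear", align_corners=False)
               for p, f in zip(self.proj, feats)]
        x = self.fuse(torch.cat(ups, dim=1))    # concat along channels, then fuse
        return self.classifier(x)               # (B, num_classes, H/4, W/4)
```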
1.3 Why SegFormer for satellite imagery?
Satellite imagery presents specific challenges compared to natural images:
| Challenge | How SegFormer helps |
|---|---|
| 14 spectral bands (not 3) | Encoder’s first patch embedding is replaced — num_channels is a config parameter |
| Large images (512 × 512 patches) | No positional encodings → resolution-agnostic; SR attention → memory-efficient |
| Small objects (buildings, roads) | Stage-1 features at H/4 preserve fine spatial details |
| Large context (land-cover categories span hundreds of pixels) | Deep stages capture global context at H/32 |
2 Instantiating the Model
The project wraps SegFormer in two classes defined in intermediate_solutions/solution_step2/models/model.py:
- SemanticSegmentationSegformer: a thin subclass of HuggingFace's SegformerPreTrainedModel that adds freeze/unfreeze helpers and a cleaner forward method.
- SegformerB5: a factory that fetches the label mapping from an S3 bucket, builds a SegformerConfig, and downloads the pre-trained nvidia/mit-b5 weights.
Why does SegformerB5 override __new__?
SegformerB5 overrides __new__ (not __init__) because it returns an instance of the parent class SemanticSegmentationSegformer rather than of SegformerB5 itself. This is a Python trick that lets a plain function call SegformerB5(...) act as a factory.
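A toy example of the trick (hypothetical classes, purely to show the mechanics): when __new__ returns an object that is not an instance of the class being constructed, Python skips that class's __init__ and hands the object straight back.

```python
class Base:
    def __init__(self, x):
        self.x = x

class Factory(Base):
    def __new__(cls, x):
        # Build and return an instance of the *parent* class.
        # Because the result is not an instance of Factory,
        # Factory.__init__ is never invoked.
        return Base(x)

obj = Factory(42)
print(type(obj).__name__)   # Base
print(obj.x)                # 42
```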
Inside, three things happen:
- The id2label/label2id dicts are fetched from S3 (a JSON file mapping integer class ids to human-readable names).
- A SegformerConfig is built from the nvidia/mit-b5 preset and patched with num_channels=14 and the correct number of classes.
- SemanticSegmentationSegformer.from_pretrained("nvidia/mit-b5", ...) is called with ignore_mismatched_sizes=True so that the first patch-embedding layer (whose weight shape changes when num_channels differs from 3) is re-initialised from scratch while all other layers keep their ImageNet weights.
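Why is ignore_mismatched_sizes needed? The first patch-embedding convolution has a weight of shape (hidden_size, num_channels, 7, 7), so changing num_channels from 3 to 14 changes that one tensor's shape while leaving every other layer untouched. A quick check with two small randomly-initialised configs (no pre-trained weights are downloaded; shapes assume the library's default MiT-B0 hidden sizes):

```python
from transformers import SegformerConfig, SegformerModel

# Two configs that differ only in the number of input channels.
cfg_rgb = SegformerConfig(num_channels=3, depths=[1, 1, 1, 1])
cfg_s2 = SegformerConfig(num_channels=14, depths=[1, 1, 1, 1])

# The first patch embedding is the only layer whose weight shape changes.
w_rgb = SegformerModel(cfg_rgb).encoder.patch_embeddings[0].proj.weight
w_s2 = SegformerModel(cfg_s2).encoder.patch_embeddings[0].proj.weight
print(w_rgb.shape)   # torch.Size([32, 3, 7, 7])
print(w_s2.shape)    # torch.Size([32, 14, 7, 7])
```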
3 Exploring the Model
3.1 Printing the architecture
The model object is a standard torch.nn.Module. You can print it to see every sub-module:
print(model)
The output is long, but the top-level structure is clear:
SemanticSegmentationSegformer(
(segformer): SegformerModel(
(encoder): SegformerEncoder(
(patch_embeddings): ModuleList(...) # 4 patch-embedding layers
(block): ModuleList(...) # 4 lists of transformer blocks
(layer_norm): ModuleList(...) # 4 layer norms
)
)
(decode_head): SegformerDecodeHead(
(linear_c): ModuleList(...) # 4 projection layers
(linear_fuse): Conv2d(...) # channel fusion
(batch_norm): BatchNorm2d(...)
(classifier): Conv2d(...) # final class prediction
)
)
3.2 Counting parameters
A quick way to understand a model’s capacity is to count how many parameters it has, and how many are currently trainable (i.e. not frozen).
def count_params(module):
total = sum(p.numel() for p in module.parameters())
trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
return total, trainable
total, trainable = count_params(model)
enc_total, enc_trainable = count_params(model.segformer)
dec_total, dec_trainable = count_params(model.decode_head)
print(f"{'Component':<20} {'Total params':>15} {'Trainable params':>18}")
print("-" * 55)
print(f"{'Encoder (MiT-B5)':<20} {enc_total:>15,} {enc_trainable:>18,}")
print(f"{'Decoder (MLP)':<20} {dec_total:>15,} {dec_trainable:>18,}")
print(f"{'Full model':<20} {total:>15,} {trainable:>18,}")
Expected output (approximate):
Component Total params Trainable params
-------------------------------------------------------
Encoder (MiT-B5) 82,030,720 82,030,720
Decoder (MLP) 4,217,248 4,217,248
Full model 86,247,968 86,247,968
Notice how lopsided the split is: the encoder carries ~95 % of the parameters, yet the decoder is what actually learns to combine multi-scale features for this specific task. This asymmetry is central to the SegFormer design philosophy.
3.3 Freezing the encoder
When fine-tuning on a new domain (e.g. switching from RGB natural images to 14-band satellite imagery), a common strategy is to freeze the encoder at first and only train the decoder. This is much faster and prevents overfitting when labelled data is scarce.
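Under the hood, freezing amounts to setting requires_grad = False on the encoder's parameters (an assumption about what the freeze()/unfreeze() helpers do, but it is the standard mechanism). A toy module shows the effect:

```python
import torch
from torch import nn

encoder = nn.Linear(4, 4)   # stand-ins for model.segformer / model.decode_head
decoder = nn.Linear(4, 2)

# Freeze the "encoder": its parameters get no gradients and no updates.
for p in encoder.parameters():
    p.requires_grad = False

trainable = sum(p.numel()
                for p in list(encoder.parameters()) + list(decoder.parameters())
                if p.requires_grad)
print(trainable)   # 10 = the decoder's 4*2 weights + 2 biases
```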
# Freeze all encoder parameters
model.freeze()
total, trainable = count_params(model)
print(f"After freezing encoder — trainable: {trainable:,} / {total:,}")
# Unfreeze for full fine-tuning
model.unfreeze()
total, trainable = count_params(model)
print(f"After unfreezing — trainable: {trainable:,} / {total:,}")
Implement the count_params helper from scratch, then:
- Count the number of parameters in each of the 4 encoder stages individually (model.segformer.encoder.block[i]).
- Which stage has the most parameters? Does this match your intuition about deeper stages being wider?
- Count the parameters in the decoder's four projection layers (model.decode_head.linear_c) and in the final classifier (model.decode_head.classifier). What fraction of decoder parameters does the classifier represent?
# 1. Parameters per encoder stage
for i, stage in enumerate(model.segformer.encoder.block):
n = ___ # TODO: count parameters in this stage
print(f"Stage {i+1}: {n:,} parameters")
# 2. Decoder breakdown
for i, proj in enumerate(model.decode_head.linear_c):
n = ___ # TODO
print(f"Decoder projection {i+1}: {n:,} parameters")
clf_params = ___ # TODO: classifier layer
print(f"Classifier: {clf_params:,} parameters")
Use sum(p.numel() for p in module.parameters()) to count parameters for any nn.Module. Each encoder stage is a ModuleList of transformer blocks accessible at model.segformer.encoder.block[i].
# 1. Parameters per encoder stage
for i, stage in enumerate(model.segformer.encoder.block):
n = sum(p.numel() for p in stage.parameters())
print(f"Stage {i+1}: {n:,} parameters")
# 2. Decoder breakdown
for i, proj in enumerate(model.decode_head.linear_c):
n = sum(p.numel() for p in proj.parameters())
print(f"Decoder projection {i+1}: {n:,} parameters")
clf_params = sum(p.numel() for p in model.decode_head.classifier.parameters())
dec_total = sum(p.numel() for p in model.decode_head.parameters())
print(f"Classifier: {clf_params:,} / {dec_total:,} decoder params "
f"({100*clf_params/dec_total:.1f}%)")
4 Running a Forward Pass
Before training, it is useful to trace a dummy forward pass to verify that all shapes are consistent.
4.1 Input shape
Our dataset produces patches of shape (14, H, W) — 14 spectral bands. The DataLoader stacks them into batches (B, 14, H, W). Let us create a random dummy batch:
import torch
B, C, H, W = 2, 14, 512, 512 # batch size, channels, height, width
dummy_input = torch.randn(B, C, H, W)
dummy_labels = torch.randint(0, model.config.num_labels, (B, H, W))
print(f"Input shape: {tuple(dummy_input.shape)}")
print(f"Labels shape: {tuple(dummy_labels.shape)}")
4.2 Output shapes
The model has two calling modes depending on whether you pass labels:
model.eval()
with torch.no_grad():
# Without labels → raw logits at H/4 × W/4
logits = model(dummy_input)
print(f"Logits shape (no labels): {tuple(logits.shape)}")
# Expected: (2, num_classes, 128, 128) — quarter resolution
# With labels → logits upsampled to label resolution
upsampled = model(dummy_input, dummy_labels)
print(f"Logits shape (with labels): {tuple(upsampled.shape)}")
# Expected: (2, num_classes, 512, 512)
If each of the four stages simply halved the resolution, the feature pyramid would run from H/2 down to H/16: 512 → 256 → 128 → 64 → 32. Yet the stages listed earlier sit at H/4 through H/32, and the logits come out at H/4. The key is that the first patch embedding in SegFormer uses stride 4 (not stride 2), which brings the resolution immediately to H/4; subsequent stages then halve once each: H/4 → H/8 → H/16 → H/32. The decoder upsamples all four feature maps to the finest scale (H/4) before fusing them, so the output is at H/4 × W/4.
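The bookkeeping is easy to verify with the stage strides from the paper (4, 2, 2, 2):

```python
size = 512
strides = [4, 2, 2, 2]   # first patch embedding strides by 4, later stages by 2
sizes = []
for s in strides:
    size //= s           # each stage divides the spatial side by its stride
    sizes.append(size)
print(sizes)             # [128, 64, 32, 16] -> H/4, H/8, H/16, H/32
```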
5 Wrapping the Model in a Lightning Module
PyTorch Lightning separates the model definition from the training logic. The SegmentationModule class in intermediate_solutions/solution_step2/models/module.py wraps SemanticSegmentationSegformer and implements the standard Lightning hooks.
5.1 Structure of SegmentationModule
from models.module import SegmentationModule
from torch import nn, optim
module = SegmentationModule(
model=model,
loss=nn.CrossEntropyLoss(ignore_index=255),
optimizer=optim.AdamW,
optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
scheduler=optim.lr_scheduler.OneCycleLR,
scheduler_params={},
scheduler_interval="step",
)
The module's key methods are:
| Method | What it does |
|---|---|
| forward(batch, labels) | Delegates to model.forward; upsamples logits when labels are given |
| training_step | Computes loss + logs IoU and building-rate metrics |
| validation_step | Same as training, on the validation set |
| test_step | Evaluates on test dataloaders |
| configure_optimizers | Returns [optimizer], [scheduler] in the Lightning format |
5.2 Loss function: Cross-Entropy
For multi-class segmentation, the standard loss is pixel-wise cross-entropy:
\[ \mathcal{L} = -\frac{1}{HW} \sum_{h,w} \log \hat{p}_{y_{h,w},h,w} \]
where \(\hat{p}_{c,h,w}\) is the softmax probability assigned to class \(c\) at pixel \((h, w)\), and \(y_{h,w}\) is the true class label. The ignore_index=255 argument tells PyTorch to skip pixels labelled 255 (used here as a “no-data” sentinel).
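A tiny example of ignore_index in action: pixels labelled 255 contribute nothing to the loss and receive no gradient.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss(ignore_index=255)
logits = torch.randn(1, 3, 4, 4, requires_grad=True)   # (B, num_classes, H, W)
labels = torch.randint(0, 3, (1, 4, 4))
labels[0, 0, :] = 255                                  # mark the first row as no-data

loss = loss_fn(logits, labels)                         # averaged over valid pixels only
loss.backward()
print(logits.grad[0, :, 0, :].abs().sum().item())      # 0.0: ignored pixels get no gradient
```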
5.3 Metrics: Intersection over Union (IoU)
Raw accuracy is misleading in segmentation because classes are often heavily imbalanced (e.g. buildings occupy a small fraction of pixels). Intersection over Union (IoU) for a class \(c\) is:
\[ \text{IoU}_c = \frac{|\hat{M}_c \cap M_c|}{|\hat{M}_c \cup M_c|} \]
where \(\hat{M}_c\) is the set of pixels predicted as class \(c\) and \(M_c\) is the set of true pixels of class \(c\). An IoU of 1.0 means perfect overlap; 0.0 means no overlap at all. Mean IoU (mIoU) averages over all classes.
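The definition translates directly into a few lines on integer masks (an illustrative helper, not the project's training.metrics.IOU, which operates on logits):

```python
import torch

def iou(pred, target, cls):
    """IoU for one class, from integer class masks."""
    p = pred == cls
    t = target == cls
    inter = (p & t).sum().item()     # pixels where both agree on class cls
    union = (p | t).sum().item()     # pixels where either says class cls
    return inter / union if union else float("nan")

pred   = torch.tensor([[0, 1], [1, 1]])
target = torch.tensor([[0, 1], [0, 1]])
print(iou(pred, target, 1))   # intersection 2, union 3 -> 0.666...
```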
import sys
sys.path.insert(0, "intermediate_solutions/solution_step2")
from training.metrics import IOU, positive_rate
# Simulate model output and labels
B, num_classes, H, W = 2, 11, 512, 512
dummy_logits = torch.randn(B, num_classes, H, W)
dummy_labels = torch.randint(0, num_classes, (B, H, W))
iou_mean, iou_building = IOU(dummy_logits, dummy_labels, logits=True)
building_rate = positive_rate(dummy_logits, logits=True)
print(f"Mean IoU: {iou_mean:.4f}")
print(f"Building IoU: {iou_building:.4f}")
print(f"Building rate: {building_rate:.4f} (fraction of pixels predicted as building)")
During early training, the model often collapses to predicting the majority class everywhere. Monitoring the fraction of pixels predicted as "building" helps detect this: if building_rate ≈ 0 after a few epochs, the model is ignoring the minority class and something is wrong (learning rate too high, class imbalance, etc.).
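A minimal version of such a monitor, assuming class index 1 is "building" (the project's positive_rate may differ in details):

```python
import torch

def building_rate(logits):
    """Fraction of pixels whose argmax prediction is class 1."""
    preds = logits.argmax(dim=1)               # (B, H, W) predicted class indices
    return (preds == 1).float().mean().item()

logits = torch.zeros(1, 2, 2, 2)
logits[0, 1, 0, :] = 1.0                       # top row predicts class 1
logits[0, 0, 1, :] = 1.0                       # bottom row predicts class 0
print(building_rate(logits))                   # 0.5
```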
6 The Full Training Pipeline
The training script intermediate_solutions/solution_step2/main.py wires everything together:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
callbacks = [
EarlyStopping(monitor="val_loss", patience=5, mode="min"),
ModelCheckpoint(
monitor="val_loss",
mode="min",
save_top_k=1,
filename="best-model",
),
]
trainer = pl.Trainer(
max_epochs=20,
accelerator="auto", # uses GPU if available, CPU otherwise
devices="auto",
callbacks=callbacks,
log_every_n_steps=10,
)
trainer.fit(module, train_loader, val_loader)
Two callbacks are used:
- EarlyStopping: halts training if the validation loss does not improve for 5 consecutive epochs, preventing over-fitting.
- ModelCheckpoint: saves only the checkpoint with the best validation loss, keeping storage requirements low.
To avoid downloading the full pre-trained weights, we can build a small SegFormer from a minimal config and run a few training steps on random data.
- Create a SegformerConfig with num_channels=14, num_labels=3, and depths=[1, 1, 1, 1] (one transformer block per stage instead of the usual [3, 6, 40, 3] for B5).
- Instantiate a SemanticSegmentationSegformer from this config (no pre-trained weights; just call the constructor directly).
- Wrap it in a SegmentationModule with CrossEntropyLoss and AdamW.
- Run 3 training steps manually (without a Trainer) by calling module.training_step(batch, 0) on a random batch.
from transformers import SegformerConfig
from models.model import SemanticSegmentationSegformer
from models.module import SegmentationModule
from torch import nn, optim
# 1. Minimal config
config = SegformerConfig(
num_channels=___, # TODO
num_labels=___, # TODO
depths=___, # TODO: [1, 1, 1, 1]
id2label={0: "background", 1: "building", 2: "vegetation"},
label2id={"background": 0, "building": 1, "vegetation": 2},
)
# 2. Model without pre-trained weights
tiny_model = SemanticSegmentationSegformer(config)
# 3. Lightning module
tiny_module = SegmentationModule(
model=tiny_model,
loss=___, # TODO: CrossEntropyLoss
optimizer=optim.AdamW,
optimizer_params={"lr": 1e-3},
scheduler=optim.lr_scheduler.ReduceLROnPlateau,
scheduler_params={"mode": "min", "patience": 3, "monitor": "val_loss"},
scheduler_interval="epoch",
)
import torch
from transformers import SegformerConfig
from models.model import SemanticSegmentationSegformer
from models.module import SegmentationModule
from torch import nn, optim
config = SegformerConfig(
num_channels=14,
num_labels=3,
depths=[1, 1, 1, 1],
id2label={0: "background", 1: "building", 2: "vegetation"},
label2id={"background": 0, "building": 1, "vegetation": 2},
)
tiny_model = SemanticSegmentationSegformer(config)
tiny_module = SegmentationModule(
model=tiny_model,
loss=nn.CrossEntropyLoss(ignore_index=255),
optimizer=optim.AdamW,
optimizer_params={"lr": 1e-3},
scheduler=optim.lr_scheduler.ReduceLROnPlateau,
scheduler_params={"mode": "min", "patience": 3, "monitor": "val_loss"},
scheduler_interval="epoch",
)
7 Summary
In this section you have:
- Understood the two-component design of SegFormer: hierarchical MiT encoder (multi-scale features, efficient SR attention) and all-MLP decoder (lightweight fusion).
- Learned how to instantiate and inspect the model, count parameters, and see the encoder/decoder split.
- Traced a forward pass from a raw (B, 14, H, W) tensor through the four encoder stages to the final (B, num_classes, H/4, W/4) logits, and upsampled them to label resolution.
- Seen how freezing the encoder reduces trainable parameters from ~86 M to ~4 M, which is useful when fine-tuning with limited data.
- Wrapped the model in a PyTorch Lightning module with cross-entropy loss and IoU metrics, ready for a full training loop.
The next section covers inference and statistics: using a pre-trained model checkpoint to produce land-cover maps over new regions and computing summary statistics with administrative boundaries.