Model Explanation

1 The SegFormer Architecture

In the previous section we downloaded and prepared Sentinel-2 multispectral images together with their CLC+ Backbone land-cover labels. We now have patches of shape (14, H, W): 14 input channels (the 12 L2A surface-reflectance bands plus NDVI and NDWI as derived layers), then height and width. Each patch is paired with a pixel-wise class mask. The goal of this section is to understand the model that will turn those patches into segmentation maps.

We use SegFormer, a Vision Transformer architecture designed specifically for semantic segmentation. SegFormer was introduced by Xie et al. (2021) in “SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers”. Its two key design choices make it both powerful and practical:

  1. A hierarchical Mix Transformer encoder (MiT) that produces multi-scale feature representations without positional encodings.
  2. A lightweight all-MLP decoder that aggregates information from all encoder stages with almost no parameters.

1.1 Encoder: the Mix Transformer (MiT)

A classic ViT cuts the image into non-overlapping 16×16 patches and processes them as a single flat sequence at one resolution. SegFormer instead uses a hierarchical encoder inspired by convolutional networks: the image passes through four successive stages, very much like the stages of a ResNet. Stage 1 works at 1/4 of the input resolution, each later stage halves the resolution again, and the channel depth grows along the way (64 → 128 → 320 → 512 for MiT-B5).

Input (H × W × C)
    │
Stage 1 → feature map H/4  × W/4  × C₁     (high resolution, local details)
Stage 2 → feature map H/8  × W/8  × C₂
Stage 3 → feature map H/16 × W/16 × C₃
Stage 4 → feature map H/32 × W/32 × C₄     (low resolution, global context)

Two design choices within each stage make the MiT especially efficient:

  • Efficient Self-Attention with Sequence Reduction (SR): standard self-attention is O(N²) in the sequence length N. MiT shortens the key/value sequence by a factor R before computing attention, cutting the cost to O(N²/R). The reduction ratio decreases across stages (R = 64, 16, 4, 1, which corresponds to a spatial stride of 8, 4, 2, 1 on each side; HuggingFace's config stores the latter as sr_ratios), so that deeper stages, whose sequences are already short, can afford full or nearly full attention.
  • Mix-FFN: the feed-forward network in each transformer block includes a depth-wise 3×3 convolution, which implicitly encodes local positional information. This removes the need for fixed positional encodings, so SegFormer can run at inference resolutions different from the training resolution without interpolating any positional codes.
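To make the two mechanisms concrete, here is a minimal PyTorch sketch of one MiT-style attention + Mix-FFN pair. It is an illustration under simplifying assumptions, not the HuggingFace implementation: the class names EfficientSelfAttention and MixFFN are hypothetical, the pre-norm layer norms, dropout and drop-path are omitted, and keys and values come from one fused linear layer. The sr argument is the per-side spatial stride, so the key/value sequence shrinks by R = sr².

```python
import torch
from torch import nn


class EfficientSelfAttention(nn.Module):
    """Illustrative SR attention (hypothetical class, not the HF implementation)."""

    def __init__(self, dim: int, num_heads: int, sr: int):
        super().__init__()
        self.num_heads, self.sr = num_heads, sr
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)  # keys and values from one fused projection
        self.proj = nn.Linear(dim, dim)
        if sr > 1:
            # kernel = stride = sr: each sr x sr window collapses to one token (R = sr**2)
            self.reduce = nn.Conv2d(dim, dim, kernel_size=sr, stride=sr)
            self.norm = nn.LayerNorm(dim)
        self.last_attn_shape = None

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        b, n, c = x.shape  # n == h * w query tokens
        q = self.q(x)
        if self.sr > 1:
            grid = x.transpose(1, 2).reshape(b, c, h, w)
            x = self.norm(self.reduce(grid).flatten(2).transpose(1, 2))  # (b, n / R, c)
        k, v = self.kv(x).chunk(2, dim=-1)

        def heads(t):  # (b, tokens, c) -> (b, heads, tokens, c // heads)
            return t.reshape(b, -1, self.num_heads, c // self.num_heads).transpose(1, 2)

        q, k, v = heads(q), heads(k), heads(v)
        attn = (q @ k.transpose(-2, -1)) * (c // self.num_heads) ** -0.5
        attn = attn.softmax(dim=-1)  # (b, heads, n, n / R) instead of (b, heads, n, n)
        self.last_attn_shape = tuple(attn.shape)
        out = (attn @ v).transpose(1, 2).reshape(b, n, c)
        return self.proj(out)


class MixFFN(nn.Module):
    """Linear -> depth-wise 3x3 conv -> GELU -> Linear (residual added by the caller)."""

    def __init__(self, dim: int, expansion: int = 4):
        super().__init__()
        hidden = dim * expansion
        self.fc1 = nn.Linear(dim, hidden)
        # depth-wise conv: mixes each channel with its 3x3 neighbours, leaking position info
        self.dwconv = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        b, n, _ = x.shape
        x = self.fc1(x)
        x = self.dwconv(x.transpose(1, 2).reshape(b, -1, h, w)).flatten(2).transpose(1, 2)
        return self.fc2(self.act(x))


# Stage-1-like setting on a 64 x 64 token grid (a 256 x 256 input seen at H/4).
h = w = 64
tokens = torch.randn(2, h * w, 64)
attn_layer = EfficientSelfAttention(dim=64, num_heads=1, sr=8)
ffn = MixFFN(dim=64)
with torch.no_grad():
    y = tokens + attn_layer(tokens, h, w)  # residual connections as in a transformer block
    z = y + ffn(y, h, w)
print(tuple(z.shape))              # (2, 4096, 64)
print(attn_layer.last_attn_shape)  # (2, 1, 4096, 64)
```

With sr=8 the 4,096 queries attend to only 64 keys, so the attention matrix has 64 columns instead of 4,096 (R = 64, as in stage 1).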

The largest variant, MiT-B5, has ~81 M parameters in the encoder alone and was pre-trained on ImageNet-1k.

1.2 Decoder: lightweight all-MLP

Once the encoder has produced four feature maps at different scales, a traditional approach (like FPN or UPerNet) would use complex convolutional blocks to fuse them. SegFormer’s decoder does something strikingly simple:

  1. Project each stage output to a common embedding dimension with a linear layer.
  2. Upsample all projected maps to the same resolution (H/4 × W/4).
  3. Concatenate them along the channel dimension.
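  4. Fuse the concatenated maps with a linear layer (a 1×1 convolution, followed by batch norm and ReLU in the HuggingFace implementation) and predict class logits with a final linear layer.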

This all-MLP decoder has ~3 M parameters, under 4 % of the total model size, and yet in the original paper it performs on par with much heavier decoder heads. The key insight is that multi-scale context is already baked into the encoder features; the decoder just needs to combine them.

Stage 1 features  ──┐
Stage 2 features  ──┤  Linear proj → upsample → H/4 × W/4
Stage 3 features  ──┤                ↓
Stage 4 features  ──┘     Concat → Linear (fuse) → Linear → logits (H/4 × W/4 × num_classes)
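The decoder steps can be sketched as a small PyTorch module. This is a hypothetical re-creation for illustration (the class name AllMLPDecoder is ours), assuming the layer layout of the HuggingFace head: per-stage projections written as 1×1 convolutions (equivalent to a per-pixel linear layer), a 1×1 fuse convolution with batch norm and ReLU, then a 1×1 classifier. The embedding width of 768 and the 11 classes are assumptions, and dropout is left out.

```python
import torch
import torch.nn.functional as F
from torch import nn


class AllMLPDecoder(nn.Module):
    """Hypothetical all-MLP head, for illustration only."""

    def __init__(self, in_dims=(64, 128, 320, 512), embed_dim=768, num_classes=11):
        super().__init__()
        # Step 1: one linear projection per stage (a 1x1 conv is a per-pixel linear layer)
        self.proj = nn.ModuleList(nn.Conv2d(c, embed_dim, kernel_size=1) for c in in_dims)
        # Fusion of the concatenated maps back to embed_dim channels
        self.fuse = nn.Sequential(
            nn.Conv2d(len(in_dims) * embed_dim, embed_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features):
        target = features[0].shape[-2:]  # stage-1 resolution, i.e. H/4 x W/4
        upsampled = [  # steps 1-2: project, then upsample to the finest scale
            F.interpolate(p(f), size=target, mode="bilinear", align_corners=False)
            for p, f in zip(self.proj, features)
        ]
        fused = self.fuse(torch.cat(upsampled, dim=1))  # step 3 + fusion
        return self.classifier(fused)  # step 4: logits at H/4 x W/4


# Random stand-ins for the four MiT-B5 feature maps of a 256 x 256 input
feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 320, 512), (64, 32, 16, 8))]
decoder = AllMLPDecoder().eval()
with torch.no_grad():
    logits = decoder(feats)
print(tuple(logits.shape))  # (1, 11, 64, 64)
```

Because every layer is a 1×1 convolution, the head never looks at neighbouring pixels; all spatial context has to come from the encoder features, which is exactly the design bet described above.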

The final logits are at quarter resolution (H/4 × W/4). During training they are bilinearly upsampled to the label resolution before computing the loss; during inference they are upsampled to the original image resolution.

1.3 Why SegFormer for satellite imagery?

Satellite imagery presents specific challenges compared to natural images:

Challenge | How SegFormer helps
-------------------------------------------------------
14 input channels (vs 3 for RGB) | The encoder's first patch embedding is replaced; num_channels is a config parameter
Large images (512 × 512 patches) | No positional encodings → resolution-agnostic; SR attention → memory-efficient
Small objects (buildings, roads) | Stage-1 features at H/4 preserve fine spatial details
Large context (land-cover categories span hundreds of pixels) | Deep stages capture global context at H/32

2 Instantiating the Model

The project wraps SegFormer in two classes defined in src/models/model.py:

  • SemanticSegmentationSegformer: a thin subclass of HuggingFace’s SegformerPreTrainedModel that adds freeze/unfreeze helpers and a cleaner forward method.
  • SegformerB5: a factory that fetches the label mapping from an S3 bucket, builds a SegformerConfig, and downloads the pre-trained nvidia/mit-b5 weights.

from src.models.model import SegformerB5

# This downloads ~330 MB of weights from HuggingFace on first run
model = SegformerB5(
    # 14 matches the layout of the pre-baked GeoTIFFs: the 12 L2A spectral bands
    # (B10 is dropped by atmospheric correction) plus NDVI and NDWI as derived
    # channels. `src/download_region.py` produces the same layout. If you swap in
    # a different dataset (different band count or different derived layers), this
    # number, the normalisation statistics, and `num_channels` in the SegformerConfig
    # must all change together — otherwise the first patch-embedding layer
    # mis-shapes inputs.
    n_bands=14,
    logits=True,             # return raw logits (not probabilities)
    freeze_encoder=False,    # keep encoder trainable
    type_labeler="CLCplus-Backbone",
)

Note: What happens inside SegformerB5.__new__?

SegformerB5 overrides __new__ (not __init__) because it returns an instance of the parent class SemanticSegmentationSegformer rather than of SegformerB5 itself. Because the returned object is not a SegformerB5 instance, Python then skips SegformerB5.__init__ entirely. This trick lets what looks like a plain constructor call, SegformerB5(...), act as a factory.
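The mechanism can be reproduced in a few lines of plain Python. The classes Segmenter and SegmenterB5 below are hypothetical stand-ins for SemanticSegmentationSegformer and SegformerB5; the real factory additionally fetches labels and downloads weights.

```python
class Segmenter:
    """Stand-in for SemanticSegmentationSegformer (hypothetical)."""

    def __init__(self, num_labels: int):
        self.num_labels = num_labels


class SegmenterB5(Segmenter):
    """Stand-in for SegformerB5: a subclass used purely as a factory."""

    def __new__(cls, num_labels: int = 11):
        # Build and return a *parent-class* instance. Because it is not an
        # instance of `cls`, Python skips SegmenterB5.__init__ afterwards.
        return Segmenter(num_labels)

    def __init__(self, num_labels: int = 11):
        raise RuntimeError("never called")


model_like = SegmenterB5(num_labels=11)
print(type(model_like).__name__)            # Segmenter
print(isinstance(model_like, SegmenterB5))  # False
```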

Inside, three things happen:

  1. The id2label / label2id dicts are fetched from S3 (a JSON file mapping integer class ids to human-readable names).
  2. A SegformerConfig is built from the nvidia/mit-b5 preset and patched with num_channels=14 and the correct number of classes.
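  3. SemanticSegmentationSegformer.from_pretrained("nvidia/mit-b5", ...) is called with ignore_mismatched_sizes=True, so that the weight of the first patch-embedding convolution (whose shape changes when num_channels differs from 3) is re-initialised from scratch while all other encoder layers keep their ImageNet weights. The decode head is not part of the nvidia/mit-b5 checkpoint, which contains only the encoder, so it starts from random initialisation as well.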
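The effect of ignore_mismatched_sizes=True can be mimicked on a toy network: copy every pretrained tensor whose name and shape match, and leave the rest at their fresh initialisation. The helper load_matching and the two-layer toy encoder are hypothetical, a simplified sketch of the idea rather than the HuggingFace code path.

```python
import torch
from torch import nn


def load_matching(target: nn.Module, source_state: dict) -> list:
    """Copy tensors whose name and shape match; return the names that were skipped."""
    own = target.state_dict()
    skipped = [k for k, v in source_state.items() if k in own and own[k].shape != v.shape]
    compatible = {k: v for k, v in source_state.items() if k in own and own[k].shape == v.shape}
    target.load_state_dict(compatible, strict=False)
    return skipped


torch.manual_seed(0)
# Toy "encoder": a 7x7 stride-4 stem (as in MiT stage 1) followed by a 3x3 stride-2 conv.
pretrained = nn.Sequential(
    nn.Conv2d(3, 64, 7, stride=4, padding=3), nn.Conv2d(64, 128, 3, stride=2, padding=1)
)
ours = nn.Sequential(
    nn.Conv2d(14, 64, 7, stride=4, padding=3), nn.Conv2d(64, 128, 3, stride=2, padding=1)
)

skipped = load_matching(ours, pretrained.state_dict())
print(skipped)                                            # ['0.weight']
print(torch.equal(ours[1].weight, pretrained[1].weight))  # True
```

In the toy, only the stem's weight is skipped: its bias has shape (64,) in both networks, so it is copied like any other matching tensor.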

3 Exploring the Model

3.1 Printing the architecture

The model object is a standard torch.nn.Module. You can print it to see every sub-module:

print(model)

The output is long, but the top-level structure is clear:

SemanticSegmentationSegformer(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(...)   # 4 patch-embedding layers
      (block): ModuleList(...)              # 4 lists of transformer blocks
      (layer_norm): ModuleList(...)         # 4 layer norms
    )
  )
  (decode_head): SegformerDecodeHead(
    (linear_c): ModuleList(...)             # 4 projection layers
    (linear_fuse): Conv2d(...)              # channel fusion
    (batch_norm): BatchNorm2d(...)
    (classifier): Conv2d(...)               # final class prediction
  )
)

3.2 Counting parameters

A quick way to understand a model’s capacity is to count how many parameters it has, and how many are currently trainable (i.e. not frozen).

def count_params(module):
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable

total, trainable = count_params(model)
enc_total, enc_trainable = count_params(model.segformer)
dec_total, dec_trainable = count_params(model.decode_head)

print(f"{'Component':<20} {'Total params':>15} {'Trainable params':>18}")
print("-" * 55)
print(f"{'Encoder (MiT-B5)':<20} {enc_total:>15,} {enc_trainable:>18,}")
print(f"{'Decoder (MLP)':<20} {dec_total:>15,} {dec_trainable:>18,}")
print(f"{'Full model':<20} {total:>15,} {trainable:>18,}")

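Expected output (approximate; the decoder's classifier adds 769 parameters per class, so the exact figures depend on the label mapping, and those below assume 11 classes):

Component               Total params   Trainable params
-------------------------------------------------------
Encoder (MiT-B5)          81,477,504         81,477,504
Decoder (MLP)              3,158,795          3,158,795
Full model                84,636,299         84,636,299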

Notice how lopsided the split is: the encoder carries ~96 % of the parameters, yet the lightweight decoder is the part that learns, from scratch, to combine the multi-scale features for this specific task. This asymmetry is central to the SegFormer design philosophy.
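As a back-of-the-envelope check, the counts can be derived by hand from the layer sizes. The sketch below assumes the HuggingFace MiT-B5 configuration (hidden sizes 64/128/320/512, depths 3/6/40/3, per-side reduction strides 8/4/2/1, patch sizes 7/3/3/3, MLP ratio 4), a decoder width of 768 and 11 classes; the function names are hypothetical.

```python
def mit_encoder_params(in_channels=14, dims=(64, 128, 320, 512), depths=(3, 6, 40, 3),
                       sr_ratios=(8, 4, 2, 1), patch_sizes=(7, 3, 3, 3), mlp_ratio=4):
    total, prev = 0, in_channels
    for c, d, sr, k in zip(dims, depths, sr_ratios, patch_sizes):
        total += prev * c * k * k + c + 2 * c  # overlapping patch-embed conv + its LayerNorm
        hidden = mlp_ratio * c
        block = 2 * c                          # layer_norm_1
        block += 4 * (c * c + c)               # query, key, value and output projections
        if sr > 1:
            block += c * c * sr * sr + c + 2 * c  # reduction conv + its LayerNorm
        block += 2 * c                         # layer_norm_2
        block += (c * hidden + hidden) + (hidden * 9 + hidden) + (hidden * c + c)  # Mix-FFN
        total += d * block + 2 * c             # d blocks + the stage's final LayerNorm
        prev = c
    return total


def mlp_decoder_params(num_classes=11, dims=(64, 128, 320, 512), embed=768):
    proj = sum((c + 1) * embed for c in dims)  # per-stage linear projections
    fuse = len(dims) * embed * embed           # 1x1 fuse conv, no bias
    bn = 2 * embed                             # BatchNorm weight + bias (running stats are buffers)
    clf = embed * num_classes + num_classes    # 1x1 classifier conv
    return proj + fuse + bn + clf


enc, dec = mit_encoder_params(), mlp_decoder_params()
print(f"encoder {enc:,}  decoder {dec:,}  total {enc + dec:,}")
# encoder 81,477,504  decoder 3,158,795  total 84,636,299
print(f"decoder share: {100 * dec / (enc + dec):.1f} %")  # decoder share: 3.7 %
```

Passing in_channels=3 gives 81,443,008, so the 14-channel stem adds only 34,496 weights; the encoder's bulk is in stage 3, whose 40 blocks alone account for about 66 M parameters.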

3.3 Freezing the encoder

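When fine-tuning on a new domain (e.g. switching from RGB natural images to 14-channel satellite imagery), a common strategy is to freeze the encoder at first and train only the decoder. This is much faster and reduces the risk of overfitting when labelled data is scarce. Keep in mind that with 14 input channels the first patch-embedding weight starts from random values, so freezing the whole encoder also freezes that untrained layer.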

# Freeze all encoder parameters
model.freeze()

total, trainable = count_params(model)
print(f"After freezing encoder — trainable: {trainable:,} / {total:,}")

# Unfreeze for full fine-tuning
model.unfreeze()
total, trainable = count_params(model)
print(f"After unfreezing — trainable: {trainable:,} / {total:,}")
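What freeze() does internally is not shown here. As a hypothetical sketch, assuming the encoder lives under model.segformer as in the printed structure, freezing amounts to toggling requires_grad, and an optional flag can keep the re-initialised first patch embedding trainable. The helper names and the toy module below are illustrative only.

```python
from torch import nn


def set_encoder_trainable(model: nn.Module, trainable: bool, keep_stem_trainable: bool = False) -> None:
    """Hypothetical helper: toggle requires_grad on everything under `model.segformer`."""
    for p in model.segformer.parameters():
        p.requires_grad = trainable
    if keep_stem_trainable:
        # The first patch embedding is the layer re-initialised for 14 channels.
        for p in model.segformer.encoder.patch_embeddings[0].parameters():
            p.requires_grad = True


def count_trainable(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


# Toy stand-in with the same attribute layout as the real model
toy = nn.Module()
toy.segformer = nn.Module()
toy.segformer.encoder = nn.Module()
toy.segformer.encoder.patch_embeddings = nn.ModuleList(
    [nn.Conv2d(14, 8, 7, stride=4, padding=3), nn.Conv2d(8, 16, 3, stride=2, padding=1)]
)
toy.decode_head = nn.Conv2d(16, 11, kernel_size=1)

set_encoder_trainable(toy, trainable=False)
head_only = count_trainable(toy)
print(head_only)  # 187 = 16 * 11 + 11 (decoder only)
set_encoder_trainable(toy, trainable=False, keep_stem_trainable=True)
with_stem = count_trainable(toy)
print(with_stem)  # 5683 = 187 + 14 * 8 * 49 + 8 (decoder + stem)
```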

Exercise 1: Parameter exploration

Implement the count_params helper from scratch, then:

  1. Count the number of parameters in each of the 4 encoder stages individually (model.segformer.encoder.block[i]).
  2. Which stage has the most parameters? Does this match your intuition about deeper stages being wider?
  3. Count the parameters in the decoder’s four projection layers (model.decode_head.linear_c) and in the final classifier (model.decode_head.classifier). What fraction of decoder parameters does the classifier represent?
# 1. Parameters per encoder stage
for i, stage in enumerate(model.segformer.encoder.block):
    n = ___  # TODO: count parameters in this stage
    print(f"Stage {i+1}: {n:,} parameters")

# 2. Decoder breakdown
for i, proj in enumerate(model.decode_head.linear_c):
    n = ___  # TODO
    print(f"Decoder projection {i+1}: {n:,} parameters")

clf_params = ___  # TODO: classifier layer
print(f"Classifier: {clf_params:,} parameters")

Hint: use sum(p.numel() for p in module.parameters()) to count parameters for any nn.Module. Each encoder stage is a ModuleList of transformer blocks accessible at model.segformer.encoder.block[i].

Solution:

# 1. Parameters per encoder stage.
# In MiT-B5 the depths are 3, 6, 40 and 3 blocks, so stage 3 dominates: it combines a
# wide embedding (C₃ = 320) with 40 of the 52 transformer blocks and holds roughly 80 % of
# the encoder's weights. Stage 4 is wider (C₄ = 512) but has only 3 blocks. This matters
# during fine-tuning: freezing stages 3 and 4 alone already locks the bulk of the weights.
for i, stage in enumerate(model.segformer.encoder.block):
    n = sum(p.numel() for p in stage.parameters())
    print(f"Stage {i+1}: {n:,} parameters")

# 2. Decoder breakdown.
# The four projection layers are small linear maps (one per encoder stage). Most of the
# decoder's weights sit in linear_fuse, the 1×1 convolution that maps the 4 × 768
# concatenated channels back to 768; the classifier itself is tiny. Seeing how light the
# head is explains why training the decoder alone is fast and less prone to overfitting.
for i, proj in enumerate(model.decode_head.linear_c):
    n = sum(p.numel() for p in proj.parameters())
    print(f"Decoder projection {i+1}: {n:,} parameters")

clf_params = sum(p.numel() for p in model.decode_head.classifier.parameters())
dec_total = sum(p.numel() for p in model.decode_head.parameters())
print(f"Classifier: {clf_params:,} / {dec_total:,} decoder params "
      f"({100*clf_params/dec_total:.1f}%)")

4 Running a Forward Pass

Before training, it is useful to trace a dummy forward pass to verify that all shapes are consistent.

4.1 Input shape

Our dataset produces patches of shape (14, H, W) — 14 input channels (12 L2A surface-reflectance bands plus NDVI and NDWI). The DataLoader stacks them into batches (B, 14, H, W). Let us create a random dummy batch:

import torch

B, C, H, W = 2, 14, 512, 512   # batch size, channels, height, width
dummy_input = torch.randn(B, C, H, W)
dummy_labels = torch.randint(0, model.config.num_labels, (B, H, W))

print(f"Input  shape: {tuple(dummy_input.shape)}")
print(f"Labels shape: {tuple(dummy_labels.shape)}")

4.2 Output shapes

The model has two calling modes depending on whether you pass labels:

model.eval()
with torch.no_grad():
    # Without labels → raw logits at H/4 × W/4
    logits = model(dummy_input)
    print(f"Logits shape (no labels):   {tuple(logits.shape)}")
    # Expected: (2, num_classes, 128, 128)  — quarter resolution

    # With labels → logits upsampled to label resolution
    upsampled = model(dummy_input, dummy_labels)
    print(f"Logits shape (with labels): {tuple(upsampled.shape)}")
    # Expected: (2, num_classes, 512, 512)

Note: Why quarter resolution?

If every stage simply halved the resolution, four stages would take a 512-pixel side from 512 to 32 (a factor of 16), and the finest feature map would sit at H/2. That is not what happens: the first overlapping patch embedding in SegFormer is a 7×7 convolution with stride 4 (not 2), which brings the resolution straight to H/4. Each later stage then halves it once with a stride-2 3×3 convolution: H/4 → H/8 → H/16 → H/32, i.e. 128 → 64 → 32 → 16 for a 512 × 512 input. The decoder upsamples all four feature maps to the finest scale (H/4) before fusing them, so the output is at H/4 × W/4.
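The resolutions follow from the usual convolution output-size formula, floor((n + 2p - k) / s) + 1. The sketch assumes the MiT patch-embedding settings (kernels 7, 3, 3, 3; strides 4, 2, 2, 2; padding = kernel // 2); the helper names are hypothetical.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def mit_stage_sizes(n: int) -> list:
    # (kernel, stride, padding) of the four overlapping patch embeddings
    specs = [(7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1)]
    sizes = []
    for k, s, p in specs:
        n = conv_out(n, k, s, p)
        sizes.append(n)
    return sizes


print(mit_stage_sizes(512))  # [128, 64, 32, 16]
print(mit_stage_sizes(500))  # [125, 63, 32, 16]
```

For sizes that are not multiples of 32, each stage rounds up (500 → 125 → 63 → 32 → 16), so the logits are only approximately H/4; upsampling them to the label size, as the model does when labels are passed, absorbs the difference.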

4.3 Inspecting intermediate hidden states

The encoder returns all four intermediate feature maps when output_hidden_states=True. This is exactly how the forward method feeds the decoder:

with torch.no_grad():  # inference only: skip building the autograd graph to save memory
    outputs = model.segformer(
        dummy_input,
        output_hidden_states=True,
        return_dict=True,
    )

for i, hs in enumerate(outputs.hidden_states):
    print(f"Stage {i+1} hidden state: {tuple(hs.shape)}")

You should see something like:

Stage 1 hidden state: (2, 64,  128, 128)   # H/4  × W/4
Stage 2 hidden state: (2, 128,  64,  64)   # H/8  × W/8
Stage 3 hidden state: (2, 320,  32,  32)   # H/16 × W/16
Stage 4 hidden state: (2, 512,  16,  16)   # H/32 × W/32

Exercise 2: Forward pass anatomy
  1. Run the dummy forward pass without labels, then manually upsample the logits to the original (512, 512) resolution using torch.nn.functional.interpolate with mode="bilinear". Verify the shape matches what you get when you pass labels directly.
  2. Apply torch.softmax(logits, dim=1) to the upsampled logits and call torch.argmax(..., dim=1) to obtain a predicted class map. What shape does it have? What do its values represent?
  3. (Bonus) For each of the 4 hidden states, compute the ratio of spatial area to the original image: (H_stage * W_stage) / (H * W). Does the pattern match the quarter-resolution rule?
import torch.nn.functional as F

# 1. Manual upsampling
logits_full = F.interpolate(
    ___,                      # TODO: the raw logits tensor
    size=___,                 # TODO: target (H, W)
    mode="bilinear",
    align_corners=False,
)
print(f"Manually upsampled: {tuple(logits_full.shape)}")

# 2. Predicted class map
probs = ___   # TODO: softmax over class dimension
pred  = ___   # TODO: argmax → (B, H, W)
print(f"Predicted map shape: {tuple(pred.shape)}")
print(f"Unique predicted classes: {pred.unique().tolist()}")

# 3. Spatial area ratios
for i, hs in enumerate(outputs.hidden_states):
    ratio = ___  # TODO
    print(f"Stage {i+1}: {ratio:.4f}")

Solution:

import torch.nn.functional as F

model.eval()
with torch.no_grad():
    # Logits = raw, unbounded per-class scores. They have shape (B, num_classes, H/4, W/4)
    # because the all-MLP decoder fuses everything at the finest encoder scale (H/4).
    logits = model(dummy_input)
    # output_hidden_states=True surfaces the four encoder feature maps. The decoder
    # combines all four — early stages (high resolution) carry local detail, late
    # stages (low resolution) carry semantic context. Returning them is what makes
    # multi-scale fusion possible.
    outputs = model.segformer(
        dummy_input, output_hidden_states=True, return_dict=True
    )

# 1. Manual upsampling.
# Bilinear (not nearest-neighbour) because logits are continuous scores: interpolating
# between two scores is meaningful, whereas a discrete class label is not.
logits_full = F.interpolate(
    logits,
    size=(H, W),
    mode="bilinear",
    align_corners=False,
)
print(f"Manually upsampled: {tuple(logits_full.shape)}")

# 2. Predicted class map.
# softmax converts logits → probabilities; argmax picks the most likely class at each
# pixel (softmax is monotonic, so argmax on the raw logits gives the same map). The class
# axis disappears: (B, num_classes, H, W) → (B, H, W), one int per pixel.
probs = torch.softmax(logits_full, dim=1)
pred  = torch.argmax(probs, dim=1)
print(f"Predicted map shape: {tuple(pred.shape)}")
print(f"Unique predicted classes: {pred.unique().tolist()}")

# 3. Spatial area ratios — stage 1 is at H/4 × W/4 → 1/16 of the input area, and each
# next stage divides that by 4 again (1/64, 1/256, 1/1024).
for i, hs in enumerate(outputs.hidden_states):
    ratio = (hs.shape[-2] * hs.shape[-1]) / (H * W)
    print(f"Stage {i+1}: {ratio:.4f}")

5 Wrapping the Model in a Lightning Module

PyTorch Lightning separates the model definition from the training logic. The SegmentationModule class in intermediate_solutions/solution_step2/models/module.py wraps SemanticSegmentationSegformer and implements the standard Lightning hooks.

5.1 Structure of SegmentationModule

from src.models.module import SegmentationModule
from torch import nn, optim

module = SegmentationModule(
    model=model,
    loss=nn.CrossEntropyLoss(ignore_index=255),
    optimizer=optim.AdamW,
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-2},
    scheduler=optim.lr_scheduler.OneCycleLR,
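    # OneCycleLR requires max_lr and total_steps (or epochs + steps_per_epoch); pass them
    # here unless SegmentationModule fills them in itself.
    scheduler_params={},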
    scheduler_interval="step",
)
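Because scheduler_interval="step", the scheduler is stepped after every optimizer update rather than once per epoch. The plain-PyTorch sketch below, assuming a hypothetical 100-step run with one dummy parameter, shows the resulting one-cycle profile. It also shows that OneCycleLR overrides the optimizer's lr: the schedule starts at max_lr / div_factor (25 by default), peaks at max_lr after pct_start (30 % by default) of the steps, then anneals to a much smaller value.

```python
import torch
from torch import nn, optim

# One dummy parameter is enough to watch the schedule.
params = [nn.Parameter(torch.zeros(1))]
opt = optim.AdamW(params, lr=1e-3, weight_decay=1e-2)
# OneCycleLR requires max_lr and total_steps (or epochs + steps_per_epoch).
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-3, total_steps=100)

lrs = []
for _ in range(100):  # one iteration = one optimizer step ("step" interval)
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()
    sched.step()

peak = lrs.index(max(lrs))
print(f"start {lrs[0]:.1e}, peak {max(lrs):.1e} at step {peak}, end {lrs[-1]:.1e}")
```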

The module’s key methods are:

Method | What it does
-------------------------------------------------------
forward(batch, labels) | Delegates to model.forward; upsamples logits when labels are given
training_step | Computes the loss and logs IoU and building-rate metrics
validation_step | Same as training_step, on the validation set
test_step | Evaluates on the test dataloaders
configure_optimizers | Returns [optimizer], [scheduler] in the Lightning format

5.2 Loss function: Cross-Entropy

For multi-class segmentation, the standard loss is pixel-wise cross-entropy:

\[ \mathcal{L} = -\frac{1}{HW} \sum_{h,w} \log \hat{p}_{y_{h,w},h,w} \]

where \(\hat{p}_{c,h,w}\) is the softmax probability assigned to class \(c\) at pixel \((h, w)\), and \(y_{h,w}\) is the true class label.

The ignore_index=255 argument tells PyTorch to drop pixels labelled 255 from the sum, and the average is then taken over the remaining valid pixels only (HW in the formula becomes the number of labelled pixels). Satellite labels routinely contain such “no-data” pixels (clouds, cloud shadows, image borders, missing observations), and 255 is the conventional sentinel used by the CLC+ preprocessing pipeline. Without ignore_index, a target of 255 is out of range for a model with far fewer classes: PyTorch rejects it with an out-of-bounds target error (a device-side assert on GPU), and remapping those pixels to a real class instead would train the model to predict a meaningless label there.
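The equivalence between the formula and nn.CrossEntropyLoss(ignore_index=255) can be checked numerically on a tiny random example (the sizes, seed and no-data row are arbitrary): averaging -log p̂ over the valid pixels by hand gives the same value as the built-in loss.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
B, K, H, W = 2, 11, 4, 4
logits = torch.randn(B, K, H, W)
labels = torch.randint(0, K, (B, H, W))
labels[0, 0, :] = 255  # mark one row as no-data

loss = nn.CrossEntropyLoss(ignore_index=255)(logits, labels)

# Manual version of the formula, restricted to valid pixels
log_p = F.log_softmax(logits, dim=1)  # (B, K, H, W)
valid = labels != 255
safe = labels.clamp(max=K - 1)  # any in-range index; ignored pixels are masked out below
nll = -log_p.gather(1, safe.unsqueeze(1)).squeeze(1)  # (B, H, W)
manual = nll[valid].mean()

print(torch.allclose(loss, manual))  # True
```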

5.3 Metrics: Intersection over Union (IoU)

Raw accuracy is misleading in segmentation because classes are often heavily imbalanced (e.g. buildings occupy a small fraction of pixels). Intersection over Union (IoU) for a class \(c\) is:

\[ \text{IoU}_c = \frac{|\hat{M}_c \cap M_c|}{|\hat{M}_c \cup M_c|} \]

where \(\hat{M}_c\) is the set of pixels predicted as class \(c\) and \(M_c\) is the set of true pixels of class \(c\). An IoU of 1.0 means perfect overlap; 0.0 means no overlap at all. Mean IoU (mIoU) averages over all classes.
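The definition translates directly into code via a confusion matrix. The helper per_class_iou below is a hypothetical sketch, not the project's src.training.metrics.IOU (whose internals are not shown here); it takes integer class maps, skips pixels labelled 255, and reports NaN for classes that appear in neither map so they do not drag the mean down.

```python
import torch


def per_class_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int,
                  ignore_index: int = 255):
    """Per-class IoU and mIoU from integer class maps (hypothetical helper)."""
    valid = target != ignore_index
    p, t = pred[valid], target[valid]
    # confusion[i, j] = number of pixels of true class i predicted as class j
    confusion = torch.bincount(t * num_classes + p, minlength=num_classes ** 2)
    confusion = confusion.reshape(num_classes, num_classes)
    intersection = confusion.diag().float()
    union = confusion.sum(0).float() + confusion.sum(1).float() - intersection
    iou = intersection / union  # NaN for classes absent from both maps
    return iou, iou[~torch.isnan(iou)].mean()


target = torch.tensor([[0, 0, 1, 1],
                       [0, 0, 1, 255]])
pred = torch.tensor([[0, 1, 1, 1],
                     [0, 0, 1, 2]])
iou, miou = per_class_iou(pred, target, num_classes=3)
print(iou.tolist(), round(miou.item(), 3))  # [0.75, 0.75, nan] 0.75
```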

import torch
from src.training.metrics import IOU, positive_rate

# Simulate model output and labels
B, num_classes, H, W = 2, 11, 512, 512
dummy_logits = torch.randn(B, num_classes, H, W)
dummy_labels = torch.randint(0, num_classes, (B, H, W))

iou_mean, iou_building = IOU(dummy_logits, dummy_labels, logits=True)
building_rate = positive_rate(dummy_logits, logits=True)

print(f"Mean IoU:      {iou_mean:.4f}")
print(f"Building IoU:  {iou_building:.4f}")
print(f"Building rate: {building_rate:.4f}  (fraction of pixels predicted as building)")

Note: Why track building rate?

During early training, the model often collapses to predicting the majority class everywhere. Monitoring the fraction of pixels predicted as “building” helps detect this: if building_rate ≈ 0 after a few epochs, the model is ignoring the minority class and something is wrong (learning rate too high, class imbalance, etc.).
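The same check generalises to every class: look at the fraction of pixels the model assigns to each one. The helper predicted_class_fractions is a hypothetical sketch, and BUILDING_CLASS = 1 is an assumed id; look up the real one in the id2label mapping.

```python
import torch


def predicted_class_fractions(logits: torch.Tensor) -> torch.Tensor:
    """Fraction of pixels assigned to each class by argmax over dim 1 (illustrative)."""
    pred = logits.argmax(dim=1)  # (B, H, W)
    counts = torch.bincount(pred.flatten(), minlength=logits.shape[1])
    return counts.float() / pred.numel()


# A collapsed model: class 0 wins everywhere by a wide margin.
collapsed = torch.zeros(2, 11, 8, 8)
collapsed[:, 0] = 5.0
fractions = predicted_class_fractions(collapsed)
BUILDING_CLASS = 1  # hypothetical id; check the id2label mapping
print(fractions[0].item(), fractions[BUILDING_CLASS].item())  # 1.0 0.0
```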

6 Summary

In this section you have:

  • Understood the two-component design of SegFormer: hierarchical MiT encoder (multi-scale features, efficient SR attention) and all-MLP decoder (lightweight fusion).
  • Learned how to instantiate and inspect the model, count parameters, and see the encoder/decoder split.
  • Traced a forward pass from a raw (B, 14, H, W) tensor through the four encoder stages to the final (B, num_classes, H/4, W/4) logits, and upsampled them to label resolution.
  • Seen how freezing the encoder reduces trainable parameters from ~85 M to ~3 M, which is useful when fine-tuning with limited data.
  • Wrapped the model in a PyTorch Lightning module with cross-entropy loss and IoU metrics.

The next step is either the (optional) Model Training section — if you want to train the model yourself — or directly Inference, which uses a pre-trained checkpoint to produce land-cover maps.