VOC Segmentation
Loading Dataset
from torchvision.datasets import VOCSegmentation
from torchvision import transforms as T
# Transforms to make data consistent
img_transform = T.Compose([
    T.Resize((256, 256)),  # or (512, 512)
    T.ToTensor(),
])
mask_transform = T.Compose([
    # NEAREST keeps masks as valid integer class labels - IMPORTANT!
    T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
    T.PILToTensor(),
    T.Lambda(lambda x: x.squeeze(0).long()),
])
# Load the dataset - root should be './data' (not './data/VOCdevkit'); set download=True on first run
train_dataset = VOCSegmentation(root='./data', year='2012', image_set='train', download=False, transform=img_transform, target_transform=mask_transform)
val_dataset = VOCSegmentation(root='./data', year='2012', image_set='val', download=False, transform=img_transform, target_transform=mask_transform)
print(f"Successfully loaded {len(train_dataset)} training images!")
print(f"Testing first sample...")
# Test loading one sample
img, mask = train_dataset[0]
img_v, mask_v = val_dataset[0]
img.shape, mask.shape, img_v.shape, mask_v.shape
Output:
Successfully loaded 1464 training images!
Testing first sample...
(torch.Size([3, 256, 256]),
torch.Size([256, 256]),
torch.Size([3, 256, 256]),
torch.Size([256, 256]))
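It is also worth checking what values the masks contain. VOC masks store class indices 0-20 plus 255 for boundary/"void" pixels; those 255s are why the loss below uses ignore_index=255.
# Inspect the label values present in the first training mask
print(mask.unique())  # expect values in [0, 20] plus 255 (boundary/"void")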
Visualizing Dataset
import numpy as np
import matplotlib.pyplot as plt
# Show one training sample and one validation sample with their masks
img, mask = train_dataset[0]
img_v, mask_v = val_dataset[0]
fig, axes = plt.subplots(2, 2, figsize=(12, 5))
axes[0, 0].imshow(img.permute(1, 2, 0))
axes[0, 1].imshow(mask)
axes[1, 0].imshow(img_v.permute(1, 2, 0))
axes[1, 1].imshow(mask_v)
plt.tight_layout()
plt.show()
U-Net Style Segmentation Model
import torch
from torch import nn
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # debugging aid: surface CUDA errors at the failing op
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 32
NUM_CLASSES = 21
LEARNING_RATE = 0.01
EPOCHS = 100
EVAL_INTERVAL = 10
train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)  # no need to shuffle validation data
class UNetSegmenter(nn.Module):  # renamed so it doesn't shadow torchvision's VOCSegmentation
    def __init__(self):
        super().__init__()
        # Encoder blocks
        self.enc0_block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 256 -> 128
        )
        self.enc1_block = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 128 -> 64
        )
        self.enc2_block = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 64 -> 32
        )
        # Decoder blocks
        self.dec0_block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2),  # 32 -> 64
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.dec1_block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2),  # 64 -> 128
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.dec2_block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2),  # 128 -> 256
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.final = nn.Conv2d(in_channels=32, out_channels=NUM_CLASSES, kernel_size=1)

    def forward(self, x):
        x = self.enc0_block(x)
        x = self.enc1_block(x)
        x = self.enc2_block(x)
        x = self.dec0_block(x)
        x = self.dec1_block(x)
        x = self.dec2_block(x)
        return self.final(x)
model = UNetSegmenter()
model.to(device)  # works on CPU too, unlike .cuda()
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)  # 255 marks boundary/"void" pixels
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
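# Shape sanity check (sketch with random tensors, not real data):
# CrossEntropyLoss for segmentation takes logits of shape (N, C, H, W) and
# integer targets of shape (N, H, W); pixels equal to ignore_index are skipped.
dummy_logits = torch.randn(2, NUM_CLASSES, 256, 256)
dummy_target = torch.randint(0, NUM_CLASSES, (2, 256, 256))
dummy_target[:, 0, :] = 255  # a row of ignored pixels
print(loss_fn(dummy_logits, dummy_target))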
def train(model, loss_fn, opt, train_dl):
    model.train()
    losses = []
    for img, mask in train_dl:
        img = img.to(device)
        mask = mask.to(device)
        # Safety net: keep 255 (the ignore_index), clamp anything else into [0, 20]
        mask = torch.where(mask == 255, 255, torch.clamp(mask, 0, 20))
        pred = model(img)
        loss = loss_fn(pred, mask)
        losses.append(loss.item())  # .item() avoids keeping the autograd graph alive
        opt.zero_grad()
        loss.backward()
        opt.step()
    net_loss = sum(losses) / len(losses)
    print(f"TRAIN: NET LOSS {net_loss:.4f}")
def test(model, val_dl):
    model.eval()
    losses = []
    for img, mask in val_dl:
        img = img.to(device)
        mask = mask.to(device)
        # Safety net: keep 255 (the ignore_index), clamp anything else into [0, 20]
        mask = torch.where(mask == 255, 255, torch.clamp(mask, 0, 20))
        pred = model(img)
        loss = loss_fn(pred, mask)
        losses.append(loss.item())
    net_loss = sum(losses) / len(losses)
    print(f"TEST: NET LOSS {net_loss:.4f}")
for e in tqdm(range(EPOCHS)):
    train(model, loss_fn, opt, train_dl)
    if e % EVAL_INTERVAL == 0:
        with torch.inference_mode():
            test(model, val_dl)
Note: Training outputs omitted for brevity
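After training, a quick qualitative check is to compare a validation mask against the model's argmax prediction. A minimal sketch, reusing img_v and mask_v from above:
# Visualize ground truth vs. predicted mask for one validation image
model.eval()
with torch.inference_mode():
    pred = model(img_v.unsqueeze(0).to(device))  # (1, 21, 256, 256) logits
pred_mask = pred.argmax(dim=1).squeeze(0).cpu()  # (256, 256) class IDs
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(mask_v)
axes[0].set_title('ground truth')
axes[1].imshow(pred_mask)
axes[1].set_title('prediction')
plt.show()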
Architecture Notes:
- Uses a U-Net style encoder-decoder architecture (though without the skip connections of a full U-Net)
- Encoder: 3 blocks of Conv → BatchNorm → ReLU → MaxPool, progressively downsampling (256 → 128 → 64 → 32)
- Decoder: 3 blocks of ConvTranspose (upsampling) → ReLU → Conv → ReLU, progressively upsampling back to 256
- Final 1x1 convolution produces 21 class logits (VOC has 20 object classes plus background), as verified by the shape check below
- Uses CrossEntropyLoss with ignore_index=255 so boundary/"void" pixels don't contribute to the loss
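As a sanity check on the shape trace above, a dummy forward pass should come back at full resolution with 21 channels:
# Verify: 256 -> (encoder) -> 32 -> (decoder) -> 256, with NUM_CLASSES channels
with torch.inference_mode():
    out = model(torch.randn(1, 3, 256, 256).to(device))
print(out.shape)  # torch.Size([1, 21, 256, 256])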
