Learned Kalman Filter
Data Generation
import torch
def generate_kalman_data(batch_size=16, seq_len=50, dt=1.0,
                         process_noise_std=0.1, obs_noise_std=0.5):
    """Simulate noisy 2-D constant-velocity tracks for learning a Kalman filter.

    Each sequence starts from a random state ``[x, y, vx, vy]`` and is rolled
    forward with a linear constant-velocity model plus Gaussian process noise.
    At every step only the position is observed, corrupted by Gaussian
    observation noise.

    Args:
        batch_size: Number of independent sequences (B).
        seq_len: Number of time steps per sequence (T). ``0`` yields empty
            tensors with the time dimension of size zero.
        dt: Time increment used in the motion model.
        process_noise_std: Standard deviation of the noise added to the full
            state at every propagation step.
        obs_noise_std: Standard deviation of the noise added to each position
            measurement.

    Returns:
        observations: [B, T, 2] - Noisy position measurements
        true_states: [B, T, 4] - Ground truth [x, y, vx, vy]
    """
    # Random starting states [x, y, vx, vy]; velocities are scaled down.
    states = torch.randn(batch_size, 4)
    states[:, 2:] *= 0.5

    # Motion model: position advances by velocity * dt, velocity is kept.
    transition = torch.tensor([
        [1, 0, dt, 0],
        [0, 1, 0, dt],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ], dtype=torch.float32)

    # Observation model: only the position components are measured.
    measurement = torch.tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
    ], dtype=torch.float32)

    # Preallocating (instead of stacking a list) keeps seq_len == 0 valid.
    true_states = torch.empty(batch_size, seq_len, 4)
    observations = torch.empty(batch_size, seq_len, 2)

    for t in range(seq_len):
        # Record the state before it is propagated.
        true_states[:, t] = states

        # Noisy position measurement of the current state.
        obs_noise = torch.randn(batch_size, 2) * obs_noise_std
        observations[:, t] = torch.matmul(states, measurement.T) + obs_noise

        # Advance the state by one step and perturb it with process noise.
        process_noise = torch.randn(batch_size, 4) * process_noise_std
        states = torch.matmul(states, transition.T) + process_noise

    return observations, true_states
# Build one batch of simulated tracks (16 sequences, 50 steps each).
observations, true_states = generate_kalman_data(16, 50)

# Summarize the model input.
print("INPUT:")
print(" observations:", observations.size())  # [16, 50, 2]
# NOTE: this is the standard deviation of the raw observations
# (trajectory spread plus noise), not of the measurement noise alone.
print(" obs noise level:", torch.std(observations).item())

# Summarize the regression target.
print("\nOUTPUT:")
print(" true_states:", true_states.size())  # [16, 50, 4]

# First five steps of the first sequence, split into position and velocity.
print("\nVisualize trajectory:")
print(" Position:", true_states[0, 0:5, 0:2])
print(" Velocity:", true_states[0, 0:5, 2:4])
# YOUR TASK: Build a learnable Kalman filter
# model = LearnedKalmanFilter()
# predicted_states = model(observations)
# loss = F.mse_loss(predicted_states, true_states)
# Output:
INPUT:
observations: torch.Size([16, 50, 2])
obs noise level: 15.132975578308105
OUTPUT:
true_states: torch.Size([16, 50, 4])
Visualize trajectory:
Position: tensor([[ 0.5475, 0.6218],
[ 0.6527, -0.1476],
[ 1.0890, -0.6857],
[ 1.2830, -1.0682],
[ 1.3929, -1.3671]])
Velocity: tensor([[ 0.1766, -0.6363],
[ 0.3616, -0.4346],
[ 0.1175, -0.4036],
[ 0.1457, -0.3459],
[ 0.1434, -0.2056]])
RNN-based Kalman Filter Model
import torch
import torch.nn as nn
# Number of iterations of the training loop; each iteration makes one call to train().
EPOCHS = 100
# Learning rate passed to the SGD optimizer.
LEARNING_RATE = 0.01
# NOTE(review): not referenced anywhere in the visible code; presumably an
# evaluation interval measured in epochs — confirm or remove.
EVAL_INT = 10
class RNNKalmannFilter(nn.Module):
    """GRU-based state estimator mapping an observation sequence to states.

    A stacked GRU reads the observation sequence (batch-first layout) and a
    linear layer projects every per-step hidden vector to a state estimate.

    NOTE: the class name keeps its original spelling ("Kalmann") so existing
    references continue to work.

    Args:
        input_size: Number of features per observation step.
        hidden_size: Width of the GRU hidden state.
        num_layers: Number of stacked GRU layers.
        state_size: Number of features in each predicted state.
    """

    def __init__(self, input_size=2, hidden_size=16, num_layers=2, state_size=4):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # The GRU emits hidden_size features per step; map them to the state.
        self.lin = nn.Linear(in_features=hidden_size, out_features=state_size)

    def forward(self, x):
        """Return one state estimate per time step.

        Args:
            x: Observation tensor laid out as [batch, time, input_size].

        Returns:
            Tensor laid out as [batch, time, state_size].
        """
        # The final hidden state returned by the GRU is not needed here.
        hidden_seq, _ = self.gru(x)
        return self.lin(hidden_seq)
# Network, mean-squared-error criterion, and a plain SGD optimizer over all
# of the network's parameters.
model = RNNKalmannFilter()
loss_fn = nn.MSELoss()
opt = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)
def train(model, loss_fn, opt, observations, true_states):
    """Run a single optimization step on the given data and report the loss.

    Args:
        model: Callable module mapping ``observations`` to predictions.
        loss_fn: Callable returning a scalar loss from (prediction, target).
        opt: Optimizer whose ``zero_grad``/``step`` update ``model``.
        observations: Input passed to ``model``.
        true_states: Target compared against the model output.

    Returns:
        float: The loss computed before the parameter update, so callers can
        log or track it. The value is also printed.
    """
    pred = model(observations)
    loss = loss_fn(pred, true_states)

    # Clear gradients left over from the previous step before backprop.
    opt.zero_grad()
    loss.backward()
    opt.step()

    loss_value = loss.item()
    print(f"loss {loss_value}")
    return loss_value
# Full-batch training: each epoch performs one optimizer step on the whole
# generated data set.
for _ in range(EPOCHS):
    train(model, loss_fn, opt, observations, true_states)
# Output:
loss 114.74261474609375
loss 112.75377655029297
loss 110.9104232788086
loss 109.04251861572266
...
(loss values decrease over training)
...
loss 36.4130744934082
loss 36.129634857177734
