Classification with PyTorch
A problem involving assigning classes to an input:
- Binary classification: one or the other
- Multi-class classification: one of the following classes
- Multi-label classification: one or more of the following classes
Binary Classification
from sklearn.datasets import make_circles
# Using the make_circles function from sklearn which literally makes a dataset of two classes of points in cocentric circles
n_samples = 1000
# Create circles
X, y = make_circles(n_samples, noise=0.03, random_state=42)
# Make DataFrame of circle data
import pandas as pd
circles = pd.DataFrame({"X1": X[:, 0],
"X2": X[:, 1],
"label": y
})
circles.head(10)
import matplotlib.pyplot as plt
plt.scatter(x=X[:, 0],
y=X[:, 1],
c=y,
cmap=plt.cm.RdYlBu)Output:
<matplotlib.collections.PathCollection at 0x7d6060121610>

print(X.shape, y.shape)Output:
(1000, 2) (1000,)
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class CircleBinary(nn.Module):
def __init__(self, device=None):
super().__init__()
self.lin = nn.Linear(2, 50, device=device)
self.lin4 = nn.Linear(50, 1, device=device)
self.relu = nn.ReLU()
def forward(self, x : torch.Tensor) -> torch.Tensor:
x = self.relu(self.lin(x))
# x = self.relu(self.lin2(x))
# x = self.relu(self.lin3(x))
x = self.lin4(x)
return x.squeeze(dim=1)
model = CircleBinary(device)
BATCH_SIZE = 10
LEARNING_RATE = 0.01
EPOCHS = 50
def calc_accuracy(y_pred, y):
''' Calculate Accuracy (Percentage)
Calculates accuracy as a percentage.
'''
correct = torch.eq(y_pred, y).sum().item() # gets tensor of equivalence, sums, get value
acc = (correct) / len(y) * 100
return acc
# moving not copying!
X = torch.as_tensor(X, device=device, dtype=torch.float)
y = torch.as_tensor(y, device=device, dtype=torch.float)
split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)
# loss_func = torch.nn.BCELoss()
loss_func = torch.nn.BCEWithLogitsLoss() # <--- this is more stable than using BCELoss after a sigmoid layer
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(EPOCHS):
model.train()
for x_b, y_b in train_dl:
y_p = model(x_b)
train_loss = loss_func(y_p, y_b)
opt.zero_grad()
train_loss.backward()
opt.step()
train_acc = sum(calc_accuracy(torch.round(torch.sigmoid(model(x_b))), y_b) for x_b, y_b in train_dl) / len(train_dl)
model.eval()
with torch.inference_mode():
test_acc = sum(calc_accuracy(torch.round(torch.sigmoid(model(x_b))), y_b) for x_b, y_b in test_dl) / len(test_dl)
print(f"EPOCH {e} | BCELOSS TRAIN {train_loss} | TRAINING ACC {train_acc} | TESTING ACC {test_acc}")Output:
torch.Size([800, 2])
torch.Size([200, 2])
torch.Size([800])
torch.Size([200])
EPOCH 0 | BCELOSS TRAIN 0.6700426340103149 | TRAINING ACC 50.625 | TESTING ACC 48.5
EPOCH 1 | BCELOSS TRAIN 0.6890271306037903 | TRAINING ACC 54.375 | TESTING ACC 51.0
EPOCH 2 | BCELOSS TRAIN 0.6950142979621887 | TRAINING ACC 56.125 | TESTING ACC 52.5
EPOCH 3 | BCELOSS TRAIN 0.6931143999099731 | TRAINING ACC 72.625 | TESTING ACC 71.0
EPOCH 4 | BCELOSS TRAIN 0.7008362412452698 | TRAINING ACC 66.875 | TESTING ACC 69.0
EPOCH 5 | BCELOSS TRAIN 0.6703662276268005 | TRAINING ACC 71.75 | TESTING ACC 70.0
EPOCH 6 | BCELOSS TRAIN 0.6747851371765137 | TRAINING ACC 73.25 | TESTING ACC 71.0
EPOCH 7 | BCELOSS TRAIN 0.6851223707199097 | TRAINING ACC 70.375 | TESTING ACC 70.0
EPOCH 8 | BCELOSS TRAIN 0.6609817147254944 | TRAINING ACC 71.375 | TESTING ACC 70.5
EPOCH 9 | BCELOSS TRAIN 0.6603668928146362 | TRAINING ACC 63.625 | TESTING ACC 64.5
EPOCH 10 | BCELOSS TRAIN 0.6925440430641174 | TRAINING ACC 67.375 | TESTING ACC 68.0
EPOCH 11 | BCELOSS TRAIN 0.6850089430809021 | TRAINING ACC 67.125 | TESTING ACC 69.0
EPOCH 12 | BCELOSS TRAIN 0.6691190004348755 | TRAINING ACC 67.5 | TESTING ACC 68.0
EPOCH 13 | BCELOSS TRAIN 0.6817285418510437 | TRAINING ACC 73.125 | TESTING ACC 72.0
EPOCH 14 | BCELOSS TRAIN 0.6745217442512512 | TRAINING ACC 77.0 | TESTING ACC 73.5
EPOCH 15 | BCELOSS TRAIN 0.6450468301773071 | TRAINING ACC 62.5 | TESTING ACC 63.5
EPOCH 16 | BCELOSS TRAIN 0.6369532346725464 | TRAINING ACC 65.375 | TESTING ACC 66.5
EPOCH 17 | BCELOSS TRAIN 0.6455376148223877 | TRAINING ACC 69.25 | TESTING ACC 71.0
EPOCH 18 | BCELOSS TRAIN 0.6477933526039124 | TRAINING ACC 65.625 | TESTING ACC 67.0
EPOCH 19 | BCELOSS TRAIN 0.6479110717773438 | TRAINING ACC 66.75 | TESTING ACC 69.0
EPOCH 20 | BCELOSS TRAIN 0.6364610195159912 | TRAINING ACC 74.375 | TESTING ACC 74.5
EPOCH 21 | BCELOSS TRAIN 0.6489265561103821 | TRAINING ACC 69.125 | TESTING ACC 72.5
EPOCH 22 | BCELOSS TRAIN 0.655064046382904 | TRAINING ACC 69.875 | TESTING ACC 72.5
EPOCH 23 | BCELOSS TRAIN 0.6492965817451477 | TRAINING ACC 74.75 | TESTING ACC 76.0
EPOCH 24 | BCELOSS TRAIN 0.6289384961128235 | TRAINING ACC 73.875 | TESTING ACC 75.5
EPOCH 25 | BCELOSS TRAIN 0.6639038920402527 | TRAINING ACC 77.875 | TESTING ACC 77.5
EPOCH 26 | BCELOSS TRAIN 0.66072678565979 | TRAINING ACC 70.875 | TESTING ACC 74.0
EPOCH 27 | BCELOSS TRAIN 0.6398707628250122 | TRAINING ACC 81.375 | TESTING ACC 80.5
EPOCH 28 | BCELOSS TRAIN 0.6358336210250854 | TRAINING ACC 81.0 | TESTING ACC 80.5
EPOCH 29 | BCELOSS TRAIN 0.6169649362564087 | TRAINING ACC 77.0 | TESTING ACC 77.0
EPOCH 30 | BCELOSS TRAIN 0.6862431764602661 | TRAINING ACC 82.375 | TESTING ACC 81.0
EPOCH 31 | BCELOSS TRAIN 0.6299324035644531 | TRAINING ACC 81.125 | TESTING ACC 80.5
EPOCH 32 | BCELOSS TRAIN 0.6076124906539917 | TRAINING ACC 82.25 | TESTING ACC 81.5
EPOCH 33 | BCELOSS TRAIN 0.6340460777282715 | TRAINING ACC 81.0 | TESTING ACC 82.5
EPOCH 34 | BCELOSS TRAIN 0.5948718190193176 | TRAINING ACC 87.5 | TESTING ACC 87.5
EPOCH 35 | BCELOSS TRAIN 0.6291066408157349 | TRAINING ACC 85.625 | TESTING ACC 85.5
EPOCH 36 | BCELOSS TRAIN 0.6187013983726501 | TRAINING ACC 88.625 | TESTING ACC 87.5
EPOCH 37 | BCELOSS TRAIN 0.6471001505851746 | TRAINING ACC 87.75 | TESTING ACC 89.0
EPOCH 38 | BCELOSS TRAIN 0.6382052302360535 | TRAINING ACC 89.75 | TESTING ACC 89.5
EPOCH 39 | BCELOSS TRAIN 0.61151522397995 | TRAINING ACC 86.25 | TESTING ACC 86.0
EPOCH 40 | BCELOSS TRAIN 0.6044560074806213 | TRAINING ACC 91.5 | TESTING ACC 91.5
EPOCH 41 | BCELOSS TRAIN 0.5917207598686218 | TRAINING ACC 90.25 | TESTING ACC 89.0
EPOCH 42 | BCELOSS TRAIN 0.6241494417190552 | TRAINING ACC 89.875 | TESTING ACC 90.5
EPOCH 43 | BCELOSS TRAIN 0.6151015162467957 | TRAINING ACC 92.125 | TESTING ACC 92.0
EPOCH 44 | BCELOSS TRAIN 0.6239864230155945 | TRAINING ACC 93.375 | TESTING ACC 93.0
EPOCH 45 | BCELOSS TRAIN 0.595759391784668 | TRAINING ACC 94.625 | TESTING ACC 93.5
EPOCH 46 | BCELOSS TRAIN 0.573060929775238 | TRAINING ACC 95.375 | TESTING ACC 94.0
EPOCH 47 | BCELOSS TRAIN 0.5988495945930481 | TRAINING ACC 94.75 | TESTING ACC 96.5
EPOCH 48 | BCELOSS TRAIN 0.6048717498779297 | TRAINING ACC 95.5 | TESTING ACC 96.5
EPOCH 49 | BCELOSS TRAIN 0.615748405456543 | TRAINING ACC 96.25 | TESTING ACC 94.0
(me before i got it working) This is shit
Well yeah, its trying to fit a linear decision boundary on a circle. In this case, the model is underfitting. As in it it trying to fit a lower dimensional decision boundary on a higher-dimensional problem. As a result, it can’t learn!
import requests
from pathlib import Path
# Download helper functions from Learn PyTorch repo (if not already downloaded)
if Path("helper_functions.py").is_file():
print("helper_functions.py already exists, skipping download")
else:
print("Downloading helper_functions.py")
request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py")
with open("helper_functions.py", "wb") as f:
f.write(request.content)
from helper_functions import plot_predictions, plot_decision_boundary
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model, X_test, y_test)Output:
helper_functions.py already exists, skipping download
Output:
<Figure size 1200x600 with 2 Axes>

How to improve on a model?
There are a number of ways:
- increase the number of layers in the model
- increase the size of each layer
- increase / decrease the learning rate
- vary the number of epochs
- change the activation function (non-linear activation lets us learn non-linear problems (most problems))
- change the loss function
- transfer learning
In our case, we need to deal the non-linearity of the problem. Generally, I ran into a number of issues when doing this:
- be aware of where is my true input output: BCELossWithLogits is odd in that it will optimize with a sigmoid function before BCE, but when I want to gauge the accuracy of my model, I actually had to run another sigmoid and round to get a value I wanted (we add a torch.sigmoid (between 0 and 1, and symmetrical) and round (ceiling of 0 or 1) to turn the continuous representation of the classification space into discrete)
- generally do not gauge accuracy metrics WITHIN a batch (it could vary too much to get anything meaningful, ie. batch of 10 would have the chance of all 10 being correct hence acc is 100%)
- the inclusion of a non-linear activation function enables a NN to learn any function (proven, but we don’t know what the optimium width is, the time to convergence, the loss function to get to there, the optimizer, etc. So it exists, but we’ll never know immediately what it is)
Multi-class Classification
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
NUM_CLASSES = 4
NUM_FEATURES = 2
RANDOM_SEED = 88
NUM_SAMPLES = 1000
X, y = make_blobs(n_samples=NUM_SAMPLES,
n_features=NUM_FEATURES,
centers=NUM_CLASSES,
cluster_std=1.5,
random_state=RANDOM_SEED
)
plt.figure(figsize=(10,7))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdYlBu)Output:
<matplotlib.collections.PathCollection at 0x7d60b6fe4f20>
Output:
<Figure size 1000x700 with 1 Axes>

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SPLIT = 0.8
BATCH_SIZE = 10
LEARNING_RATE = 0.01
EPOCHS = 100
X = torch.as_tensor(X, dtype=torch.float, device=device)
y = torch.as_tensor(y, dtype=torch.long, device=device)
print(X.shape)
print(y.shape)
split_i = int(SPLIT*len(X))
X_train, X_test = X[:split_i], X[split_i:]
y_train, y_test = y[:split_i], y[split_i:]
# Dataloader
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)
# Model
class MultiClassifier(nn.Module):
def __init__(self, device):
super().__init__()
self.lin0 = nn.Linear(2, 20, device=device)
self.lin1 = nn.Linear(20, NUM_CLASSES, device=device)
# self.relu = nn.ReLU() actually this example didnt actually need a non-linear activation
# most likely because the decision boundary can be linear
self.softmax = nn.Softmax()
def forward(self, x):
x = self.lin0(x)
x = self.lin1(x)
return self.softmax(x)
model = MultiClassifier(device)
# Loss and Optimizer
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()
# Define Accuracy Metric
def calc_accuracy(y_pred, y):
correct = sum(torch.eq(y_pred.argmax(dim=1), y).squeeze())
acc = correct / len(y)
return acc
# Train Loop
for e in range(EPOCHS):
model.train()
opt.zero_grad()
for x_b, y_b in train_dl:
pred = model(x_b)
loss = loss_fn(pred, y_b)
loss.backward()
opt.step()
train_acc = sum(calc_accuracy(model(x_b), y_b) for x_b, y_b in train_dl) / len(train_dl)
model.eval()
with torch.inference_mode():
test_acc = sum(calc_accuracy(model(x_b), y_b) for x_b, y_b in test_dl) / len(test_dl)
print(f"EPOCH {e} | TRAIN BCE LOSS {loss} | TRAIN ACC {train_acc} | TEST ACC {test_acc}")Output:
torch.Size([1000, 2])
torch.Size([1000])
EPOCH 0 | TRAIN BCE LOSS 1.3057008981704712 | TRAIN ACC 0.5099999308586121 | TEST ACC 0.46500006318092346
EPOCH 1 | TRAIN BCE LOSS 1.2484840154647827 | TRAIN ACC 0.6487500667572021 | TEST ACC 0.60999995470047
EPOCH 2 | TRAIN BCE LOSS 1.0800795555114746 | TRAIN ACC 0.7087499499320984 | TEST ACC 0.6749998927116394
EPOCH 3 | TRAIN BCE LOSS 0.996385395526886 | TRAIN ACC 0.6912499666213989 | TEST ACC 0.6549999713897705
Output:
/home/eddy/code/mine/pytorch-deep-learning/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1775: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
return self._call_impl(*args, **kwargs)
Output:
EPOCH 4 | TRAIN BCE LOSS 0.9900083541870117 | TRAIN ACC 0.7137498259544373 | TEST ACC 0.6699999570846558
EPOCH 5 | TRAIN BCE LOSS 0.969978928565979 | TRAIN ACC 0.7137498259544373 | TEST ACC 0.6699999570846558
EPOCH 6 | TRAIN BCE LOSS 0.9669672250747681 | TRAIN ACC 0.7137498259544373 | TEST ACC 0.6699999570846558
EPOCH 7 | TRAIN BCE LOSS 0.9650943875312805 | TRAIN ACC 0.7149998545646667 | TEST ACC 0.6699999570846558
EPOCH 8 | TRAIN BCE LOSS 0.9631978869438171 | TRAIN ACC 0.7149998545646667 | TEST ACC 0.6699999570846558
EPOCH 9 | TRAIN BCE LOSS 0.9616975784301758 | TRAIN ACC 0.717499852180481 | TEST ACC 0.6699999570846558
EPOCH 10 | TRAIN BCE LOSS 0.9603195190429688 | TRAIN ACC 0.717499852180481 | TEST ACC 0.6699999570846558
EPOCH 11 | TRAIN BCE LOSS 0.9591120481491089 | TRAIN ACC 0.7187498807907104 | TEST ACC 0.6699999570846558
EPOCH 12 | TRAIN BCE LOSS 0.9580297470092773 | TRAIN ACC 0.7187498807907104 | TEST ACC 0.6699999570846558
EPOCH 13 | TRAIN BCE LOSS 0.9570589065551758 | TRAIN ACC 0.7199999094009399 | TEST ACC 0.6799999475479126
EPOCH 14 | TRAIN BCE LOSS 0.9561826586723328 | TRAIN ACC 0.7237498760223389 | TEST ACC 0.6799999475479126
EPOCH 15 | TRAIN BCE LOSS 0.9553885459899902 | TRAIN ACC 0.7237498760223389 | TEST ACC 0.6799999475479126
EPOCH 16 | TRAIN BCE LOSS 0.9546659588813782 | TRAIN ACC 0.7237498760223389 | TEST ACC 0.6799999475479126
EPOCH 17 | TRAIN BCE LOSS 0.9540055990219116 | TRAIN ACC 0.7237498760223389 | TEST ACC 0.6799999475479126
EPOCH 18 | TRAIN BCE LOSS 0.9534003138542175 | TRAIN ACC 0.7249999046325684 | TEST ACC 0.6799999475479126
EPOCH 19 | TRAIN BCE LOSS 0.9528433084487915 | TRAIN ACC 0.7262498736381531 | TEST ACC 0.6799999475479126
EPOCH 20 | TRAIN BCE LOSS 0.9523292779922485 | TRAIN ACC 0.7262498736381531 | TEST ACC 0.6799999475479126
EPOCH 21 | TRAIN BCE LOSS 0.9518534541130066 | TRAIN ACC 0.7262498736381531 | TEST ACC 0.6799999475479126
EPOCH 22 | TRAIN BCE LOSS 0.9514118432998657 | TRAIN ACC 0.7274999022483826 | TEST ACC 0.6799999475479126
EPOCH 23 | TRAIN BCE LOSS 0.9510008096694946 | TRAIN ACC 0.7274999022483826 | TEST ACC 0.6799999475479126
EPOCH 24 | TRAIN BCE LOSS 0.9506174325942993 | TRAIN ACC 0.7287499308586121 | TEST ACC 0.6799999475479126
EPOCH 25 | TRAIN BCE LOSS 0.9502588510513306 | TRAIN ACC 0.7299999594688416 | TEST ACC 0.684999942779541
EPOCH 26 | TRAIN BCE LOSS 0.9499231576919556 | TRAIN ACC 0.731249988079071 | TEST ACC 0.684999942779541
EPOCH 27 | TRAIN BCE LOSS 0.9496078491210938 | TRAIN ACC 0.7325000166893005 | TEST ACC 0.684999942779541
EPOCH 28 | TRAIN BCE LOSS 0.9493112564086914 | TRAIN ACC 0.7325000166893005 | TEST ACC 0.684999942779541
EPOCH 29 | TRAIN BCE LOSS 0.9490319490432739 | TRAIN ACC 0.73375004529953 | TEST ACC 0.684999942779541
EPOCH 30 | TRAIN BCE LOSS 0.9487684965133667 | TRAIN ACC 0.7350000143051147 | TEST ACC 0.684999942779541
EPOCH 31 | TRAIN BCE LOSS 0.9485197067260742 | TRAIN ACC 0.7350000143051147 | TEST ACC 0.684999942779541
EPOCH 32 | TRAIN BCE LOSS 0.9482845067977905 | TRAIN ACC 0.7350000143051147 | TEST ACC 0.684999942779541
EPOCH 33 | TRAIN BCE LOSS 0.9480619430541992 | TRAIN ACC 0.7362499833106995 | TEST ACC 0.6899999976158142
EPOCH 34 | TRAIN BCE LOSS 0.9478511810302734 | TRAIN ACC 0.7362499833106995 | TEST ACC 0.6899999976158142
EPOCH 35 | TRAIN BCE LOSS 0.9476515054702759 | TRAIN ACC 0.7362499833106995 | TEST ACC 0.6949999332427979
EPOCH 36 | TRAIN BCE LOSS 0.9474622011184692 | TRAIN ACC 0.7362499833106995 | TEST ACC 0.6949999332427979
EPOCH 37 | TRAIN BCE LOSS 0.947282612323761 | TRAIN ACC 0.7362499833106995 | TEST ACC 0.6949999332427979
EPOCH 38 | TRAIN BCE LOSS 0.9471122622489929 | TRAIN ACC 0.737500011920929 | TEST ACC 0.6949999332427979
EPOCH 39 | TRAIN BCE LOSS 0.9469509124755859 | TRAIN ACC 0.737500011920929 | TEST ACC 0.6949999332427979
EPOCH 40 | TRAIN BCE LOSS 0.9467977285385132 | TRAIN ACC 0.737500011920929 | TEST ACC 0.6949999332427979
EPOCH 41 | TRAIN BCE LOSS 0.9466525912284851 | TRAIN ACC 0.737500011920929 | TEST ACC 0.6949999332427979
EPOCH 42 | TRAIN BCE LOSS 0.9465147852897644 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 43 | TRAIN BCE LOSS 0.946384072303772 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 44 | TRAIN BCE LOSS 0.9462602734565735 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 45 | TRAIN BCE LOSS 0.9461425542831421 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 46 | TRAIN BCE LOSS 0.9460309743881226 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 47 | TRAIN BCE LOSS 0.9459251165390015 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 48 | TRAIN BCE LOSS 0.9458244442939758 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 49 | TRAIN BCE LOSS 0.9457284808158875 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 50 | TRAIN BCE LOSS 0.9456372261047363 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 51 | TRAIN BCE LOSS 0.9455501437187195 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 52 | TRAIN BCE LOSS 0.9454669952392578 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 53 | TRAIN BCE LOSS 0.9453871846199036 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 54 | TRAIN BCE LOSS 0.9453107714653015 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 55 | TRAIN BCE LOSS 0.9452372789382935 | TRAIN ACC 0.7387499809265137 | TEST ACC 0.6949999332427979
EPOCH 56 | TRAIN BCE LOSS 0.9451664090156555 | TRAIN ACC 0.7399999499320984 | TEST ACC 0.6949999332427979
EPOCH 57 | TRAIN BCE LOSS 0.945097804069519 | TRAIN ACC 0.7399999499320984 | TEST ACC 0.6949999332427979
EPOCH 58 | TRAIN BCE LOSS 0.9450313448905945 | TRAIN ACC 0.7399999499320984 | TEST ACC 0.6949999332427979
EPOCH 59 | TRAIN BCE LOSS 0.9449666738510132 | TRAIN ACC 0.7399999499320984 | TEST ACC 0.6949999332427979
EPOCH 60 | TRAIN BCE LOSS 0.9449036717414856 | TRAIN ACC 0.7399999499320984 | TEST ACC 0.6949999332427979
EPOCH 61 | TRAIN BCE LOSS 0.9448421597480774 | TRAIN ACC 0.7412499785423279 | TEST ACC 0.6949999332427979
EPOCH 62 | TRAIN BCE LOSS 0.9447816610336304 | TRAIN ACC 0.7412499785423279 | TEST ACC 0.6949999332427979
EPOCH 63 | TRAIN BCE LOSS 0.9447223544120789 | TRAIN ACC 0.7412499785423279 | TEST ACC 0.6949999332427979
EPOCH 64 | TRAIN BCE LOSS 0.9446637034416199 | TRAIN ACC 0.7412499785423279 | TEST ACC 0.6949999332427979
EPOCH 65 | TRAIN BCE LOSS 0.944605827331543 | TRAIN ACC 0.7424999475479126 | TEST ACC 0.6949999332427979
EPOCH 66 | TRAIN BCE LOSS 0.9445484280586243 | TRAIN ACC 0.7437499761581421 | TEST ACC 0.6949999332427979
EPOCH 67 | TRAIN BCE LOSS 0.9444913864135742 | TRAIN ACC 0.7449999451637268 | TEST ACC 0.6949999332427979
EPOCH 68 | TRAIN BCE LOSS 0.944434642791748 | TRAIN ACC 0.7449999451637268 | TEST ACC 0.6949999332427979
EPOCH 69 | TRAIN BCE LOSS 0.9443780183792114 | TRAIN ACC 0.7449999451637268 | TEST ACC 0.6949999332427979
EPOCH 70 | TRAIN BCE LOSS 0.9443214535713196 | TRAIN ACC 0.7449999451637268 | TEST ACC 0.6949999332427979
EPOCH 71 | TRAIN BCE LOSS 0.9442648887634277 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 72 | TRAIN BCE LOSS 0.9442082643508911 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 73 | TRAIN BCE LOSS 0.9441514015197754 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 74 | TRAIN BCE LOSS 0.9440944790840149 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 75 | TRAIN BCE LOSS 0.9440372586250305 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 76 | TRAIN BCE LOSS 0.9439799189567566 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 77 | TRAIN BCE LOSS 0.943922221660614 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999332427979
EPOCH 78 | TRAIN BCE LOSS 0.9438641667366028 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 79 | TRAIN BCE LOSS 0.943805992603302 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 80 | TRAIN BCE LOSS 0.9437476396560669 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 81 | TRAIN BCE LOSS 0.9436887502670288 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 82 | TRAIN BCE LOSS 0.9436299204826355 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 83 | TRAIN BCE LOSS 0.9435707926750183 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 84 | TRAIN BCE LOSS 0.9435116052627563 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 85 | TRAIN BCE LOSS 0.9434520602226257 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 86 | TRAIN BCE LOSS 0.9433925747871399 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 87 | TRAIN BCE LOSS 0.9433329701423645 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 88 | TRAIN BCE LOSS 0.9432734251022339 | TRAIN ACC 0.747499942779541 | TEST ACC 0.6899999976158142
EPOCH 89 | TRAIN BCE LOSS 0.9432138204574585 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 90 | TRAIN BCE LOSS 0.9431542158126831 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 91 | TRAIN BCE LOSS 0.9430948495864868 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 92 | TRAIN BCE LOSS 0.9430354833602905 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 93 | TRAIN BCE LOSS 0.9429763555526733 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 94 | TRAIN BCE LOSS 0.9429174661636353 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 95 | TRAIN BCE LOSS 0.9428586959838867 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 96 | TRAIN BCE LOSS 0.9428003430366516 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 97 | TRAIN BCE LOSS 0.9427421689033508 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 98 | TRAIN BCE LOSS 0.9426843523979187 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6899999976158142
EPOCH 99 | TRAIN BCE LOSS 0.9426270723342896 | TRAIN ACC 0.7462499737739563 | TEST ACC 0.6949999928474426
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model, X_test, y_test)Output:
<Figure size 1200x600 with 2 Axes>

im a fking genius
Evaluation Metrics
Theres a couple of them and are all under torchmetrics
torchmetrics.Accuracyout of 100 predictions, how many were correct?torchmetrics.Precisiontrue positives over sum to true positives and false positives- higher the precision the more true positives there are
torchmetrics.Recalltrue positives over sum to true positives and false negatives- lower the recall the more false negatives there are
If we predict true but its false then is a false positive If we predict false but its true then its a false negative If we predict true and its true then its a true positive If we predict false and its fal then its a true negative
torchmetrics.F1Score()computes a combined metric for precision and recall- 2precisionrecall / (precision + recall)
torchmetrics.ConfusionMatrix()computes the whole TP, TN, FP, FN in a single tabletorchmetrics.Accuracygot some more, but I feel like remembering them is not the best use of time?
Practice
from sklearn.datasets import make_moons
import pandas as pd
import matplotlib.pyplot as plt
X, y = make_moons(n_samples=1000)
moons = pd.DataFrame({
"X1": X[:, 0],
"X2": X[:, 1],
"label": y
})
plt.scatter(x=X[:, 0], y=X[:,1], c=y, cmap=plt.cm.RdYlBu)Output:
<matplotlib.collections.PathCollection at 0x7d6032aa9490>
Output:
<Figure size 640x480 with 1 Axes>

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SPLIT = 0.8
BATCH_SIZE = 10
LEARNING_RATE = 0.01
EPOCHS = 150
X = torch.as_tensor(X, dtype=torch.float, device=device)
y = torch.as_tensor(y, dtype=torch.float, device=device)
split_i = int(SPLIT*len(X))
X_train, X_test = X[:split_i], X[split_i:]
y_train, y_test = y[:split_i], y[split_i:]
# Dataloaders
train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE)
test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE)
# Model
class MoonBinary(nn.Module):
def __init__(self, device=None):
super().__init__()
self.lin0 = nn.Linear(2, 50, device=device)
self.lin1 = nn.Linear(50, 1, device=device)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.lin0(x))
x = self.lin1(x)
return x.squeeze()
# Model, Loss, Opt
model = MoonBinary(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
# Accuracy
def calc_accuracy(y_p, y):
correct = sum(torch.eq(torch.round(torch.sigmoid(y_p)), y))
acc = correct / len(y)
return acc
# Train, Test Loop
for e in range(EPOCHS):
model.train()
for x_b, y_b in train_dl:
pred = model(x_b)
loss = loss_fn(pred, y_b)
opt.zero_grad()
loss.backward()
opt.step() # batched gradient descent BY THE WAY
train_acc = sum(calc_accuracy(model(x_b), y_b) for x_b, y_b in train_dl) / len(train_dl)
with torch.inference_mode():
test_acc = sum(calc_accuracy(model(x_b), y_b) for x_b, y_b in test_dl) / len(test_dl)
print(f"EPOCH {e} | TRAINBCELOSS {loss} | TRAINACC {train_acc} | TESTACC {test_acc}")Output:
EPOCH 0 | TRAINBCELOSS 0.5968034267425537 | TRAINACC 0.8037501573562622 | TESTACC 0.7750000357627869
EPOCH 1 | TRAINBCELOSS 0.5352381467819214 | TRAINACC 0.8262501955032349 | TESTACC 0.8050000071525574
EPOCH 2 | TRAINBCELOSS 0.5016025304794312 | TRAINACC 0.8375000953674316 | TESTACC 0.8149999976158142
EPOCH 3 | TRAINBCELOSS 0.4814481735229492 | TRAINACC 0.8437501788139343 | TESTACC 0.824999988079071
EPOCH 4 | TRAINBCELOSS 0.4681767523288727 | TRAINACC 0.851250171661377 | TESTACC 0.8299999237060547
EPOCH 5 | TRAINBCELOSS 0.4586009085178375 | TRAINACC 0.8562502264976501 | TESTACC 0.8350000381469727
EPOCH 6 | TRAINBCELOSS 0.4511278569698334 | TRAINACC 0.8612502217292786 | TESTACC 0.8399999737739563
EPOCH 7 | TRAINBCELOSS 0.4448629319667816 | TRAINACC 0.8650001883506775 | TESTACC 0.8449999690055847
EPOCH 8 | TRAINBCELOSS 0.43938514590263367 | TRAINACC 0.8675002455711365 | TESTACC 0.8549999594688416
EPOCH 9 | TRAINBCELOSS 0.43448057770729065 | TRAINACC 0.8712501525878906 | TESTACC 0.8549999594688416
EPOCH 10 | TRAINBCELOSS 0.4299914538860321 | TRAINACC 0.8750002980232239 | TESTACC 0.85999995470047
EPOCH 11 | TRAINBCELOSS 0.4257783591747284 | TRAINACC 0.8762502670288086 | TESTACC 0.8649999499320984
EPOCH 12 | TRAINBCELOSS 0.42179355025291443 | TRAINACC 0.8762502670288086 | TESTACC 0.8749999403953552
EPOCH 13 | TRAINBCELOSS 0.41799721121788025 | TRAINACC 0.8800002932548523 | TESTACC 0.8749999403953552
EPOCH 14 | TRAINBCELOSS 0.4144265353679657 | TRAINACC 0.8825003504753113 | TESTACC 0.8749999403953552
EPOCH 15 | TRAINBCELOSS 0.4110715091228485 | TRAINACC 0.8850004076957703 | TESTACC 0.8749999403953552
EPOCH 16 | TRAINBCELOSS 0.4079027771949768 | TRAINACC 0.8862503170967102 | TESTACC 0.8799999356269836
EPOCH 17 | TRAINBCELOSS 0.40492573380470276 | TRAINACC 0.8887503743171692 | TESTACC 0.8799999356269836
EPOCH 18 | TRAINBCELOSS 0.40215036273002625 | TRAINACC 0.8900003433227539 | TESTACC 0.8799999356269836
EPOCH 19 | TRAINBCELOSS 0.399512380361557 | TRAINACC 0.8912503123283386 | TESTACC 0.8849999308586121
EPOCH 20 | TRAINBCELOSS 0.39698395133018494 | TRAINACC 0.8925004005432129 | TESTACC 0.8849999308586121
EPOCH 21 | TRAINBCELOSS 0.39457017183303833 | TRAINACC 0.8950003981590271 | TESTACC 0.8849999308586121
EPOCH 22 | TRAINBCELOSS 0.3922736346721649 | TRAINACC 0.8962503671646118 | TESTACC 0.8849999308586121
EPOCH 23 | TRAINBCELOSS 0.3900493383407593 | TRAINACC 0.8975003361701965 | TESTACC 0.8849999308586121
EPOCH 24 | TRAINBCELOSS 0.3878992199897766 | TRAINACC 0.8987503051757812 | TESTACC 0.8849999308586121
EPOCH 25 | TRAINBCELOSS 0.3858410120010376 | TRAINACC 0.900000274181366 | TESTACC 0.8849999308586121
EPOCH 26 | TRAINBCELOSS 0.38388893008232117 | TRAINACC 0.9025002717971802 | TESTACC 0.8849999308586121
EPOCH 27 | TRAINBCELOSS 0.3819725513458252 | TRAINACC 0.9025002717971802 | TESTACC 0.8849999308586121
EPOCH 28 | TRAINBCELOSS 0.3801158368587494 | TRAINACC 0.9037503600120544 | TESTACC 0.8849999308586121
EPOCH 29 | TRAINBCELOSS 0.3782866597175598 | TRAINACC 0.9037503600120544 | TESTACC 0.8849999308586121
EPOCH 30 | TRAINBCELOSS 0.3764672577381134 | TRAINACC 0.9050002098083496 | TESTACC 0.8899999856948853
EPOCH 31 | TRAINBCELOSS 0.3746752440929413 | TRAINACC 0.9050002098083496 | TESTACC 0.8899999856948853
EPOCH 32 | TRAINBCELOSS 0.37287262082099915 | TRAINACC 0.9062501788139343 | TESTACC 0.8899999856948853
EPOCH 33 | TRAINBCELOSS 0.3710886836051941 | TRAINACC 0.9075002074241638 | TESTACC 0.8899999856948853
EPOCH 34 | TRAINBCELOSS 0.3693262040615082 | TRAINACC 0.9087502360343933 | TESTACC 0.8899999856948853
EPOCH 35 | TRAINBCELOSS 0.3675501346588135 | TRAINACC 0.9087502360343933 | TESTACC 0.8899999856948853
EPOCH 36 | TRAINBCELOSS 0.3658146262168884 | TRAINACC 0.9100003242492676 | TESTACC 0.8899999856948853
EPOCH 37 | TRAINBCELOSS 0.36406633257865906 | TRAINACC 0.9112502336502075 | TESTACC 0.8899999856948853
EPOCH 38 | TRAINBCELOSS 0.3623354434967041 | TRAINACC 0.9125003218650818 | TESTACC 0.8899999856948853
EPOCH 39 | TRAINBCELOSS 0.36061760783195496 | TRAINACC 0.9125003218650818 | TESTACC 0.8899999856948853
EPOCH 40 | TRAINBCELOSS 0.35888874530792236 | TRAINACC 0.9137502908706665 | TESTACC 0.8899999856948853
EPOCH 41 | TRAINBCELOSS 0.3571382462978363 | TRAINACC 0.9150002598762512 | TESTACC 0.8899999856948853
EPOCH 42 | TRAINBCELOSS 0.3553925156593323 | TRAINACC 0.9150002598762512 | TESTACC 0.8899999856948853
EPOCH 43 | TRAINBCELOSS 0.3536626398563385 | TRAINACC 0.9150002598762512 | TESTACC 0.8899999856948853
EPOCH 44 | TRAINBCELOSS 0.3520314693450928 | TRAINACC 0.9175001978874207 | TESTACC 0.8899999856948853
EPOCH 45 | TRAINBCELOSS 0.3503742218017578 | TRAINACC 0.9175001978874207 | TESTACC 0.8899999856948853
EPOCH 46 | TRAINBCELOSS 0.3486935794353485 | TRAINACC 0.9175001978874207 | TESTACC 0.8899999856948853
EPOCH 47 | TRAINBCELOSS 0.34702324867248535 | TRAINACC 0.9175001978874207 | TESTACC 0.8899999856948853
EPOCH 48 | TRAINBCELOSS 0.34534573554992676 | TRAINACC 0.9187502861022949 | TESTACC 0.8949999213218689
EPOCH 49 | TRAINBCELOSS 0.3436354696750641 | TRAINACC 0.9187502861022949 | TESTACC 0.8949999213218689
EPOCH 50 | TRAINBCELOSS 0.34190040826797485 | TRAINACC 0.9187502861022949 | TESTACC 0.8949999213218689
EPOCH 51 | TRAINBCELOSS 0.34012389183044434 | TRAINACC 0.9200003743171692 | TESTACC 0.8949999213218689
EPOCH 52 | TRAINBCELOSS 0.33829447627067566 | TRAINACC 0.9212503433227539 | TESTACC 0.8949999213218689
EPOCH 53 | TRAINBCELOSS 0.33644866943359375 | TRAINACC 0.9212503433227539 | TESTACC 0.8949999213218689
EPOCH 54 | TRAINBCELOSS 0.33460575342178345 | TRAINACC 0.9225003123283386 | TESTACC 0.8949999213218689
EPOCH 55 | TRAINBCELOSS 0.3327866196632385 | TRAINACC 0.9225003123283386 | TESTACC 0.8949999213218689
EPOCH 56 | TRAINBCELOSS 0.3309022784233093 | TRAINACC 0.9237503409385681 | TESTACC 0.8949999213218689
EPOCH 57 | TRAINBCELOSS 0.3290051519870758 | TRAINACC 0.9237503409385681 | TESTACC 0.8949999213218689
EPOCH 58 | TRAINBCELOSS 0.32712945342063904 | TRAINACC 0.9250003099441528 | TESTACC 0.8949999213218689
EPOCH 59 | TRAINBCELOSS 0.32524561882019043 | TRAINACC 0.9250003099441528 | TESTACC 0.8949999213218689
EPOCH 60 | TRAINBCELOSS 0.32334497570991516 | TRAINACC 0.9250003099441528 | TESTACC 0.8949999213218689
EPOCH 61 | TRAINBCELOSS 0.3214181363582611 | TRAINACC 0.9275003671646118 | TESTACC 0.8949999213218689
EPOCH 62 | TRAINBCELOSS 0.3194522261619568 | TRAINACC 0.9275003671646118 | TESTACC 0.8949999213218689
EPOCH 63 | TRAINBCELOSS 0.3175146281719208 | TRAINACC 0.9275003671646118 | TESTACC 0.8949999213218689
EPOCH 64 | TRAINBCELOSS 0.31556645035743713 | TRAINACC 0.9287503361701965 | TESTACC 0.8949999213218689
EPOCH 65 | TRAINBCELOSS 0.3135727047920227 | TRAINACC 0.9300003051757812 | TESTACC 0.8949999213218689
EPOCH 66 | TRAINBCELOSS 0.3115679621696472 | TRAINACC 0.9300003051757812 | TESTACC 0.8949999213218689
EPOCH 67 | TRAINBCELOSS 0.30948781967163086 | TRAINACC 0.931250274181366 | TESTACC 0.8949999213218689
EPOCH 68 | TRAINBCELOSS 0.30735182762145996 | TRAINACC 0.931250274181366 | TESTACC 0.8949999213218689
EPOCH 69 | TRAINBCELOSS 0.30524855852127075 | TRAINACC 0.931250274181366 | TESTACC 0.8999999165534973
EPOCH 70 | TRAINBCELOSS 0.30309659242630005 | TRAINACC 0.9325003027915955 | TESTACC 0.8999999165534973
EPOCH 71 | TRAINBCELOSS 0.3010371923446655 | TRAINACC 0.9325003027915955 | TESTACC 0.8999999165534973
EPOCH 72 | TRAINBCELOSS 0.29891228675842285 | TRAINACC 0.9325003027915955 | TESTACC 0.8999999165534973
EPOCH 73 | TRAINBCELOSS 0.29679107666015625 | TRAINACC 0.9350002408027649 | TESTACC 0.8999999165534973
EPOCH 74 | TRAINBCELOSS 0.29464173316955566 | TRAINACC 0.9350002408027649 | TESTACC 0.8999999165534973
EPOCH 75 | TRAINBCELOSS 0.2924636900424957 | TRAINACC 0.9350002408027649 | TESTACC 0.8999999165534973
EPOCH 76 | TRAINBCELOSS 0.29028794169425964 | TRAINACC 0.9362502098083496 | TESTACC 0.8999999165534973
EPOCH 77 | TRAINBCELOSS 0.2880886495113373 | TRAINACC 0.9362502098083496 | TESTACC 0.9049999117851257
EPOCH 78 | TRAINBCELOSS 0.2858891189098358 | TRAINACC 0.9362502098083496 | TESTACC 0.9049999117851257
EPOCH 79 | TRAINBCELOSS 0.2837630808353424 | TRAINACC 0.9362502098083496 | TESTACC 0.9049999117851257
EPOCH 80 | TRAINBCELOSS 0.28163039684295654 | TRAINACC 0.9362502098083496 | TESTACC 0.9149999618530273
EPOCH 81 | TRAINBCELOSS 0.27946972846984863 | TRAINACC 0.9362502098083496 | TESTACC 0.9149999618530273
EPOCH 82 | TRAINBCELOSS 0.27723976969718933 | TRAINACC 0.9362502098083496 | TESTACC 0.9149999618530273
EPOCH 83 | TRAINBCELOSS 0.27503716945648193 | TRAINACC 0.9387502670288086 | TESTACC 0.9149999618530273
EPOCH 84 | TRAINBCELOSS 0.2727715075016022 | TRAINACC 0.9387502670288086 | TESTACC 0.9149999618530273
EPOCH 85 | TRAINBCELOSS 0.27054059505462646 | TRAINACC 0.9387502670288086 | TESTACC 0.9149999618530273
EPOCH 86 | TRAINBCELOSS 0.2683239281177521 | TRAINACC 0.9400002360343933 | TESTACC 0.9200000166893005
EPOCH 87 | TRAINBCELOSS 0.2661179006099701 | TRAINACC 0.9400002360343933 | TESTACC 0.9200000166893005
EPOCH 88 | TRAINBCELOSS 0.26390373706817627 | TRAINACC 0.9400002360343933 | TESTACC 0.9200000166893005
EPOCH 89 | TRAINBCELOSS 0.2616461515426636 | TRAINACC 0.9425002336502075 | TESTACC 0.9200000166893005
EPOCH 90 | TRAINBCELOSS 0.25943443179130554 | TRAINACC 0.9425002336502075 | TESTACC 0.9200000166893005
EPOCH 91 | TRAINBCELOSS 0.2572488486766815 | TRAINACC 0.9425002336502075 | TESTACC 0.9200000166893005
EPOCH 92 | TRAINBCELOSS 0.2550351619720459 | TRAINACC 0.945000171661377 | TESTACC 0.9200000166893005
EPOCH 93 | TRAINBCELOSS 0.25290781259536743 | TRAINACC 0.945000171661377 | TESTACC 0.9200000166893005
EPOCH 94 | TRAINBCELOSS 0.25070762634277344 | TRAINACC 0.945000171661377 | TESTACC 0.9200000166893005
EPOCH 95 | TRAINBCELOSS 0.24849490821361542 | TRAINACC 0.9462501406669617 | TESTACC 0.9249998927116394
EPOCH 96 | TRAINBCELOSS 0.24625921249389648 | TRAINACC 0.9462501406669617 | TESTACC 0.9249998927116394
EPOCH 97 | TRAINBCELOSS 0.24404366314411163 | TRAINACC 0.9462501406669617 | TESTACC 0.9249998927116394
EPOCH 98 | TRAINBCELOSS 0.24185048043727875 | TRAINACC 0.9475001692771912 | TESTACC 0.9299999475479126
EPOCH 99 | TRAINBCELOSS 0.23956310749053955 | TRAINACC 0.9475001692771912 | TESTACC 0.9299999475479126
EPOCH 100 | TRAINBCELOSS 0.23733022809028625 | TRAINACC 0.9500002264976501 | TESTACC 0.9299999475479126
EPOCH 101 | TRAINBCELOSS 0.23511801660060883 | TRAINACC 0.9500002264976501 | TESTACC 0.9299999475479126
EPOCH 102 | TRAINBCELOSS 0.23281097412109375 | TRAINACC 0.9500002264976501 | TESTACC 0.9299999475479126
EPOCH 103 | TRAINBCELOSS 0.23061347007751465 | TRAINACC 0.9525001645088196 | TESTACC 0.9299999475479126
EPOCH 104 | TRAINBCELOSS 0.2284357100725174 | TRAINACC 0.9525001645088196 | TESTACC 0.9299999475479126
EPOCH 105 | TRAINBCELOSS 0.22621750831604004 | TRAINACC 0.9525001645088196 | TESTACC 0.9299999475479126
EPOCH 106 | TRAINBCELOSS 0.22403974831104279 | TRAINACC 0.955000102519989 | TESTACC 0.9299999475479126
EPOCH 107 | TRAINBCELOSS 0.22188666462898254 | TRAINACC 0.955000102519989 | TESTACC 0.9299999475479126
EPOCH 108 | TRAINBCELOSS 0.2197176069021225 | TRAINACC 0.955000102519989 | TESTACC 0.934999942779541
EPOCH 109 | TRAINBCELOSS 0.2175581306219101 | TRAINACC 0.9562501907348633 | TESTACC 0.934999942779541
EPOCH 110 | TRAINBCELOSS 0.21543851494789124 | TRAINACC 0.9562501907348633 | TESTACC 0.934999942779541
EPOCH 111 | TRAINBCELOSS 0.21338076889514923 | TRAINACC 0.9575001001358032 | TESTACC 0.934999942779541
EPOCH 112 | TRAINBCELOSS 0.21126461029052734 | TRAINACC 0.9587500691413879 | TESTACC 0.934999942779541
EPOCH 113 | TRAINBCELOSS 0.20910589396953583 | TRAINACC 0.9600001573562622 | TESTACC 0.934999942779541
EPOCH 114 | TRAINBCELOSS 0.20704428851604462 | TRAINACC 0.9600001573562622 | TESTACC 0.934999942779541
EPOCH 115 | TRAINBCELOSS 0.20496883988380432 | TRAINACC 0.9600001573562622 | TESTACC 0.9399999976158142
EPOCH 116 | TRAINBCELOSS 0.20281553268432617 | TRAINACC 0.9612501263618469 | TESTACC 0.9399999976158142
EPOCH 117 | TRAINBCELOSS 0.2008403092622757 | TRAINACC 0.9625000953674316 | TESTACC 0.9399999976158142
EPOCH 118 | TRAINBCELOSS 0.19874177873134613 | TRAINACC 0.9625000953674316 | TESTACC 0.9449998736381531
EPOCH 119 | TRAINBCELOSS 0.19678282737731934 | TRAINACC 0.9625000953674316 | TESTACC 0.9449998736381531
EPOCH 120 | TRAINBCELOSS 0.1947566121816635 | TRAINACC 0.9637500643730164 | TESTACC 0.9499999284744263
EPOCH 121 | TRAINBCELOSS 0.1927836388349533 | TRAINACC 0.9637500643730164 | TESTACC 0.9499999284744263
EPOCH 122 | TRAINBCELOSS 0.19083981215953827 | TRAINACC 0.9637500643730164 | TESTACC 0.9499999284744263
EPOCH 123 | TRAINBCELOSS 0.18880391120910645 | TRAINACC 0.9662501215934753 | TESTACC 0.9499999284744263
EPOCH 124 | TRAINBCELOSS 0.18693207204341888 | TRAINACC 0.9662501215934753 | TESTACC 0.9499999284744263
EPOCH 125 | TRAINBCELOSS 0.1849227249622345 | TRAINACC 0.9675000309944153 | TESTACC 0.9499999284744263
EPOCH 126 | TRAINBCELOSS 0.18305471539497375 | TRAINACC 0.9687501192092896 | TESTACC 0.9499999284744263
EPOCH 127 | TRAINBCELOSS 0.18110330402851105 | TRAINACC 0.9687501192092896 | TESTACC 0.9499999284744263
EPOCH 128 | TRAINBCELOSS 0.17923371493816376 | TRAINACC 0.9700002074241638 | TESTACC 0.9499999284744263
EPOCH 129 | TRAINBCELOSS 0.17735783755779266 | TRAINACC 0.971250057220459 | TESTACC 0.9499999284744263
EPOCH 130 | TRAINBCELOSS 0.1755225509405136 | TRAINACC 0.971250057220459 | TESTACC 0.9499999284744263
EPOCH 131 | TRAINBCELOSS 0.17361319065093994 | TRAINACC 0.9725001454353333 | TESTACC 0.9499999284744263
EPOCH 132 | TRAINBCELOSS 0.17186909914016724 | TRAINACC 0.9725001454353333 | TESTACC 0.9549999237060547
EPOCH 133 | TRAINBCELOSS 0.16997766494750977 | TRAINACC 0.9725001454353333 | TESTACC 0.9599999785423279
EPOCH 134 | TRAINBCELOSS 0.16823987662792206 | TRAINACC 0.9725001454353333 | TESTACC 0.9599999785423279
EPOCH 135 | TRAINBCELOSS 0.16641424596309662 | TRAINACC 0.9725001454353333 | TESTACC 0.9599999785423279
EPOCH 136 | TRAINBCELOSS 0.16464504599571228 | TRAINACC 0.9750000834465027 | TESTACC 0.9599999785423279
EPOCH 137 | TRAINBCELOSS 0.16291101276874542 | TRAINACC 0.9750000834465027 | TESTACC 0.9599999785423279
EPOCH 138 | TRAINBCELOSS 0.16117341816425323 | TRAINACC 0.9750000834465027 | TESTACC 0.9599999785423279
EPOCH 139 | TRAINBCELOSS 0.15954437851905823 | TRAINACC 0.976250171661377 | TESTACC 0.9649999737739563
EPOCH 140 | TRAINBCELOSS 0.15789483487606049 | TRAINACC 0.976250171661377 | TESTACC 0.9649999737739563
EPOCH 141 | TRAINBCELOSS 0.15619178116321564 | TRAINACC 0.976250171661377 | TESTACC 0.9649999737739563
EPOCH 142 | TRAINBCELOSS 0.15460099279880524 | TRAINACC 0.9775000810623169 | TESTACC 0.9649999737739563
EPOCH 143 | TRAINBCELOSS 0.15297706425189972 | TRAINACC 0.9787501692771912 | TESTACC 0.9649999737739563
EPOCH 144 | TRAINBCELOSS 0.15131308138370514 | TRAINACC 0.9800001382827759 | TESTACC 0.9649999737739563
EPOCH 145 | TRAINBCELOSS 0.14975881576538086 | TRAINACC 0.9800001382827759 | TESTACC 0.9649999737739563
EPOCH 146 | TRAINBCELOSS 0.14820724725723267 | TRAINACC 0.9812501072883606 | TESTACC 0.9649999737739563
EPOCH 147 | TRAINBCELOSS 0.14667795598506927 | TRAINACC 0.9825000762939453 | TESTACC 0.9649999737739563
EPOCH 148 | TRAINBCELOSS 0.14517246186733246 | TRAINACC 0.9825000762939453 | TESTACC 0.9649999737739563
EPOCH 149 | TRAINBCELOSS 0.14367911219596863 | TRAINACC 0.9825000762939453 | TESTACC 0.9649999737739563
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model, X_test, y_test)Output:
<Figure size 1200x600 with 2 Axes>

More Practise
# Code for creating a spiral dataset from CS231n
import numpy as np
import matplotlib.pyplot as plt
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K,D)) # data matrix (each row = single example)
y = np.zeros(N*K, dtype='uint8') # class labels
for j in range(K):
ix = range(N*j,N*(j+1))
r = np.linspace(0.0,1,N) # radius
t = np.linspace(j*4,(j+1)*4,N) + np.random.randn(N)*0.2 # theta
X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
y[ix] = j
# lets visualize the data
plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
plt.show()Output:
<Figure size 640x480 with 1 Axes>

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SPLIT = 0.8
LEARNING_RATE = 0.05
EPOCHS = 200
EVAL_INTERVAL = 2
BATCH_SIZE = 10
X = torch.as_tensor(X, dtype=torch.float32, device=device)
y = torch.as_tensor(y, dtype=torch.long, device=device)
random_shuffle = torch.randperm(len(X)) # TORCH.RANDPERM TO SHUFFLE!!!!!!!
X = X[random_shuffle]
y = y[random_shuffle]
split_i = int(SPLIT*len(X))
X_train, X_test = X[:split_i], X[split_i:]
y_train, y_test = y[:split_i], y[split_i:]
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)
# Model, Loss, Opt
class SpiralClassification(nn.Module):
def __init__(self, device=None):
super().__init__()
self.lin0 = nn.Linear(2, 40, device=device)
self.lin1 = nn.Linear(40, K, device=device)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.lin0(x))
x = self.lin1(x)
return x
model = SpiralClassification(device)
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
# Train, Test Loop
def calc_acc(y_pred, y):
y_pred_argmax = torch.softmax(y_pred, dim=1).argmax(dim=1)
acc = sum(torch.eq(y_pred_argmax.squeeze(), y)) / len(y)
return acc
def train(model, loss_fn, opt, train_dl):
model.train()
train_loss = []
train_acc = []
for x_b, y_b in train_dl:
pred = model(x_b)
loss = loss_fn(pred, y_b)
train_loss.append(loss)
train_acc.append(calc_acc(pred, y_b))
opt.zero_grad()
loss.backward()
opt.step()
train_loss_avg = sum(train_loss) / len(train_loss)
train_acc_avg = sum(train_acc) / len(train_acc)
print(f"TRAIN LOSS {train_loss_avg} | TRAIN ACC {train_acc_avg}")
def test(model, test_dl):
model.eval()
with torch.inference_mode():
test_acc = []
for x_b, y_b in test_dl:
pred = model(x_b)
test_acc.append(calc_acc(pred, y_b))
test_acc_avg = sum(test_acc) / len(test_acc)
print(f"TEST ACC {test_acc_avg}")
for e in range(EPOCHS):
train(model, loss_fn, opt, train_dl)
if e % EVAL_INTERVAL == 0:
test(model, test_dl)Output:
TRAIN LOSS 1.0657037496566772 | TRAIN ACC 0.4208333492279053
TEST ACC 0.5833333730697632
TRAIN LOSS 0.9841545820236206 | TRAIN ACC 0.5625001192092896
TRAIN LOSS 0.9259986877441406 | TRAIN ACC 0.5583333969116211
TEST ACC 0.5333333611488342
TRAIN LOSS 0.8823034167289734 | TRAIN ACC 0.5375000238418579
TRAIN LOSS 0.8489739298820496 | TRAIN ACC 0.5291666984558105
TEST ACC 0.5666667222976685
TRAIN LOSS 0.8232254981994629 | TRAIN ACC 0.5250000357627869
TRAIN LOSS 0.8030401468276978 | TRAIN ACC 0.5291666984558105
TEST ACC 0.5666667222976685
TRAIN LOSS 0.7871977686882019 | TRAIN ACC 0.5250000357627869
TRAIN LOSS 0.7745394706726074 | TRAIN ACC 0.5333333611488342
TEST ACC 0.5833333730697632
TRAIN LOSS 0.764284610748291 | TRAIN ACC 0.5333333611488342
TRAIN LOSS 0.7558233737945557 | TRAIN ACC 0.5333333611488342
TEST ACC 0.6000000238418579
TRAIN LOSS 0.7486969828605652 | TRAIN ACC 0.5333333611488342
TRAIN LOSS 0.7425075769424438 | TRAIN ACC 0.5375000238418579
TEST ACC 0.5833333730697632
TRAIN LOSS 0.7370821833610535 | TRAIN ACC 0.5375000238418579
TRAIN LOSS 0.7322777509689331 | TRAIN ACC 0.5416666865348816
TEST ACC 0.5833333730697632
TRAIN LOSS 0.7279372215270996 | TRAIN ACC 0.5416666865348816
TRAIN LOSS 0.7239362597465515 | TRAIN ACC 0.5416666865348816
TEST ACC 0.5833333730697632
TRAIN LOSS 0.7202266454696655 | TRAIN ACC 0.5458333492279053
TRAIN LOSS 0.7167061567306519 | TRAIN ACC 0.5416667461395264
TEST ACC 0.6166666746139526
TRAIN LOSS 0.7133333683013916 | TRAIN ACC 0.5416667461395264
TRAIN LOSS 0.7100488543510437 | TRAIN ACC 0.5458333492279053
TEST ACC 0.6166666746139526
TRAIN LOSS 0.7067779302597046 | TRAIN ACC 0.5458333492279053
TRAIN LOSS 0.7035346031188965 | TRAIN ACC 0.5458333492279053
TEST ACC 0.6333333253860474
TRAIN LOSS 0.7003307342529297 | TRAIN ACC 0.5458333492279053
TRAIN LOSS 0.6971020102500916 | TRAIN ACC 0.5458333492279053
TEST ACC 0.6333333253860474
TRAIN LOSS 0.6938474178314209 | TRAIN ACC 0.5458333492279053
TRAIN LOSS 0.690616250038147 | TRAIN ACC 0.5500000715255737
TEST ACC 0.6333333253860474
TRAIN LOSS 0.6873027086257935 | TRAIN ACC 0.5625
TRAIN LOSS 0.6839281320571899 | TRAIN ACC 0.5750000476837158
TEST ACC 0.6500000357627869
TRAIN LOSS 0.6805365085601807 | TRAIN ACC 0.5750000476837158
TRAIN LOSS 0.6771494746208191 | TRAIN ACC 0.5750000476837158
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6737172603607178 | TRAIN ACC 0.5750000476837158
TRAIN LOSS 0.6703263521194458 | TRAIN ACC 0.5791667103767395
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6668359637260437 | TRAIN ACC 0.5791667103767395
TRAIN LOSS 0.6634132862091064 | TRAIN ACC 0.5750000476837158
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6599887609481812 | TRAIN ACC 0.5750000476837158
TRAIN LOSS 0.6565784811973572 | TRAIN ACC 0.5791667103767395
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6531292796134949 | TRAIN ACC 0.5791667103767395
TRAIN LOSS 0.6497141718864441 | TRAIN ACC 0.5791667103767395
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6462591886520386 | TRAIN ACC 0.5916666984558105
TRAIN LOSS 0.6427377462387085 | TRAIN ACC 0.595833420753479
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6392015218734741 | TRAIN ACC 0.595833420753479
TRAIN LOSS 0.635694682598114 | TRAIN ACC 0.6000000834465027
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6322484016418457 | TRAIN ACC 0.6000000834465027
TRAIN LOSS 0.6287260055541992 | TRAIN ACC 0.6125000715255737
TEST ACC 0.6666666865348816
TRAIN LOSS 0.6251420974731445 | TRAIN ACC 0.6125000715255737
TRAIN LOSS 0.6216278076171875 | TRAIN ACC 0.6208333969116211
TEST ACC 0.6833333373069763
TRAIN LOSS 0.6180787086486816 | TRAIN ACC 0.6208333969116211
TRAIN LOSS 0.6145181655883789 | TRAIN ACC 0.625
TEST ACC 0.7000000476837158
TRAIN LOSS 0.610988199710846 | TRAIN ACC 0.6291666626930237
TRAIN LOSS 0.6073516607284546 | TRAIN ACC 0.6291666626930237
TEST ACC 0.7000000476837158
TRAIN LOSS 0.603826642036438 | TRAIN ACC 0.6291666626930237
TRAIN LOSS 0.600206196308136 | TRAIN ACC 0.6291666626930237
TEST ACC 0.7000000476837158
TRAIN LOSS 0.5965393781661987 | TRAIN ACC 0.6333333253860474
TRAIN LOSS 0.5928676128387451 | TRAIN ACC 0.6458333134651184
TEST ACC 0.7000000476837158
TRAIN LOSS 0.5892030000686646 | TRAIN ACC 0.6583333015441895
TRAIN LOSS 0.5854144096374512 | TRAIN ACC 0.6583333015441895
TEST ACC 0.7166666984558105
TRAIN LOSS 0.5816792845726013 | TRAIN ACC 0.6625000238418579
TRAIN LOSS 0.5779438018798828 | TRAIN ACC 0.6666666269302368
TEST ACC 0.7166666984558105
TRAIN LOSS 0.5741564035415649 | TRAIN ACC 0.6666666269302368
TRAIN LOSS 0.5702852010726929 | TRAIN ACC 0.6708333492279053
TEST ACC 0.7333333492279053
TRAIN LOSS 0.5664829611778259 | TRAIN ACC 0.6874999403953552
TRAIN LOSS 0.5625776052474976 | TRAIN ACC 0.6874999403953552
TEST ACC 0.75
TRAIN LOSS 0.5587122440338135 | TRAIN ACC 0.699999988079071
TRAIN LOSS 0.5548640489578247 | TRAIN ACC 0.699999988079071
TEST ACC 0.75
TRAIN LOSS 0.5509438514709473 | TRAIN ACC 0.70416659116745
TRAIN LOSS 0.5470256805419922 | TRAIN ACC 0.70416659116745
TEST ACC 0.75
TRAIN LOSS 0.5430793762207031 | TRAIN ACC 0.70416659116745
TRAIN LOSS 0.539167046546936 | TRAIN ACC 0.70416659116745
TEST ACC 0.75
TRAIN LOSS 0.5352030396461487 | TRAIN ACC 0.70416659116745
TRAIN LOSS 0.5311792492866516 | TRAIN ACC 0.70416659116745
TEST ACC 0.7666666507720947
TRAIN LOSS 0.5272237062454224 | TRAIN ACC 0.70416659116745
TRAIN LOSS 0.5231789350509644 | TRAIN ACC 0.7124999761581421
TEST ACC 0.783333420753479
TRAIN LOSS 0.5191644430160522 | TRAIN ACC 0.7166666388511658
TRAIN LOSS 0.5151196718215942 | TRAIN ACC 0.7250000238418579
TEST ACC 0.783333420753479
TRAIN LOSS 0.5111501812934875 | TRAIN ACC 0.7291666865348816
TRAIN LOSS 0.507134199142456 | TRAIN ACC 0.7291666865348816
TEST ACC 0.8000000715255737
TRAIN LOSS 0.5030743479728699 | TRAIN ACC 0.7250000238418579
TRAIN LOSS 0.4991127848625183 | TRAIN ACC 0.7291666865348816
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4950891137123108 | TRAIN ACC 0.7291666865348816
TRAIN LOSS 0.49107861518859863 | TRAIN ACC 0.7291666865348816
TEST ACC 0.8166667819023132
TRAIN LOSS 0.48709774017333984 | TRAIN ACC 0.7333332896232605
TRAIN LOSS 0.4832792282104492 | TRAIN ACC 0.7374999523162842
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4792233109474182 | TRAIN ACC 0.7458333373069763
TRAIN LOSS 0.4753933846950531 | TRAIN ACC 0.7458333373069763
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4714268147945404 | TRAIN ACC 0.7458333373069763
TRAIN LOSS 0.4676346480846405 | TRAIN ACC 0.7499999403953552
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4637102782726288 | TRAIN ACC 0.7541666030883789
TRAIN LOSS 0.4599130153656006 | TRAIN ACC 0.7541666030883789
TEST ACC 0.8166667819023132
TRAIN LOSS 0.45611584186553955 | TRAIN ACC 0.7583333253860474
TRAIN LOSS 0.452392578125 | TRAIN ACC 0.7583333253860474
TEST ACC 0.8166667819023132
TRAIN LOSS 0.44859811663627625 | TRAIN ACC 0.7583333253860474
TRAIN LOSS 0.4449712634086609 | TRAIN ACC 0.7583333253860474
TEST ACC 0.8166667819023132
TRAIN LOSS 0.44137078523635864 | TRAIN ACC 0.762499988079071
TRAIN LOSS 0.43772754073143005 | TRAIN ACC 0.7666666507720947
TEST ACC 0.8166667819023132
TRAIN LOSS 0.43409860134124756 | TRAIN ACC 0.7666666507720947
TRAIN LOSS 0.4306368827819824 | TRAIN ACC 0.7666666507720947
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4270704984664917 | TRAIN ACC 0.7666666507720947
TRAIN LOSS 0.4236374795436859 | TRAIN ACC 0.7708332538604736
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4202161729335785 | TRAIN ACC 0.7708332538604736
TRAIN LOSS 0.41684743762016296 | TRAIN ACC 0.7749999761581421
TEST ACC 0.8166667819023132
TRAIN LOSS 0.41342979669570923 | TRAIN ACC 0.7791666388511658
TRAIN LOSS 0.4101482331752777 | TRAIN ACC 0.7833332419395447
TEST ACC 0.8166667819023132
TRAIN LOSS 0.40689608454704285 | TRAIN ACC 0.7874999046325684
TRAIN LOSS 0.4037189483642578 | TRAIN ACC 0.7916666269302368
TEST ACC 0.8166667819023132
TRAIN LOSS 0.4005892276763916 | TRAIN ACC 0.7916666269302368
TRAIN LOSS 0.3974658250808716 | TRAIN ACC 0.7916666269302368
TEST ACC 0.833333432674408
TRAIN LOSS 0.394326776266098 | TRAIN ACC 0.7916666269302368
TRAIN LOSS 0.39123600721359253 | TRAIN ACC 0.7999999523162842
TEST ACC 0.8500000834465027
TRAIN LOSS 0.3881646990776062 | TRAIN ACC 0.7958332896232605
TRAIN LOSS 0.3850294351577759 | TRAIN ACC 0.8041666746139526
TEST ACC 0.8500000834465027
TRAIN LOSS 0.38199231028556824 | TRAIN ACC 0.8124998807907104
TRAIN LOSS 0.3790113627910614 | TRAIN ACC 0.8166665434837341
TEST ACC 0.8500000834465027
TRAIN LOSS 0.37602147459983826 | TRAIN ACC 0.8166665434837341
TRAIN LOSS 0.37313589453697205 | TRAIN ACC 0.8166665434837341
TEST ACC 0.8666667342185974
TRAIN LOSS 0.37029194831848145 | TRAIN ACC 0.8166665434837341
TRAIN LOSS 0.36742323637008667 | TRAIN ACC 0.8208332061767578
TEST ACC 0.8666667342185974
TRAIN LOSS 0.3646599054336548 | TRAIN ACC 0.8249999284744263
TRAIN LOSS 0.3619324862957001 | TRAIN ACC 0.8249999284744263
TEST ACC 0.8666667342185974
TRAIN LOSS 0.35922738909721375 | TRAIN ACC 0.8333332538604736
TRAIN LOSS 0.356632798910141 | TRAIN ACC 0.8333332538604736
TEST ACC 0.8666667342185974
TRAIN LOSS 0.35400497913360596 | TRAIN ACC 0.8333332538604736
TRAIN LOSS 0.3514353930950165 | TRAIN ACC 0.8333332538604736
TEST ACC 0.8666667342185974
TRAIN LOSS 0.34901154041290283 | TRAIN ACC 0.8333332538604736
TRAIN LOSS 0.34646713733673096 | TRAIN ACC 0.8333332538604736
TEST ACC 0.8833333849906921
TRAIN LOSS 0.3440144956111908 | TRAIN ACC 0.8333332538604736
TRAIN LOSS 0.34156033396720886 | TRAIN ACC 0.8333332538604736
TEST ACC 0.9000000357627869
TRAIN LOSS 0.33921581506729126 | TRAIN ACC 0.8333332538604736
TRAIN LOSS 0.336823046207428 | TRAIN ACC 0.8333332538604736
TEST ACC 0.9166666865348816
TRAIN LOSS 0.3345527648925781 | TRAIN ACC 0.8333332538604736
TRAIN LOSS 0.3322502076625824 | TRAIN ACC 0.8333332538604736
TEST ACC 0.9166666865348816
TRAIN LOSS 0.330020934343338 | TRAIN ACC 0.8374998569488525
TRAIN LOSS 0.3278135061264038 | TRAIN ACC 0.841666579246521
TEST ACC 0.9166666865348816
TRAIN LOSS 0.3256322741508484 | TRAIN ACC 0.841666579246521
TRAIN LOSS 0.32348328828811646 | TRAIN ACC 0.8499999046325684
TEST ACC 0.9333333969116211
TRAIN LOSS 0.32139039039611816 | TRAIN ACC 0.8499999046325684
TRAIN LOSS 0.31929320096969604 | TRAIN ACC 0.8499999046325684
TEST ACC 0.9333333969116211
TRAIN LOSS 0.31722623109817505 | TRAIN ACC 0.8499999046325684
TRAIN LOSS 0.31518983840942383 | TRAIN ACC 0.8499999046325684
TEST ACC 0.9333333969116211
TRAIN LOSS 0.3132067620754242 | TRAIN ACC 0.8541666269302368
TRAIN LOSS 0.3112490475177765 | TRAIN ACC 0.8541666269302368
TEST ACC 0.9333333969116211
TRAIN LOSS 0.30930984020233154 | TRAIN ACC 0.8541666269302368
TRAIN LOSS 0.3073555827140808 | TRAIN ACC 0.8624998927116394
TEST ACC 0.9333333969116211
TRAIN LOSS 0.30552464723587036 | TRAIN ACC 0.8624998927116394
TRAIN LOSS 0.30361059308052063 | TRAIN ACC 0.8624998927116394
TEST ACC 0.9333333969116211
TRAIN LOSS 0.30183541774749756 | TRAIN ACC 0.8624998927116394
TRAIN LOSS 0.3000202775001526 | TRAIN ACC 0.8624998927116394
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2982179820537567 | TRAIN ACC 0.8666665554046631
TRAIN LOSS 0.29648905992507935 | TRAIN ACC 0.8666665554046631
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2947505712509155 | TRAIN ACC 0.8666665554046631
TRAIN LOSS 0.293094277381897 | TRAIN ACC 0.8708332777023315
TEST ACC 0.9333333969116211
TRAIN LOSS 0.29142141342163086 | TRAIN ACC 0.8708332777023315
TRAIN LOSS 0.2897477447986603 | TRAIN ACC 0.8708332777023315
TEST ACC 0.9333333969116211
TRAIN LOSS 0.28814563155174255 | TRAIN ACC 0.8708332777023315
TRAIN LOSS 0.2865486741065979 | TRAIN ACC 0.8708332777023315
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2849707007408142 | TRAIN ACC 0.8708332777023315
TRAIN LOSS 0.28341713547706604 | TRAIN ACC 0.8708332777023315
TEST ACC 0.9333333969116211
TRAIN LOSS 0.28188595175743103 | TRAIN ACC 0.8708332777023315
TRAIN LOSS 0.28036460280418396 | TRAIN ACC 0.8749999403953552
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2788779139518738 | TRAIN ACC 0.8749999403953552
TRAIN LOSS 0.2774117887020111 | TRAIN ACC 0.8749999403953552
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2759571373462677 | TRAIN ACC 0.8749999403953552
TRAIN LOSS 0.27450570464134216 | TRAIN ACC 0.8833332061767578
TEST ACC 0.9333333969116211
TRAIN LOSS 0.27307623624801636 | TRAIN ACC 0.8874998092651367
TRAIN LOSS 0.27169013023376465 | TRAIN ACC 0.8916665315628052
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2702730596065521 | TRAIN ACC 0.8916665315628052
TRAIN LOSS 0.26887303590774536 | TRAIN ACC 0.8916665315628052
TEST ACC 0.9333333969116211
TRAIN LOSS 0.26749542355537415 | TRAIN ACC 0.8916665315628052
TRAIN LOSS 0.26612573862075806 | TRAIN ACC 0.8916665315628052
TEST ACC 0.9333333969116211
TRAIN LOSS 0.26477378606796265 | TRAIN ACC 0.8916665315628052
TRAIN LOSS 0.2634212374687195 | TRAIN ACC 0.8916665315628052
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2621036171913147 | TRAIN ACC 0.8916665315628052
TRAIN LOSS 0.2608124613761902 | TRAIN ACC 0.8916665315628052
TEST ACC 0.9333333969116211
TRAIN LOSS 0.25951993465423584 | TRAIN ACC 0.8916665315628052
TRAIN LOSS 0.25825217366218567 | TRAIN ACC 0.8958331942558289
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2569856345653534 | TRAIN ACC 0.8999998569488525
TRAIN LOSS 0.2557488679885864 | TRAIN ACC 0.8999998569488525
TEST ACC 0.9333333969116211
TRAIN LOSS 0.254518061876297 | TRAIN ACC 0.8999998569488525
TRAIN LOSS 0.25332367420196533 | TRAIN ACC 0.904166579246521
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2521016001701355 | TRAIN ACC 0.904166579246521
TRAIN LOSS 0.2508886754512787 | TRAIN ACC 0.904166579246521
TEST ACC 0.9333333969116211
TRAIN LOSS 0.24968427419662476 | TRAIN ACC 0.904166579246521
TRAIN LOSS 0.2485085427761078 | TRAIN ACC 0.904166579246521
TEST ACC 0.9333333969116211
TRAIN LOSS 0.24733632802963257 | TRAIN ACC 0.904166579246521
TRAIN LOSS 0.24619154632091522 | TRAIN ACC 0.904166579246521
TEST ACC 0.9333333969116211
TRAIN LOSS 0.24508240818977356 | TRAIN ACC 0.904166579246521
TRAIN LOSS 0.2439526915550232 | TRAIN ACC 0.904166579246521
TEST ACC 0.9333333969116211
TRAIN LOSS 0.24283580482006073 | TRAIN ACC 0.904166579246521
TRAIN LOSS 0.2417338341474533 | TRAIN ACC 0.904166579246521
TEST ACC 0.9333333969116211
TRAIN LOSS 0.24064365029335022 | TRAIN ACC 0.9124998450279236
TRAIN LOSS 0.23956865072250366 | TRAIN ACC 0.9124998450279236
TEST ACC 0.9333333969116211
TRAIN LOSS 0.23850569128990173 | TRAIN ACC 0.9124998450279236
TRAIN LOSS 0.23745763301849365 | TRAIN ACC 0.9124998450279236
TEST ACC 0.9333333969116211
TRAIN LOSS 0.2364090383052826 | TRAIN ACC 0.9124998450279236
TRAIN LOSS 0.23538100719451904 | TRAIN ACC 0.9124998450279236
TEST ACC 0.9333333969116211
TRAIN LOSS 0.23436205089092255 | TRAIN ACC 0.9124998450279236
TRAIN LOSS 0.23334352672100067 | TRAIN ACC 0.9166665077209473
TEST ACC 0.9333333969116211
TRAIN LOSS 0.23234805464744568 | TRAIN ACC 0.9208332300186157
TRAIN LOSS 0.23137104511260986 | TRAIN ACC 0.9208332300186157
TEST ACC 0.9333333969116211
TRAIN LOSS 0.23038282990455627 | TRAIN ACC 0.9208332300186157
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model, X_test, y_test)Output:
<Figure size 1200x600 with 2 Axes>

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
SPLIT = 0.8
BATCH_SIZE = 32
LEARNING_RATE = 0.05
EPOCHS = 1000
INTERVAL = 2
X = torch.as_tensor(X, device='cpu')
y = torch.as_tensor(y, device='cpu')
shuffle = torch.randperm(len(X))
X = X[shuffle]
y = y[shuffle]
split_i = int(SPLIT*len(X))
X_train, X_test = X[:split_i], X[split_i:]
y_train, y_test = y[:split_i], y[split_i:]
train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE)
test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE)
class Spiral(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(2, 30)
self.lin1 = nn.Linear(30, K)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.lin0(x))
x = self.lin1(x)
return x
def calc_acc(y_pred, y):
y_pred_argmax = torch.softmax(y_pred, dim=1).argmax(dim=1)
acc = sum(torch.eq(y_pred_argmax, y))/len(y)
return acc
def train(model, loss_fn, opt, train_dl):
model.train()
losses = []
accs = []
for x_b, y_b in train_dl:
pred = model(x_b)
loss = loss_fn(pred, y_b)
losses.append(loss)
accs.append(calc_acc(pred, y_b))
opt.zero_grad()
loss.backward()
opt.step()
acc_avg = sum(accs) / len(accs)
loss_avg = sum(losses) / len(losses)
print(f"train acc {acc_avg} | loss_avg {loss_avg}")
def test(model, test_dl):
model.eval()
with torch.inference_mode():
accs = []
for x_b, y_b in test_dl:
pred = model(x_b)
accs.append(calc_acc(pred, y_b))
accs_avg = sum(accs)/ len(accs)
print(f"TEST ACC {accs_avg}")
model = Spiral()
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(EPOCHS):
train(model, loss_fn, opt, train_dl)
if e % INTERVAL == 0:
test(model, test_dl)Output:
train acc 0.37890625 | loss_avg 1.0739420652389526
TEST ACC 0.4464285671710968
train acc 0.453125 | loss_avg 1.0422812700271606
train acc 0.5546875 | loss_avg 1.019551157951355
TEST ACC 0.4330357313156128
train acc 0.54296875 | loss_avg 1.0007585287094116
train acc 0.51953125 | loss_avg 0.9839129447937012
TEST ACC 0.4977678656578064
train acc 0.54296875 | loss_avg 0.968197762966156
train acc 0.54296875 | loss_avg 0.9533126950263977
TEST ACC 0.5
train acc 0.54296875 | loss_avg 0.9391645789146423
train acc 0.54296875 | loss_avg 0.9257025718688965
TEST ACC 0.4821428656578064
train acc 0.5390625 | loss_avg 0.9128650426864624
train acc 0.53125 | loss_avg 0.9006432890892029
TEST ACC 0.4821428656578064
train acc 0.52734375 | loss_avg 0.8889985084533691
train acc 0.5234375 | loss_avg 0.8779304027557373
TEST ACC 0.484375
train acc 0.51953125 | loss_avg 0.8674430251121521
train acc 0.5234375 | loss_avg 0.857520341873169
TEST ACC 0.484375
train acc 0.5234375 | loss_avg 0.8481522798538208
train acc 0.5234375 | loss_avg 0.8393054604530334
TEST ACC 0.484375
train acc 0.5234375 | loss_avg 0.8309529423713684
train acc 0.51953125 | loss_avg 0.8230631947517395
TEST ACC 0.484375
train acc 0.5234375 | loss_avg 0.8156328201293945
train acc 0.53125 | loss_avg 0.8085906505584717
TEST ACC 0.484375
train acc 0.52734375 | loss_avg 0.8019410371780396
train acc 0.52734375 | loss_avg 0.7956579923629761
TEST ACC 0.484375
train acc 0.52734375 | loss_avg 0.7897065281867981
train acc 0.5390625 | loss_avg 0.7840877771377563
TEST ACC 0.484375
train acc 0.53515625 | loss_avg 0.7787821888923645
train acc 0.53515625 | loss_avg 0.773770809173584
TEST ACC 0.484375
train acc 0.5390625 | loss_avg 0.7690286636352539
train acc 0.5390625 | loss_avg 0.7645596861839294
TEST ACC 0.5
train acc 0.5390625 | loss_avg 0.7603194117546082
train acc 0.5390625 | loss_avg 0.7563295364379883
TEST ACC 0.484375
train acc 0.54296875 | loss_avg 0.752537190914154
train acc 0.54296875 | loss_avg 0.7489384412765503
TEST ACC 0.484375
train acc 0.54296875 | loss_avg 0.7455693483352661
train acc 0.546875 | loss_avg 0.7423689961433411
TEST ACC 0.484375
train acc 0.546875 | loss_avg 0.7393431663513184
train acc 0.546875 | loss_avg 0.736453652381897
TEST ACC 0.484375
train acc 0.546875 | loss_avg 0.7337056398391724
train acc 0.546875 | loss_avg 0.731086015701294
TEST ACC 0.484375
train acc 0.546875 | loss_avg 0.7285877466201782
train acc 0.546875 | loss_avg 0.7262130975723267
TEST ACC 0.484375
train acc 0.546875 | loss_avg 0.7239452004432678
train acc 0.546875 | loss_avg 0.721779465675354
TEST ACC 0.484375
train acc 0.546875 | loss_avg 0.7197105288505554
train acc 0.546875 | loss_avg 0.7177078723907471
TEST ACC 0.484375
train acc 0.546875 | loss_avg 0.7157747149467468
train acc 0.546875 | loss_avg 0.7139238119125366
TEST ACC 0.5
train acc 0.546875 | loss_avg 0.7121438980102539
train acc 0.5546875 | loss_avg 0.7104165554046631
TEST ACC 0.5
train acc 0.5546875 | loss_avg 0.7087557315826416
train acc 0.5546875 | loss_avg 0.7071480751037598
TEST ACC 0.5
train acc 0.5546875 | loss_avg 0.7055858969688416
train acc 0.5546875 | loss_avg 0.7040607929229736
TEST ACC 0.5
train acc 0.5546875 | loss_avg 0.7025690078735352
train acc 0.5546875 | loss_avg 0.7010958790779114
TEST ACC 0.5
train acc 0.5546875 | loss_avg 0.6996676325798035
train acc 0.5546875 | loss_avg 0.6982657313346863
TEST ACC 0.5
train acc 0.5546875 | loss_avg 0.6968927979469299
train acc 0.5546875 | loss_avg 0.6955443024635315
TEST ACC 0.5
train acc 0.5546875 | loss_avg 0.6942127346992493
train acc 0.5625 | loss_avg 0.6929039359092712
TEST ACC 0.5
train acc 0.5625 | loss_avg 0.6916161775588989
train acc 0.5625 | loss_avg 0.6903461217880249
TEST ACC 0.5
train acc 0.56640625 | loss_avg 0.6890968680381775
train acc 0.57421875 | loss_avg 0.6878643035888672
TEST ACC 0.5
train acc 0.57421875 | loss_avg 0.6866599321365356
train acc 0.57421875 | loss_avg 0.6854611039161682
TEST ACC 0.5
train acc 0.578125 | loss_avg 0.6842828989028931
train acc 0.578125 | loss_avg 0.6831086874008179
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.6819444894790649
train acc 0.58203125 | loss_avg 0.6807853579521179
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.6796298623085022
train acc 0.58203125 | loss_avg 0.6784722805023193
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.6773213744163513
train acc 0.58203125 | loss_avg 0.6761704087257385
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.6750238537788391
train acc 0.58203125 | loss_avg 0.6738806366920471
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.6727311611175537
train acc 0.58203125 | loss_avg 0.671582818031311
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.670437216758728
train acc 0.578125 | loss_avg 0.6692879796028137
TEST ACC 0.5
train acc 0.578125 | loss_avg 0.6681408286094666
train acc 0.58203125 | loss_avg 0.6670057773590088
TEST ACC 0.5
train acc 0.58203125 | loss_avg 0.6658620834350586
train acc 0.5859375 | loss_avg 0.6647154688835144
TEST ACC 0.515625
train acc 0.5859375 | loss_avg 0.663550853729248
train acc 0.5859375 | loss_avg 0.6623866558074951
TEST ACC 0.515625
train acc 0.59375 | loss_avg 0.6612167358398438
train acc 0.59375 | loss_avg 0.6600497364997864
TEST ACC 0.515625
train acc 0.59375 | loss_avg 0.658875584602356
train acc 0.59375 | loss_avg 0.6577147841453552
TEST ACC 0.515625
train acc 0.59765625 | loss_avg 0.6565470695495605
train acc 0.59765625 | loss_avg 0.6553821563720703
TEST ACC 0.515625
train acc 0.6015625 | loss_avg 0.6542139053344727
train acc 0.60546875 | loss_avg 0.6530461311340332
TEST ACC 0.515625
train acc 0.60546875 | loss_avg 0.6518716812133789
train acc 0.60546875 | loss_avg 0.6506864428520203
TEST ACC 0.515625
train acc 0.60546875 | loss_avg 0.6494995951652527
train acc 0.60546875 | loss_avg 0.6483058929443359
TEST ACC 0.515625
train acc 0.60546875 | loss_avg 0.6471095681190491
train acc 0.60546875 | loss_avg 0.6459074020385742
TEST ACC 0.515625
train acc 0.609375 | loss_avg 0.6446970105171204
train acc 0.609375 | loss_avg 0.6434816122055054
TEST ACC 0.515625
train acc 0.609375 | loss_avg 0.6422584056854248
train acc 0.609375 | loss_avg 0.6410391926765442
TEST ACC 0.515625
train acc 0.609375 | loss_avg 0.6398136019706726
train acc 0.609375 | loss_avg 0.6385864615440369
TEST ACC 0.515625
train acc 0.61328125 | loss_avg 0.6373634338378906
train acc 0.61328125 | loss_avg 0.6361363530158997
TEST ACC 0.515625
train acc 0.61328125 | loss_avg 0.6349045634269714
train acc 0.61328125 | loss_avg 0.6336672306060791
TEST ACC 0.515625
train acc 0.61328125 | loss_avg 0.6324185729026794
train acc 0.6171875 | loss_avg 0.6311583518981934
TEST ACC 0.515625
train acc 0.6171875 | loss_avg 0.6298933029174805
train acc 0.6171875 | loss_avg 0.6286513805389404
TEST ACC 0.515625
train acc 0.6171875 | loss_avg 0.6273716688156128
train acc 0.62109375 | loss_avg 0.6260882019996643
TEST ACC 0.515625
train acc 0.62109375 | loss_avg 0.624798595905304
train acc 0.62109375 | loss_avg 0.6235056519508362
TEST ACC 0.515625
train acc 0.62109375 | loss_avg 0.6221989989280701
train acc 0.62109375 | loss_avg 0.6208907961845398
TEST ACC 0.515625
train acc 0.62109375 | loss_avg 0.6195746064186096
train acc 0.625 | loss_avg 0.6182553172111511
TEST ACC 0.515625
train acc 0.625 | loss_avg 0.6169331073760986
train acc 0.625 | loss_avg 0.6155962347984314
TEST ACC 0.515625
train acc 0.625 | loss_avg 0.6142570972442627
train acc 0.625 | loss_avg 0.6129077076911926
TEST ACC 0.515625
train acc 0.625 | loss_avg 0.6115556359291077
train acc 0.625 | loss_avg 0.6101982593536377
TEST ACC 0.515625
train acc 0.625 | loss_avg 0.6088441610336304
train acc 0.625 | loss_avg 0.6074883937835693
TEST ACC 0.515625
train acc 0.625 | loss_avg 0.6061235070228577
train acc 0.625 | loss_avg 0.6047525405883789
TEST ACC 0.515625
train acc 0.62890625 | loss_avg 0.6033714413642883
train acc 0.62890625 | loss_avg 0.6019877195358276
TEST ACC 0.53125
train acc 0.62890625 | loss_avg 0.6005977392196655
train acc 0.62890625 | loss_avg 0.5992115139961243
TEST ACC 0.53125
train acc 0.62890625 | loss_avg 0.597815990447998
train acc 0.62890625 | loss_avg 0.596425473690033
TEST ACC 0.53125
train acc 0.6328125 | loss_avg 0.5950320363044739
train acc 0.63671875 | loss_avg 0.5936371088027954
TEST ACC 0.546875
train acc 0.63671875 | loss_avg 0.5922213792800903
train acc 0.640625 | loss_avg 0.5908105969429016
TEST ACC 0.546875
train acc 0.640625 | loss_avg 0.5894197225570679
train acc 0.640625 | loss_avg 0.5879999399185181
TEST ACC 0.546875
train acc 0.640625 | loss_avg 0.5865835547447205
train acc 0.640625 | loss_avg 0.585166335105896
TEST ACC 0.546875
train acc 0.640625 | loss_avg 0.5837392210960388
train acc 0.640625 | loss_avg 0.5823339223861694
TEST ACC 0.5625
train acc 0.640625 | loss_avg 0.5809054970741272
train acc 0.640625 | loss_avg 0.5794726610183716
TEST ACC 0.5625
train acc 0.640625 | loss_avg 0.5780341625213623
train acc 0.640625 | loss_avg 0.576623260974884
TEST ACC 0.5625
train acc 0.640625 | loss_avg 0.575194239616394
train acc 0.64453125 | loss_avg 0.5737690925598145
TEST ACC 0.5625
train acc 0.64453125 | loss_avg 0.5723375678062439
train acc 0.64453125 | loss_avg 0.5709031224250793
TEST ACC 0.5625
train acc 0.64453125 | loss_avg 0.5694713592529297
train acc 0.64453125 | loss_avg 0.5680195093154907
TEST ACC 0.5625
train acc 0.6484375 | loss_avg 0.566581130027771
train acc 0.6484375 | loss_avg 0.5651392936706543
TEST ACC 0.5625
train acc 0.6484375 | loss_avg 0.5636974573135376
train acc 0.6484375 | loss_avg 0.5622517466545105
TEST ACC 0.5625
train acc 0.6484375 | loss_avg 0.5607870221138
train acc 0.6484375 | loss_avg 0.5593372583389282
TEST ACC 0.5625
train acc 0.6484375 | loss_avg 0.5578868985176086
train acc 0.6484375 | loss_avg 0.5564336180686951
TEST ACC 0.5625
train acc 0.6484375 | loss_avg 0.5549836754798889
train acc 0.6484375 | loss_avg 0.5535200834274292
TEST ACC 0.5625
train acc 0.6484375 | loss_avg 0.5520595908164978
train acc 0.6484375 | loss_avg 0.5506043434143066
TEST ACC 0.5803571343421936
train acc 0.6484375 | loss_avg 0.5491631031036377
train acc 0.6484375 | loss_avg 0.5477069020271301
TEST ACC 0.5803571343421936
train acc 0.65625 | loss_avg 0.5462605953216553
train acc 0.65625 | loss_avg 0.5448019504547119
TEST ACC 0.5803571343421936
train acc 0.65625 | loss_avg 0.5433551073074341
train acc 0.6640625 | loss_avg 0.5418948531150818
TEST ACC 0.5803571343421936
train acc 0.66796875 | loss_avg 0.5404279828071594
train acc 0.66796875 | loss_avg 0.5389740467071533
TEST ACC 0.5803571343421936
train acc 0.66796875 | loss_avg 0.5375096797943115
train acc 0.66796875 | loss_avg 0.5360447764396667
TEST ACC 0.5803571343421936
train acc 0.66796875 | loss_avg 0.5345861911773682
train acc 0.671875 | loss_avg 0.533119797706604
TEST ACC 0.5803571343421936
train acc 0.671875 | loss_avg 0.5316686630249023
train acc 0.67578125 | loss_avg 0.5302025675773621
TEST ACC 0.5803571343421936
train acc 0.6796875 | loss_avg 0.5287414193153381
train acc 0.68359375 | loss_avg 0.5272737145423889
TEST ACC 0.5803571343421936
train acc 0.6875 | loss_avg 0.5258105993270874
train acc 0.6875 | loss_avg 0.5243470072746277
TEST ACC 0.5803571343421936
train acc 0.6953125 | loss_avg 0.522878885269165
train acc 0.69921875 | loss_avg 0.5214173197746277
TEST ACC 0.5803571343421936
train acc 0.703125 | loss_avg 0.5199467539787292
train acc 0.703125 | loss_avg 0.5184826850891113
TEST ACC 0.5982142686843872
train acc 0.703125 | loss_avg 0.5170143246650696
train acc 0.70703125 | loss_avg 0.5155519843101501
TEST ACC 0.5982142686843872
train acc 0.71484375 | loss_avg 0.5140751600265503
train acc 0.71875 | loss_avg 0.5126067399978638
TEST ACC 0.6160714626312256
train acc 0.71875 | loss_avg 0.5111532807350159
train acc 0.71875 | loss_avg 0.5096826553344727
TEST ACC 0.6160714626312256
train acc 0.71875 | loss_avg 0.5082189440727234
train acc 0.71875 | loss_avg 0.5067524909973145
TEST ACC 0.6160714626312256
train acc 0.71875 | loss_avg 0.5052952766418457
train acc 0.73046875 | loss_avg 0.5038262605667114
TEST ACC 0.6316964626312256
train acc 0.73046875 | loss_avg 0.5023663640022278
train acc 0.73046875 | loss_avg 0.5009061098098755
TEST ACC 0.6316964626312256
train acc 0.73046875 | loss_avg 0.4994451701641083
train acc 0.73046875 | loss_avg 0.4979945123195648
TEST ACC 0.6316964626312256
train acc 0.73046875 | loss_avg 0.4965370297431946
train acc 0.734375 | loss_avg 0.4950794577598572
TEST ACC 0.6316964626312256
train acc 0.734375 | loss_avg 0.4936247169971466
train acc 0.734375 | loss_avg 0.4921756982803345
TEST ACC 0.6316964626312256
train acc 0.734375 | loss_avg 0.4907274842262268
train acc 0.734375 | loss_avg 0.4892759621143341
TEST ACC 0.6316964626312256
train acc 0.734375 | loss_avg 0.4878336489200592
train acc 0.734375 | loss_avg 0.4863903224468231
TEST ACC 0.6316964626312256
train acc 0.73828125 | loss_avg 0.4849461317062378
train acc 0.73828125 | loss_avg 0.48351696133613586
TEST ACC 0.6495535373687744
train acc 0.7421875 | loss_avg 0.48206815123558044
train acc 0.7421875 | loss_avg 0.48063403367996216
TEST ACC 0.6495535373687744
train acc 0.7421875 | loss_avg 0.47919759154319763
train acc 0.7421875 | loss_avg 0.47777116298675537
TEST ACC 0.6495535373687744
train acc 0.7421875 | loss_avg 0.476337194442749
train acc 0.7421875 | loss_avg 0.47490769624710083
TEST ACC 0.6339285373687744
train acc 0.7421875 | loss_avg 0.473482608795166
train acc 0.75 | loss_avg 0.47206345200538635
TEST ACC 0.6339285373687744
train acc 0.75 | loss_avg 0.4706467390060425
train acc 0.7578125 | loss_avg 0.4692261815071106
TEST ACC 0.6339285373687744
train acc 0.7578125 | loss_avg 0.4678073227405548
train acc 0.7578125 | loss_avg 0.4663974642753601
TEST ACC 0.6495535373687744
train acc 0.7578125 | loss_avg 0.464981347322464
train acc 0.7578125 | loss_avg 0.4635755717754364
TEST ACC 0.6495535373687744
train acc 0.7578125 | loss_avg 0.4621688723564148
train acc 0.7578125 | loss_avg 0.4607711136341095
TEST ACC 0.6495535373687744
train acc 0.76171875 | loss_avg 0.4593673348426819
train acc 0.76171875 | loss_avg 0.4579731523990631
TEST ACC 0.6495535373687744
train acc 0.76171875 | loss_avg 0.45657962560653687
train acc 0.76171875 | loss_avg 0.4551919400691986
TEST ACC 0.6674107313156128
train acc 0.76953125 | loss_avg 0.4538031816482544
train acc 0.7734375 | loss_avg 0.4524244964122772
TEST ACC 0.6674107313156128
train acc 0.7734375 | loss_avg 0.4510452151298523
train acc 0.78125 | loss_avg 0.4496718645095825
TEST ACC 0.6852678656578064
train acc 0.78125 | loss_avg 0.44829249382019043
train acc 0.78125 | loss_avg 0.4469214081764221
TEST ACC 0.6852678656578064
train acc 0.78125 | loss_avg 0.44555824995040894
train acc 0.78125 | loss_avg 0.444199800491333
TEST ACC 0.6852678656578064
train acc 0.78125 | loss_avg 0.44283849000930786
train acc 0.78125 | loss_avg 0.4414844512939453
TEST ACC 0.6852678656578064
train acc 0.78515625 | loss_avg 0.4401353597640991
train acc 0.78515625 | loss_avg 0.43879246711730957
TEST ACC 0.6852678656578064
train acc 0.7890625 | loss_avg 0.43744874000549316
train acc 0.7890625 | loss_avg 0.43610966205596924
TEST ACC 0.6852678656578064
train acc 0.7890625 | loss_avg 0.43477925658226013
train acc 0.7890625 | loss_avg 0.4334568977355957
TEST ACC 0.6852678656578064
train acc 0.7890625 | loss_avg 0.432137131690979
train acc 0.7890625 | loss_avg 0.4308132827281952
TEST ACC 0.7008928656578064
train acc 0.7890625 | loss_avg 0.4294957220554352
train acc 0.79296875 | loss_avg 0.428187757730484
TEST ACC 0.7008928656578064
train acc 0.79296875 | loss_avg 0.4268844723701477
train acc 0.79296875 | loss_avg 0.4255804121494293
TEST ACC 0.7165178656578064
train acc 0.79296875 | loss_avg 0.42429301142692566
train acc 0.79296875 | loss_avg 0.42299985885620117
TEST ACC 0.7165178656578064
train acc 0.79296875 | loss_avg 0.4217296838760376
train acc 0.796875 | loss_avg 0.42044463753700256
TEST ACC 0.7165178656578064
train acc 0.796875 | loss_avg 0.4191663861274719
train acc 0.796875 | loss_avg 0.4179040789604187
TEST ACC 0.734375
train acc 0.796875 | loss_avg 0.41663578152656555
train acc 0.796875 | loss_avg 0.41537269949913025
TEST ACC 0.734375
train acc 0.80078125 | loss_avg 0.414135217666626
train acc 0.80078125 | loss_avg 0.41287651658058167
TEST ACC 0.734375
train acc 0.80078125 | loss_avg 0.4116441011428833
train acc 0.80078125 | loss_avg 0.4104124903678894
TEST ACC 0.734375
train acc 0.80078125 | loss_avg 0.4091813564300537
train acc 0.80078125 | loss_avg 0.40795350074768066
TEST ACC 0.734375
train acc 0.80078125 | loss_avg 0.4067384898662567
train acc 0.80078125 | loss_avg 0.4055228531360626
TEST ACC 0.734375
train acc 0.8046875 | loss_avg 0.404314249753952
train acc 0.80078125 | loss_avg 0.40311285853385925
TEST ACC 0.734375
train acc 0.80859375 | loss_avg 0.40191197395324707
train acc 0.80859375 | loss_avg 0.4007188081741333
TEST ACC 0.734375
train acc 0.81640625 | loss_avg 0.3995383083820343
train acc 0.81640625 | loss_avg 0.39835238456726074
TEST ACC 0.734375
train acc 0.81640625 | loss_avg 0.39717671275138855
train acc 0.81640625 | loss_avg 0.3960029184818268
TEST ACC 0.734375
train acc 0.81640625 | loss_avg 0.39484483003616333
train acc 0.81640625 | loss_avg 0.393679678440094
TEST ACC 0.734375
train acc 0.81640625 | loss_avg 0.39252448081970215
train acc 0.81640625 | loss_avg 0.39137449860572815
TEST ACC 0.734375
train acc 0.81640625 | loss_avg 0.39023616909980774
train acc 0.8203125 | loss_avg 0.38909244537353516
TEST ACC 0.734375
train acc 0.8203125 | loss_avg 0.3879695534706116
train acc 0.8203125 | loss_avg 0.38683807849884033
TEST ACC 0.734375
train acc 0.8203125 | loss_avg 0.38571739196777344
train acc 0.8203125 | loss_avg 0.38461020588874817
TEST ACC 0.734375
train acc 0.8203125 | loss_avg 0.3834913372993469
train acc 0.8203125 | loss_avg 0.38239502906799316
TEST ACC 0.734375
train acc 0.82421875 | loss_avg 0.3812827169895172
train acc 0.82421875 | loss_avg 0.38019564747810364
TEST ACC 0.734375
train acc 0.82421875 | loss_avg 0.3791239261627197
train acc 0.82421875 | loss_avg 0.378036230802536
TEST ACC 0.734375
train acc 0.82421875 | loss_avg 0.37694913148880005
train acc 0.82421875 | loss_avg 0.3758971691131592
TEST ACC 0.734375
train acc 0.82421875 | loss_avg 0.3748163878917694
train acc 0.828125 | loss_avg 0.37376296520233154
TEST ACC 0.734375
train acc 0.828125 | loss_avg 0.37271925806999207
train acc 0.828125 | loss_avg 0.3716661036014557
TEST ACC 0.734375
train acc 0.828125 | loss_avg 0.37062013149261475
train acc 0.83203125 | loss_avg 0.36959367990493774
TEST ACC 0.734375
train acc 0.83203125 | loss_avg 0.36856609582901
train acc 0.83203125 | loss_avg 0.36753737926483154
TEST ACC 0.75
train acc 0.83203125 | loss_avg 0.36652672290802
train acc 0.83203125 | loss_avg 0.36550891399383545
TEST ACC 0.75
train acc 0.83203125 | loss_avg 0.36450135707855225
train acc 0.83203125 | loss_avg 0.3634992241859436
TEST ACC 0.75
train acc 0.83203125 | loss_avg 0.3624943494796753
train acc 0.8359375 | loss_avg 0.3614926040172577
TEST ACC 0.75
train acc 0.8359375 | loss_avg 0.36049994826316833
train acc 0.8359375 | loss_avg 0.35952433943748474
TEST ACC 0.75
train acc 0.8359375 | loss_avg 0.3585370182991028
train acc 0.8359375 | loss_avg 0.3575628399848938
TEST ACC 0.75
train acc 0.83984375 | loss_avg 0.3565979599952698
train acc 0.83984375 | loss_avg 0.35563263297080994
TEST ACC 0.7678571343421936
train acc 0.83984375 | loss_avg 0.35467374324798584
train acc 0.83984375 | loss_avg 0.3537242114543915
TEST ACC 0.7678571343421936
train acc 0.83984375 | loss_avg 0.35277312994003296
train acc 0.83984375 | loss_avg 0.35183364152908325
TEST ACC 0.7678571343421936
train acc 0.83984375 | loss_avg 0.35089805722236633
train acc 0.83984375 | loss_avg 0.34997987747192383
TEST ACC 0.7678571343421936
train acc 0.83984375 | loss_avg 0.3490544855594635
train acc 0.8359375 | loss_avg 0.3481441140174866
TEST ACC 0.7678571343421936
train acc 0.83984375 | loss_avg 0.3472207486629486
train acc 0.8359375 | loss_avg 0.34631797671318054
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.3454127311706543
train acc 0.8359375 | loss_avg 0.34452393651008606
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.34362637996673584
train acc 0.8359375 | loss_avg 0.3427388072013855
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.3418363630771637
train acc 0.8359375 | loss_avg 0.34095361828804016
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.3400740325450897
train acc 0.8359375 | loss_avg 0.3391885757446289
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.33832237124443054
train acc 0.8359375 | loss_avg 0.3374485373497009
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.33659207820892334
train acc 0.8359375 | loss_avg 0.33572766184806824
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.3348836600780487
train acc 0.8359375 | loss_avg 0.3340195417404175
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.33315902948379517
train acc 0.8359375 | loss_avg 0.33229342103004456
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.33147159218788147
train acc 0.8359375 | loss_avg 0.3306419253349304
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.32982179522514343
train acc 0.8359375 | loss_avg 0.3289947807788849
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.32817500829696655
train acc 0.8359375 | loss_avg 0.32735738158226013
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.32654649019241333
train acc 0.8359375 | loss_avg 0.32575154304504395
TEST ACC 0.7678571343421936
train acc 0.8359375 | loss_avg 0.3249330222606659
train acc 0.8359375 | loss_avg 0.3241176903247833
TEST ACC 0.7834821343421936
train acc 0.8359375 | loss_avg 0.32332783937454224
train acc 0.8359375 | loss_avg 0.322533518075943
TEST ACC 0.7834821343421936
train acc 0.8359375 | loss_avg 0.3217444121837616
train acc 0.8359375 | loss_avg 0.32096540927886963
TEST ACC 0.7834821343421936
train acc 0.8359375 | loss_avg 0.3201807141304016
train acc 0.8359375 | loss_avg 0.3194139897823334
TEST ACC 0.7834821343421936
train acc 0.8359375 | loss_avg 0.31863269209861755
train acc 0.8359375 | loss_avg 0.31785497069358826
TEST ACC 0.7834821343421936
train acc 0.8359375 | loss_avg 0.3170812129974365
train acc 0.83984375 | loss_avg 0.3163032829761505
TEST ACC 0.7834821343421936
train acc 0.83984375 | loss_avg 0.3155325949192047
train acc 0.83984375 | loss_avg 0.3147764503955841
TEST ACC 0.7834821343421936
train acc 0.83984375 | loss_avg 0.3140205144882202
train acc 0.84375 | loss_avg 0.31325972080230713
TEST ACC 0.7834821343421936
train acc 0.84375 | loss_avg 0.312503457069397
train acc 0.84375 | loss_avg 0.3117527961730957
TEST ACC 0.7834821343421936
train acc 0.84375 | loss_avg 0.3109942674636841
train acc 0.84375 | loss_avg 0.3102568984031677
TEST ACC 0.7834821343421936
train acc 0.84375 | loss_avg 0.30951347947120667
train acc 0.84375 | loss_avg 0.308775395154953
TEST ACC 0.7834821343421936
train acc 0.84375 | loss_avg 0.30803725123405457
train acc 0.84765625 | loss_avg 0.30731916427612305
TEST ACC 0.7834821343421936
train acc 0.84765625 | loss_avg 0.3065977692604065
train acc 0.84765625 | loss_avg 0.30588215589523315
TEST ACC 0.7834821343421936
train acc 0.8515625 | loss_avg 0.30518051981925964
train acc 0.85546875 | loss_avg 0.3044714331626892
TEST ACC 0.7834821343421936
train acc 0.85546875 | loss_avg 0.3037680387496948
train acc 0.85546875 | loss_avg 0.30306583642959595
TEST ACC 0.7834821343421936
train acc 0.85546875 | loss_avg 0.30236801505088806
train acc 0.85546875 | loss_avg 0.3016725480556488
TEST ACC 0.7834821343421936
train acc 0.85546875 | loss_avg 0.30098867416381836
train acc 0.859375 | loss_avg 0.3003019690513611
TEST ACC 0.7834821343421936
train acc 0.859375 | loss_avg 0.2996370494365692
train acc 0.859375 | loss_avg 0.29894721508026123
TEST ACC 0.7834821343421936
train acc 0.859375 | loss_avg 0.29827427864074707
train acc 0.859375 | loss_avg 0.2976064085960388
TEST ACC 0.7834821343421936
train acc 0.859375 | loss_avg 0.2969557046890259
train acc 0.859375 | loss_avg 0.2962760329246521
TEST ACC 0.7834821343421936
train acc 0.859375 | loss_avg 0.2956194579601288
train acc 0.859375 | loss_avg 0.29497605562210083
TEST ACC 0.7834821343421936
train acc 0.859375 | loss_avg 0.29431232810020447
train acc 0.859375 | loss_avg 0.2936622202396393
TEST ACC 0.7834821343421936
train acc 0.859375 | loss_avg 0.29301658272743225
train acc 0.86328125 | loss_avg 0.2923775911331177
TEST ACC 0.7834821343421936
train acc 0.86328125 | loss_avg 0.2917136251926422
train acc 0.86328125 | loss_avg 0.291076123714447
TEST ACC 0.7834821343421936
train acc 0.86328125 | loss_avg 0.2904456555843353
train acc 0.86328125 | loss_avg 0.2898081839084625
TEST ACC 0.7834821343421936
train acc 0.86328125 | loss_avg 0.289174884557724
train acc 0.86328125 | loss_avg 0.28853243589401245
TEST ACC 0.7834821343421936
train acc 0.86328125 | loss_avg 0.28789281845092773
train acc 0.86328125 | loss_avg 0.28725647926330566
TEST ACC 0.7834821343421936
train acc 0.8671875 | loss_avg 0.28664523363113403
train acc 0.8671875 | loss_avg 0.28603288531303406
TEST ACC 0.7834821343421936
train acc 0.8671875 | loss_avg 0.28542083501815796
train acc 0.8671875 | loss_avg 0.28481292724609375
TEST ACC 0.7834821343421936
train acc 0.8671875 | loss_avg 0.2842090129852295
train acc 0.87109375 | loss_avg 0.2836052179336548
TEST ACC 0.7834821343421936
train acc 0.87109375 | loss_avg 0.28300023078918457
train acc 0.875 | loss_avg 0.2823965549468994
TEST ACC 0.7834821343421936
train acc 0.875 | loss_avg 0.2817961573600769
train acc 0.875 | loss_avg 0.2812008261680603
TEST ACC 0.7834821343421936
train acc 0.875 | loss_avg 0.28060686588287354
train acc 0.875 | loss_avg 0.2800155282020569
TEST ACC 0.7834821343421936
train acc 0.87890625 | loss_avg 0.2794229984283447
train acc 0.87890625 | loss_avg 0.2788335382938385
TEST ACC 0.7834821343421936
train acc 0.8828125 | loss_avg 0.27824667096138
train acc 0.8828125 | loss_avg 0.2776634395122528
TEST ACC 0.7834821343421936
train acc 0.8828125 | loss_avg 0.27708548307418823
train acc 0.88671875 | loss_avg 0.27650946378707886
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.27593886852264404
train acc 0.88671875 | loss_avg 0.2753664553165436
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2747988700866699
train acc 0.88671875 | loss_avg 0.27423399686813354
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2736741900444031
train acc 0.88671875 | loss_avg 0.2731190621852875
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2725636959075928
train acc 0.88671875 | loss_avg 0.27201053500175476
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.271460622549057
train acc 0.88671875 | loss_avg 0.27091625332832336
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2703719735145569
train acc 0.88671875 | loss_avg 0.2698334753513336
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.26929378509521484
train acc 0.88671875 | loss_avg 0.26875975728034973
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.26822447776794434
train acc 0.88671875 | loss_avg 0.26769503951072693
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2671678066253662
train acc 0.88671875 | loss_avg 0.2666432559490204
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.26612237095832825
train acc 0.88671875 | loss_avg 0.2656024992465973
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2650984823703766
train acc 0.88671875 | loss_avg 0.2645760774612427
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.26407694816589355
train acc 0.88671875 | loss_avg 0.263563334941864
TEST ACC 0.7834821343421936
train acc 0.88671875 | loss_avg 0.2630620300769806
train acc 0.88671875 | loss_avg 0.2625545263290405
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.2620564103126526
train acc 0.890625 | loss_avg 0.2615662217140198
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.26105639338493347
train acc 0.890625 | loss_avg 0.2605726420879364
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.26008155941963196
train acc 0.890625 | loss_avg 0.259598970413208
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.259124755859375
train acc 0.890625 | loss_avg 0.2586349546909332
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.25815820693969727
train acc 0.890625 | loss_avg 0.2576853334903717
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.2572188377380371
train acc 0.890625 | loss_avg 0.2567470967769623
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.25627267360687256
train acc 0.890625 | loss_avg 0.25580641627311707
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.25534793734550476
train acc 0.890625 | loss_avg 0.25487613677978516
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.2544137239456177
train acc 0.890625 | loss_avg 0.25395479798316956
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.25349336862564087
train acc 0.890625 | loss_avg 0.2530381679534912
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.25259289145469666
train acc 0.890625 | loss_avg 0.252132385969162
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.2516811788082123
train acc 0.890625 | loss_avg 0.25123482942581177
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.25078773498535156
train acc 0.890625 | loss_avg 0.2503582537174225
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.249909445643425
train acc 0.890625 | loss_avg 0.24947378039360046
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.2490389198064804
train acc 0.890625 | loss_avg 0.24859407544136047
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.2481614053249359
train acc 0.890625 | loss_avg 0.2477230429649353
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.24727711081504822
train acc 0.890625 | loss_avg 0.2468460351228714
TEST ACC 0.7834821343421936
train acc 0.890625 | loss_avg 0.24641093611717224
train acc 0.890625 | loss_avg 0.24599018692970276
TEST ACC 0.7991071343421936
train acc 0.890625 | loss_avg 0.24557030200958252
train acc 0.890625 | loss_avg 0.24513843655586243
TEST ACC 0.7991071343421936
train acc 0.890625 | loss_avg 0.24472324550151825
train acc 0.890625 | loss_avg 0.24430914223194122
TEST ACC 0.8169642686843872
train acc 0.890625 | loss_avg 0.2438828945159912
train acc 0.890625 | loss_avg 0.2434733808040619
TEST ACC 0.8169642686843872
train acc 0.890625 | loss_avg 0.24306447803974152
train acc 0.890625 | loss_avg 0.24264468252658844
TEST ACC 0.8169642686843872
train acc 0.890625 | loss_avg 0.24224141240119934
train acc 0.89453125 | loss_avg 0.24183575809001923
TEST ACC 0.8169642686843872
train acc 0.89453125 | loss_avg 0.2414323389530182
train acc 0.89453125 | loss_avg 0.24101980030536652
TEST ACC 0.8169642686843872
train acc 0.89453125 | loss_avg 0.24062272906303406
train acc 0.89453125 | loss_avg 0.24021555483341217
TEST ACC 0.8169642686843872
train acc 0.89453125 | loss_avg 0.23982371389865875
train acc 0.89453125 | loss_avg 0.23942068219184875
TEST ACC 0.8169642686843872
train acc 0.89453125 | loss_avg 0.23903149366378784
train acc 0.89453125 | loss_avg 0.23863276839256287
TEST ACC 0.8169642686843872
train acc 0.89453125 | loss_avg 0.23824571073055267
train acc 0.89453125 | loss_avg 0.23786328732967377
TEST ACC 0.8169642686843872
train acc 0.89453125 | loss_avg 0.2374589443206787
train acc 0.89453125 | loss_avg 0.237077534198761
TEST ACC 0.8348214626312256
train acc 0.89453125 | loss_avg 0.2366812527179718
train acc 0.89453125 | loss_avg 0.23630493879318237
TEST ACC 0.8348214626312256
train acc 0.89453125 | loss_avg 0.23593364655971527
train acc 0.89453125 | loss_avg 0.23553532361984253
TEST ACC 0.8348214626312256
train acc 0.89453125 | loss_avg 0.23516003787517548
train acc 0.8984375 | loss_avg 0.23478397727012634
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.23439937829971313
train acc 0.8984375 | loss_avg 0.23403006792068481
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.23364537954330444
train acc 0.8984375 | loss_avg 0.23328179121017456
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.23291364312171936
train acc 0.8984375 | loss_avg 0.23253455758094788
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.23217357695102692
train acc 0.8984375 | loss_avg 0.23180195689201355
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.2314438372850418
train acc 0.8984375 | loss_avg 0.23108458518981934
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.23071983456611633
train acc 0.8984375 | loss_avg 0.23035532236099243
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.23000474274158478
train acc 0.8984375 | loss_avg 0.22964251041412354
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.22929374873638153
train acc 0.8984375 | loss_avg 0.22893472015857697
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.22856895625591278
train acc 0.8984375 | loss_avg 0.22823429107666016
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.22787556052207947
train acc 0.8984375 | loss_avg 0.22753536701202393
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.2271866649389267
train acc 0.8984375 | loss_avg 0.22685188055038452
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.2265048772096634
train acc 0.8984375 | loss_avg 0.22617243230342865
TEST ACC 0.8348214626312256
train acc 0.8984375 | loss_avg 0.22582779824733734
train acc 0.8984375 | loss_avg 0.22548907995224
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.22516228258609772
train acc 0.90234375 | loss_avg 0.22482219338417053
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.22449633479118347
train acc 0.90234375 | loss_avg 0.22415845096111298
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.2238285094499588
train acc 0.90234375 | loss_avg 0.2235070914030075
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.2231747955083847
train acc 0.90234375 | loss_avg 0.22284455597400665
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.222518652677536
train acc 0.90234375 | loss_avg 0.2222069948911667
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.22187666594982147
train acc 0.90234375 | loss_avg 0.22155523300170898
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.2212342619895935
train acc 0.90234375 | loss_avg 0.22092509269714355
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.22059988975524902
train acc 0.90234375 | loss_avg 0.2202819436788559
TEST ACC 0.8348214626312256
train acc 0.90234375 | loss_avg 0.2199621945619583
train acc 0.90234375 | loss_avg 0.21964730322360992
TEST ACC 0.8348214626312256
train acc 0.90625 | loss_avg 0.219345822930336
train acc 0.90625 | loss_avg 0.219025120139122
TEST ACC 0.8348214626312256
train acc 0.90625 | loss_avg 0.21871499717235565
train acc 0.90625 | loss_avg 0.2184019386768341
TEST ACC 0.8348214626312256
train acc 0.90625 | loss_avg 0.21809537708759308
train acc 0.90625 | loss_avg 0.21778574585914612
TEST ACC 0.8348214626312256
train acc 0.90625 | loss_avg 0.2174815982580185
train acc 0.90625 | loss_avg 0.21717222034931183
TEST ACC 0.8348214626312256
train acc 0.91015625 | loss_avg 0.21687087416648865
train acc 0.91015625 | loss_avg 0.2165772020816803
TEST ACC 0.8348214626312256
train acc 0.91015625 | loss_avg 0.2162669599056244
train acc 0.9140625 | loss_avg 0.2159656584262848
TEST ACC 0.8348214626312256
train acc 0.9140625 | loss_avg 0.21566356718540192
train acc 0.9140625 | loss_avg 0.21536663174629211
TEST ACC 0.8348214626312256
train acc 0.9140625 | loss_avg 0.2150668501853943
train acc 0.9140625 | loss_avg 0.2147722989320755
TEST ACC 0.8348214626312256
train acc 0.9140625 | loss_avg 0.21447572112083435
train acc 0.9140625 | loss_avg 0.2141784131526947
TEST ACC 0.8348214626312256
train acc 0.9140625 | loss_avg 0.21388548612594604
train acc 0.9140625 | loss_avg 0.2135918140411377
TEST ACC 0.8348214626312256
train acc 0.9140625 | loss_avg 0.21330209076404572
train acc 0.9140625 | loss_avg 0.2130085676908493
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.21272194385528564
train acc 0.9140625 | loss_avg 0.21243257820606232
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2121465653181076
train acc 0.9140625 | loss_avg 0.21186155080795288
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2115756720304489
train acc 0.9140625 | loss_avg 0.2112889438867569
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.21100828051567078
train acc 0.9140625 | loss_avg 0.21072587370872498
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.21044398844242096
train acc 0.9140625 | loss_avg 0.21016626060009003
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20988711714744568
train acc 0.9140625 | loss_avg 0.20960839092731476
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2093351036310196
train acc 0.9140625 | loss_avg 0.20905981957912445
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2087840884923935
train acc 0.9140625 | loss_avg 0.2085127830505371
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20824626088142395
train acc 0.9140625 | loss_avg 0.2079644501209259
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20769020915031433
train acc 0.9140625 | loss_avg 0.2074151337146759
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20713835954666138
train acc 0.9140625 | loss_avg 0.20686814188957214
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2066023051738739
train acc 0.9140625 | loss_avg 0.2063184231519699
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2060387283563614
train acc 0.9140625 | loss_avg 0.20576301217079163
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2054794877767563
train acc 0.9140625 | loss_avg 0.205215722322464
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20495258271694183
train acc 0.9140625 | loss_avg 0.2046884149312973
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20442216098308563
train acc 0.9140625 | loss_avg 0.20416143536567688
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20390038192272186
train acc 0.9140625 | loss_avg 0.20363777875900269
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20337577164173126
train acc 0.9140625 | loss_avg 0.2031191736459732
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20285920798778534
train acc 0.9140625 | loss_avg 0.20259921252727509
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20234447717666626
train acc 0.9140625 | loss_avg 0.20209041237831116
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.20183363556861877
train acc 0.9140625 | loss_avg 0.20157743990421295
TEST ACC 0.8504464626312256
train acc 0.9140625 | loss_avg 0.2013261616230011
train acc 0.91796875 | loss_avg 0.20107465982437134
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.20082105696201324
train acc 0.91796875 | loss_avg 0.2005685716867447
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.20032009482383728
train acc 0.91796875 | loss_avg 0.20007002353668213
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.1998198926448822
train acc 0.91796875 | loss_avg 0.19957461953163147
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.19932858645915985
train acc 0.91796875 | loss_avg 0.1990835815668106
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.19883674383163452
train acc 0.91796875 | loss_avg 0.19859573245048523
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.19835515320301056
train acc 0.91796875 | loss_avg 0.1981128752231598
TEST ACC 0.8504464626312256
train acc 0.91796875 | loss_avg 0.19787220656871796
train acc 0.91796875 | loss_avg 0.19763357937335968
TEST ACC 0.8660714626312256
train acc 0.91796875 | loss_avg 0.1973954737186432
train acc 0.91796875 | loss_avg 0.19715608656406403
TEST ACC 0.8660714626312256
train acc 0.91796875 | loss_avg 0.1969173699617386
train acc 0.91796875 | loss_avg 0.19668132066726685
TEST ACC 0.8660714626312256
train acc 0.91796875 | loss_avg 0.19644512236118317
train acc 0.91796875 | loss_avg 0.19621065258979797
TEST ACC 0.8660714626312256
train acc 0.91796875 | loss_avg 0.19597238302230835
train acc 0.91796875 | loss_avg 0.19574208557605743
TEST ACC 0.8660714626312256
train acc 0.91796875 | loss_avg 0.19550754129886627
train acc 0.91796875 | loss_avg 0.19527588784694672
TEST ACC 0.8660714626312256
train acc 0.91796875 | loss_avg 0.19504126906394958
train acc 0.91796875 | loss_avg 0.1948143094778061
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.19458416104316711
train acc 0.91796875 | loss_avg 0.19435641169548035
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.19412411749362946
train acc 0.91796875 | loss_avg 0.19390013813972473
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.1936713606119156
train acc 0.91796875 | loss_avg 0.19344711303710938
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.19321878254413605
train acc 0.91796875 | loss_avg 0.19299417734146118
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.1927710324525833
train acc 0.91796875 | loss_avg 0.19254635274410248
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.1923239827156067
train acc 0.91796875 | loss_avg 0.192097008228302
TEST ACC 0.8839285373687744
train acc 0.91796875 | loss_avg 0.19188573956489563
train acc 0.921875 | loss_avg 0.19165657460689545
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.1914326697587967
train acc 0.921875 | loss_avg 0.19120477139949799
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.19098085165023804
train acc 0.921875 | loss_avg 0.19075897336006165
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.19054663181304932
train acc 0.921875 | loss_avg 0.19032591581344604
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.1900903582572937
train acc 0.921875 | loss_avg 0.1898835152387619
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18966613709926605
train acc 0.921875 | loss_avg 0.18943262100219727
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.1892300844192505
train acc 0.921875 | loss_avg 0.1890110969543457
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.1887938529253006
train acc 0.921875 | loss_avg 0.18857870995998383
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.1883513331413269
train acc 0.921875 | loss_avg 0.18814869225025177
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18793630599975586
train acc 0.921875 | loss_avg 0.18772242963314056
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18750856816768646
train acc 0.921875 | loss_avg 0.18728645145893097
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18708547949790955
train acc 0.921875 | loss_avg 0.1868760734796524
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18666693568229675
train acc 0.921875 | loss_avg 0.18645811080932617
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18625383079051971
train acc 0.921875 | loss_avg 0.18604597449302673
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.1858261674642563
train acc 0.921875 | loss_avg 0.18562527000904083
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18540999293327332
train acc 0.921875 | loss_avg 0.18519558012485504
TEST ACC 0.9017857313156128
train acc 0.921875 | loss_avg 0.18498381972312927
train acc 0.921875 | loss_avg 0.18477016687393188
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.18455302715301514
train acc 0.92578125 | loss_avg 0.1843416690826416
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.1841166913509369
train acc 0.92578125 | loss_avg 0.18389801681041718
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.18367791175842285
train acc 0.92578125 | loss_avg 0.1834627091884613
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.183244526386261
train acc 0.92578125 | loss_avg 0.18303003907203674
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.18281452357769012
train acc 0.92578125 | loss_avg 0.18259745836257935
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.18238775432109833
train acc 0.92578125 | loss_avg 0.18217946588993073
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.1819729208946228
train acc 0.92578125 | loss_avg 0.18176399171352386
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.18155866861343384
train acc 0.92578125 | loss_avg 0.18135206401348114
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.18114611506462097
train acc 0.92578125 | loss_avg 0.1809377670288086
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.1807398498058319
train acc 0.92578125 | loss_avg 0.1805291622877121
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.1803172379732132
train acc 0.92578125 | loss_avg 0.1801091581583023
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.17989987134933472
train acc 0.92578125 | loss_avg 0.17969082295894623
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.17948438227176666
train acc 0.92578125 | loss_avg 0.17928028106689453
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.17907536029815674
train acc 0.92578125 | loss_avg 0.1788693070411682
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.17866235971450806
train acc 0.92578125 | loss_avg 0.1784592866897583
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.17825816571712494
train acc 0.92578125 | loss_avg 0.17805549502372742
TEST ACC 0.9017857313156128
train acc 0.92578125 | loss_avg 0.17785364389419556
train acc 0.9296875 | loss_avg 0.17764341831207275
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17745381593704224
train acc 0.9296875 | loss_avg 0.17724363505840302
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17704790830612183
train acc 0.9296875 | loss_avg 0.17685166001319885
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17666392028331757
train acc 0.9296875 | loss_avg 0.17645663022994995
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.1762649416923523
train acc 0.9296875 | loss_avg 0.17607131600379944
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17587782442569733
train acc 0.9296875 | loss_avg 0.17569394409656525
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17549076676368713
train acc 0.9296875 | loss_avg 0.17529718577861786
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17510803043842316
train acc 0.9296875 | loss_avg 0.1749168485403061
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.1747259646654129
train acc 0.9296875 | loss_avg 0.17453521490097046
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17435400187969208
train acc 0.9296875 | loss_avg 0.17415285110473633
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17396719753742218
train acc 0.9296875 | loss_avg 0.173780620098114
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17359362542629242
train acc 0.9296875 | loss_avg 0.17340697348117828
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17322079837322235
train acc 0.9296875 | loss_avg 0.17303401231765747
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17285075783729553
train acc 0.9296875 | loss_avg 0.1726665496826172
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17248260974884033
train acc 0.9296875 | loss_avg 0.1722991168498993
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.1721145212650299
train acc 0.9296875 | loss_avg 0.17193427681922913
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17175260186195374
train acc 0.9296875 | loss_avg 0.17157138884067535
TEST ACC 0.9017857313156128
train acc 0.9296875 | loss_avg 0.17139039933681488
train acc 0.93359375 | loss_avg 0.1712101697921753
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.17102950811386108
train acc 0.93359375 | loss_avg 0.1708516627550125
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.170673668384552
train acc 0.93359375 | loss_avg 0.17049522697925568
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.17031779885292053
train acc 0.93359375 | loss_avg 0.17013856768608093
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.1699645072221756
train acc 0.93359375 | loss_avg 0.1697879433631897
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16961266100406647
train acc 0.93359375 | loss_avg 0.16943711042404175
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16926269233226776
train acc 0.93359375 | loss_avg 0.1690865159034729
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16891556978225708
train acc 0.93359375 | loss_avg 0.16873541474342346
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16855643689632416
train acc 0.93359375 | loss_avg 0.16837604343891144
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16819779574871063
train acc 0.93359375 | loss_avg 0.16801796853542328
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16784250736236572
train acc 0.93359375 | loss_avg 0.1676664501428604
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16749024391174316
train acc 0.93359375 | loss_avg 0.16731494665145874
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16713795065879822
train acc 0.93359375 | loss_avg 0.16696299612522125
TEST ACC 0.9017857313156128
train acc 0.93359375 | loss_avg 0.16678383946418762
train acc 0.9375 | loss_avg 0.16660569608211517
TEST ACC 0.9017857313156128
train acc 0.9375 | loss_avg 0.16642755270004272
train acc 0.94140625 | loss_avg 0.16625045239925385
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16607247292995453
train acc 0.94140625 | loss_avg 0.1658993512392044
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16572429239749908
train acc 0.94140625 | loss_avg 0.16555041074752808
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16537612676620483
train acc 0.94140625 | loss_avg 0.16520141065120697
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16503110527992249
train acc 0.94140625 | loss_avg 0.16485926508903503
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.1646786779165268
train acc 0.9453125 | loss_avg 0.1645006537437439
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.16431300342082977
train acc 0.9453125 | loss_avg 0.1641131341457367
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.16391368210315704
train acc 0.94140625 | loss_avg 0.16371531784534454
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16351714730262756
train acc 0.94140625 | loss_avg 0.16332806646823883
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16313792765140533
train acc 0.94140625 | loss_avg 0.16294871270656586
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.1627608984708786
train acc 0.94140625 | loss_avg 0.16257257759571075
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16238941252231598
train acc 0.94140625 | loss_avg 0.1622052639722824
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16202214360237122
train acc 0.94140625 | loss_avg 0.16183963418006897
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16166211664676666
train acc 0.94140625 | loss_avg 0.1614835262298584
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16130872070789337
train acc 0.94140625 | loss_avg 0.1611328274011612
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16095834970474243
train acc 0.94140625 | loss_avg 0.16078411042690277
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.1606101244688034
train acc 0.94140625 | loss_avg 0.16043740510940552
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.16026337444782257
train acc 0.94140625 | loss_avg 0.1600934863090515
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.159923255443573
train acc 0.94140625 | loss_avg 0.15975205600261688
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.15957826375961304
train acc 0.94140625 | loss_avg 0.15940241515636444
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.15922850370407104
train acc 0.94140625 | loss_avg 0.1590532660484314
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.1588822901248932
train acc 0.94140625 | loss_avg 0.15871039032936096
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.15854306519031525
train acc 0.94140625 | loss_avg 0.1583675891160965
TEST ACC 0.9017857313156128
train acc 0.94140625 | loss_avg 0.1582021564245224
train acc 0.94140625 | loss_avg 0.15802836418151855
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15786221623420715
train acc 0.9453125 | loss_avg 0.15767735242843628
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15750525891780853
train acc 0.9453125 | loss_avg 0.15732425451278687
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15715235471725464
train acc 0.9453125 | loss_avg 0.15697284042835236
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15680259466171265
train acc 0.9453125 | loss_avg 0.1566300094127655
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15645372867584229
train acc 0.9453125 | loss_avg 0.15628640353679657
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15611153841018677
train acc 0.9453125 | loss_avg 0.15594585239887238
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15578053891658783
train acc 0.9453125 | loss_avg 0.15562006831169128
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15546418726444244
train acc 0.9453125 | loss_avg 0.15529903769493103
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.1551463007926941
train acc 0.9453125 | loss_avg 0.15498407185077667
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15482690930366516
train acc 0.9453125 | loss_avg 0.15467384457588196
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.1545131653547287
train acc 0.9453125 | loss_avg 0.15436090528964996
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.1542012095451355
train acc 0.9453125 | loss_avg 0.1540462076663971
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15389567613601685
train acc 0.9453125 | loss_avg 0.15374130010604858
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.1535954773426056
train acc 0.9453125 | loss_avg 0.15344201028347015
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.1532936543226242
train acc 0.9453125 | loss_avg 0.15314894914627075
TEST ACC 0.9017857313156128
train acc 0.9453125 | loss_avg 0.15299740433692932
train acc 0.94921875 | loss_avg 0.1528538167476654
TEST ACC 0.9017857313156128
train acc 0.94921875 | loss_avg 0.15270265936851501
train acc 0.94921875 | loss_avg 0.15255944430828094
TEST ACC 0.9017857313156128
train acc 0.94921875 | loss_avg 0.15240955352783203
train acc 0.94921875 | loss_avg 0.15226717293262482
TEST ACC 0.9017857313156128
train acc 0.94921875 | loss_avg 0.15211808681488037
train acc 0.94921875 | loss_avg 0.15197843313217163
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1518310308456421
train acc 0.94921875 | loss_avg 0.1516880840063095
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.15155063569545746
train acc 0.94921875 | loss_avg 0.15140490233898163
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.15126575529575348
train acc 0.94921875 | loss_avg 0.15113000571727753
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.15098704397678375
train acc 0.94921875 | loss_avg 0.1508481204509735
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.15071412920951843
train acc 0.94921875 | loss_avg 0.15056562423706055
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1504373699426651
train acc 0.94921875 | loss_avg 0.1502998173236847
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.15015427768230438
train acc 0.94921875 | loss_avg 0.150024875998497
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14988094568252563
train acc 0.94921875 | loss_avg 0.14975112676620483
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14961586892604828
train acc 0.94921875 | loss_avg 0.1494731456041336
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14934542775154114
train acc 0.94921875 | loss_avg 0.14920321106910706
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14907599985599518
train acc 0.94921875 | loss_avg 0.14894233644008636
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14879931509494781
train acc 0.94921875 | loss_avg 0.14866892993450165
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14852219820022583
train acc 0.94921875 | loss_avg 0.14839087426662445
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14824171364307404
train acc 0.94921875 | loss_avg 0.1481153964996338
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.147975355386734
train acc 0.94921875 | loss_avg 0.14783336222171783
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14770109951496124
train acc 0.94921875 | loss_avg 0.14756101369857788
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1474231481552124
train acc 0.94921875 | loss_avg 0.1472817361354828
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14713449776172638
train acc 0.94921875 | loss_avg 0.14698852598667145
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1468524932861328
train acc 0.94921875 | loss_avg 0.14671657979488373
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1465775966644287
train acc 0.94921875 | loss_avg 0.1464356929063797
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1462976485490799
train acc 0.94921875 | loss_avg 0.14614802598953247
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14601951837539673
train acc 0.94921875 | loss_avg 0.1458773910999298
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1457415372133255
train acc 0.94921875 | loss_avg 0.14560052752494812
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14546649158000946
train acc 0.94921875 | loss_avg 0.1453273743391037
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14519169926643372
train acc 0.94921875 | loss_avg 0.14505407214164734
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14491894841194153
train acc 0.94921875 | loss_avg 0.14478179812431335
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14464686810970306
train acc 0.94921875 | loss_avg 0.14450645446777344
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14437445998191833
train acc 0.94921875 | loss_avg 0.14423400163650513
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1441049873828888
train acc 0.94921875 | loss_avg 0.14396800100803375
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1438392698764801
train acc 0.94921875 | loss_avg 0.14370322227478027
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.1435755491256714
train acc 0.94921875 | loss_avg 0.14343003928661346
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14330874383449554
train acc 0.94921875 | loss_avg 0.14317724108695984
TEST ACC 0.9174107313156128
train acc 0.94921875 | loss_avg 0.14304454624652863
train acc 0.95703125 | loss_avg 0.14291951060295105
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1427880972623825
train acc 0.95703125 | loss_avg 0.14266254007816315
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1425303816795349
train acc 0.95703125 | loss_avg 0.1423972249031067
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.14228180050849915
train acc 0.95703125 | loss_avg 0.14214980602264404
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1420309841632843
train acc 0.95703125 | loss_avg 0.14191022515296936
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1417785882949829
train acc 0.95703125 | loss_avg 0.14166706800460815
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.14154872298240662
train acc 0.95703125 | loss_avg 0.14141768217086792
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.14130577445030212
train acc 0.95703125 | loss_avg 0.14118583500385284
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.14105860888957977
train acc 0.95703125 | loss_avg 0.14094746112823486
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.14082780480384827
train acc 0.95703125 | loss_avg 0.14071092009544373
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.140586718916893
train acc 0.95703125 | loss_avg 0.14047321677207947
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.14035829901695251
train acc 0.95703125 | loss_avg 0.14022979140281677
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1401243507862091
train acc 0.95703125 | loss_avg 0.14000463485717773
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.139882430434227
train acc 0.95703125 | loss_avg 0.13977576792240143
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13965557515621185
train acc 0.95703125 | loss_avg 0.13954254984855652
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13942044973373413
train acc 0.95703125 | loss_avg 0.13931168615818024
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13919851183891296
train acc 0.95703125 | loss_avg 0.13907405734062195
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13897131383419037
train acc 0.95703125 | loss_avg 0.13885466754436493
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13872690498828888
train acc 0.95703125 | loss_avg 0.13861845433712006
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13849741220474243
train acc 0.95703125 | loss_avg 0.13837607204914093
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13826580345630646
train acc 0.95703125 | loss_avg 0.13814090192317963
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13803434371948242
train acc 0.95703125 | loss_avg 0.13791689276695251
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1377953290939331
train acc 0.95703125 | loss_avg 0.13768677413463593
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13757291436195374
train acc 0.95703125 | loss_avg 0.1374494433403015
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13734285533428192
train acc 0.95703125 | loss_avg 0.13722985982894897
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13711082935333252
train acc 0.95703125 | loss_avg 0.13700652122497559
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.1368950605392456
train acc 0.95703125 | loss_avg 0.1367751955986023
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13667112588882446
train acc 0.95703125 | loss_avg 0.13655227422714233
TEST ACC 0.9174107313156128
train acc 0.95703125 | loss_avg 0.13644972443580627
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model, X_test, y_test)Output:
<Figure size 1200x600 with 2 Axes>

second pass took 12min to write from scratch. no copying… still kinda slow
a = torch.rand((8))
b = torch.rand((8))
b = b.unsqueeze(0)
a = a.unsqueeze(1)
a.shape, b.shape
a.shape, b.shape, b@a # DUDE HOW DID YOU NOT KNOW HOW TO DO A DOT PRODUCT WITH MATMUL????????? 1X8 8X1 MAKES 1X1!!!!!!Output:
(torch.Size([8, 1]), torch.Size([1, 8]), tensor([[1.9803]]))
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=1000, random_state=88)
X.shape, y.shape
plt.scatter(x=X[:,0], y=X[:,1], c=y, cmap=plt.cm.RdYlBu)Output:
<matplotlib.collections.PathCollection at 0x7d60382d40e0>
Output:
<Figure size 640x480 with 1 Axes>

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
import tqdm as tqdm
BATCH_SIZE = 32
LEARNING_RATE = 0.01
EPOCHS = 100
EVAL_INT = 10
X = torch.as_tensor(X, dtype=torch.float)
y = torch.as_tensor(y)
shuffle = torch.randperm(len(X))
X = X[shuffle]
y = y[shuffle]
split = int(0.8*len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]
train_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE, shuffle=True)
class ClusterModel(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(2, 30)
self.lin1 = nn.Linear(30, 3)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.lin0(x))
x = self.lin1(x) # you gotta make sure that you are passing all the layers in right
return x
model = ClusterModel()
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
def train(model, train_dl, loss_fn, opt):
for x_b, y_b in train_dl:
pred = model(x_b)
loss = loss_fn(pred, y_b)
opt.zero_grad()
loss.backward()
opt.step()
def test(model, test_dl):
for x_b, y_b in test_dl:
pred = model(x_b)
for e in tqdm.tqdm(range(EPOCHS)):
train(model, train_dl, loss_fn, opt)
if e % EVAL_INT == 0:
with torch.inference_mode():
test(model, test_dl)
print("Done")
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model, X_test, y_test)Output:
0%| | 0/100 [00:00<?, ?it/s]
Output:
100%|██████████| 100/100 [00:00<00:00, 139.38it/s]
Output:
Done
Output:
<Figure size 1200x600 with 2 Axes>

