diff --git a/src/main.py b/src/main.py index 243a31e..69793f1 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,4 @@ +import logging from PIL import Image import torch import torch.nn as nn @@ -6,6 +7,10 @@ from torch.utils.data import DataLoader import numpy as np +# Set up logging +logging.basicConfig(level=logging.ERROR) +logger = logging.getLogger('training') + # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ transforms.ToTensor(), @@ -35,14 +40,17 @@ def forward(self, x): optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() -# Training loop +# Training loop with error logging epochs = 3 for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + try: + for images, labels in trainloader: + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() + except Exception as e: + logger.error("Exception occurred", exc_info=True) torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file