Skip to content

Commit

Permalink
fix for torch shape error
Browse files Browse the repository at this point in the history
  • Loading branch information
tanwarsh committed Nov 4, 2024
1 parent 7aca44e commit a0929ea
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions openfl-tutorials/Federated_Pytorch_MNIST_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,20 @@
"metadata": {},
"outputs": [],
"source": [
"def one_hot(labels, classes):\n",
" return np.eye(classes)[labels]\n",
"\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"trainset = torchvision.datasets.MNIST(root='./data', train=True,\n",
" download=True, transform=transform)\n",
"\n",
"train_images,train_labels = trainset.train_data, np.array(trainset.train_labels)\n",
"train_images,train_labels = trainset.data, np.array(trainset.targets)\n",
"train_images = torch.from_numpy(np.expand_dims(train_images, axis=1)).float()\n",
"\n",
"validset = torchvision.datasets.MNIST(root='./data', train=False,\n",
" download=True, transform=transform)\n",
"\n",
"valid_images,valid_labels = validset.test_data, np.array(validset.test_labels)\n",
"valid_images = torch.from_numpy(np.expand_dims(valid_images, axis=1)).float()\n",
"valid_labels = one_hot(valid_labels,10)"
"valid_images,valid_labels = validset.data, np.array(validset.targets)\n",
"valid_images = torch.from_numpy(np.expand_dims(valid_images, axis=1)).float()"
]
},
{
Expand Down

0 comments on commit a0929ea

Please sign in to comment.