From a0929eaca6d4a7402048c83176202df5e4dc2c69 Mon Sep 17 00:00:00 2001 From: yes Date: Tue, 22 Oct 2024 02:13:02 -0700 Subject: [PATCH] fix for torch shape error --- .../Federated_Pytorch_MNIST_Tutorial.ipynb | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/openfl-tutorials/Federated_Pytorch_MNIST_Tutorial.ipynb b/openfl-tutorials/Federated_Pytorch_MNIST_Tutorial.ipynb index 9522073280..193f880ca5 100644 --- a/openfl-tutorials/Federated_Pytorch_MNIST_Tutorial.ipynb +++ b/openfl-tutorials/Federated_Pytorch_MNIST_Tutorial.ipynb @@ -65,9 +65,6 @@ "metadata": {}, "outputs": [], "source": [ - "def one_hot(labels, classes):\n", - " return np.eye(classes)[labels]\n", - "\n", "transform = transforms.Compose(\n", " [transforms.ToTensor(),\n", " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", @@ -75,15 +72,13 @@ "trainset = torchvision.datasets.MNIST(root='./data', train=True,\n", " download=True, transform=transform)\n", "\n", - "train_images,train_labels = trainset.train_data, np.array(trainset.train_labels)\n", + "train_images,train_labels = trainset.data, np.array(trainset.targets)\n", "train_images = torch.from_numpy(np.expand_dims(train_images, axis=1)).float()\n", - "\n", "validset = torchvision.datasets.MNIST(root='./data', train=False,\n", " download=True, transform=transform)\n", "\n", - "valid_images,valid_labels = validset.test_data, np.array(validset.test_labels)\n", - "valid_images = torch.from_numpy(np.expand_dims(valid_images, axis=1)).float()\n", - "valid_labels = one_hot(valid_labels,10)" + "valid_images,valid_labels = validset.data, np.array(validset.targets)\n", + "valid_images = torch.from_numpy(np.expand_dims(valid_images, axis=1)).float()" ] }, {