-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
87 lines (69 loc) · 2.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
from awave.transform1d import DWT1d
from awave.transform2d import DWT2d
from awave.filtermodel import FilterConv
from config import *
import time
from icecream import ic
import pywt
import os
from torchvision import datasets, transforms, models
def train1d(data, filter_model, device):
    """Train a 1-D adaptive wavelet transform and save it under ``models/``.

    Wraps ``filter_model`` in a ``DWT1d``, fits it on ``data`` with the
    module-level ``BATCH_SIZE``, ``NUM_EPOCHS`` and ``LR`` settings, and
    serializes the whole fitted object with ``torch.save``.

    Args:
        data: Training data, forwarded to ``DWT1d.fit`` as ``X``.
        filter_model: Filter prediction model handed to ``DWT1d``.
        device: Device the transform is created on and moved to.
    """
    # Initializing
    awt = DWT1d(filter_model=filter_model, device=device).to(device=device)
    # Training
    awt.fit(X=data, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, lr=LR)
    # torch.save does not create missing parent directories; without this a
    # fresh checkout would train to completion and then fail while saving.
    os.makedirs("models", exist_ok=True)
    # The timestamp keeps repeated runs from overwriting each other.
    name = f"models/{awt.__module__}__BATCH-{BATCH_SIZE}__EPOCH-{NUM_EPOCHS}__DATA-{DATA_NAME}__FILTER-{OUT_CHANNELS}__TIME-{time.time()}.pth"
    torch.save(awt, name)
def train2d(data, filter_model, device):
    """Fit a 2-D adaptive wavelet transform and return the fitted object.

    Args:
        data: Training data, forwarded to ``DWT2d.fit`` as ``X``.
        filter_model: Filter prediction model handed to ``DWT2d``.
        device: Device the transform is created on and moved to.

    Returns:
        The ``DWT2d`` instance after ``fit`` has been called on it.
    """
    # Build the transform with LEVEL decomposition levels and a 'db3' wavelet.
    transform = DWT2d(
        filter_model=filter_model,
        J=LEVEL,
        device=device,
        useExistingFilter=False,
        wave='db3',
    )
    transform = transform.to(device=device)

    # Test split is read from disk; it is passed to fit() exactly as loaded.
    test_path = f'data/{DATA_NAME}_test.pth'
    held_out = torch.load(test_path)

    # Optimisation settings shared by every run of this script.
    fit_options = dict(batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, lr=LR)
    transform.fit(X=data, X_test=held_out, **fit_options)
    return transform
if __name__ == "__main__":
    # Script entry point: build the filter model, load the training tensor,
    # train the 2-D transform and save the fitted object under models/.

    # Select the device: CUDA when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Provide the filter prediction model: a ResNet-18 whose final layer is
    # replaced by dropout + a linear layer with OUT_CHANNELS outputs.
    # model = FilterConv(in_channels = IN_CHANNELS, out_channels = OUT_CHANNELS).to(device)
    model = models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(num_ftrs, OUT_CHANNELS)
    )
    model.to(device)
    print(model)

    # Load the data. The commented lines below are alternative data sources
    # kept for experimentation.
    # transform = transforms.Compose([
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.5), (0.5))
    # ])
    # original = transform(pywt.data.camera()).squeeze()
    # original = torch.stack([original, original, original])
    # data = [original for i in range(100)]
    # data = torch.stack(data)
    # data = torch.load(DATA_PATH).to(device)
    # # # ic(data.shape, x[0].shape)
    # x = torch.split(data, min(BATCH_SIZE*500, data.size(0)), 0)
    # data = torch.rand([1000, 3, 32, 32])
    # ic(len(x1))
    # ic(x1[0].shape)
    # Dry run an example on model
    # ic(model(x1[0]).shape)

    # Following line for CIFAR10 dataset
    data = torch.load(f'data/{DATA_NAME}_train.pth').to(device)
    ic(data.shape)

    # Train the model
    awt = train2d(data, model, device)

    # os.makedirs creates the parent 'models/' directory too and tolerates an
    # existing target; os.mkdir would raise FileNotFoundError on a fresh
    # checkout where 'models/' is absent.
    os.makedirs(f'models/{awt.__module__}/', exist_ok=True)
    name = f"models/{awt.__module__}/filtersize_{OUT_CHANNELS}-batchsize_{BATCH_SIZE}-epochs_{NUM_EPOCHS}-LR_{LR}-J{LEVEL}.pth"
    torch.save(awt, name)
    # train1d(data, model, device)