-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
41 lines (35 loc) · 1018 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from data import LatAccelDataModule
from model import MLControlsSim
import pytorch_lightning as pl
from lightning.pytorch.loggers import CSVLogger
def validate(
    ckpt_path="lightning_logs/version_9/checkpoints/epoch=0-step=1000.ckpt",
    data_path="../NNFF/data/CHEVROLET_VOLT_PREMIER_2017/000",
    batch_size=2 ** 10,
):
    """Run one validation pass of a trained ``MLControlsSim`` checkpoint.

    The defaults reproduce the previously hard-coded run, so ``validate()``
    with no arguments behaves as before. Callers can now point it at a
    different checkpoint or dataset without editing the source.

    Args:
        ckpt_path: Path of the checkpoint passed to
            ``MLControlsSim.load_from_checkpoint``.
        data_path: Directory passed to ``LatAccelDataModule`` as ``path``.
            The default is the ``000`` sub-directory, not the full dataset
            root that ``main`` trains on.
        batch_size: Batch size passed to ``LatAccelDataModule``.

    Returns:
        The value returned by ``Trainer.validate``. Per the Lightning docs,
        this is a list of logged-metric dicts, one per validation dataloader.
        The original implementation discarded it and returned ``None``.
    """
    model = MLControlsSim.load_from_checkpoint(ckpt_path)
    data = LatAccelDataModule(
        path=data_path,
        batch_size=batch_size,
    )
    trainer = pl.Trainer()
    return trainer.validate(model, datamodule=data)
def main(
    data_path="../NNFF/data/CHEVROLET_VOLT_PREMIER_2017/",
    batch_size=2 ** 10,
    max_steps=10_000,
):
    """Train ``MLControlsSim`` on the lateral-acceleration dataset.

    The defaults reproduce the previously hard-coded training run, so
    ``main()`` with no arguments behaves as before.

    Args:
        data_path: Dataset root passed to ``LatAccelDataModule`` as ``path``.
        batch_size: Batch size passed to ``LatAccelDataModule``.
        max_steps: Total optimizer steps before ``Trainer.fit`` stops.
    """
    # Take the logger from the same distribution as ``pl.Trainer``
    # (``pytorch_lightning``). The module-level import pulls ``CSVLogger``
    # from ``lightning.pytorch``, a separate package, and Lightning does not
    # support mixing objects from the two in one Trainer.
    from pytorch_lightning.loggers import CSVLogger

    data = LatAccelDataModule(
        path=data_path,
        batch_size=batch_size,
    )
    model = MLControlsSim(
        n_layers=4,
        n_head=4,
        n_embd=128,
        lr=6e-4,
        weight_decay=0.1,
    )
    # Log metrics as CSV under the current working directory.
    logger = CSVLogger(".")
    trainer = pl.Trainer(
        max_steps=max_steps,
        precision=32,
        logger=logger,
        # Run validation every 200 training batches instead of once per epoch.
        val_check_interval=200,
        # Clip the gradient norm to 1.0 to guard against exploding updates.
        gradient_clip_val=1.0,
    )
    trainer.fit(model, datamodule=data)
if __name__ == "__main__":
    # Running this file directly launches training; importing it does not.
    main()