Skip to content

Commit

Permalink
Merge pull request #67 from GabrielBG0/restructuring
Browse files Browse the repository at this point in the history
Moving models to specific folders by use
  • Loading branch information
GabrielBG0 authored Jul 4, 2024
2 parents 4c6e199 + 586648f commit 77ee77e
Show file tree
Hide file tree
Showing 20 changed files with 30 additions and 35 deletions.
8 changes: 4 additions & 4 deletions minerva/models/nets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .base import SimpleSupervisedModel
from .deeplabv3 import DeepLabV3
from .setr import SETR_PUP
from .unet import UNet
from .wisenet import WiseNet
from .image.deeplabv3 import DeepLabV3
from .image.setr import SETR_PUP
from .image.unet import UNet
from .image.wisenet import WiseNet

__all__ = [
"SimpleSupervisedModel",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Optional, Sequence

from torch import Tensor, load, nn, optim
from torch import Tensor, nn, optim
from torchmetrics import Metric
from torchvision.models.resnet import resnet50
from torchvision.models.segmentation.deeplabv3 import ASPP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import lightning as L
import torch
from torch import nn
from torchmetrics import JaccardIndex, Metric
from torchmetrics import Metric

from minerva.models.nets.vit import _VisionTransformerBackbone
from minerva.utils.upsample import Upsample, resize
from minerva.models.nets.image.vit import _VisionTransformerBackbone
from minerva.utils.upsample import Upsample


class _SETRUPHead(nn.Module):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
""" Full assembly of the parts to form the complete network """

import time
from typing import Dict, Optional
from typing import Optional

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR, StepLR

from minerva.models.nets.base import SimpleSupervisedModel

Expand Down Expand Up @@ -227,5 +223,5 @@ def __init__(
loss_fn=loss_fn or torch.nn.MSELoss(),
learning_rate=learning_rate,
flatten=False,
**kwargs
**kwargs,
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Tuple

import lightning as L
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import time
from typing import Tuple

import lightning as L
import numpy as np
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchmetrics import Accuracy

from minerva.models.nets.base import SimpleSupervisedModel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import time
from functools import partial
from typing import Literal, Tuple
from typing import Tuple

import lightning as L
import numpy as np
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchmetrics import Accuracy

from minerva.models.nets.base import SimpleSupervisedModel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from minerva.models.nets.deeplabv3 import DeepLabV3, DeepLabV3Backbone
from minerva.models.nets.image.deeplabv3 import DeepLabV3, DeepLabV3Backbone


def test_deeplabv3_model():
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from minerva.models.nets.vit import (
from minerva.models.nets.image.vit import (
mae_vit_base_patch16,
mae_vit_base_patch16D4d256,
mae_vit_huge_patch14,
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from minerva.models.nets.cnns import CNN_HaEtAl_1D, CNN_HaEtAl_2D
from minerva.models.nets.time_series.cnns import CNN_HaEtAl_1D, CNN_HaEtAl_2D


def test_cnn_ha_etal_1d_forward():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from minerva.models.nets.cnns import CNN_PF_2D, CNN_PFF_2D
from minerva.models.nets.time_series.cnns import CNN_PF_2D, CNN_PFF_2D


def test_cnn_pf_forward():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import torch
from minerva.models.nets.imu_transformer import IMUTransformerEncoder, IMUCNN

from minerva.models.nets.time_series.imu_transformer import (
IMUCNN,
IMUTransformerEncoder,
)


def test_imu_transformer_forward():
input_shape = (6, 60)
Expand All @@ -18,4 +23,4 @@ def test_imu_cnn_forward():

x = torch.rand(1, *input_shape)
y = model(x)
assert y is not None
assert y is not None
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
from minerva.models.nets.inception_time import InceptionTime

from minerva.models.nets.time_series.inception_time import InceptionTime


def test_inception_time_forward():
input_shape = (6, 60)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import torch

from minerva.models.nets.resnet import ResNet1D_8, ResNetSE1D_5, ResNetSE1D_8
from minerva.models.nets.time_series.resnet import (
ResNet1D_8,
ResNetSE1D_5,
ResNetSE1D_8,
)


def test_resnet_1d_8_forward():
Expand Down

0 comments on commit 77ee77e

Please sign in to comment.