EideticNet is a PyTorch-based framework for training neural networks that can learn multiple tasks sequentially without catastrophic forgetting. It accomplishes this through iterative pruning, selective deletion of synaptic connections on a task-specific basis, and parameter freezing. The available pruning methods are Taylor-expansion-based pruning, Lp-norm weight-magnitude pruning, and random pruning.
- For each task on which a network is trained, accuracy is preserved perfectly.
- Functions for propagating sparsity induced by pruning to subsequent layers.
- Forward transfer is configurable. When forward transfer is enabled, the features learned while training previous tasks are available to (the subsequent layers of) later tasks. When forward transfer is disabled, each task occupies its own disjoint subnetwork.
- Batch normalization parameters are preserved for each task.
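The pruning-and-freezing cycle described above can be pictured with plain PyTorch primitives. The following is a conceptual sketch only, not EideticNet's implementation: the layer sizes, the 30% pruning ratio, and the gradient-hook freezing are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Two consecutive layers standing in for part of a larger network.
fc1 = nn.Linear(784, 256)
fc2 = nn.Linear(256, 10)

# 1. Structured L2-norm pruning: remove the 30% of fc1's output neurons
#    with the smallest L2 weight norm.
prune.ln_structured(fc1, name="weight", amount=0.3, n=2, dim=0)

# 2. Propagate the induced sparsity: zero the columns of fc2 that read from
#    the pruned fc1 neurons (with forward transfer disabled, a later task's
#    layers would likewise be cut off from another task's neurons).
dead = fc1.weight_mask.sum(dim=1) == 0
with torch.no_grad():
    fc2.weight[:, dead] = 0.0

# 3. Freeze the surviving weights: zero their gradients so that training
#    later tasks only updates the capacity freed by pruning.
kept = fc1.weight_mask.clone()
fc1.weight_orig.register_hook(lambda grad: grad * (1 - kept))
```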
- Install the package:

  ```bash
  pip install -e .
  ```

- Install the required dependencies:

  ```bash
  pip install torch torchmetrics torchvision
  ```

  Note that the command for installing PyTorch may vary depending on your environment.
- MLP (Multi-Layer Perceptron)

  ```python
  from eideticnet_training.networks import MLP

  mlp = MLP(
      in_features=784,       # Input dimension
      num_classes=[10, 10],  # List of output dimensions for each task
      num_layers=2,          # Number of hidden layers
      width=4096,            # Width of hidden layers
      dropout=0.0,           # Dropout probability
      bn=True,               # Use batch normalization
  )
  ```

- ConvNet

  ```python
  from eideticnet_training.networks import ConvNet

  cnn = ConvNet(
      in_channels=3,         # Input channels
      num_classes=[10, 10],  # List of output dimensions for each task
      num_layers=2,          # Number of conv layers per block
      width=32,              # Base width of conv layers
      dropout=0.0,           # Dropout probability
      bn=True,               # Use batch normalization
  )
  ```

- ResNet (18/34/50/101)

  ```python
  from eideticnet_training.networks import ResNet

  # ResNet-18
  resnet18 = ResNet(
      in_channels=3, num_classes=[10, 10], n_blocks=[2, 2, 2, 2], expansion=1
  )

  # ResNet-34
  resnet34 = ResNet(
      in_channels=3, num_classes=[10, 10], n_blocks=[3, 4, 6, 3], expansion=1
  )

  # ResNet-50
  resnet50 = ResNet(
      in_channels=3, num_classes=[10, 10], n_blocks=[3, 4, 6, 3], expansion=4
  )

  # ResNet-101
  resnet101 = ResNet(
      in_channels=3, num_classes=[10, 10], n_blocks=[3, 4, 23, 3], expansion=4
  )
  ```
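As a quick smoke test, the constructors above can be exercised with dummy inputs. This is a minimal sketch assuming only the `MLP` signature shown in the list above; whether a fresh model accepts a forward pass before `prepare_for_task` is called, and the layout of the multi-head output, are determined by the library.

```python
import torch

from eideticnet_training.networks import MLP

# Hypothetical two-task setup over flattened 28x28 grayscale images.
mlp = MLP(
    in_features=784,
    num_classes=[10, 10],  # two tasks, 10 classes each
    num_layers=2,
    width=4096,
    dropout=0.0,
    bn=True,
)

x = torch.randn(8, 784)  # dummy batch of 8 flattened images
outputs = mlp(x)         # forward pass; one classification head per task
print(sum(p.numel() for p in mlp.parameters()), "parameters")
```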
- Define your network by inheriting from `EideticNetwork`:

  ```python
  from eideticnet_training.networks import EideticNetwork


  class MyNetwork(EideticNetwork):
      def __init__(self):
          super().__init__()
          # Define network layers
          ...

      def forward(self, x):
          # Define forward pass
          ...

      def _bridge_prune(self, pct, pruning_type, score_threshold=None):
          # Define pruning connections between layers
          ...
  ```
- Train sequential tasks:

  ```python
  model = MyNetwork()
  optimizer = torch.optim.Adam(model.parameters())

  for task_id in range(num_tasks):
      # Prepare for new task
      model.prepare_for_task(task_id)

      # Train the task
      model.train_task(
          dataloader=train_loader,
          metric=accuracy_metric,
          optimizer=optimizer,
          test_batch_size=256,
          pruning_step=0.1,    # Prune 10% of parameters per iteration
          pruning_type="l2",   # Use L2 norm-based pruning
          validation_tasks=validation_datasets,
          validation_metrics=validation_accuracy_metrics,
          early_stopping_patience=5,
      )
  ```
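The loop above references several objects it does not define (`num_tasks`, `train_loader`, `accuracy_metric`, `validation_datasets`, `validation_accuracy_metrics`). One hypothetical way to construct them for a two-task MNIST → FashionMNIST sequence is sketched below; the exact types `train_task` expects (for example, datasets versus dataloaders for validation) are assumptions here, and `experiments/sequential_classification.py` shows the real wiring.

```python
import torchmetrics
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()

# One dataset per task; inside the loop, index these with task_id
# (e.g., train_loader = train_loaders[task_id]).
train_sets = [
    datasets.MNIST("data", train=True, download=True, transform=transform),
    datasets.FashionMNIST("data", train=True, download=True, transform=transform),
]
validation_datasets = [
    datasets.MNIST("data", train=False, download=True, transform=transform),
    datasets.FashionMNIST("data", train=False, download=True, transform=transform),
]
num_tasks = len(train_sets)

train_loaders = [
    DataLoader(ds, batch_size=256, shuffle=True) for ds in train_sets
]

# One accuracy metric per task (both tasks have 10 classes here).
validation_accuracy_metrics = [
    torchmetrics.Accuracy(task="multiclass", num_classes=10)
    for _ in range(num_tasks)
]
accuracy_metric = torchmetrics.Accuracy(task="multiclass", num_classes=10)
```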
For working examples, see:

- Simple example: `tests/test_eidetic_network.py`
- Full implementation: `experiments/sequential_classification.py`
Run the test suite:

```bash
pip install -e .'[test]'
pytest
```
See `CONTRIBUTING.md`.