amazon-science/eideticnet-training

EideticNet

EideticNet is a PyTorch-based framework for training neural networks that can learn multiple tasks sequentially without (catastrophic) forgetting. It accomplishes this through iterative pruning, selective deletion of synaptic connections on a task-specific basis, and parameter freezing. The available pruning methods are Taylor expansion-based pruning, Lp-norm weight magnitude pruning, and random pruning.
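To make Lp-norm weight magnitude pruning concrete, here is a minimal sketch that uses PyTorch's built-in `torch.nn.utils.prune` utilities on a single linear layer. It illustrates the general technique only; EideticNet's own pruning code may be organized differently, and the layer sizes and the 50% amount are arbitrary choices for the example.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# A small layer chosen only for illustration (not an EideticNet network).
layer = nn.Linear(in_features=8, out_features=4)

# Remove the 50% of output neurons (rows of the weight matrix, dim=0)
# whose weight vectors have the smallest L2 norm (n=2).
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# After pruning, `layer.weight` is computed as `weight_orig * weight_mask`,
# so pruned rows are exactly zero.
pruned_rows = (layer.weight.abs().sum(dim=1) == 0).nonzero().flatten().tolist()
num_pruned = len(pruned_rows)
print(f"pruned output neurons: {pruned_rows}")
```

Swapping `prune.ln_structured` for `prune.random_structured` (without the `n` argument) gives the random variant. Taylor expansion-based pruning needs gradient information and is not covered by this sketch.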

Features

  • For each task on which a network is trained, accuracy is preserved perfectly.
  • Functions for propagating sparsity induced by pruning to subsequent layers.
  • Forward transfer is configurable. When forward transfer is enabled, the features learned during training of previous tasks are available to (the subsequent layers of) subsequent ones. When forward transfer is disabled, each task occupies its own disjoint subnetwork.
  • Batch normalization parameters are preserved for each task.
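The idea behind propagating sparsity to subsequent layers can be sketched in plain PyTorch. This is an illustration under simplifying assumptions (two bias-free linear layers, ReLU, no batch normalization), not EideticNet's own propagation function. With biases or batch normalization a pruned neuron can still emit a nonzero constant, which a real implementation has to account for.

```python
import torch

torch.manual_seed(0)

# Two consecutive bias-free layers: y = W2 @ relu(W1 @ x).
w1 = torch.randn(4, 3)  # 4 hidden neurons, 3 inputs
w2 = torch.randn(2, 4)  # 2 outputs, 4 hidden neurons

# Suppose pruning removed all incoming weights of hidden neurons 1 and 3.
w1[[1, 3], :] = 0.0

# A hidden neuron with no incoming weights always outputs relu(0) = 0, so the
# next layer's weights that read from it are dead and can be pruned as well.
dead = w1.abs().sum(dim=1) == 0  # boolean mask over hidden neurons
w2_propagated = w2.clone()
w2_propagated[:, dead] = 0.0

# The network output is unchanged by the propagated pruning.
x = torch.randn(3)
out_before = w2 @ torch.relu(w1 @ x)
out_after = w2_propagated @ torch.relu(w1 @ x)
print(torch.allclose(out_before, out_after))
```

The freed columns of `w2` can then be reused by a later task without affecting the outputs of the earlier one.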

Installation

  1. Install the package:
pip install -e .
  2. Install required dependencies:
pip install torch torchmetrics torchvision

Note that the command for installing PyTorch may vary depending on your environment.

Supported networks

  • MLP (Multi-Layer Perceptron)

    from eideticnet_training.networks import MLP
    
    mlp = MLP(
        in_features=784,  # Input dimension
        num_classes=[10, 10],  # List of output dimensions for each task
        num_layers=2,  # Number of hidden layers
        width=4096,  # Width of hidden layers
        dropout=0.0,  # Dropout probability
        bn=True  # Use batch normalization
    )
  • ConvNet

    from eideticnet_training.networks import ConvNet
    
    cnn = ConvNet(
        in_channels=3,  # Input channels
        num_classes=[10, 10],  # List of output dimensions for each task
        num_layers=2,  # Number of conv layers per block
        width=32,  # Base width of conv layers
        dropout=0.0,  # Dropout probability
        bn=True  # Use batch normalization
    )
  • ResNet (18/34/50/101)

    from eideticnet_training.networks import ResNet
    
    # ResNet-18
    resnet18 = ResNet(
        in_channels=3,
        num_classes=[10, 10],
        n_blocks=[2, 2, 2, 2],
        expansion=1
    )
    
    # ResNet-34
    resnet34 = ResNet(
        in_channels=3,
        num_classes=[10, 10],
        n_blocks=[3, 4, 6, 3],
        expansion=1
    )
    
    # ResNet-50
    resnet50 = ResNet(
        in_channels=3,
        num_classes=[10, 10],
        n_blocks=[3, 4, 6, 3],
        expansion=4
    )
    
    # ResNet-101
    resnet101 = ResNet(
        in_channels=3,
        num_classes=[10, 10],
        n_blocks=[3, 4, 23, 3],
        expansion=4
    )
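
The `n_blocks` and `expansion` values above follow the standard ResNet configurations. As a sanity check, the nominal depth can be recomputed from them. The helper below is hypothetical (it is not part of `eideticnet_training`) and assumes the usual convention that `expansion=1` means two-convolution basic blocks and `expansion=4` means three-convolution bottleneck blocks.

```python
def resnet_depth(n_blocks, expansion):
    """Nominal ResNet depth, assuming the standard block conventions."""
    # Basic blocks (expansion=1) hold 2 conv layers; bottlenecks (expansion=4) hold 3.
    convs_per_block = 2 if expansion == 1 else 3
    # +2 accounts for the stem convolution and the final fully connected layer.
    return sum(n_blocks) * convs_per_block + 2


configs = {
    "resnet18": ([2, 2, 2, 2], 1),
    "resnet34": ([3, 4, 6, 3], 1),
    "resnet50": ([3, 4, 6, 3], 4),
    "resnet101": ([3, 4, 23, 3], 4),
}

depths = {name: resnet_depth(blocks, exp) for name, (blocks, exp) in configs.items()}
for name, depth in depths.items():
    print(name, depth)
```

Note that ResNet-34 and ResNet-50 share the same `n_blocks`; only `expansion` (and thus the block type) differs.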

Basic Usage

  1. Define your network by inheriting from EideticNetwork:
from eideticnet_training.networks import EideticNetwork

class MyNetwork(EideticNetwork):
    def __init__(self):
        super().__init__()
        # Define network layers

    def forward(self, x):
        # Define forward pass
        raise NotImplementedError

    def _bridge_prune(self, pct, pruning_type, score_threshold=None):
        # Define pruning connections between layers
        raise NotImplementedError
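  2. Train sequential tasks:
import torch

# num_tasks, train_loader, accuracy_metric, validation_datasets and
# validation_accuracy_metrics are assumed to be defined by the caller.
model = MyNetwork()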
optimizer = torch.optim.Adam(model.parameters())

for task_id in range(num_tasks):
    # Prepare for new task
    model.prepare_for_task(task_id)

    # Train the task
    model.train_task(
        dataloader=train_loader,
        metric=accuracy_metric,
        optimizer=optimizer,
        test_batch_size=256,
        pruning_step=0.1,  # Prune 10% of parameters per iteration
        pruning_type="l2",  # Use L2 norm-based pruning
        validation_tasks=validation_datasets,
        validation_metrics=validation_accuracy_metrics,
        early_stopping_patience=5
    )
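
To get a feel for what `pruning_step=0.1` implies over several pruning iterations, the arithmetic can be worked out directly. The README does not say whether each step removes 10% of the original parameters or 10% of those still remaining, so this sketch computes both readings. The function names are hypothetical and not part of the library.

```python
import math


def remaining_linear(step, iterations):
    """Fraction left if each iteration removes `step` of the ORIGINAL parameters."""
    return max(0.0, 1.0 - step * iterations)


def remaining_geometric(step, iterations):
    """Fraction left if each iteration removes `step` of the REMAINING parameters."""
    return (1.0 - step) ** iterations


for k in (1, 3, 5, 10):
    linear = remaining_linear(0.1, k)
    geometric = remaining_geometric(0.1, k)
    print(f"after {k:2d} iterations: linear={linear:.3f}, geometric={geometric:.3f}")
```

Under the first reading the network is fully pruned after ten iterations; under the second, roughly a third of the parameters still remain at that point.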

For working examples, see:

  • Simple example: tests/test_eidetic_network.py
  • Full implementation: experiments/sequential_classification.py

Testing

Run the test suite:

pip install -e .'[test]'
pytest

Contributing

See CONTRIBUTING.md.
