Class Indices Overwritten During Continual Learning with Avalanche 0.4.0a #1503

@Funkvay

Description

I'm encountering an issue with class indices being overwritten when performing continual learning using Avalanche 0.4.0a.

Environment:

Avalanche Version: 0.4.0a
OS: Ubuntu 20.04
Python Environment: default (not using Conda)

I'm using Avalanche for continual learning, with the goal that when the model trains on new data it retains knowledge of the old data while integrating the new information. To simplify troubleshooting I'm working with a small dataset of roughly 10-15 photos, but the indexing issue persists regardless of dataset size. Specifically, after adding 3 new classes to a model initially trained on 5 classes (indexed 0-4), I expected predictions over the full set to use indices 0, 1, 2, 3, 4, 5, 6, 7. Instead, the new classes are predicted as 0, 1, 2, so I get 0, 1, 2, 3, 4, 0, 1, 2.
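To make the symptom concrete, here is the mapping I expected versus what I actually observe (a plain illustration, not Avalanche code):

```python
# Indices the model was originally trained on
old_classes = list(range(5))        # [0, 1, 2, 3, 4]

# Expected: the 3 new classes continue the numbering
expected = old_classes + [5, 6, 7]  # [0, 1, 2, 3, 4, 5, 6, 7]

# Observed: the new classes come out with indices starting from 0 again
observed = old_classes + [0, 1, 2]  # [0, 1, 2, 3, 4, 0, 1, 2]

print(expected)
print(observed)
```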

Here is the code that I have:

import argparse
import os
from torchvision.models import resnet101
import torch
from torch.nn import CrossEntropyLoss
from torchvision import transforms, datasets
from avalanche.benchmarks import nc_benchmark
from avalanche.training.supervised import Naive
from avalanche.training.plugins import ReplayPlugin, EvaluationPlugin
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
)
from avalanche.logging import InteractiveLogger
import torch.nn.init as init


def main(args):
    # Load the saved model's weights
    saved_weights = torch.load(args.model_path)
    num_old_classes = saved_weights['fc.weight'].shape[0]

    # Initialize the ResNet-101 model
    model = resnet101()

    # Adjust the final layer to match the number of classes in the saved model
    num_features = model.fc.in_features
    model.fc = torch.nn.Linear(num_features, num_old_classes)
    model.load_state_dict(saved_weights)

    # Determine the number of classes in the new dataset
    num_new_classes = len(os.listdir(os.path.join(args.data_path, 'train')))
    total_classes = num_old_classes + num_new_classes
    
    # Expand the final layer of the model to accommodate the new classes
    weights = model.fc.weight.data
    biases = model.fc.bias.data
    model.fc = torch.nn.Linear(in_features=num_features, out_features=total_classes)
    with torch.no_grad():
        model.fc.weight[:num_old_classes] = weights
        model.fc.bias[:num_old_classes] = biases

    # Initialize weights for the new classes
    init.kaiming_uniform_(model.fc.weight[num_old_classes:], mode='fan_in', nonlinearity='relu')
    init.zeros_(model.fc.bias[num_old_classes:])

    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define data transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load datasets
    train_dataset = datasets.ImageFolder(root=os.path.join(args.data_path, 'train'), transform=transform)
    test_dataset = datasets.ImageFolder(root=os.path.join(args.data_path, 'val'), transform=transform)

    # Adjust labels for the new dataset to be continuous with the old dataset
    train_dataset.targets = [label + num_old_classes for label in train_dataset.targets]
    test_dataset.targets = [label + num_old_classes for label in test_dataset.targets]

    # Create a continual learning benchmark
    benchmark = nc_benchmark(train_dataset, test_dataset, n_experiences=num_new_classes, task_labels=False, seed=1234)

    # Define logging and evaluation plugins
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger],
    )

    # Define the continual learning strategy
    cl_strategy = Naive(
        model,
        torch.optim.SGD(model.parameters(), lr=0.01),
        CrossEntropyLoss(),
        train_mb_size=8,
        train_epochs=10,
        eval_mb_size=8,
        device=device,
        plugins=[ReplayPlugin(mem_size=1005)],
        evaluator=eval_plugin,
    )

    # Continual learning training loop
    for experience in benchmark.train_stream:
        cl_strategy.train(experience)

    # Save the updated model
    torch.save(model.state_dict(), args.model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the new dataset directory.")
    args = parser.parse_args()
    main(args)
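One detail I'm unsure about: I shift only `targets`, but as far as I can tell torchvision's `DatasetFolder.__getitem__` reads the label from `samples`, so the two attributes can get out of sync. A minimal stand-in sketch (`FakeImageFolder` is a hypothetical mock for illustration, not the real torchvision class):

```python
class FakeImageFolder:
    """Stand-in mimicking how torchvision's DatasetFolder stores labels."""
    def __init__(self, samples):
        self.samples = samples                  # list of (path, class_index)
        self.targets = [s[1] for s in samples]  # derived copy of the labels

    def __getitem__(self, index):
        path, target = self.samples[index]      # label is read from `samples`
        return path, target

ds = FakeImageFolder([("a.jpg", 0), ("b.jpg", 1)])
num_old_classes = 5

# Shift only `targets`, as in my script above
ds.targets = [t + num_old_classes for t in ds.targets]

print(ds.targets)            # shifted: [5, 6]
print(ds[0][1], ds[1][1])    # labels served by __getitem__ are still 0 1
```

If Avalanche reads labels through `__getitem__` (i.e. from `samples`) rather than from `targets`, this could explain why the new classes still come out as 0, 1, 2 — but I may be misreading the internals.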

I'd like to understand how to preserve correct class indexing throughout retraining and evaluation, ideally without changing the core architecture and while staying within Avalanche.

Any suggestions on how to fix the indexing would be appreciated.
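For completeness, this is how I understand the class-id remapping options of `nc_benchmark` (parameter names taken from the `nc_benchmark` signature as I understand it; a sketch, not verified against 0.4.0a):

```python
benchmark = nc_benchmark(
    train_dataset,
    test_dataset,
    n_experiences=num_new_classes,
    task_labels=False,
    seed=1234,
    # As I understand it, these flags control class-id remapping:
    class_ids_from_zero_in_each_exp=False,     # don't restart ids at 0 in each experience
    class_ids_from_zero_from_first_exp=False,  # don't remap ids to start at 0 overall
)
```

With both flags at their (I believe default) `False` values I expected the original ids to be preserved, which is why the 0, 1, 2 predictions surprise me.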
