Skip to content

Commit

Permalink
fix python scripts for local mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Pop John committed Aug 6, 2024
1 parent 9ca0bd0 commit d771bcb
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import { ScriptFacadeService } from '../../../core/services';
import {
AlgorithmKey,
AlgorithmType,
AlgorithmTypeKeyValue,
TrainAlgorithmsEnum
} from '../../../model-compression/models/enums/algorithms.enum';
import { isScriptActive } from '../../../model-compression/models/enums/script-status.enum';
Expand All @@ -45,9 +44,20 @@ import { isNilOrEmptyString } from '../../../shared/shared.utils';
export class PanelAlgorithmTypeForTrainingComponent implements OnInit {
@Input({ required: true }) controlKey = '';

readonly algorithmTypesOptions = AlgorithmTypeKeyValue.filter(
(option) => option.key !== AlgorithmType.TRAIN && option.key !== AlgorithmType.AWQ
);
readonly algorithmTypesOptions = [
{
key: AlgorithmType.QUANTIZATION,
value: 'Quantization'
},
{
key: AlgorithmType.PRUNING,
value: 'Pruning'
},
{
key: AlgorithmType.MACHINE_UNLEARNING,
value: 'Machine Unlearning'
}
];
readonly ALGORITHM_TYPE_CONTROL_NAME = 'algorithmType';

get parentFormGroup() {
Expand Down
21 changes: 15 additions & 6 deletions machine_learning_core/examples_pruning/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_checkpoints_dir = os.path.join(script_dir, 'models_checkpoints')

def prepare_model(model_arch = 'resnet18', device='cpu', logger=None):
def prepare_model(model_arch='resnet18', device='cpu', logger=None):
logger.log(f'==> Building model {model_arch}...')

if model_arch in globals():
Expand All @@ -48,23 +48,32 @@ def prepare_model(model_arch = 'resnet18', device='cpu', logger=None):
net = model_constructor()
net = net.to(device)

checkpoint_path = os.path.join(script_dir, models_checkpoints_dir, f'{model_arch}.pt')
checkpoint_path = os.path.join(models_checkpoints_dir, f'{model_arch}.pt')
try:
checkpoint = torch.load(checkpoint_path, map_location=device)
new_state_dict = OrderedDict()

for k, v in checkpoint.items():
name = k[7:]
name = k[7:] if k.startswith('module.') else k
new_state_dict[name] = v

net.load_state_dict(new_state_dict)
logger.log(f"Model state dict keys: {net.state_dict().keys()}")
logger.log(f"Checkpoint state dict keys: {new_state_dict.keys()}")

missing_keys, unexpected_keys = net.load_state_dict(new_state_dict, strict=False)

if missing_keys:
logger.log(f"Missing keys when loading state dict: {missing_keys}")
if unexpected_keys:
logger.log(f"Unexpected keys when loading state dict: {unexpected_keys}")

logger.log(f"Loaded checkpoint for {model_arch} from {checkpoint_path}")
except FileNotFoundError:
error_msg = f"No checkpoint found for {model_arch} at {checkpoint_path}. Please train the model first."
logger.log(error_msg)
raise FileNotFoundError(error_msg)
except KeyError:
error_msg = f"Checkpoint for {model_arch} at {checkpoint_path} does not have the expected format. Please ensure the checkpoint is correct and try again."
except KeyError as e:
error_msg = f"Checkpoint for {model_arch} at {checkpoint_path} does not have the expected format: {e}. Please ensure the checkpoint is correct and try again."
logger.log(error_msg)
raise RuntimeError(error_msg)

Expand Down
19 changes: 14 additions & 5 deletions machine_learning_core/examples_quant/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,32 @@ def prepare_model(model_arch = 'resnet18', device='cpu', logger=None):
net = model_constructor()
net = net.to(device)

checkpoint_path = os.path.join(script_dir, models_checkpoints_dir, f'{model_arch}.pt')
checkpoint_path = os.path.join(models_checkpoints_dir, f'{model_arch}.pt')
try:
checkpoint = torch.load(checkpoint_path, map_location=device)
new_state_dict = OrderedDict()

for k, v in checkpoint.items():
name = k[7:]
name = k[7:] if k.startswith('module.') else k # Strip 'module.' prefix if it exists
new_state_dict[name] = v

net.load_state_dict(new_state_dict)
logger.log(f"Model state dict keys: {net.state_dict().keys()}")
logger.log(f"Checkpoint state dict keys: {new_state_dict.keys()}")

missing_keys, unexpected_keys = net.load_state_dict(new_state_dict, strict=False)

if missing_keys:
logger.log(f"Missing keys when loading state dict: {missing_keys}")
if unexpected_keys:
logger.log(f"Unexpected keys when loading state dict: {unexpected_keys}")

logger.log(f"Loaded checkpoint for {model_arch} from {checkpoint_path}")
except FileNotFoundError:
error_msg = f"No checkpoint found for {model_arch} at {checkpoint_path}. Please train the model first."
logger.log(error_msg)
raise FileNotFoundError(error_msg)
except KeyError:
error_msg = f"Checkpoint for {model_arch} at {checkpoint_path} does not have the expected format. Please ensure the checkpoint is correct and try again."
except KeyError as e:
error_msg = f"Checkpoint for {model_arch} at {checkpoint_path} does not have the expected format: {e}. Please ensure the checkpoint is correct and try again."
logger.log(error_msg)
raise RuntimeError(error_msg)

Expand Down
21 changes: 15 additions & 6 deletions machine_learning_core/examples_unlearning/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_checkpoints_dir = os.path.join(script_dir, 'models_checkpoints')

def prepare_model(model_arch = 'resnet18', device='cpu', logger=None):
def prepare_model(model_arch='resnet18', device='cpu', logger=None):
logger.log(f'==> Building model {model_arch}...')

if model_arch in globals():
Expand All @@ -48,23 +48,32 @@ def prepare_model(model_arch = 'resnet18', device='cpu', logger=None):
net = model_constructor()
net = net.to(device)

checkpoint_path = os.path.join(script_dir, models_checkpoints_dir, f'{model_arch}.pt')
checkpoint_path = os.path.join(models_checkpoints_dir, f'{model_arch}.pt')
try:
checkpoint = torch.load(checkpoint_path, map_location=device)
new_state_dict = OrderedDict()

for k, v in checkpoint.items():
name = k[7:]
name = k[7:] if k.startswith('module.') else k # Strip 'module.' prefix if it exists
new_state_dict[name] = v

net.load_state_dict(new_state_dict)
logger.log(f"Model state dict keys: {net.state_dict().keys()}")
logger.log(f"Checkpoint state dict keys: {new_state_dict.keys()}")

missing_keys, unexpected_keys = net.load_state_dict(new_state_dict, strict=False)

if missing_keys:
logger.log(f"Missing keys when loading state dict: {missing_keys}")
if unexpected_keys:
logger.log(f"Unexpected keys when loading state dict: {unexpected_keys}")

logger.log(f"Loaded checkpoint for {model_arch} from {checkpoint_path}")
except FileNotFoundError:
error_msg = f"No checkpoint found for {model_arch} at {checkpoint_path}. Please train the model first."
logger.log(error_msg)
raise FileNotFoundError(error_msg)
except KeyError:
error_msg = f"Checkpoint for {model_arch} at {checkpoint_path} does not have the expected format. Please ensure the checkpoint is correct and try again."
except KeyError as e:
error_msg = f"Checkpoint for {model_arch} at {checkpoint_path} does not have the expected format: {e}. Please ensure the checkpoint is correct and try again."
logger.log(error_msg)
raise RuntimeError(error_msg)

Expand Down

0 comments on commit d771bcb

Please sign in to comment.