
Commit

added invvit
guochengqian committed Oct 19, 2023
1 parent d35aa63 commit f2f7532
Showing 5 changed files with 74 additions and 15 deletions.
48 changes: 48 additions & 0 deletions cfgs/s3dis_pix4point/pix4pointInv.yaml
@@ -0,0 +1,48 @@
model:
  NAME: BaseSeg
  encoder_args:
    NAME: InvPointViT
    in_channels: 7
    embed_dim: 384
    depth: 12
    num_heads: 6
    mlp_ratio: 4.
    drop_rate: 0.
    attn_drop_rate: 0.0
    drop_path_rate: 0.1
    add_pos_each_block: True
    qkv_bias: True
    act_args:
      act: 'gelu' # better than relu
    norm_args:
      norm: 'ln'
      eps: 1.0e-6
    embed_args:
      NAME: P3Embed
      feature_type: 'dp_df' # show an ablation study of this.
      reduction: 'max'
      sample_ratio: 0.0625
      normalize_dp: False
      group_size: 32
      subsample: 'fps' # random, FPS
      group: 'knn' # change it to group args.
      conv_args:
        order: conv-norm-act
      layers: 4
      norm_args:
        norm: 'ln2d'
  decoder_args:
    NAME: PointViTDecoder
    channel_scaling: 1
    global_feat: cls,max
    progressive_input: True
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: null
    mlps: [256]
    norm_args:
      norm: 'ln1d'

mode: finetune_encoder_inv
pretrained_path: pretrained/imagenet/mae_s.pth
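For orientation, here is a minimal sketch of how a config like this is typically consumed by the PointNeXt/openpoints example scripts. `EasyConfig` is imported elsewhere in this commit; `build_model_from_cfg` and the `recursive=True` loading behaviour are assumed from the repo's conventions, so treat this as illustrative rather than part of the commit.

```python
# Hedged sketch, not part of this commit: loading the new config and building the model.
from openpoints.utils import EasyConfig           # imported by the scripts in this diff
from openpoints.models import build_model_from_cfg  # assumed repo helper

cfg = EasyConfig()
# recursive=True is assumed to also pull in the shared default cfgs of the repo layout.
cfg.load('cfgs/s3dis_pix4point/pix4pointInv.yaml', recursive=True)

# encoder_args.NAME == 'InvPointViT' selects the inverted ViT encoder;
# cfg.mode and cfg.pretrained_path drive the finetune_encoder_inv branch shown below.
model = build_model_from_cfg(cfg.model)
print(cfg.mode, cfg.pretrained_path)
```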
6 changes: 2 additions & 4 deletions docs/projects/pix4point.md
@@ -38,12 +38,12 @@ Please refer to DeiT's repo for details.
- finetune Image Pretrained Transformer
```bash
CUDA_VISIBLE_DEVICES=0 python examples/segmentation/main.py --cfg cfgs/s3dis_sphere_pix4point/pix4point.yaml
CUDA_VISIBLE_DEVICES=0 python examples/segmentation/main.py --cfg cfgs/s3dis_pix4point/pix4point.yaml
```
- test
```bash
CUDA_VISIBLE_DEVICES=0 python examples/segmentation/main.py --cfg cfgs/s3dis_sphere_pix4point/pix4point.yaml mode=test pretrained_path=<pretrained_path>
CUDA_VISIBLE_DEVICES=0 python examples/segmentation/main.py --cfg cfgs/s3dis_pix4point/pix4point.yaml mode=test pretrained_path=<pretrained_path>
```
@@ -87,5 +87,3 @@ If you are using our code in your work, please kindly cite the following:
primaryClass={cs.CV}
}
```
6 changes: 5 additions & 1 deletion examples/classification/train.py
@@ -3,7 +3,7 @@
import torch, torch.nn as nn
from torch import distributed as dist
from torch.utils.tensorboard import SummaryWriter
from openpoints.utils import set_random_seed, save_checkpoint, load_checkpoint, resume_checkpoint, setup_logger_dist, \
from openpoints.utils import set_random_seed, save_checkpoint, load_checkpoint, load_checkpoint_inv, resume_checkpoint, setup_logger_dist, \
    cal_model_parm_nums, Wandb
from openpoints.utils import AverageMeter, ConfusionMatrix, get_mious
from openpoints.dataset import build_dataloader_from_cfg
@@ -148,6 +148,10 @@ def main(gpu, cfg, profile=False):
            # finetune the whole model
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint(model.encoder, cfg.pretrained_path)
        elif cfg.mode == 'finetune_encoder_inv':
            # finetune the whole model
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint_inv(model.encoder, cfg.pretrained_path)
    else:
        logging.info('Training from scratch')
    train_loader = build_dataloader_from_cfg(cfg.batch_size,
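`load_checkpoint_inv` is imported from `openpoints.utils` in this commit, but its body is not part of the diff. As a loose, hypothetical sketch of what such a loader could do (names and behaviour below are assumptions, not the repo's implementation), it might copy the shape-matching weights of a 2D ViT checkpoint, e.g. `pretrained/imagenet/mae_s.pth`, into the InvPointViT encoder:

```python
# Hypothetical sketch only: NOT the repo's load_checkpoint_inv.
import torch

def load_checkpoint_inv_sketch(encoder, pretrained_path):
    ckpt = torch.load(pretrained_path, map_location='cpu')
    state = ckpt.get('model', ckpt)  # checkpoints often nest weights under 'model'
    own = encoder.state_dict()
    # keep only tensors whose names and shapes match the point encoder,
    # implicitly skipping image-specific parts such as the 2D patch embedding
    matched = {k: v for k, v in state.items()
               if k in own and v.shape == own[k].shape}
    own.update(matched)
    encoder.load_state_dict(own)
    return len(matched)  # how many tensors were transferred
```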
27 changes: 18 additions & 9 deletions examples/segmentation/main.py
@@ -12,7 +12,7 @@
from torch.utils.tensorboard import SummaryWriter
from torch_scatter import scatter
from openpoints.utils import set_random_seed, save_checkpoint, load_checkpoint, resume_checkpoint, setup_logger_dist, \
    cal_model_parm_nums, Wandb, generate_exp_directory, resume_exp_directory, EasyConfig, dist_utils, find_free_port
    cal_model_parm_nums, Wandb, generate_exp_directory, resume_exp_directory, EasyConfig, dist_utils, find_free_port, load_checkpoint_inv
from openpoints.utils import AverageMeter, ConfusionMatrix, get_mious
from openpoints.dataset import build_dataloader_from_cfg, get_features_by_keys, get_class_weights
from openpoints.dataset.data_util import voxelize
@@ -174,7 +174,7 @@ def main(gpu, cfg):
    else:
        if cfg.mode == 'val':
            best_epoch, best_val = load_checkpoint(model, pretrained_path=cfg.pretrained_path)
            val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg, num_votes=1)
            val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg, num_votes=1, epoch=epoch)
            with np.printoptions(precision=2, suppress=True):
                logging.info(
                    f'Best ckpt @E{best_epoch}, val_oa , val_macc, val_miou: {val_oa:.2f} {val_macc:.2f} {val_miou:.2f}, '
@@ -196,8 +196,13 @@
            return test_miou

        elif 'encoder' in cfg.mode:
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint(model_module.encoder, cfg.pretrained_path, cfg.get('pretrained_module', None))
            if 'inv' in cfg.mode:
                logging.info(f'Finetuning from {cfg.pretrained_path}')
                load_checkpoint_inv(model.encoder, cfg.pretrained_path)
            else:
                logging.info(f'Finetuning from {cfg.pretrained_path}')
                load_checkpoint(model_module.encoder, cfg.pretrained_path, cfg.get('pretrained_module', None))

        else:
            logging.info(f'Finetuning from {cfg.pretrained_path}')
            load_checkpoint(model, cfg.pretrained_path, cfg.get('pretrained_module', None))
@@ -243,7 +248,7 @@ def main(gpu, cfg):

        is_best = False
        if epoch % cfg.val_freq == 0:
            val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg)
            val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg, epoch=epoch)
            if val_miou > best_val:
                is_best = True
                best_val = val_miou
@@ -295,7 +300,8 @@ def main(gpu, cfg):
    load_checkpoint(model, pretrained_path=os.path.join(cfg.ckpt_dir, f'{cfg.run_name}_ckpt_best.pth'))
    cfg.csv_path = os.path.join(cfg.run_dir, cfg.run_name + f'.csv')
    if 'sphere' in cfg.dataset.common.NAME.lower():
        test_miou, test_macc, test_oa, test_ious, test_accs = validate_sphere(model, val_loader, cfg)
        # TODO:
        test_miou, test_macc, test_oa, test_ious, test_accs = validate_sphere(model, val_loader, cfg, epoch=epoch)
    else:
        data_list = generate_data_list(cfg)
        test_miou, test_macc, test_oa, test_ious, test_accs, _ = test(model, data_list, cfg)
@@ -313,7 +319,7 @@ def main(gpu, cfg):
    load_checkpoint(model, pretrained_path=os.path.join(cfg.ckpt_dir, f'{cfg.run_name}_ckpt_best.pth'))
    set_random_seed(cfg.seed)
    val_miou, val_macc, val_oa, val_ious, val_accs = validate_fn(model, val_loader, cfg, num_votes=20,
                                                                 data_transform=data_transform)
                                                                 data_transform=data_transform, epoch=epoch)
    if writer is not None:
        writer.add_scalar('val_miou20', val_miou, cfg.epochs + 50)

@@ -349,6 +355,7 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler
        vis_points(data['pos'].cpu().numpy()[0], data['x'][0, :3, :].transpose(1, 0))
        end of debug """
        data['x'] = get_features_by_keys(data, cfg.feature_keys)
        data['epoch'] = epoch
        with torch.cuda.amp.autocast(enabled=cfg.use_amp):
            logits = model(data)
            loss = criterion(logits, target) if 'mask' not in cfg.criterion_args.NAME.lower() \
@@ -386,7 +393,7 @@ def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler


@torch.no_grad()
def validate(model, val_loader, cfg, num_votes=1, data_transform=None):
def validate(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1):
    model.eval()  # set model to eval mode
    cm = ConfusionMatrix(num_classes=cfg.num_classes, ignore_index=cfg.ignore_index)
    pbar = tqdm(enumerate(val_loader), total=val_loader.__len__(), desc='Val')
@@ -396,6 +403,7 @@ def validate(model, val_loader, cfg, num_votes=1, data_transform=None):
            data[key] = data[key].cuda(non_blocking=True)
        target = data['y'].squeeze(-1)
        data['x'] = get_features_by_keys(data, cfg.feature_keys)
        data['epoch'] = epoch
        logits = model(data)
        if 'mask' not in cfg.criterion_args.NAME or cfg.get('use_maks', False):
            cm.update(logits.argmax(dim=1), target)
@@ -430,7 +438,7 @@ def validate(model, val_loader, cfg, num_votes=1, data_transform=None):


@torch.no_grad()
def validate_sphere(model, val_loader, cfg, num_votes=1, data_transform=None):
def validate_sphere(model, val_loader, cfg, num_votes=1, data_transform=None, epoch=-1):
"""
validation for sphere sampled input points with mask.
in this case, between different batches, there are overlapped points.
@@ -451,6 +459,7 @@ def validate_sphere(model, val_loader, cfg, num_votes=1, data_transform=None):
        for key in data.keys():
            data[key] = data[key].cuda(non_blocking=True)
        data['x'] = get_features_by_keys(data, cfg.feature_keys)
        data['epoch'] = epoch
        logits = model(data)
        all_logits.append(logits)
        idx_points.append(data['input_inds'])
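The remaining changes thread the current epoch into the batch dict (`data['epoch'] = epoch`, with `-1` as the default during standalone validation), so a model can condition its behaviour on training progress. A made-up example of a module reading that field (not part of openpoints; the encoder/head call signatures are assumptions):

```python
# Illustrative only: a segmentation wrapper that reads data['epoch'].
import torch.nn as nn

class EpochAwareSeg(nn.Module):
    def __init__(self, encoder, head, warmup_epochs=10):
        super().__init__()
        self.encoder, self.head = encoder, head
        self.warmup_epochs = warmup_epochs

    def forward(self, data):
        epoch = data.get('epoch', -1)      # -1 signals "not inside a training loop"
        feats = self.encoder(data['pos'], data['x'])
        if 0 <= epoch < self.warmup_epochs:
            feats = feats.detach()         # e.g. freeze-like warmup of the encoder
        return self.head(feats)
```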
