Skip to content

Commit

Permalink
Updates and fixes for newer PyTorch versions. Using PyTorch's tensorb…
Browse files Browse the repository at this point in the history
…oard instead of TensorboardX.
  • Loading branch information
martin-danelljan committed Jan 16, 2020
1 parent 8370fe9 commit cf49b54
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 26 deletions.
4 changes: 2 additions & 2 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
- It is possible to use any PyTorch supported version of CUDA (not necessarily v10).
- For more details about PyTorch installation, see https://pytorch.org/get-started/previous-versions/.

#### Install matplotlib, pandas, opencv, visdom and tensorboadX
#### Install matplotlib, pandas, opencv, visdom and tensorboad
```bash
conda install matplotlib pandas
pip install opencv-python tensorboardX visdom
pip install opencv-python visdom tb-nightly
```


Expand Down
4 changes: 2 additions & 2 deletions INSTALL_win.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
- It is possible to use any PyTorch supported version of CUDA (not necessarily v10), but better be the same version with your preinstalled CUDA (if you have one)
- For more details about PyTorch installation, see https://pytorch.org/get-started/previous-versions/.

#### Install matplotlib, pandas, opencv, visdom and tensorboadX
#### Install matplotlib, pandas, opencv, visdom and tensorboad
```bash
conda install matplotlib pandas
pip install opencv-python tensorboardX visdom
pip install opencv-python visdom tb-nightly
```


Expand Down
10 changes: 0 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
# PyTracking
A general python framework for training and running visual object trackers, based on **PyTorch**.

### **News:** Code released for **DiMP**!!!
Code now released for our new tracker **DiMP**, accepted as an Oral at ICCV 2019.
This release also includes many **new features**, including:
* Visualization with Visdom
* VOT integration
* Many new network modules
* Multi GPU training
* PyTorch v1.2 support


## Highlights

Expand Down
4 changes: 2 additions & 2 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ pip install opencv-python

echo ""
echo ""
echo "****************** Installing tensorboardX ******************"
pip install tensorboardX
echo "****************** Installing tensorboard ******************"
pip install tb-nightly

echo ""
echo ""
Expand Down
23 changes: 17 additions & 6 deletions ltr/admin/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
import sys
from pathlib import Path
import importlib
import inspect


def load_trained_network(workspace_dir, network_path, checkpoint=None):
checkpoint_dir = os.path.join(workspace_dir, 'checkpoints')
directory = '{}/{}'.format(checkpoint_dir, network_path)

net, _ = load_network(directory, checkpoint)
return net

def load_network(network_dir=None, checkpoint=None, constructor_fun_name=None, constructor_module=None, **kwargs):
"""Loads a network checkpoint file.
Expand All @@ -18,6 +26,7 @@ def load_network(network_dir=None, checkpoint=None, constructor_fun_name=None, c
The extra keyword arguments are supplied to the network constructor to replace saved ones.
"""


if network_dir is not None:
net_path = Path(network_dir)
else:
Expand Down Expand Up @@ -58,14 +67,16 @@ def load_network(network_dir=None, checkpoint=None, constructor_fun_name=None, c
net_constr.fun_name = constructor_fun_name
if constructor_module is not None:
net_constr.fun_module = constructor_module
for arg, val in kwargs.items():
if arg in net_constr.kwds.keys():
net_constr.kwds[arg] = val
else:
print('WARNING: Keyword argument "{}" not found when loading network.'.format(arg))
# Legacy networks before refactoring
if net_constr.fun_module.startswith('dlframework.'):
net_constr.fun_module = net_constr.fun_module[len('dlframework.'):]
net_fun = getattr(importlib.import_module(net_constr.fun_module), net_constr.fun_name)
net_fun_args = list(inspect.signature(net_fun).parameters.keys())
for arg, val in kwargs.items():
if arg in net_fun_args:
net_constr.kwds[arg] = val
else:
print('WARNING: Keyword argument "{}" not found when loading network. It was ignored.'.format(arg))
net = net_constr.get()
else:
raise RuntimeError('No constructor for the given network.')
Expand Down Expand Up @@ -118,4 +129,4 @@ def _cleanup_legacy_env():
if m.startswith('dlframework'):
del_modules.append(m)
for m in del_modules:
del sys.modules[m]
del sys.modules[m]
2 changes: 1 addition & 1 deletion ltr/admin/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from collections import OrderedDict
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter


class TensorboardWriter:
Expand Down
2 changes: 1 addition & 1 deletion ltr/models/target_classifier/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def forward(self, weights, feat, bb, sample_weight=None, num_iter=None, compute_
scores_grad = sample_weight * (score_mask * scores_grad)

# Compute optimal step length
alpha_num = (weights_grad * weights_grad).view(num_sequences, -1).sum(dim=1)
alpha_num = (weights_grad * weights_grad).sum(dim=(1,2,3))
alpha_den = ((scores_grad * scores_grad).view(num_images, num_sequences, -1).sum(dim=(0,2)) + reg_weight * alpha_num).clamp(1e-8)
alpha = alpha_num / alpha_den

Expand Down
3 changes: 1 addition & 2 deletions pytracking/libs/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,4 @@ def conv1x1(input: torch.Tensor, weight: torch.Tensor):
if weight is None:
return input

return torch.matmul(weight.view(weight.shape[0], weight.shape[1]),
input.view(input.shape[0], input.shape[1], -1)).view(input.shape[0], weight.shape[0], input.shape[2], input.shape[3])
return torch.conv2d(input, weight)

0 comments on commit cf49b54

Please sign in to comment.