Skip to content

Commit

Permalink
Refactoring hash-dependent functions (#63)
Browse files Browse the repository at this point in the history
* Fixing PermissionError on NamedTempFile

* Revert "Fixing PermissionError on NamedTempFile"

This reverts commit d6eb146.

* Improving how to deal with hashes
  • Loading branch information
Alexandre de Siqueira authored Mar 15, 2021
1 parent fc56abd commit db2b552
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 98 deletions.
128 changes: 36 additions & 92 deletions butterfly/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from pooch import retrieve
from urllib import request

import hashlib
import socket


Expand All @@ -12,15 +14,9 @@
}

# URLs of the online SHA256 hash files, one per model. Keys match the
# weight-file stems used by `_get_model_info`.
URL_HASH = {
    'id_gender' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/id_gender/SHA256SUM-id_gender',
    'id_position' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/id_position/SHA256SUM-id_position',
    'segmentation' : 'https://gitlab.com/alexdesiqueira/mothra-models/-/raw/main/models/segmentation/SHA256SUM-segmentation'
}


Expand All @@ -39,39 +35,8 @@ def _get_model_info(weights):
URL of the file for the latest model.
url_hash : str
URL of the hash file for the latest model.
local_hash : pathlib.Path
Path of the local hash file.
"""
return (URL_MODEL.get(weights.stem), URL_HASH.get(weights.stem),
LOCAL_HASH.get(weights.stem))


def _check_hashes(weights):
    """Helper. Fetch the hash files for `weights` when they are missing.
    Parameters
    ----------
    weights : str or pathlib.Path
        Path of the file containing weights.
    Returns
    -------
    None
    """
    _, url_hash, local_hash = _get_model_info(weights)

    # Local hash file, plus a sibling copy named after the online hash file.
    online_copy = local_hash.parent / Path(url_hash).name

    # Download each hash file only if it is not already on disk.
    for target in (local_hash, online_copy):
        if not target.is_file():
            download_hash_from_url(url_hash=url_hash, filename=target)

    return None
return (URL_MODEL.get(weights.stem), URL_HASH.get(weights.stem))


def download_weights(weights):
Expand All @@ -86,9 +51,7 @@ def download_weights(weights):
-------
None
"""
# check if hashes are in disk, then get info from the model.
_check_hashes(weights)
_, url_hash, local_hash = _get_model_info(weights)
_, url_hash = _get_model_info(weights)

# check if weights is in its folder. If not, download the file.
if not weights.is_file():
Expand All @@ -97,34 +60,15 @@ def download_weights(weights):
# file exists: check if we have the last version; download if not.
else:
if has_internet():
local_hash_val = read_hash_local(filename=local_hash)
url_hash_val = read_hash_from_url(path=local_hash.parent,
url_hash=url_hash)
local_hash_val = read_hash_local(weights)
url_hash_val = read_hash_from_url(url_hash)
if local_hash_val != url_hash_val:
print('New training data available. Downloading...')
fetch_data(weights)

return None


def download_hash_from_url(url_hash, filename):
    """Fetch the SHA256 hash file at `url_hash` and save it locally.
    Parameters
    ----------
    url_hash : str
        URL of the SHA256 hash.
    filename : str
        Filename to save the SHA256 hash locally.
    Returns
    -------
    None
    """
    # known_hash is None: we are downloading the hash itself, so there is
    # nothing yet to verify it against.
    retrieve(url=url_hash, fname=filename, path='.', known_hash=None)
    return None


def fetch_data(weights):
    """Downloads and checks the hash of `weights`, according to its filename.
    Parameters
    ----------
    weights : str or pathlib.Path
        Path of the file containing weights.
    Returns
    -------
    None
    """
    url_model, url_hash = _get_model_info(weights)

    # Read the expected SHA256 straight from the online hash file, and let
    # pooch verify the downloaded weights against it.
    url_hash_val = read_hash_from_url(url_hash)
    retrieve(url=url_model,
             known_hash=f'sha256:{url_hash_val}',
             fname=weights,
             path='.')

Expand All @@ -166,40 +106,44 @@ def has_internet():
return socket.gethostbyname(socket.gethostname()) != '127.0.0.1'


def read_hash_local(weights):
    """Reads local SHA256 hash from weights.
    Parameters
    ----------
    weights : str or pathlib.Path
        Path of the file containing weights.
    Returns
    -------
    local_hash : str or None
        SHA256 hash of weights file.
    Notes
    -----
    Returns None if file is not found.
    """
    # Hash the file in fixed-size chunks so large weight files do not have
    # to fit in memory at once.
    BUFFER_SIZE = 65536
    sha256 = hashlib.sha256()

    try:
        with open(weights, 'rb') as file_weights:
            while True:
                data = file_weights.read(BUFFER_SIZE)
                if not data:
                    break
                sha256.update(data)
        local_hash = sha256.hexdigest()
    except FileNotFoundError:
        local_hash = None
    return local_hash


def read_hash_from_url(url_hash):
    """Returns the SHA256 hash online for the file in `url_hash`.
    Parameters
    ----------
    url_hash : str
        URL of the hash file for the latest model.
    Returns
    -------
    online_hash : str
        SHA256 hash for the file in `url_hash`.
    """
    # Some hosts reject urllib's default user agent; present a browser one.
    user_agent = 'Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.9.0.7) Gecko/2009021910 Firefox/3.0.7'
    headers = {'User-Agent':user_agent,}

    aux_req = request.Request(url_hash, None, headers)
    # Context manager closes the HTTP response instead of leaking the socket.
    with request.urlopen(aux_req) as response:
        hashes = response.read()

    # expecting only one hash, and not interested in the filename:
    online_hash, _ = hashes.decode('ascii').split()

    return online_hash
1 change: 0 additions & 1 deletion models/.SHA256SUM_ONLINE-id_gender

This file was deleted.

1 change: 0 additions & 1 deletion models/.SHA256SUM_ONLINE-id_position

This file was deleted.

1 change: 0 additions & 1 deletion models/.SHA256SUM_ONLINE-segmentation

This file was deleted.

1 change: 0 additions & 1 deletion models/SHA256SUM-id_gender

This file was deleted.

1 change: 0 additions & 1 deletion models/SHA256SUM-id_position

This file was deleted.

1 change: 0 additions & 1 deletion models/SHA256SUM-segmentation

This file was deleted.

0 comments on commit db2b552

Please sign in to comment.