Skip to content

Commit

Permalink
add encrypted field
Browse files Browse the repository at this point in the history
We can specify `encrypted=True` for any field and I will make this to be
saved encrypted using AES256.
  • Loading branch information
tomkukral committed Dec 14, 2017
1 parent dc92948 commit 514f2d9
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM python:3.6-slim
FROM python:3.6

# prepare directory
WORKDIR /code
Expand Down
81 changes: 72 additions & 9 deletions kqueen/storages/etcd.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from .exceptions import BackendError
from Crypto import Random
from Crypto.Cipher import AES
from datetime import datetime
from dateutil.parser import parse as du_parse
from flask import current_app
from kqueen.config import current_config

import base64
import bcrypt
import etcd
import hashlib
import importlib
import json
import logging
import uuid
import importlib
import six
from datetime import datetime
from dateutil.parser import parse as du_parse
from kqueen.config import current_config
from flask import current_app
from .exceptions import BackendError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +50,12 @@ def __init__(self, *args, **kwargs):
else:
self.value = kwargs.get('value', None)

# set block size for crypto
self.bs = 16

# field parameters
self.required = kwargs.get('required', False)
self.encrypted = kwargs.get('encrypted', False)

def on_create(self, **kwargs):
"""Optional action that should be run only on newly created objects"""
Expand Down Expand Up @@ -97,6 +106,55 @@ def validate(self):
"""
return True

def _get_encryption_key(self):
"""
Read encryption key and format it.
Returns:
Encryption key.
"""

# check for key
config = current_config()
key = config.get('SECRET_KEY')

if key is None:
raise Exception('Missing SECRET_KEY')

# calculate hash passowrd
return hashlib.sha256(key.encode('utf-8')).digest()[:self.bs]

def _pad(self, s):
return s + (self.bs - len(s) % self.bs) * chr(self.bs - len(s) % self.bs)

def _unpad(self, s):
return s[:-ord(s[len(s) - 1:])]

def encrypt(self):
"""Encrypt stored value"""

key = self._get_encryption_key()
padded = self._pad(str(self.serialize()))

iv = Random.new().read(self.bs)
suite = AES.new(key, AES.MODE_CBC, iv)
encrypted = suite.encrypt(padded)
encoded = base64.b64encode(iv + encrypted)

return encoded

def decrypt(self, crypted, **kwargs):
key = self._get_encryption_key()
decoded = base64.b64decode(crypted)

iv = decoded[:self.bs]
suite = AES.new(key, AES.MODE_CBC, iv)
decrypted = suite.decrypt(decoded[self.bs:]).decode('utf-8')

serialized = self._unpad(decrypted)
self.deserialize(serialized, **kwargs)
print('Seralizing from: {}, value: {}'.format(serialized, self.value))

def __str__(self):
return str(self.value)

Expand All @@ -115,7 +173,7 @@ class StringField(Field):
class BoolField(Field):

def deserialize(self, serialized, **kwargs):
if isinstance(serialized, six.string_types):
if isinstance(serialized, str):
value = json.loads(serialized)
self.set_value(value, **kwargs)

Expand Down Expand Up @@ -156,10 +214,15 @@ class DatetimeField(Field):

def deserialize(self, serialized, **kwargs):
value = None

# convert to float if serialized is digit
if isinstance(serialized, str) and serialized.isdigit():
serialized = float(serialized)

if isinstance(serialized, (float, int)):
value = datetime.fromtimestamp(serialized)
self.set_value(value, **kwargs)
elif isinstance(serialized, six.string_types):
elif isinstance(serialized, str):
value = du_parse(serialized)
self.set_value(value, **kwargs)

Expand Down
117 changes: 89 additions & 28 deletions kqueen/storages/test_model_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,30 @@

import datetime
import pytest
import itertools


def create_model(required=False, global_ns=False):
def create_model(required=False, global_ns=False, encrypted=False):
class TestModel(Model, metaclass=ModelMeta):
if global_ns:
global_namespace = global_ns

id = IdField(required=required)
string = StringField(required=required)
json = JSONField(required=required)
password = PasswordField(required=required)
relation = RelationField(required=required)
datetime = DatetimeField(required=required)
boolean = BoolField(required=required)
id = IdField(required=required, encrypte=encrypted)
string = StringField(required=required, encrypte=encrypted)
json = JSONField(required=required, encrypte=encrypted)
password = PasswordField(required=required, encrypted=encrypted)
relation = RelationField(required=required, encrypte=encrypted)
datetime = DatetimeField(required=required, encrypte=encrypted)
boolean = BoolField(required=required, encrypte=encrypted)

_required = required
_global_ns = global_ns
_encrypted = encrypted

if _global_ns:
_namespace = None
else:
_namespace = namespace

return TestModel

Expand Down Expand Up @@ -58,8 +68,25 @@ def model_serialized(related=None):
)


@pytest.fixture
def create_object():
@pytest.fixture(params=itertools.product(*[
[True, False],
[True, False],
[True, False]
]))
def get_object(request):
return create_object(*request.param)


@pytest.fixture(params=itertools.product(*[
[True, False],
[True, False],
[True, False]
]))
def get_model(request):
return create_model(*request.param)


def create_object(required=False, global_ns=False, encrypted=False):
model = create_model()

obj1 = model(namespace, **model_kwargs)
Expand Down Expand Up @@ -111,14 +138,14 @@ def test_save_skip_validation(self):


class TestModelAddId:
def test_id_added(self, create_object):
obj = create_object
def test_id_added(self, get_object):
obj = get_object

assert obj.id is None
assert obj.verify_id()
assert obj.id is not None

create_object.save()
obj.save()


class TestRequiredFields:
Expand All @@ -131,23 +158,23 @@ def test_required(self, required):


class TestGetFieldNames:
def test_get_field_names(self, create_object):
field_names = create_object.__class__.get_field_names()
def test_get_field_names(self, get_object):
field_names = get_object.__class__.get_field_names()
req = model_fields

assert set(field_names) == set(req)

def test_get_dict(self, create_object):
dicted = create_object.get_dict()
def test_get_dict(self, get_object):
dicted = get_object.get_dict()

assert isinstance(dicted, dict)


class TestFieldSetGet:
"""Validate getters and setters for fields"""
@pytest.mark.parametrize('field_name', model_kwargs.keys())
def test_get_fields(self, field_name, create_object):
at = getattr(create_object, field_name)
def test_get_fields(self, field_name, get_object):
at = getattr(get_object, field_name)
req = model_kwargs[field_name]

assert at == req
Expand All @@ -166,22 +193,22 @@ def test_set_fields(self, field_name):
class TestSerialization:
"""Serialization and deserialization create same objects"""

def test_serizalization(self, create_object):
serialized = create_object.serialize()
def test_serizalization(self, get_object):
serialized = get_object.serialize()

assert serialized == model_serialized(related=create_object.relation)
assert serialized == model_serialized(related=get_object.relation)

def test_deserialization(self, create_object, monkeypatch):
def test_deserialization(self, get_object, monkeypatch):
def fake(self, class_name):
return create_object.__class__
return get_object.__class__

monkeypatch.setattr(RelationField, '_get_related_class', fake)

object_class = create_object.__class__
create_object.save()
new_object = object_class.deserialize(create_object.serialize(), namespace=namespace)
object_class = get_object.__class__
get_object.save()
new_object = object_class.deserialize(get_object.serialize(), namespace=namespace)

assert new_object.get_dict() == create_object.get_dict()
assert new_object.get_dict() == get_object.get_dict()


class TestGetDict:
Expand Down Expand Up @@ -401,3 +428,37 @@ def test_dict_value_returns_boolean(self):
self.field.set_value(self.boolean)

assert self.field.dict_value() == self.boolean


#
# Encryption
#
class TestFieldEncryption:
def test_get_encryption_key(self, get_model):

obj = get_model(get_model._namespace, **model_kwargs)

field = obj._string
KEY_LENGTH = obj._string.bs
key = field._get_encryption_key()

assert len(key) == KEY_LENGTH

@pytest.mark.parametrize('field_name, field_value', model_kwargs.items())
def test_encryption_and_decryption(self, field_name, field_value):

cls = create_model(False, False, True)
obj = cls(namespace, **model_kwargs)

field = getattr(obj, '_{}'.format(field_name))
field.set_value(field_value)

# encryption
encrypted = field.encrypt()
print(encrypted)

# decryption
field.set_value(None)
field.decrypt(encrypted)

assert field.value == field_value
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ bcrypt
Flask==0.12.2
Flask-JWT==0.3.2
flask-swagger-ui
gunicorn
kubernetes
prometheus_client
pycrypto
python-etcd
python-jenkins
prometheus_client
pyyaml
requests
gunicorn

# dev
coveralls
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
'flask-swagger-ui',
'gunicorn',
'kubernetes',
'pycrypto',
'prometheus_client',
'python-etcd',
'python-jenkins',
Expand Down

0 comments on commit 514f2d9

Please sign in to comment.