Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mikrotik worker v1.6.3 #85

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,7 @@ venv.bak/
*.db

# Logs
*.log
*.log

# Database files
*.sqlite
1 change: 1 addition & 0 deletions env.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export APP_DB_URL="postgresql://whohacks:S3cret@localhost:5432/whohacks"
export OAUTH_OPENID="http://sso.hsp.sh/auth/realms/hsp/.well-known/openid-configuration"
export OAUTH_CLIENT_ID="fake-development-client-id"
export LOGLEVEL="DEBUG" # DEBUG | INFO | WARNING | ERROR
env | grep APP_ > .env
env | grep OUATH_ > .env
52 changes: 52 additions & 0 deletions helpers/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
import os

loggers = {}


def get_loglevel(loglevel: str):
match loglevel:
case "DEBUG":
loglevel = logging.DEBUG
case "INFO":
loglevel = logging.INFO
case "WARNING":
loglevel = logging.WARNING
case "ERROR":
loglevel = logging.ERROR
case "CRITICAL":
loglevel = logging.CRITICAL
case _:
loglevel = logging.INFO

return loglevel


def init_logger(name: str) -> logging.Logger:
if loggers.get(name):
return loggers[name]

logger = logging.getLogger(name)
loglevel = get_loglevel(os.environ.get("LOGLEVEL"))

formatter = logging.Formatter(
fmt=f"({name}) %(asctime)s %(module)s %(levelname)s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)

for handler in logger.handlers:
logger.removeHandler(handler)

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
file_handler = logging.FileHandler(f"{name}.log")
file_handler.setFormatter(formatter)

logger.addHandler(stream_handler)
logger.addHandler(file_handler)

logger.setLevel(loglevel)

loggers[name] = logger

return logger
11 changes: 3 additions & 8 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from unittest import TestCase

from helpers.logger import init_logger
from whois.app import WhohacksApp
from whois.data.db.database import Database
from whois.settings.testing import app_settings, mikrotik_settings
Expand All @@ -9,19 +10,13 @@
class ApiTestCase(TestCase):

def setUp(self):
self.logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s %(module)s %(levelname)s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
level=logging.DEBUG,
force=True,
)
self.logger = init_logger(__name__)
self.logger.addHandler(logging.FileHandler(f"{__name__}.log"))

self.db = Database("sqlite:///whohacks.test.sqlite")
self.db.drop()
self.db.create_db()
self.whois = WhohacksApp(app_settings, mikrotik_settings, self.db, self.logger)
self.whois = WhohacksApp(app_settings, mikrotik_settings, self.db)
self.app = self.whois.app.test_client()
self.app.testing = True

Expand Down
110 changes: 59 additions & 51 deletions whois/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,14 @@
from logging import Logger

from authlib.integrations.flask_client import OAuth
from flask import (
Flask,
abort,
flash,
jsonify,
redirect,
render_template,
request,
url_for,
)
from flask import (Flask, abort, flash, jsonify, redirect, render_template,
request, url_for)
from flask_cors import CORS
from flask_login import (
LoginManager,
current_user,
login_required,
login_user,
logout_user,
)
from flask_login import (LoginManager, current_user, login_required,
login_user, logout_user)
from sqlalchemy.orm.exc import NoResultFound

from helpers.logger import init_logger
from whois.data.db.database import Database
from whois.data.repository.device_repository import DeviceRepository
from whois.data.repository.user_repository import UserRepository
Expand All @@ -39,10 +27,9 @@ def __init__(
app_settings: AppSettings,
mikrotik_settings: MikrotikSettings,
database: Database,
logger: Logger,
):
self.logger = logger
self.logger.debug("Initializing WhohacksApp...")
self.logger = init_logger("WhohacksApp")
self.logger.info("Initializing WhohacksApp...")

self.app = Flask(__name__)
self.cors = CORS(self.app, resources={r"/api/*": {"origins": "*"}})
Expand All @@ -66,7 +53,7 @@ def __init__(
self.add_template_filters()
self.register_routes()

self.logger.debug("Initialized WhohacksApp")
self.logger.info("Initialized WhohacksApp")

self.common_vars_tpl = {"app": self.app.config.get_namespace("APP_")}

Expand All @@ -78,6 +65,12 @@ def __init__(
client_kwargs={"scope": "openid profile email"},
)

def _log_endpoint(self, endpoint: str, details: str = None) -> None:
if details:
self.logger.debug(f'Called "{endpoint}", ' + details)
else:
self.logger.debug(f'Called "{endpoint}"')

def add_rules(self) -> None:
self.login_manager.user_loader(self.load_user)
self.app.before_request = self.before_request
Expand Down Expand Up @@ -116,14 +109,15 @@ def register_routes(self) -> None:

# Rules for Flask App
def load_user(self, user_id):
self.logger.debug(f'Load user with ID: "{user_id}"')
try:
return self.user_repository.get_by_id(user_id)
except NoResultFound as exc:
self.app.logger.error("{}".format(exc))
return None

def before_request(self):
self.app.logger.info("connecting to db")
self.logger.debug("Preparing for request")
self.database.connect()

if request.headers.getlist("X-Forwarded-For"):
Expand All @@ -141,19 +135,14 @@ def before_request(self):
flash("Outside local network, some functions forbidden!", "outside-warning")

def after_request(self, error):
if self.database.is_connected:
self.app.logger.info("Closing the database connection")
self.database.disconnect()
else:
self.app.logger.info("Database connection was already closed")

self.database.disconnect()
if error:
self.app.logger.error(error)
self.logger.error(error)

# Routes for Flask App
def index(self):
"""Serve list of people in hs, show panel for logged users"""
self.logger.debug("Called '/'")
self._log_endpoint("/")
recent = self.device_repository.get_recent(
timedelta(**self.app_settings.RECENT_TIME)
)
Expand All @@ -172,7 +161,7 @@ def index(self):

@login_required
def devices(self):
self.logger.debug("Called '/devices'")
self._log_endpoint("/devices")
recent = self.device_repository.get_recent(
timedelta(**self.app_settings.RECENT_TIME)
)
Expand Down Expand Up @@ -200,7 +189,8 @@ def now_at_space(self):
used by other services in HS,
requests should be from hsp.sh domain or from HSWAN
"""
self.logger.debug("Called '/api/now'")
self._log_endpoint("/api/now")
self.logger.debug(f'Recieved arguments: "{request.args}"')
period = {**self.app_settings.RECENT_TIME}

for key in ["days", "hours", "minutes"]:
Expand All @@ -210,27 +200,31 @@ def now_at_space(self):
recent = self.device_repository.get_recent(
timedelta(**self.app_settings.RECENT_TIME)
)
users = self.helpers.filter_hidden(self.helpers.owners_from_devices(recent))
owners = self.helpers.owners_from_devices(recent)
users = self.helpers.filter_hidden(
[self.user_repository.get_by_id(owner) for owner in owners]
)

data = {
"users": sorted(map(str, self.helpers.filter_anon_names(users))),
"headcount": len(users),
"unknown_devices": len(self.helpers.unclaimed_devices(recent)),
}

self.logger.info("sending request for /api/now {}".format(data))
self.logger.info("Sending request for /api/now {}".format(data))

return jsonify(data)

def set_device_flags(self, device, new_flags):
self.logger.debug(f"Update device flags: {device=}, {new_flags=}")
if device.owner is not None and device.owner.get_id() != current_user.get_id():
self.logger.error("no permission for {}".format(current_user.username))
self.logger.error("No permission for {}".format(current_user.username))
flash("No permission!".format(device.mac_address), "error")
return
device.is_hidden = "hidden" in new_flags
device.is_esp = "esp" in new_flags
device.is_infrastructure = "infrastructure" in new_flags
print(device.flags)
self.logger.debug(f"New device_flags={device.flags}")
device.save()
self.logger.info(
"{} changed {} flags to {}".format(
Expand All @@ -241,7 +235,7 @@ def set_device_flags(self, device, new_flags):

def device_view(self, mac_address):
"""Get info about device, claim device, release device"""
self.logger.debug("Called '/device'")
self._log_endpoint("/device")
try:
device = self.device_repository.get_by_mac_address(mac_address)
except NoResultFound as exc:
Expand All @@ -262,8 +256,9 @@ def device_view(self, mac_address):
return render_template("device.html", device=device, **self.common_vars_tpl)

def claim_device(self, device):
self.logger.debug(f"Claim device: {device.__repr__()}")
if device.owner is not None:
self.logger.error("no permission for {}".format(current_user.username))
self.logger.error("No permission for {}".format(current_user.username))
flash("No permission!".format(device.mac_address), "error")
return
device.owner = current_user.get_id()
Expand All @@ -274,8 +269,9 @@ def claim_device(self, device):
flash("Claimed {}!".format(device.mac_address), "success")

def unclaim_device(self, device):
self.logger.debug(f"Unclaim device: {device.__repr__()}")
if device.owner is not None and device.owner.get_id() != current_user.get_id():
self.logger.error("no permission for {}".format(current_user.username))
self.logger.error("No permission for {}".format(current_user.username))
flash("No permission!".format(device.mac_address), "error")
return
device.owner = None
Expand Down Expand Up @@ -316,10 +312,10 @@ def register_form(self):

def login(self):
"""Login using query to DB or SSO"""
self.logger.debug("Called '/login'")
self._log_endpoint("/login")

if current_user.is_authenticated:
self.logger.error("Shouldn't login when auth")
self.logger.info(f"User {current_user} is already authenticated")
flash("You are already logged in", "error")
return redirect(url_for("devices"))

Expand All @@ -331,14 +327,21 @@ def login(self):
user = None

if user:
self.logger.debug(f"User found: {user}")
if user.is_sso and self.app_settings.OIDC_ENABLED:
# User created via sso -> redirect to sso login
self.logger.info("Redirect to SSO user: {}".format(user.username))
self.logger.info(
f"User {user} is an SSO user. Redirecting to SSO login"
)
return redirect(url_for("login_oauth"))
elif user.auth(request.form["password"]):
# User password hash match -> login user successfully
login_user(user)
self.logger.info("logged in: {}".format(user.username))
self.logger.info(f"User {user} is a regular user. Attempt login")
login_success = login_user(user)
if login_success:
self.logger.info(f"User {user} was successfully authenticated")
else:
self.logger.info(f"User {user} was NOT authenticated")

else:
pass

Expand All @@ -351,7 +354,9 @@ def login(self):
)
return redirect(url_for("devices"))
else:
self.logger.info("failed log in: {}".format(request.form["username"]))
self.logger.info(
f'Failed to log in: username={request.form["username"]}'
)
flash("Invalid credentials", "error")

return render_template(
Expand All @@ -361,12 +366,12 @@ def login(self):
)

def login_oauth(self):
self.logger.debug("Called '/login/oauth'")
self._log_endpoint("/login/oauth")
redirect_uri = url_for("callback", _external=True)
return self.oauth.sso.authorize_redirect(redirect_uri)

def callback(self):
self.logger.debug("Called '/login/callback'")
self._log_endpoint("/login/callback")
token = self.oauth.sso.authorize_access_token()
user_info = self.oauth.sso.parse_id_token(token)
if user_info:
Expand Down Expand Up @@ -399,16 +404,15 @@ def callback(self):
return redirect(url_for("login"))

def logout(self):
self.logger.debug("Called '/logout'")
username = current_user.username
self._log_endpoint("/logout", f"{username=}")
logout_user()
self.app.logger.info("logged out: {}".format(username))
flash("Logged out.", "info")
return redirect(url_for("index"))

def profile_edit(self):
# TODO: logging
self.logger.debug("Called '/profile'")
self._log_endpoint("/profile")
if request.method == "POST":
if current_user.auth(request.values.get("password", None)) is True:
try:
Expand Down Expand Up @@ -456,8 +460,10 @@ def register(self, username, password, display_name=None):
:param display_name: displayed username
:return: user instance
"""
self.logger.debug(f"Registering user: {username}")
user = User(username=username, display_name=display_name)
user.password = password
self.logger.debug(f"Final user: {user.__repr__()}")
self.user_repository.insert(user)
return user

Expand All @@ -468,6 +474,8 @@ def register_from_sso(self, username, display_name=None):
:param display_name: displayed username
:return: user instance
"""
self.logger.debug(f"Registering user (SSO): {username}")
user = User(username=username, display_name=display_name)
self.logger.debug(f"Final user: {user.__repr__()}")
self.user_repository.insert(user)
return user
Loading
Loading