diff --git a/src/discord_auth.py b/src/discord_auth.py index accdeb6..c2c3995 100644 --- a/src/discord_auth.py +++ b/src/discord_auth.py @@ -51,7 +51,7 @@ def discord_oauth_login(): return redirect(discord_oauth.gen_auth_url()) -@discord_blueprint.route("/discord/oauth_callback", methods=["GET","POST"]) +@discord_blueprint.route("/discord/oauth_callback", methods=["GET", "POST"]) @ratelimit(method="POST", limit=10, interval=5) def discord_oauth_callback(): """ @@ -64,10 +64,52 @@ def discord_oauth_callback(): log.debug("OAuth Response Code: [{}]".format(request.args.get("code"))) global discord_oauth token = discord_oauth.get_access_token(request.args.get("code")) - log.debug("token= {}".format(token)) + log.debug("token=[{}]".format(token)) user_json = discord_oauth.get_user_info(token) + log.debug("User data: [{}]".format(str(user_json))) # process user info/login/etc - return "userdata: " + str(user_json) + if user_json: + # lookup by email + user = Users.query.filter_by(email=user_json["email"]).first() + if user is None: + # Check if user changed email + discord_user = DiscordUser.query.filter_by(id=user_json["id"]).first() + if discord_user: + user = Users.query.filter_by(email=discord_user.email) + if user: + user.email = user_json["email"] + discord_user.email = user_json["email"] + db.session.commit() + else: + log.error("Login failed: user[{user}], discord_user[{d_user}], \ + oauth[{user_json}]".format(user=user, d_user=discord_user, + user_json=user_json)) + return "Error logging in via Discord Oauth2" + else: + # Create new user + user = Users( + name=user_json["username"], + email=user_json["email"], + oauth_id=user_json["id"], + verified=user_json["verified"] + ) + discord_user = DiscordUser( + id=user_json["id"], + username=user_json["username"], + discriminator=user_json["discriminator"], + avatar_hash=user_json["avatar"], + mfa_enabled=user_json["mfa_enabled"], + verified=user_json["verified"], + email=user_json["email"] + ) + db.session.add(user) + db.session.add(discord_user) + db.session.commit() + # Login + login_user(user) + else: + return "Error logging in via Discord OAuth2" + return redirect('/challenges') def check_debug_mode(debug: bool): @@ -119,13 +161,13 @@ def setup_oauth(config): global discord_oauth global plugin_name discord_oauth = Discord_Oauth( - client_id=config["client_id"], - client_secret=config["client_secret"], - scope=config["scope"], - redirect_uri="https://{}/discord/oauth_callback".format(config["domain"]), - discord_api_url=config["base_discord_api_url"], - plugin_name=plugin_name - ) + client_id=config["client_id"], + client_secret=config["client_secret"], + scope=config["scope"], + redirect_uri="https://{}/discord/oauth_callback".format(config["domain"]), + discord_api_url=config["base_discord_api_url"], + plugin_name=plugin_name + ) # Load plugin into CTFd