diff --git a/src/retk/config.py b/src/retk/config.py index 569ccff..d8190a0 100644 --- a/src/retk/config.py +++ b/src/retk/config.py @@ -83,8 +83,8 @@ def __init__(self): ) self.REFRESH_TOKEN_EXPIRE_DELTA = datetime.timedelta(days=self.JWT_REFRESH_EXPIRED_DAYS) - self.ACCESS_TOKEN_EXPIRE_DELTA = datetime.timedelta(minutes=self.JWT_ACCESS_EXPIRED_MINS) - + # self.ACCESS_TOKEN_EXPIRE_DELTA = datetime.timedelta(minutes=self.JWT_ACCESS_EXPIRED_MINS) + self.ACCESS_TOKEN_EXPIRE_DELTA = datetime.timedelta(seconds=5) @lru_cache() def get_settings() -> Settings: diff --git a/src/retk/const/response_codes.py b/src/retk/const/response_codes.py index 34689ce..15b354a 100644 --- a/src/retk/const/response_codes.py +++ b/src/retk/const/response_codes.py @@ -40,7 +40,7 @@ class CodeEnum(IntEnum): ACCOUNT_EXIST_TRY_FORGET_PASSWORD = 30 USER_DISABLED = 31 NOT_PERMITTED = 32 - EXPIRED_ACCESS_TOKEN = 33 + EXPIRED_OR_NO_ACCESS_TOKEN = 33 USER_NOT_EXIST = 34 INVALID_PARAMS = 35 @@ -89,9 +89,13 @@ class CodeMessage: zh="账户已存在,请尝试通过忘记密码找回", en="Account exists, try forget password to recover", ), - CodeEnum.USER_DISABLED: CodeMessage(zh="用户已被禁用", en="User has been disabled"), + CodeEnum.USER_DISABLED: CodeMessage( + zh="因违反平台规则,此账户已被禁用", + en="This account has been disabled due to violation of platform rules" + ), CodeEnum.NOT_PERMITTED: CodeMessage(zh="无权限", en="Not permitted"), - CodeEnum.EXPIRED_ACCESS_TOKEN: CodeMessage(zh="访问令牌已过期", en="Access token has expired"), + CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN: CodeMessage(zh="访问令牌已过期或失效", + en="Access token has expired or invalid"), CodeEnum.USER_NOT_EXIST: CodeMessage(zh="用户不存在", en="User does not exist"), CodeEnum.INVALID_PARAMS: CodeMessage(zh="无效参数", en="Invalid parameter"), } @@ -130,7 +134,7 @@ class CodeMessage: CodeEnum.ACCOUNT_EXIST_TRY_FORGET_PASSWORD: 422, CodeEnum.USER_DISABLED: 403, CodeEnum.NOT_PERMITTED: 403, - CodeEnum.EXPIRED_ACCESS_TOKEN: 200, + CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN: 200, CodeEnum.USER_NOT_EXIST: 404, CodeEnum.INVALID_PARAMS: 400, } diff --git a/src/retk/controllers/account.py b/src/retk/controllers/account.py index 9583a4e..a17eb93 100644 --- a/src/retk/controllers/account.py +++ b/src/retk/controllers/account.py @@ -96,13 +96,19 @@ async def login( ) -> JSONResponse: # TODO: 后台应记录成功登录用户名和 IP、时间. # 当尝试登录 IP 不在历史常登录 IP 地理位置时,应进行多因素二次验证用户身份,防止用户因密码泄漏被窃取账户 - u, code = await user.get_by_email(req.email, disabled=False, exclude_manager=False) + u, code = await user.get_by_email(req.email, disabled=None, exclude_manager=False) if code != const.CodeEnum.OK: raise json_exception( request_id=req_id, code=code, language=req.language, ) + if u["disabled"]: + raise json_exception( + request_id=req_id, + code=const.CodeEnum.USER_DISABLED, + language=req.language, + ) if not await account.manager.is_right_password( email=u["email"], @@ -146,7 +152,7 @@ async def auto_login( payload = jwt_decode(token=token) except Exception: # pylint: disable=broad-except return r - u, code = await user.get(uid=payload["uid"]) + u, code = await user.get(uid=payload["uid"], disabled=False) if code != const.CodeEnum.OK: return r return schemas.user.get_user_info_response_from_u_dict(u, request_id=req_id) diff --git a/src/retk/controllers/manager.py b/src/retk/controllers/manager.py index c08b712..3fba28f 100644 --- a/src/retk/controllers/manager.py +++ b/src/retk/controllers/manager.py @@ -20,9 +20,9 @@ async def get_user_info( req: schemas.manager.GetUserRequest, ) -> schemas.user.UserInfoResponse: if __check_use_uid(au=au, req=req): - u, code = await user.get(uid=req.uid, exclude_manager=True) + u, code = await user.get(uid=req.uid, disabled=None, exclude_manager=True) else: - u, code = await user.get_by_email(email=req.email, exclude_manager=True) + u, code = await user.get_by_email(email=req.email, disabled=None, exclude_manager=True) maybe_raise_json_exception(au=au, code=code) return schemas.user.get_user_info_response_from_u_dict(u=u, request_id=au.request_id) diff --git a/src/retk/core/files/importing/async_tasks/obsidian/task.py b/src/retk/core/files/importing/async_tasks/obsidian/task.py index 4f18c46..0bdae82 100644 --- a/src/retk/core/files/importing/async_tasks/obsidian/task.py +++ b/src/retk/core/files/importing/async_tasks/obsidian/task.py @@ -75,7 +75,7 @@ async def upload_obsidian_task( # noqa: C901 logger.debug(f"obsidian upload, uid={uid}, filter time: {t2 - t1:.2f}") # add new md files with only title - u, code = await core.user.get(uid=uid) + u, code = await core.user.get(uid=uid, disabled=False) au = AuthedUser( u=convert_user_dict_to_authed_user(u), request_id=request_id, diff --git a/src/retk/core/files/importing/async_tasks/text/task.py b/src/retk/core/files/importing/async_tasks/text/task.py index e5cfced..3aac86d 100644 --- a/src/retk/core/files/importing/async_tasks/text/task.py +++ b/src/retk/core/files/importing/async_tasks/text/task.py @@ -37,7 +37,7 @@ async def update_text_task( # noqa: C901 ) return - u, code = await core.user.get(uid=uid) + u, code = await core.user.get(uid=uid, disabled=False) au = AuthedUser( u=convert_user_dict_to_authed_user(u), request_id=request_id, diff --git a/src/retk/core/user.py b/src/retk/core/user.py index b2578e6..2c749f5 100644 --- a/src/retk/core/user.py +++ b/src/retk/core/user.py @@ -131,7 +131,7 @@ async def patch( # noqa: C901 if res.modified_count != 1: return None, const.CodeEnum.OPERATION_FAILED - return await get(uid=au.u.id) + return await get(uid=au.u.id, disabled=None) def __get_user_condition(condition: dict, exclude_manager: bool) -> dict: @@ -145,7 +145,7 @@ def __get_user_condition(condition: dict, exclude_manager: bool) -> dict: async def get_by_email( email: str, - disabled: bool = False, + disabled: Optional[bool] = False, exclude_manager: bool = False, ) -> Tuple[Optional[tps.UserMeta], const.CodeEnum]: if config.get_settings().ONE_USER: @@ -158,10 +158,12 @@ async def get_by_email( async def get_account( account: str, source: int, - disabled: bool = False, + disabled: Optional[bool] = False, exclude_manager: bool = False, ) -> Tuple[Optional[tps.UserMeta], const.CodeEnum]: - c = {"source": source, "account": account, "disabled": disabled} + c = {"source": source, "account": account} + if disabled is not None: + c["disabled"] = disabled c = __get_user_condition(condition=c, exclude_manager=exclude_manager) u = await client.coll.users.find_one(c) if u is None: @@ -169,12 +171,15 @@ async def get_account( return u, const.CodeEnum.OK -async def get(uid: str, exclude_manager: bool = False) -> Tuple[Optional[tps.UserMeta], const.CodeEnum]: - c = {"id": uid, "disabled": False} +async def get(uid: str, disabled: Optional[bool] = False, exclude_manager: bool = False) -> Tuple[ + Optional[tps.UserMeta], const.CodeEnum]: + c = {"id": uid} + if disabled is not None: + c["disabled"] = disabled c = __get_user_condition(condition=c, exclude_manager=exclude_manager) u = await client.coll.users.find_one(c) if u is None: - return None, const.CodeEnum.USER_DISABLED + return None, const.CodeEnum.USER_NOT_EXIST if u["usedSpace"] < 0: # reset usedSpace to 0 await client.coll.users.update_one( diff --git a/src/retk/routes/utils.py b/src/retk/routes/utils.py index f783699..bc556bf 100644 --- a/src/retk/routes/utils.py +++ b/src/retk/routes/utils.py @@ -61,6 +61,7 @@ def verify_referer(referer: Optional[str] = Header(None)): async def __process_auth_headers( # noqa: C901 + is_refresh_token: bool, refresh_token_id: str, token: str = Header(alias="Authorization", default=""), request_id: str = Header( @@ -70,7 +71,7 @@ async def __process_auth_headers( # noqa: C901 if token is None or token == "": raise json_exception( request_id=request_id, - code=const.CodeEnum.INVALID_AUTH, + code=const.CodeEnum.INVALID_AUTH if is_refresh_token else const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN, log_msg="empty token", ) au = AuthedUser( @@ -86,19 +87,19 @@ async def __process_auth_headers( # noqa: C901 if is_access is None: code = const.CodeEnum.INVALID_AUTH err = "invalid token" - elif (is_access and refresh_token_id != "") or (not is_access and refresh_token_id == ""): + elif is_access == is_refresh_token: code = const.CodeEnum.INVALID_AUTH err = "invalid token" - elif refresh_token_id != "" and payload["uid"] != refresh_token_id: + elif is_refresh_token and payload["uid"] != refresh_token_id: code = const.CodeEnum.INVALID_AUTH err = "invalid token" else: - u, code = await core.user.get(uid=payload["uid"]) + u, code = await core.user.get(uid=payload["uid"], disabled=False) if code != const.CodeEnum.OK: err = f"get user failed, code={code}" except jwt.exceptions.ExpiredSignatureError: - code = const.CodeEnum.EXPIRED_AUTH if refresh_token_id != "" else const.CodeEnum.EXPIRED_ACCESS_TOKEN + code = const.CodeEnum.EXPIRED_AUTH if is_refresh_token else const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN err = "auth expired" except jwt.exceptions.DecodeError: code = const.CodeEnum.INVALID_AUTH @@ -126,7 +127,12 @@ async def process_normal_headers( default="", alias="RequestId", max_length=const.settings.MD_MAX_LENGTH ) ) -> AuthedUser: - return await __process_auth_headers(refresh_token_id="", token=token, request_id=request_id) + return await __process_auth_headers( + is_refresh_token=False, + refresh_token_id="", + token=token, + request_id=request_id + ) async def process_refresh_token_headers( @@ -136,7 +142,12 @@ async def process_refresh_token_headers( default="", alias="RequestId", max_length=const.settings.MD_MAX_LENGTH ) ) -> AuthedUser: - return await __process_auth_headers(refresh_token_id=id_, token=token, request_id=request_id) + return await __process_auth_headers( + is_refresh_token=True, + refresh_token_id=id_, + token=token, + request_id=request_id + ) async def process_no_auth_headers( diff --git a/src/retk/safety.py b/src/retk/safety.py index 3caf47b..b2ea135 100644 --- a/src/retk/safety.py +++ b/src/retk/safety.py @@ -20,10 +20,7 @@ cookie_domain = None cookie_secure = False -if vue_app_mode == "dev": - cookie_samesite = "strict" -else: - cookie_samesite = "strict" +cookie_samesite = "strict" class CSPMiddleware(BaseHTTPMiddleware): diff --git a/tests/test_api.py b/tests/test_api.py index 45a310a..29499b0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -46,6 +46,15 @@ def test_home(self): self.assertEqual(Response, type(resp)) def test_register(self): + resp = self.client.put( + "/api/account/auto-login", + headers={"RequestId": "xxx"}, + ) + self.assertEqual(200, resp.status_code) + rj = resp.json() + self.assertEqual("xxx", rj["requestId"]) + self.assertIsNone(rj["user"]) + token, _ = account.app_captcha.generate() data = jwt_decode(token) resp = self.client.post( @@ -86,7 +95,7 @@ def test_email_verification(self, mock_send): ) self.assertEqual(200, resp.status_code) rj = resp.json() - self.assertNotEqual("", rj["accessToken"]) + self.assertNotEqual("", rj["token"]) self.assertEqual("xxx", rj["requestId"]) @@ -104,24 +113,28 @@ async def asyncSetUp(self) -> None: "email": const.DEFAULT_USER["email"], "password": "", }) - self.assertEqual(200, resp.status_code) - rj = resp.json() - self.assertEqual(840, len(rj["accessToken"])) - self.assertTrue(rj["accessToken"].startswith("Bearer ")) - self.refresh_token = rj["refreshToken"] + self.assertEqual(200, resp.status_code, msg=resp.json()) + self.access_token = resp.cookies.get(const.settings.COOKIE_ACCESS_TOKEN) + self.refresh_token_id = resp.cookies.get(const.settings.COOKIE_REFRESH_TOKEN_ID) + self.refresh_token = resp.cookies.get(const.settings.COOKIE_REFRESH_TOKEN) + + self.assertEqual(842, len(self.access_token)) + self.assertTrue(self.access_token.startswith("\"Bearer "), msg=self.access_token) + self.assertTrue(self.refresh_token.startswith("\"Bearer "), msg=self.access_token) self.default_headers = { - "Authorization": rj["accessToken"], "RequestId": "xxx", } + resp = self.client.get( "/api/users", - headers=self.default_headers + headers=self.default_headers, ) - self.assertEqual(200, resp.status_code, msg=resp.json()) rj = resp.json() - self.uid = rj["uid"] + self.assertEqual(200, resp.status_code, msg=rj) + self.assertNotIn("detail", rj) async def asyncTearDown(self) -> None: + self.client.cookies.clear() await client.drop() shutil.rmtree(Path(__file__).parent / "tmp" / ".data" / "files", ignore_errors=True) shutil.rmtree(Path(__file__).parent / "tmp" / ".data" / "md", ignore_errors=True) @@ -145,34 +158,51 @@ def check_ok_response(self, resp: Response, status_code: int = 200, rid="xxx") - self.assertEqual(rid, rj["requestId"]) return rj - async def test_access_refresh_token(self): + def set_access_token(self, token: str): + try: + self.client.cookies.delete(const.settings.COOKIE_ACCESS_TOKEN) + except KeyError: + pass + self.client.cookies[const.settings.COOKIE_ACCESS_TOKEN] = token + + async def test_auto_login(self): resp = self.client.put( + "/api/account/auto-login", + headers={"RequestId": "xxx"}, + ) + self.assertEqual(200, resp.status_code) + rj = resp.json() + self.assertEqual("xxx", rj["requestId"]) + self.assertIsNotNone(rj["user"]) + + self.client.cookies.delete(const.settings.COOKIE_ACCESS_TOKEN) + resp = self.client.put( + "/api/account/auto-login", + headers={"RequestId": "xxx"}, + ) + self.assertEqual(200, resp.status_code) + self.assertIsNone(resp.json()["user"]) + + async def test_access_refresh_token(self): + self.client.put( "/api/account/login", json={ "email": const.DEFAULT_USER["email"], "password": "", }) - rj = resp.json() - access_token = rj["accessToken"] - refresh_token = rj["refreshToken"] resp = self.client.get( "/api/users", - headers={ - "Authorization": access_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) self.assertEqual(200, resp.status_code) + self.set_access_token(resp.cookies.get(const.settings.COOKIE_REFRESH_TOKEN)) resp = self.client.get( "/api/users", - headers={ - "Authorization": refresh_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) - self.error_check(resp, 401, const.CodeEnum.INVALID_AUTH) + self.error_check(resp, 200, const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN) async def test_access_token_expire(self): aed = config.get_settings().ACCESS_TOKEN_EXPIRE_DELTA @@ -187,81 +217,64 @@ async def test_access_token_expire(self): "password": "", "language": "zh", }) - rj = resp.json() - access_token = rj["accessToken"] - refresh_token = rj["refreshToken"] + self.assertEqual(200, resp.status_code) time.sleep(0.001) resp = self.client.get( "/api/users", headers={ - "Authorization": access_token, "RequestId": "xxx" - } + }, ) self.assertEqual(200, resp.status_code) rj = resp.json() detail = rj["detail"] - self.assertEqual(const.CodeEnum.EXPIRED_ACCESS_TOKEN.value, detail["code"], msg=detail) + self.assertEqual(const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN.value, detail["code"], msg=detail) self.assertEqual("xxx", detail["requestId"], msg=detail) resp = self.client.get( "/api/account/access-token", - headers={ - "Authorization": refresh_token, - "RequestId": "xxx", - "ID": self.uid, - } + headers=self.default_headers, ) self.error_check(resp, 401, const.CodeEnum.EXPIRED_AUTH) config.get_settings().REFRESH_TOKEN_EXPIRE_DELTA = red - resp = self.client.put( + self.client.put( "/api/account/login", json={ "email": const.DEFAULT_USER["email"], "password": "", }) - rj = resp.json() - access_token = rj["accessToken"] - refresh_token = rj["refreshToken"] + time.sleep(0.001) resp = self.client.get( "/api/users", - headers={ - "Authorization": access_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) self.assertEqual(200, resp.status_code) rj = resp.json() - self.assertEqual(const.CodeEnum.EXPIRED_ACCESS_TOKEN.value, rj["detail"]["code"], msg=rj) + self.assertEqual(const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN.value, rj["detail"]["code"], msg=rj) config.get_settings().ACCESS_TOKEN_EXPIRE_DELTA = aed + old_access_token = resp.cookies.get(const.settings.COOKIE_ACCESS_TOKEN) resp = self.client.get( "/api/account/access-token", - headers={ - "Authorization": refresh_token, - "RequestId": "xxx", - "ID": self.uid, - } + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) - self.assertNotEqual(access_token, rj["accessToken"]) - self.assertTrue(rj["accessToken"].startswith("Bearer ")) - self.assertEqual("", rj["refreshToken"]) + self.check_ok_response(resp, 200) + at = resp.cookies.get(const.settings.COOKIE_ACCESS_TOKEN) + self.assertNotEqual(old_access_token, at) + self.assertTrue(at.startswith("\"Bearer ")) resp = self.client.get( "/api/users", - headers={ - "Authorization": rj["accessToken"], - "RequestId": "xxx" - } + headers=self.default_headers, ) self.check_ok_response(resp, 200) - self.assertEqual(self.uid, resp.json()["uid"]) + rj = resp.json() + self.assertEqual("rethink", rj["user"]["nickname"]) async def test_add_user_update_password(self): config.get_settings().ONE_USER = False @@ -271,14 +284,12 @@ async def test_add_user_update_password(self): code = data["code"].replace(config.get_settings().CAPTCHA_SALT, "") email = "a@b.c" + del self.client.cookies[const.settings.COOKIE_ACCESS_TOKEN] resp = self.client.get( "/api/users", - headers={ - "Authorization": "xxxx", - "RequestId": "xxx" - }) - self.assertEqual(401, resp.status_code) - self.error_check(resp, 401, const.CodeEnum.INVALID_AUTH) + headers=self.default_headers, + ) + self.error_check(resp, 200, const.CodeEnum.EXPIRED_OR_NO_ACCESS_TOKEN) lang = "zh" resp = self.client.post( @@ -290,18 +301,16 @@ async def test_add_user_update_password(self): "verification": code, "language": lang, }, - headers={"RequestId": "xxx"} + headers=self.default_headers ) - rj = self.check_ok_response(resp, 201) - u_token = rj["accessToken"] + _ = self.check_ok_response(resp, 201) + u_token = resp.cookies.get(const.settings.COOKIE_ACCESS_TOKEN) self.assertNotEqual("", u_token) resp = self.client.get( "/api/users", - headers={ - "Authorization": rj["accessToken"], - "RequestId": "xxx" - }) + headers=self.default_headers, + ) rj = self.check_ok_response(resp, 200) self.assertEqual("a**@b.c", rj["user"]["email"]) self.assertEqual("zh", rj["user"]["settings"]["language"]) @@ -312,10 +321,7 @@ async def test_add_user_update_password(self): "oldPassword": "xxx111", "newPassword": "abc222", }, - headers={ - "Authorization": u_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) self.error_check(resp, 400, const.CodeEnum.OLD_PASSWORD_ERROR, language=lang) @@ -325,10 +331,7 @@ async def test_add_user_update_password(self): "oldPassword": "abc111", "newPassword": "abc222", }, - headers={ - "Authorization": u_token, - "RequestId": "xxx", - } + headers=self.default_headers, ) _ = self.check_ok_response(resp, 200) u = await client.coll.users.find_one({"email": email}) @@ -367,13 +370,13 @@ def test_update_user(self): "nodeDisplayMethod": const.NodeDisplayMethodEnum.LIST.value, } }, - headers=self.default_headers + headers=self.default_headers, ) _ = self.check_ok_response(resp, 200) resp = self.client.get( "/api/users", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual("new nickname", rj["user"]["nickname"]) @@ -391,7 +394,7 @@ def test_update_user(self): "editorCodeTheme": "dracula", } }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual("zh", rj["user"]["settings"]["language"]) @@ -408,13 +411,13 @@ def test_recent_search(self): "p": 0, "limit": 5 }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( "/api/recent/searched", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(["aaa"], rj["queries"]) @@ -429,7 +432,7 @@ def test_node(self): "p": 0, "limit": 5 }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertGreater(len(rj["data"]["nodes"]), 0) @@ -443,9 +446,9 @@ def test_node(self): "p": 0, "limit": 5 }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.post( "/api/nodes", @@ -453,14 +456,14 @@ def test_node(self): "md": "node1\ntext", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 201) node = rj["node"] resp = self.client.get( f'/api/nodes/{node["id"]}', - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) n = rj["node"] @@ -473,13 +476,13 @@ def test_node(self): json={ "md": "node2\ntext", }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( f'/api/nodes/{node["id"]}', - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) n = rj["node"] @@ -490,23 +493,23 @@ def test_node(self): resp = self.client.put( f'/api/trash/{node["id"]}', - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( "/api/trash", params={"p": 0, "limit": 10}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(1, len(rj["data"]["nodes"])) resp = self.client.put( f'/api/trash/{node["id"]}/restore', - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( f'/api/nodes/{node["id"]}/at', @@ -537,23 +540,22 @@ def test_node(self): f'/api/trash/{node["id"]}', headers=self.default_headers ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.delete( f"/api/trash/{node['id']}", - headers=self.default_headers + headers=self.default_headers, ) self.assertEqual(200, resp.status_code) - rj = resp.json() resp = self.client.delete( "/api/trash/ssa", - headers=self.default_headers + headers=self.default_headers, ) self.error_check(resp, 404, const.CodeEnum.NODE_NOT_EXIST) resp = self.client.get( f'/api/nodes/{node["id"]}', - headers=self.default_headers + headers=self.default_headers, ) self.error_check(resp, 404, const.CodeEnum.NODE_NOT_EXIST) @@ -586,7 +588,7 @@ def test_batch(self): "md": f"node{i}\ntext", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers=self.default_headers + headers=self.default_headers, ) self.check_ok_response(resp, 201) @@ -599,7 +601,7 @@ def test_batch(self): "p": 0, "limit": 5 }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(5, len(rj["data"]["nodes"])) @@ -610,14 +612,14 @@ def test_batch(self): json={ "nids": [n["id"] for n in rj["data"]["nodes"][:3]], }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( "/api/trash", params={"p": 0, "limit": 10}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(3, len(rj["data"]["nodes"])) @@ -628,14 +630,14 @@ def test_batch(self): json={ "nids": [n["id"] for n in rj["data"]["nodes"][:2]], }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( "/api/trash", params={"p": 0, "limit": 10}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(1, len(rj["data"]["nodes"])) @@ -646,13 +648,13 @@ def test_batch(self): json={ "nids": [n["id"] for n in rj["data"]["nodes"]], }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( "/api/trash", params={"p": 0, "limit": 10}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(0, len(rj["data"]["nodes"])) @@ -690,7 +692,7 @@ def test_update_obsidian(self): time.sleep(0.1) resp = self.client.get( "/api/files/upload-process", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual("obsidian", rj["type"], msg=rj) @@ -708,7 +710,7 @@ def test_update_obsidian(self): "p": 0, "limit": 5, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(5, len(rj["data"]["nodes"])) @@ -738,7 +740,7 @@ def test_upload_text(self): time.sleep(0.1) resp = self.client.get( "/api/files/upload-process", - headers=self.default_headers + headers=self.default_headers, ) self.assertEqual(200, resp.status_code) rj = resp.json() @@ -757,7 +759,7 @@ def test_upload_text(self): "p": 0, "limit": 5, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(4, len(rj["data"]["nodes"])) @@ -774,7 +776,7 @@ def test_upload_image(self): resp = self.client.post( "/api/files/vditor", files={"file[]": f1}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual({ @@ -792,7 +794,7 @@ def test_upload_file(self): resp = self.client.post( "/api/files/vditor", files={"file[]": f1}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual({ @@ -810,7 +812,7 @@ def test_upload_invalid_file(self): resp = self.client.post( "/api/files/vditor", files={"file[]": f1}, - headers=self.default_headers + headers=self.default_headers, ) self.assertEqual(200, resp.status_code, msg=resp.json()) rj = resp.json() @@ -825,7 +827,7 @@ def test_fetch_image(self): resp = self.client.post( "/api/files/vditor/images", json={"url": img}, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual(img, rj["data"]["url"]) @@ -834,7 +836,7 @@ def test_fetch_image(self): resp = self.client.post( "/api/files/vditor/images", json={"url": img}, - headers=self.default_headers + headers=self.default_headers, ) self.error_check(resp, 400, const.CodeEnum.FILE_OPEN_ERROR) @@ -849,7 +851,7 @@ def test_put_quick_node(self, mocker): "md": "node1\ntext", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 201) node = rj["node"] @@ -863,7 +865,7 @@ def test_put_quick_node(self, mocker): "md": "https://baidu.com", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 201) node = rj["node"] @@ -894,7 +896,7 @@ def test_md_history( "md": "title\ntext", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 201) n1 = rj["node"] @@ -906,9 +908,9 @@ def test_md_history( json={ "md": "title1\ntext", }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) time.sleep(0.001) @@ -917,13 +919,13 @@ def test_md_history( json={ "md": "title2\ntext", }, - headers=self.default_headers + headers=self.default_headers, ) - rj = self.check_ok_response(resp, 200) + self.check_ok_response(resp, 200) resp = self.client.get( f"/api/nodes/{n1['id']}/history", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) hist = rj["versions"] @@ -931,7 +933,7 @@ def test_md_history( resp = self.client.get( f"/api/nodes/{n1['id']}/history/{hist[1]}/md", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertEqual("title1\ntext", rj["md"]) @@ -954,14 +956,14 @@ def check_one_plugin(ps): register_official_plugins() resp = self.client.get( "/api/plugins", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) check_one_plugin(rj["plugins"]) resp = self.client.get( "/api/plugins/editor-side", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) check_one_plugin(rj["plugins"]) @@ -969,7 +971,7 @@ def check_one_plugin(ps): resp = self.client.get( f"/api/plugins/{pid}", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertNotEqual("", rj["html"]) @@ -983,14 +985,14 @@ def check_one_plugin(ps): "p": 0, "limit": 5, }, - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) nid = rj["data"]["nodes"][0]["id"] resp = self.client.get( f"/api/plugins/{pid}/editor-side/{nid}", - headers=self.default_headers + headers=self.default_headers, ) rj = self.check_ok_response(resp, 200) self.assertNotEqual("", rj["html"]) @@ -1003,7 +1005,7 @@ def check_one_plugin(ps): "method": "test", "data": "test", }, - headers=self.default_headers + headers=self.default_headers, ) rj = resp.json() self.assertEqual(False, rj["success"]) @@ -1020,7 +1022,7 @@ def check_one_plugin(ps): "method": "method", "data": "test", }, - headers=self.default_headers + headers=self.default_headers, ) rj = resp.json() self.assertEqual("method", rj["method"]) @@ -1030,12 +1032,13 @@ def check_one_plugin(ps): config.get_settings().PLUGINS = False async def test_manager(self): + manager_token = self.client.cookies.get(const.settings.COOKIE_ACCESS_TOKEN) resp = self.client.put( "/api/managers/users/disable", json={ "uid": "xxx", }, - headers=self.default_headers + headers=self.default_headers, ) self.error_check(resp, 403, const.CodeEnum.NOT_PERMITTED) @@ -1052,7 +1055,7 @@ async def test_manager(self): json={ "uid": "xxx", }, - headers=self.default_headers + headers=self.default_headers, ) self.error_check(resp, 404, const.CodeEnum.USER_NOT_EXIST) @@ -1074,16 +1077,17 @@ async def test_manager(self): }, headers={"RequestId": "xxx"} ) - rj = self.check_ok_response(resp, 201) - u_token = rj["accessToken"] + self.check_ok_response(resp, 201) + u_token = resp.cookies.get(const.settings.COOKIE_ACCESS_TOKEN) uid = (await client.coll.users.find_one({"email": email}))["id"] + self.set_access_token(manager_token) resp = self.client.put( "/api/managers/users", json={ "uid": uid, }, - headers=self.default_headers + headers=self.default_headers, ) self.check_ok_response(resp, 200) @@ -1092,14 +1096,14 @@ async def test_manager(self): json={ "email": email, }, - headers=self.default_headers + headers=self.default_headers, ) self.check_ok_response(resp, 200) resp = self.client.put( "/api/managers/users", json={}, - headers=self.default_headers + headers=self.default_headers, ) self.error_check(resp, 400, const.CodeEnum.INVALID_PARAMS) @@ -1108,54 +1112,60 @@ async def test_manager(self): json={ "uid": uid, }, - headers=self.default_headers + headers=self.default_headers, ) self.check_ok_response(resp, 200) + resp = self.client.put( + "/api/account/login", + json={ + "email": email, + "password": "abc111", + }, + headers=self.default_headers, + ) + self.error_check(resp, 403, const.CodeEnum.USER_DISABLED) + + self.set_access_token(u_token) resp = self.client.get( "/api/users", - headers={ - "Authorization": u_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) - self.error_check(resp, 403, const.CodeEnum.USER_DISABLED) + self.error_check(resp, 404, const.CodeEnum.USER_NOT_EXIST) + self.set_access_token(manager_token) resp = self.client.put( "/api/managers/users/enable", json={ "email": email, }, - headers=self.default_headers + headers=self.default_headers, ) self.check_ok_response(resp, 200) + self.set_access_token(u_token) resp = self.client.get( "/api/users", - headers={ - "Authorization": u_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) self.check_ok_response(resp, 200) + self.set_access_token(manager_token) resp = self.client.put( "/api/managers/users/delete", json={ "uid": uid, }, - headers=self.default_headers + headers=self.default_headers, ) self.check_ok_response(resp, 200) + self.set_access_token(u_token) resp = self.client.get( "/api/users", - headers={ - "Authorization": u_token, - "RequestId": "xxx" - } + headers=self.default_headers, ) - self.error_check(resp, 403, const.CodeEnum.USER_DISABLED) + self.error_check(resp, 404, const.CodeEnum.USER_NOT_EXIST) doc = await client.coll.users.update_one( {"id": admin_uid}, @@ -1166,7 +1176,7 @@ async def test_manager(self): async def test_statistic_user_behavior(self): # login - resp = self.client.put( + self.client.put( "/api/account/login", json={ "email": const.DEFAULT_USER["email"], @@ -1179,8 +1189,6 @@ async def test_statistic_user_behavior(self): ).to_list(None) self.assertEqual(const.UserBehaviorTypeEnum.LOGIN.value, docs[-1]["type"]) - token = resp.json()["accessToken"] - # create node resp = self.client.post( "/api/nodes", @@ -1188,10 +1196,7 @@ async def test_statistic_user_behavior(self): "md": "node1\ntext", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1206,10 +1211,7 @@ async def test_statistic_user_behavior(self): "md": "node1\ntext", "type": const.NodeTypeEnum.MARKDOWN.value, }, - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1220,10 +1222,7 @@ async def test_statistic_user_behavior(self): # trash node self.client.put( f"/api/trash/{resp.json()['node']['id']}", - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1233,10 +1232,7 @@ async def test_statistic_user_behavior(self): # restore node self.client.put( f"/api/trash/{resp.json()['node']['id']}/restore", - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1246,17 +1242,11 @@ async def test_statistic_user_behavior(self): # delete node self.client.put( f"/api/trash/{resp.json()['node']['id']}", - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) self.client.delete( f"/api/trash/{resp.json()['node']['id']}", - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1273,10 +1263,7 @@ async def test_statistic_user_behavior(self): "p": 0, "limit": 5 }, - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1292,10 +1279,7 @@ async def test_statistic_user_behavior(self): "p": 0, "limit": 10, }, - headers={ - "Authorization": token, - "RequestId": "xxx" - } + headers=self.default_headers, ) docs = await client.coll.user_behavior.find( {"uid": uid} @@ -1306,14 +1290,11 @@ async def test_statistic_user_behavior(self): # logout resp = self.client.post( "/api/statistic/user-behavior", - headers={ - "Authorization": token, - "RequestId": "xxx" - }, json={ "type": const.UserBehaviorTypeEnum.LOGOUT.value, "remark": "logout", - } + }, + headers=self.default_headers, ) self.assertEqual(201, resp.status_code) docs = await client.coll.user_behavior.find( diff --git a/tests/test_core_local.py b/tests/test_core_local.py index ab22f41..b44eb73 100644 --- a/tests/test_core_local.py +++ b/tests/test_core_local.py @@ -96,9 +96,13 @@ async def test_user(self): self.assertEqual(const.CodeEnum.OK, code) u, code = await core.user.get(uid=_uid) - self.assertEqual(const.CodeEnum.USER_DISABLED, code) + self.assertEqual(const.CodeEnum.USER_NOT_EXIST, code) self.assertIsNone(u) + u, code = await core.user.get(uid=_uid, disabled=None) + self.assertEqual(const.CodeEnum.OK, code) + self.assertTrue(u["disabled"]) + code = await core.account.manager.enable_by_uid(uid=_uid) self.assertEqual(const.CodeEnum.OK, code) diff --git a/tests/test_core_remote.py b/tests/test_core_remote.py index 108dbba..c369012 100644 --- a/tests/test_core_remote.py +++ b/tests/test_core_remote.py @@ -163,8 +163,12 @@ async def test_user(self): code = await core.account.manager.disable_by_uid(uid=_id) self.assertEqual(const.CodeEnum.OK, code) - u, code = await core.user.get(uid=_id) - self.assertEqual(const.CodeEnum.USER_DISABLED, code) + u, code = await core.user.get(uid=_id, disabled=False) + self.assertEqual(const.CodeEnum.USER_NOT_EXIST, code) + + u, code = await core.user.get(uid=_id, disabled=None) + self.assertEqual(const.CodeEnum.OK, code) + self.assertTrue(u["disabled"]) code = await core.account.manager.enable_by_uid(uid=_id) self.assertEqual(const.CodeEnum.OK, code)