diff --git a/common/Httpx.py b/common/Httpx.py index a02d0b0..4496dee 100644 --- a/common/Httpx.py +++ b/common/Httpx.py @@ -1,9 +1,9 @@ # ---------------------------------------- -# - mode: python - -# - author: helloplhm-qwq - -# - name: Httpx.py - -# - project: lx-music-api-server - -# - license: MIT - +# - mode: python - +# - author: helloplhm-qwq - +# - name: Httpx.py - +# - project: lx-music-api-server - +# - license: MIT - # ---------------------------------------- # This file is part of the "lx-music-api-server" project. @@ -21,15 +21,16 @@ from . import utils from . import variable + def is_valid_utf8(text) -> bool: try: if isinstance(text, bytes): - text = text.decode('utf-8') + text = text.decode("utf-8") # 判断是否为有效的utf-8字符串 if "\ufffe" in text: return False try: - text.encode('utf-8').decode('utf-8') + text.encode("utf-8").decode("utf-8") return True except UnicodeDecodeError: return False @@ -37,42 +38,48 @@ def is_valid_utf8(text) -> bool: logger.error(traceback.format_exc()) return False + def is_plain_text(text) -> bool: # 判断是否为纯文本 - pattern = re.compile(r'[^\x00-\x7F]') + pattern = re.compile(r"[^\x00-\x7F]") return not bool(pattern.search(text)) + def convert_dict_to_form_string(dic: dict) -> str: # 将字典转换为表单字符串 - return '&'.join([f'{k}={v}' for k, v in dic.items()]) + return "&".join([f"{k}={v}" for k, v in dic.items()]) + def log_plaintext(text: str) -> str: - if (text.startswith('{') and text.endswith('}')): + if text.startswith("{") and text.endswith("}"): try: text = json.loads(text) except: pass - elif (text.startswith('')): # xml data + elif text.startswith(""): # xml data try: - text = f'xml: {utils.load_xml(text)}' + text = f"xml: {utils.load_xml(text)}" except: pass return text + # 内置的UA列表 -ua_list = [ 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36 Edg/112.0.1722.39', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1788.0', - 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1788.0 uacq', - 'Mozilla/5.0 (Windows NT 10.0; WOW64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5666.197 Safari/537.36', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 uacq', - 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36' - ] +ua_list = [ + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36 Edg/112.0.1722.39", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1788.0", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1788.0 uacq", + "Mozilla/5.0 (Windows NT 10.0; WOW64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.5666.197 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 uacq", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36", +] # 日志记录器 -logger = log.log('http_utils') +logger = log.log("http_utils") + -def request(url: str, options = {}) -> requests.Response: - ''' +def request(url: str, options={}) -> requests.Response: + """ Http请求主函数, 用于发送网络请求 - url: 需要请求的URL地址(必填) - options: 请求的配置参数(可选, 留空时为GET请求, 总体与nodejs的请求的options填写差不多) @@ -84,15 +91,15 @@ def request(url: str, options = {}) -> requests.Response: - no-cache: 不缓存 - : 缓存可用秒数 - cache-ignore: 缓存忽略关键字 - + @ return: requests.Response类型的响应数据 - ''' + """ # 缓存读取 - cache_key = f'{url}{options}' - if (isinstance(options.get('cache-ignore'), list)): - for i in options.get('cache-ignore'): - cache_key = cache_key.replace(str(i), '') - options.pop('cache-ignore') + cache_key = f"{url}{options}" + if isinstance(options.get("cache-ignore"), list): + for i in options.get("cache-ignore"): + cache_key = cache_key.replace(str(i), "") + options.pop("cache-ignore") cache_key = utils.createMD5(cache_key) if options.get("cache") and options["cache"] != "no-cache": cache = config.getCache("httpx", cache_key) @@ -104,77 +111,83 @@ def request(url: str, options = {}) -> requests.Response: options.pop("cache") else: cache_info = None - # 获取请求方法,没有则默认为GET请求 try: - method = options['method'].upper() - options.pop('method') + method = options["method"].upper() + options.pop("method") except Exception as e: - method = 'GET' + method = "GET" # 获取User-Agent,没有则从ua_list中随机选择一个 try: - d_lower = {k.lower(): v for k, v in options['headers'].items()} - useragent = d_lower['user-agent'] + d_lower = {k.lower(): v for k, v in options["headers"].items()} + useragent = d_lower["user-agent"] except: try: - options['headers']['User-Agent'] = random.choice(ua_list) + options["headers"]["User-Agent"] = random.choice(ua_list) except: - options['headers'] = {} - options['headers']['User-Agent'] = random.choice(ua_list) + options["headers"] = {} + options["headers"]["User-Agent"] = random.choice(ua_list) # 检查是否在国内 - if ((not variable.iscn) and (not options["headers"].get("X-Forwarded-For"))): + if (not variable.iscn) and (not options["headers"].get("X-Forwarded-For")): options["headers"]["X-Forwarded-For"] = variable.fakeip # 获取请求主函数 try: reqattr = getattr(requests, method.lower()) except AttributeError: - raise AttributeError('Unsupported method: '+method) + raise AttributeError("Unsupported method: " + method) # 请求前记录 - logger.debug(f'HTTP Request: {url}\noptions: {options}') + logger.debug(f"HTTP Request: {url}\noptions: {options}") # 转换body/form参数为原生的data参数,并为form请求追加Content-Type头 - if (method == 'POST') or (method == 'PUT'): - if options.get('body'): - options['data'] = options['body'] - options.pop('body') - if options.get('form'): - options['data'] = convert_dict_to_form_string(options['form']) - options.pop('form') - options['headers']['Content-Type'] = 'application/x-www-form-urlencoded' - if (isinstance(options['data'], dict)): - options['data'] = json.dumps(options['data']) + if (method == "POST") or (method == "PUT"): + if options.get("body"): + options["data"] = options["body"] + options.pop("body") + if options.get("form"): + options["data"] = convert_dict_to_form_string(options["form"]) + options.pop("form") + options["headers"]["Content-Type"] = "application/x-www-form-urlencoded" + if isinstance(options["data"], dict): + options["data"] = json.dumps(options["data"]) # 进行请求 try: logger.info("-----start----- " + url) req = reqattr(url, **options) except Exception as e: - logger.error(f'HTTP Request runs into an Error: {log.highlight_error(traceback.format_exc())}') + logger.error(f"HTTP Request runs into an Error: {log.highlight_error(traceback.format_exc())}") raise e # 请求后记录 - logger.debug(f'Request to {url} succeed with code {req.status_code}') - if (req.content.startswith(b'\x78\x9c') or req.content.startswith(b'\x78\x01')): # zlib headers + logger.debug(f"Request to {url} succeed with code {req.status_code}") + if req.content.startswith(b"\x78\x9c") or req.content.startswith(b"\x78\x01"): # zlib headers try: decompressed = zlib.decompress(req.content) - if (is_valid_utf8(decompressed)): + if is_valid_utf8(decompressed): logger.debug(log_plaintext(decompressed.decode("utf-8"))) else: - logger.debug('response is not text binary, ignore logging it') + logger.debug("response is not text binary, ignore logging it") except: - logger.debug('response is not text binary, ignore logging it') + logger.debug("response is not text binary, ignore logging it") else: - if (is_valid_utf8(req.content)): + if is_valid_utf8(req.content): logger.debug(log_plaintext(req.content.decode("utf-8"))) else: - logger.debug('response is not text binary, ignore logging it') + logger.debug("response is not text binary, ignore logging it") # 缓存写入 - if (cache_info and cache_info != "no-cache"): + if cache_info and cache_info != "no-cache": cache_data = pickle.dumps(req) expire_time = (cache_info if isinstance(cache_info, int) else 3600) + int(time.time()) - config.updateCache("httpx", cache_key, {"expire": True, "time": expire_time, "data": utils.createBase64Encode(cache_data)}) + config.updateCache( + "httpx", + cache_key, + {"expire": True, "time": expire_time, "data": utils.createBase64Encode(cache_data)}, + expire_time, + ) logger.debug("缓存已更新: " + url) + def _json(): return json.loads(req.content) - setattr(req, 'json', _json) + + setattr(req, "json", _json) # 返回请求 return req @@ -184,22 +197,25 @@ def checkcn(): req = request("https://mips.kugou.com/check/iscn?&format=json") body = utils.CreateObject(req.json()) variable.iscn = bool(body.flag) - if (not variable.iscn): - variable.fakeip = config.read_config('common.fakeip') + if not variable.iscn: + variable.fakeip = config.read_config("common.fakeip") logger.info(f"您在非中国大陆服务器({body.country})上启动了项目,已自动开启ip伪装") - logger.warning("此方式无法解决咪咕音乐的链接获取问题,您可以配置代理,服务器地址可在下方链接中找到\nhttps://hidemy.io/cn/proxy-list/?country=CN#list") + logger.warning( + "此方式无法解决咪咕音乐的链接获取问题,您可以配置代理,服务器地址可在下方链接中找到\nhttps://hidemy.io/cn/proxy-list/?country=CN#list" + ) except Exception as e: - logger.warning('检查服务器位置失败,已忽略') + logger.warning("检查服务器位置失败,已忽略") logger.warning(traceback.format_exc()) + class ClientResponse: # 这个类为了方便aiohttp响应与requests响应的跨类使用,也为了解决pickle无法缓存的问题 def __init__(self, status, content, headers): self.status = status self.content = content self.headers = headers - self.text = content.decode("utf-8", errors='ignore') - + self.text = content.decode("utf-8", errors="ignore") + def json(self): return json.loads(self.content) @@ -208,11 +224,12 @@ async def convert_to_requests_response(aiohttp_response) -> ClientResponse: content = await aiohttp_response.content.read() # 从aiohttp响应中读取字节数据 status_code = aiohttp_response.status # 获取状态码 headers = dict(aiohttp_response.headers.items()) # 获取标头信息并转换为字典 - + return ClientResponse(status_code, content, headers) -async def AsyncRequest(url, options = {}) -> ClientResponse: - ''' + +async def AsyncRequest(url, options={}) -> ClientResponse: + """ Http异步请求主函数, 用于发送网络请求 - url: 需要请求的URL地址(必填) - options: 请求的配置参数(可选, 留空时为GET请求, 总体与nodejs的请求的options填写差不多) @@ -224,17 +241,17 @@ async def AsyncRequest(url, options = {}) -> ClientResponse: - no-cache: 不缓存 - : 缓存可用秒数 - cache-ignore: 缓存忽略关键字 - + @ return: common.Httpx.ClientResponse类型的响应数据 - ''' - if (not variable.aioSession): + """ + if not variable.aioSession: variable.aioSession = aiohttp.ClientSession(trust_env=True) # 缓存读取 - cache_key = f'{url}{options}' - if (isinstance(options.get('cache-ignore'), list)): - for i in options.get('cache-ignore'): - cache_key = cache_key.replace(str(i), '') - options.pop('cache-ignore') + cache_key = f"{url}{options}" + if isinstance(options.get("cache-ignore"), list): + for i in options.get("cache-ignore"): + cache_key = cache_key.replace(str(i), "") + options.pop("cache-ignore") cache_key = utils.createMD5(cache_key) if options.get("cache") and options["cache"] != "no-cache": cache = config.getCache("httpx_async", cache_key) @@ -247,76 +264,80 @@ async def AsyncRequest(url, options = {}) -> ClientResponse: options.pop("cache") else: cache_info = None - # 获取请求方法,没有则默认为GET请求 try: - method = options['method'] - options.pop('method') + method = options["method"] + options.pop("method") except Exception as e: - method = 'GET' + method = "GET" # 获取User-Agent,没有则从ua_list中随机选择一个 try: - d_lower = {k.lower(): v for k, v in options['headers'].items()} - useragent = d_lower['user-agent'] + d_lower = {k.lower(): v for k, v in options["headers"].items()} + useragent = d_lower["user-agent"] except: try: - options['headers']['User-Agent'] = random.choice(ua_list) + options["headers"]["User-Agent"] = random.choice(ua_list) except: - options['headers'] = {} - options['headers']['User-Agent'] = random.choice(ua_list) + options["headers"] = {} + options["headers"]["User-Agent"] = random.choice(ua_list) # 检查是否在国内 - if ((not variable.iscn) and (not options["headers"].get("X-Forwarded-For"))): + if (not variable.iscn) and (not options["headers"].get("X-Forwarded-For")): options["headers"]["X-Forwarded-For"] = variable.fakeip # 获取请求主函数 try: reqattr = getattr(variable.aioSession, method.lower()) except AttributeError: - raise AttributeError('Unsupported method: '+method) + raise AttributeError("Unsupported method: " + method) # 请求前记录 - logger.debug(f'HTTP Request: {url}\noptions: {options}') + logger.debug(f"HTTP Request: {url}\noptions: {options}") # 转换body/form参数为原生的data参数,并为form请求追加Content-Type头 - if (method == 'POST') or (method == 'PUT'): - if (options.get('body') is not None): - options['data'] = options['body'] - options.pop('body') - if (options.get('form') is not None): - options['data'] = convert_dict_to_form_string(options['form']) - options.pop('form') - options['headers']['Content-Type'] = 'application/x-www-form-urlencoded' - if (isinstance(options.get('data'), dict)): - options['data'] = json.dumps(options['data']) + if (method == "POST") or (method == "PUT"): + if options.get("body") is not None: + options["data"] = options["body"] + options.pop("body") + if options.get("form") is not None: + options["data"] = convert_dict_to_form_string(options["form"]) + options.pop("form") + options["headers"]["Content-Type"] = "application/x-www-form-urlencoded" + if isinstance(options.get("data"), dict): + options["data"] = json.dumps(options["data"]) # 进行请求 try: logger.info("-----start----- " + url) req_ = await reqattr(url, **options) except Exception as e: - logger.error(f'HTTP Request runs into an Error: {log.highlight_error(traceback.format_exc())}') + logger.error(f"HTTP Request runs into an Error: {log.highlight_error(traceback.format_exc())}") raise e # 请求后记录 - logger.debug(f'Request to {url} succeed with code {req_.status}') + logger.debug(f"Request to {url} succeed with code {req_.status}") # 为懒人提供的不用改代码移植的方法 # 才不是梓澄呢 req = await convert_to_requests_response(req_) - if (req.content.startswith(b'\x78\x9c') or req.content.startswith(b'\x78\x01')): # zlib headers + if req.content.startswith(b"\x78\x9c") or req.content.startswith(b"\x78\x01"): # zlib headers try: decompressed = zlib.decompress(req.content) - if (is_valid_utf8(decompressed)): + if is_valid_utf8(decompressed): logger.debug(log_plaintext(decompressed.decode("utf-8"))) else: - logger.debug('response is not text binary, ignore logging it') + logger.debug("response is not text binary, ignore logging it") except: - logger.debug('response is not text binary, ignore logging it') + logger.debug("response is not text binary, ignore logging it") else: - if (is_valid_utf8(req.content)): + if is_valid_utf8(req.content): logger.debug(log_plaintext(req.content.decode("utf-8"))) else: - logger.debug('response is not text binary, ignore logging it') + logger.debug("response is not text binary, ignore logging it") # 缓存写入 - if (cache_info and cache_info != "no-cache"): + if cache_info and cache_info != "no-cache": cache_data = pickle.dumps(req) expire_time = (cache_info if isinstance(cache_info, int) else 3600) + int(time.time()) - config.updateCache("httpx_async", cache_key, {"expire": True, "time": expire_time, "data": utils.createBase64Encode(cache_data)}) + config.updateCache( + "httpx_async", + cache_key, + {"expire": True, "time": expire_time, "data": utils.createBase64Encode(cache_data)}, + expire_time, + ) logger.debug("缓存已更新: " + url) # 返回请求 - return req \ No newline at end of file + return req diff --git a/common/config.py b/common/config.py index 565b0f3..238f59c 100644 --- a/common/config.py +++ b/common/config.py @@ -19,33 +19,52 @@ from .log import log from . import default_config import threading +import redis -logger = log('config_manager') +logger = log("config_manager") # 创建线程本地存储对象 local_data = threading.local() +local_cache = threading.local() +local_redis = threading.local() + def get_data_connection(): - # 检查线程本地存储对象是否存在连接对象,如果不存在则创建一个新的连接对象 - if (not hasattr(local_data, 'connection')): - local_data.connection = sqlite3.connect('./config/data.db') return local_data.connection -# 创建线程本地存储对象 -local_cache = threading.local() - - def get_cache_connection(): - # 检查线程本地存储对象是否存在连接对象,如果不存在则创建一个新的连接对象 - if not hasattr(local_cache, 'connection'): - local_cache.connection = sqlite3.connect('./cache.db') return local_cache.connection +def get_redis_connection(): + return local_redis.connection + + +def handle_connect_db(): + try: + local_data.connection = sqlite3.connect("./config/data.db") + if read_config("common.cache.adapter") == "redis": + host = read_config("common.cache.redis.host") + port = read_config("common.cache.redis.port") + user = read_config("common.cache.redis.user") + password = read_config("common.cache.redis.password") + db = read_config("common.cache.redis.db") + client = redis.Redis(host=host, port=port, username=user, password=password, db=db) + if not client.ping(): + raise + local_redis.connection = client + else: + local_cache.connection = sqlite3.connect("./cache.db") + except: + logger.error("连接数据库失败") + sys.exit(1) + + class ConfigReadException(Exception): pass + yaml = yaml_.YAML() default_str = default_config.default default = yaml.load(default_str) @@ -54,8 +73,10 @@ class ConfigReadException(Exception): def handle_default_config(): with open("./config/config.yml", "w", encoding="utf-8") as f: f.write(default_str) - if (not os.getenv('build')): - logger.info(f'首次启动或配置文件被删除,已创建默认配置文件\n建议您到{variable.workdir + os.path.sep}config.yml修改配置后重新启动服务器') + if not os.getenv("build"): + logger.info( + f"首次启动或配置文件被删除,已创建默认配置文件\n建议您到{variable.workdir + os.path.sep}config.yml修改配置后重新启动服务器" + ) return default @@ -96,8 +117,7 @@ def save_data(config_data): # Insert the new configuration data into the 'data' table for key, value in config_data.items(): - cursor.execute( - "INSERT INTO data (key, value) VALUES (?, ?)", (key, json.dumps(value))) + cursor.execute("INSERT INTO data (key, value) VALUES (?, ?)", (key, json.dumps(value))) conn.commit() @@ -106,51 +126,69 @@ def save_data(config_data): logger.error(traceback.format_exc()) -def getCache(module, key): - try: - # 连接到数据库(如果数据库不存在,则会自动创建) - conn = get_cache_connection() - - # 创建一个游标对象 - cursor = conn.cursor() +def handleBuildRedisKey(module, key): + prefix = read_config("common.cache.redis.key_prefix") + return f"{prefix}:{module}:{key}" - cursor.execute("SELECT data FROM cache WHERE module=? AND key=?", - (module, key)) - result = cursor.fetchone() - if result: - cache_data = json.loads(result[0]) - cache_data["time"] = int(cache_data["time"]) - if (not cache_data['expire']): - return cache_data - if (int(time.time()) < int(cache_data['time'])): +def getCache(module, key): + try: + if read_config("common.cache.adapter") == "redis": + redis = get_redis_connection() + key = handleBuildRedisKey(module, key) + result = redis.get(key) + if result: + cache_data = json.loads(result) return cache_data + else: + # 连接到数据库(如果数据库不存在,则会自动创建) + conn = get_cache_connection() + + # 创建一个游标对象 + cursor = conn.cursor() + + cursor.execute("SELECT data FROM cache WHERE module=? AND key=?", (module, key)) + + result = cursor.fetchone() + if result: + cache_data = json.loads(result[0]) + cache_data["time"] = int(cache_data["time"]) + if not cache_data["expire"]: + return cache_data + if int(time.time()) < int(cache_data["time"]): + return cache_data except: pass # traceback.print_exc() - return False + return None -def updateCache(module, key, data): +def updateCache(module, key, data, expire=None): try: - # 连接到数据库(如果数据库不存在,则会自动创建) - conn = get_cache_connection() - - # 创建一个游标对象 - cursor = conn.cursor() - - cursor.execute( - "SELECT data FROM cache WHERE module=? AND key=?", (module, key)) - result = cursor.fetchone() - if result: - cursor.execute( - "UPDATE cache SET data = ? WHERE module = ? AND key = ?", (json.dumps(data), module, key)) + if read_config("common.cache.adapter") == "redis": + redis = get_redis_connection() + key = handleBuildRedisKey(module, key) + redis.set(key, json.dumps(data), ex=expire if expire and expire > 0 else None) else: - cursor.execute( - "INSERT INTO cache (module, key, data) VALUES (?, ?, ?)", (module, key, json.dumps(data))) - conn.commit() + # 连接到数据库(如果数据库不存在,则会自动创建) + conn = get_cache_connection() + + # 创建一个游标对象 + cursor = conn.cursor() + + cursor.execute("SELECT data FROM cache WHERE module=? AND key=?", (module, key)) + result = cursor.fetchone() + if result: + cursor.execute( + "UPDATE cache SET data = ? WHERE module = ? AND key = ?", (json.dumps(data), module, key) + ) + else: + cursor.execute( + "INSERT INTO cache (module, key, data) VALUES (?, ?, ?)", (module, key, json.dumps(data)) + ) + conn.commit() except: - logger.error('缓存写入遇到错误…') + logger.error("缓存写入遇到错误…") logger.error(traceback.format_exc()) @@ -158,13 +196,13 @@ def resetRequestTime(ip): config_data = load_data() try: try: - config_data['requestTime'][ip] = 0 + config_data["requestTime"][ip] = 0 except KeyError: - config_data['requestTime'] = {} - config_data['requestTime'][ip] = 0 + config_data["requestTime"] = {} + config_data["requestTime"][ip] = 0 save_data(config_data) except: - logger.error('配置写入遇到错误…') + logger.error("配置写入遇到错误…") logger.error(traceback.format_exc()) @@ -172,20 +210,20 @@ def updateRequestTime(ip): try: config_data = load_data() try: - config_data['requestTime'][ip] = time.time() + config_data["requestTime"][ip] = time.time() except KeyError: - config_data['requestTime'] = {} - config_data['requestTime'][ip] = time.time() + config_data["requestTime"] = {} + config_data["requestTime"][ip] = time.time() save_data(config_data) except: - logger.error('配置写入遇到错误...') + logger.error("配置写入遇到错误...") logger.error(traceback.format_exc()) def getRequestTime(ip): config_data = load_data() try: - value = config_data['requestTime'][ip] + value = config_data["requestTime"][ip] except: value = 0 return value @@ -193,7 +231,7 @@ def getRequestTime(ip): def read_data(key): config = load_data() - keys = key.split('.') + keys = key.split(".") value = config for k in keys: if k not in value and keys.index(k) != len(keys) - 1: @@ -208,7 +246,7 @@ def read_data(key): def write_data(key, value): config = load_data() - keys = key.split('.') + keys = key.split(".") current = config for k in keys[:-1]: if k not in current: @@ -223,7 +261,7 @@ def write_data(key, value): def push_to_list(key, obj): config = load_data() - keys = key.split('.') + keys = key.split(".") current = config for k in keys[:-1]: if k not in current: @@ -240,10 +278,10 @@ def push_to_list(key, obj): def write_config(key, value): config = None - with open('./config/config.yml', 'r', encoding='utf-8') as f: + with open("./config/config.yml", "r", encoding="utf-8") as f: config = yaml_.YAML().load(f) - keys = key.split('.') + keys = key.split(".") current = config for k in keys[:-1]: if k not in current: @@ -258,14 +296,14 @@ def write_config(key, value): y.preserve_blank_lines = True # 写入配置并保留注释和空行 - with open('./config/config.yml', 'w', encoding='utf-8') as f: + with open("./config/config.yml", "w", encoding="utf-8") as f: y.dump(config, f) def read_default_config(key): try: config = default - keys = key.split('.') + keys = key.split(".") value = config for k in keys: if isinstance(value, dict): @@ -286,7 +324,7 @@ def read_default_config(key): def _read_config(key): try: config = variable.config - keys = key.split('.') + keys = key.split(".") value = config for k in keys: if isinstance(value, dict): @@ -307,7 +345,7 @@ def _read_config(key): def read_config(key): try: config = variable.config - keys = key.split('.') + keys = key.split(".") value = config for k in keys: if isinstance(value, dict): @@ -323,23 +361,23 @@ def read_config(key): return value except: default_value = read_default_config(key) - if (isinstance(default_value, type(None))): - logger.warning(f'配置文件{key}不存在') + if isinstance(default_value, type(None)): + logger.warning(f"配置文件{key}不存在") else: for i in range(len(keys)): - tk = '.'.join(keys[:(i + 1)]) + tk = ".".join(keys[: (i + 1)]) tkvalue = _read_config(tk) - logger.debug(f'configfix: 读取配置文件{tk}的值:{tkvalue}') - if ((tkvalue is None) or (tkvalue == {})): + logger.debug(f"configfix: 读取配置文件{tk}的值:{tkvalue}") + if (tkvalue is None) or (tkvalue == {}): write_config(tk, read_default_config(tk)) - logger.info(f'配置文件{tk}不存在,已创建') + logger.info(f"配置文件{tk}不存在,已创建") return default_value def write_data(key, value): config = load_data() - keys = key.split('.') + keys = key.split(".") current = config for k in keys[:-1]: if k not in current: @@ -351,26 +389,26 @@ def write_data(key, value): save_data(config) -def initConfig(): - if (not os.path.exists('./config')): - os.mkdir('config') - if (os.path.exists('./config.json')): - shutil.move('config.json','./config') - if (os.path.exists('./data.db')): - shutil.move('./data.db','./config') - if (os.path.exists('./config/config.json')): - os.rename('./config/config.json', './config/config.json.bak') +def init_config(): + if not os.path.exists("./config"): + os.mkdir("config") + if os.path.exists("./config.json"): + shutil.move("config.json", "./config") + if os.path.exists("./data.db"): + shutil.move("./data.db", "./config") + if os.path.exists("./config/config.json"): + os.rename("./config/config.json", "./config/config.json.bak") handle_default_config() - logger.warning('json配置文件已不再使用,已将其重命名为config.json.bak') - logger.warning('配置文件不会自动更新(因为变化太大),请手动修改配置文件重启服务器') + logger.warning("json配置文件已不再使用,已将其重命名为config.json.bak") + logger.warning("配置文件不会自动更新(因为变化太大),请手动修改配置文件重启服务器") sys.exit(0) try: with open("./config/config.yml", "r", encoding="utf-8") as f: try: variable.config = yaml.load(f.read()) - if (not isinstance(variable.config, dict)): - logger.warning('配置文件并不是一个有效的字典,使用默认值') + if not isinstance(variable.config, dict): + logger.warning("配置文件并不是一个有效的字典,使用默认值") variable.config = default with open("./config/config.yml", "w", encoding="utf-8") as f: yaml.dump(variable.config, f) @@ -384,125 +422,133 @@ def initConfig(): except FileNotFoundError: variable.config = handle_default_config() # print(variable.config) - variable.log_length_limit = read_config('common.log_length_limit') - variable.debug_mode = read_config('common.debug_mode') + variable.log_length_limit = read_config("common.log_length_limit") + variable.debug_mode = read_config("common.debug_mode") logger.debug("配置文件加载成功") - conn = sqlite3.connect('./cache.db') + + # 尝试连接数据库 + handle_connect_db() + + conn = sqlite3.connect("./cache.db") # 创建一个游标对象 cursor = conn.cursor() # 创建一个表来存储缓存数据 - cursor.execute('''CREATE TABLE IF NOT EXISTS cache + cursor.execute( + """CREATE TABLE IF NOT EXISTS cache (id INTEGER PRIMARY KEY AUTOINCREMENT, module TEXT NOT NULL, key TEXT NOT NULL, -data TEXT NOT NULL)''') +data TEXT NOT NULL)""" + ) conn.close() - conn2 = sqlite3.connect('./config/data.db') + conn2 = sqlite3.connect("./config/data.db") # 创建一个游标对象 cursor2 = conn2.cursor() - cursor2.execute('''CREATE TABLE IF NOT EXISTS data + cursor2.execute( + """CREATE TABLE IF NOT EXISTS data (key TEXT PRIMARY KEY, -value TEXT)''') +value TEXT)""" + ) conn2.close() - logger.debug('数据库初始化成功') + logger.debug("数据库初始化成功") # handle data - all_data_keys = {'banList': [], 'requestTime': {}, 'banListRaw': []} + all_data_keys = {"banList": [], "requestTime": {}, "banListRaw": []} data = load_data() - if (data == {}): - write_data('banList', []) - write_data('requestTime', {}) - logger.info('数据库内容为空,已写入默认值') + if data == {}: + write_data("banList", []) + write_data("requestTime", {}) + logger.info("数据库内容为空,已写入默认值") for k, v in all_data_keys.items(): - if (k not in data): + if k not in data: write_data(k, v) - logger.info(f'数据库中不存在{k},已创建') + logger.info(f"数据库中不存在{k},已创建") # 处理代理配置 - if (read_config('common.proxy.enable')): - if (read_config('common.proxy.http_value')): - os.environ['http_proxy'] = read_config('common.proxy.http_value') - logger.info('HTTP协议代理地址: ' + - read_config('common.proxy.http_value')) - if (read_config('common.proxy.https_value')): - os.environ['https_proxy'] = read_config('common.proxy.https_value') - logger.info('HTTPS协议代理地址: ' + - read_config('common.proxy.https_value')) - logger.info('代理功能已开启,请确保代理地址正确,否则无法连接网络') + if read_config("common.proxy.enable"): + if read_config("common.proxy.http_value"): + os.environ["http_proxy"] = read_config("common.proxy.http_value") + logger.info("HTTP协议代理地址: " + read_config("common.proxy.http_value")) + if read_config("common.proxy.https_value"): + os.environ["https_proxy"] = read_config("common.proxy.https_value") + logger.info("HTTPS协议代理地址: " + read_config("common.proxy.https_value")) + logger.info("代理功能已开启,请确保代理地址正确,否则无法连接网络") # cookie池 - if (read_config('common.cookiepool')): - logger.info('已启用cookie池功能,请确定配置的cookie都能正确获取链接') - logger.info('传统的源 - 单用户cookie配置将被忽略') - logger.info('所以即使某个源你只有一个cookie,也请填写到cookiepool对应的源中,否则将无法使用该cookie') + if read_config("common.cookiepool"): + logger.info("已启用cookie池功能,请确定配置的cookie都能正确获取链接") + logger.info("传统的源 - 单用户cookie配置将被忽略") + logger.info("所以即使某个源你只有一个cookie,也请填写到cookiepool对应的源中,否则将无法使用该cookie") variable.use_cookie_pool = True # 移除已经过期的封禁数据 - banlist = read_data('banList') - banlistRaw = read_data('banListRaw') + banlist = read_data("banList") + banlistRaw = read_data("banListRaw") count = 0 for b in banlist: - if (b['expire'] and (time.time() > b['expire_time'])): + if b["expire"] and (time.time() > b["expire_time"]): count += 1 banlist.remove(b) - if (b['ip'] in banlistRaw): - banlistRaw.remove(b['ip']) - write_data('banList', banlist) - write_data('banListRaw', banlistRaw) - if (count != 0): - logger.info(f'已移除{count}条过期封禁数据') + if b["ip"] in banlistRaw: + banlistRaw.remove(b["ip"]) + write_data("banList", banlist) + write_data("banListRaw", banlistRaw) + if count != 0: + logger.info(f"已移除{count}条过期封禁数据") # 处理旧版数据库的banListRaw - banlist = read_data('banList') - banlistRaw = read_data('banListRaw') - if (banlist != [] and banlistRaw == []): + banlist = read_data("banList") + banlistRaw = read_data("banListRaw") + if banlist != [] and banlistRaw == []: for b in banlist: - banlistRaw.append(b['ip']) + banlistRaw.append(b["ip"]) return def ban_ip(ip_addr, ban_time=-1): - if read_config('security.banlist.enable'): - banList = read_data('banList') - banList.append({ - 'ip': ip_addr, - 'expire': read_config('security.banlist.expire.enable'), - 'expire_time': read_config('security.banlist.expire.length') if (ban_time == -1) else ban_time, - }) - write_data('banList', banList) - banListRaw = read_data('banListRaw') - if (ip_addr not in banListRaw): + if read_config("security.banlist.enable"): + banList = read_data("banList") + banList.append( + { + "ip": ip_addr, + "expire": read_config("security.banlist.expire.enable"), + "expire_time": read_config("security.banlist.expire.length") if (ban_time == -1) else ban_time, + } + ) + write_data("banList", banList) + banListRaw = read_data("banListRaw") + if ip_addr not in banListRaw: banListRaw.append(ip_addr) - write_data('banListRaw', banListRaw) + write_data("banListRaw", banListRaw) else: - if (variable.banList_suggest < 10): + if variable.banList_suggest < 10: variable.banList_suggest += 1 - logger.warning('黑名单功能已被关闭,我们墙裂建议你开启这个功能以防止恶意请求') + logger.warning("黑名单功能已被关闭,我们墙裂建议你开启这个功能以防止恶意请求") def check_ip_banned(ip_addr): - if read_config('security.banlist.enable'): - banList = read_data('banList') - banlistRaw = read_data('banListRaw') - if (ip_addr in banlistRaw): + if read_config("security.banlist.enable"): + banList = read_data("banList") + banlistRaw = read_data("banListRaw") + if ip_addr in banlistRaw: for b in banList: - if (b['ip'] == ip_addr): - if (b['expire']): - if (b['expire_time'] > int(time.time())): + if b["ip"] == ip_addr: + if b["expire"]: + if b["expire_time"] > int(time.time()): return True else: banList.remove(b) - banlistRaw.remove(b['ip']) - write_data('banListRaw', banlistRaw) - write_data('banList', banList) + banlistRaw.remove(b["ip"]) + write_data("banListRaw", banlistRaw) + write_data("banList", banList) return False else: return True @@ -512,10 +558,10 @@ def check_ip_banned(ip_addr): else: return False else: - if (variable.banList_suggest <= 10): + if variable.banList_suggest <= 10: variable.banList_suggest += 1 - logger.warning('黑名单功能已被关闭,我们墙裂建议你开启这个功能以防止恶意请求') + logger.warning("黑名单功能已被关闭,我们墙裂建议你开启这个功能以防止恶意请求") return False -initConfig() +init_config() diff --git a/common/default_config.py b/common/default_config.py index 0107dda..caa410a 100644 --- a/common/default_config.py +++ b/common/default_config.py @@ -51,6 +51,18 @@ local_music: # 服务器侧本地音乐相关配置,如果需要使用此功能请确保你的带宽足够 audio_path: ./audio temp_path: ./temp + # 缓存配置 + cache: + # 适配器 [redis,sql] + adapter: sql + # redis 配置 + redis: + host: 127.0.0.1 + port: 6379 + db: 0 + user: "" + password: "" + key_prefix: "LXAPISERVER" security: rate_limit: # 请求速率限制 填入的值为至少间隔多久才能进行一次请求,单位:秒,不限制请填为0 diff --git a/modules/__init__.py b/modules/__init__.py index 0aff336..6ab19a2 100644 --- a/modules/__init__.py +++ b/modules/__init__.py @@ -11,6 +11,7 @@ from common.utils import require from common import log from common import config + # 从.引入的包并没有在代码中直接使用,但是是用require在请求时进行引入的,不要动 from . import kw from . import mg @@ -20,194 +21,194 @@ import traceback import time -logger = log.log('api_handler') +logger = log.log("api_handler") sourceExpirationTime = { - 'tx': { + "tx": { "expire": True, "time": 80400, # 不知道tx为什么要取一个这么不对劲的数字当过期时长 }, - 'kg': { + "kg": { "expire": True, "time": 24 * 60 * 60, # 24 hours }, - 'kw': { - "expire": True, - "time": 60 * 60 # 60 minutes - }, - 'wy': { + "kw": {"expire": True, "time": 60 * 60}, # 60 minutes + "wy": { "expire": True, "time": 20 * 60, # 20 minutes }, - 'mg': { + "mg": { "expire": False, "time": 0, - } - + }, } -async def url(source, songId, quality, query = {}): - if (not quality): +async def url(source, songId, quality, query={}): + if not quality: return { - 'code': 2, - 'msg': '需要参数"quality"', - 'data': None, + "code": 2, + "msg": '需要参数"quality"', + "data": None, } - - if (source == "kg"): + + if source == "kg": songId = songId.lower() - + try: - cache = config.getCache('urls', f'{source}_{songId}_{quality}') + cache = config.getCache("urls", f"{source}_{songId}_{quality}") if cache: logger.debug(f'使用缓存的{source}_{songId}_{quality}数据,URL:{cache["url"]}') return { - 'code': 0, - 'msg': 'success', - 'data': cache['url'], - 'extra': { - 'cache': True, - 'quality': { - 'target': quality, - 'result': quality, + "code": 0, + "msg": "success", + "data": cache["url"], + "extra": { + "cache": True, + "quality": { + "target": quality, + "result": quality, }, - 'expire': { + "expire": { # 在更新缓存的时候把有效期的75%作为链接可用时长,现在加回来 - 'time': int(cache['time'] + (sourceExpirationTime[source]['time'] * 0.25)) if cache['expire'] else None, - 'canExpire': cache['expire'], - } + "time": ( + int(cache["time"] + (sourceExpirationTime[source]["time"] * 0.25)) + if cache["expire"] + else None + ), + "canExpire": cache["expire"], + }, }, } except: logger.error(traceback.format_exc()) try: - func = require('modules.' + source + '.url') + func = require("modules." + source + ".url") except: return { - 'code': 1, - 'msg': '未知的源或不支持的方法', - 'data': None, + "code": 1, + "msg": "未知的源或不支持的方法", + "data": None, } try: result = await func(songId, quality) logger.info(f'获取{source}_{songId}_{quality}成功,URL:{result["url"]}') - canExpire = sourceExpirationTime[source]['expire'] - expireTime = sourceExpirationTime[source]['time'] + int(time.time()) - config.updateCache('urls', f'{source}_{songId}_{quality}', { - "expire": canExpire, - # 取有效期的75%作为链接可用时长 - "time": int(expireTime - sourceExpirationTime[source]['time'] * 0.25), - "url": result['url'], - }) + canExpire = sourceExpirationTime[source]["expire"] + expireTime = sourceExpirationTime[source]["time"] + int(time.time()) + canUseTime = int(expireTime - sourceExpirationTime[source]["time"] * 0.25) + config.updateCache( + "urls", + f"{source}_{songId}_{quality}", + { + "expire": canExpire, + # 取有效期的75%作为链接可用时长 + "time": canUseTime, + "url": result["url"], + }, + canUseTime if canExpire else None, + ) logger.debug(f'缓存已更新:{source}_{songId}_{quality}, URL:{result["url"]}, expire: {expireTime}') return { - 'code': 0, - 'msg': 'success', - 'data': result['url'], - 'extra': { - 'cache': False, - 'quality': { - 'target': quality, - 'result': result['quality'], + "code": 0, + "msg": "success", + "data": result["url"], + "extra": { + "cache": False, + "quality": { + "target": quality, + "result": result["quality"], }, - 'expire': { - 'time': expireTime if canExpire else None, - 'canExpire': canExpire, + "expire": { + "time": expireTime if canExpire else None, + "canExpire": canExpire, }, }, } except FailedException as e: - logger.info(f'获取{source}_{songId}_{quality}失败,原因:' + e.args[0]) + logger.info(f"获取{source}_{songId}_{quality}失败,原因:" + e.args[0]) return { - 'code': 2, - 'msg': e.args[0], - 'data': None, + "code": 2, + "msg": e.args[0], + "data": None, } + async def lyric(source, songId, _, query): - cache = config.getCache('lyric', f'{source}_{songId}') + cache = config.getCache("lyric", f"{source}_{songId}") if cache: - return { - 'code': 0, - 'msg': 'success', - 'data': cache['data'] - } + return {"code": 0, "msg": "success", "data": cache["data"]} try: - func = require('modules.' + source + '.lyric') + func = require("modules." + source + ".lyric") except: return { - 'code': 1, - 'msg': '未知的源或不支持的方法', - 'data': None, + "code": 1, + "msg": "未知的源或不支持的方法", + "data": None, } try: result = await func(songId) - config.updateCache('lyric', f'{source}_{songId}', { - "data": result, - "time": int(time.time() + (86400 * 3)), # 歌词缓存3天 - "expire": True, - }) - logger.debug(f'缓存已更新:{source}_{songId}, lyric: {result}') - return { - 'code': 0, - 'msg': 'success', - 'data': result - } + expireTime = int(time.time() + (86400 * 3)) + config.updateCache( + "lyric", + f"{source}_{songId}", + { + "data": result, + "time": expireTime, # 歌词缓存3天 + "expire": True, + }, + expireTime, + ) + logger.debug(f"缓存已更新:{source}_{songId}, lyric: {result}") + return {"code": 0, "msg": "success", "data": result} except FailedException as e: return { - 'code': 2, - 'msg': e.args[0], - 'data': None, + "code": 2, + "msg": e.args[0], + "data": None, } + async def search(source, songid, _, query): try: - func = require('modules.' + source + '.search') + func = require("modules." + source + ".search") except: return { - 'code': 1, - 'msg': '未知的源或不支持的方法', - 'data': None, + "code": 1, + "msg": "未知的源或不支持的方法", + "data": None, } try: result = await func(songid, query) - return { - 'code': 0, - 'msg': 'success', - 'data': result - } + return {"code": 0, "msg": "success", "data": result} except FailedException as e: return { - 'code': 2, - 'msg': e.args[0], - 'data': None, + "code": 2, + "msg": e.args[0], + "data": None, } + async def other(method, source, songid, _, query): try: - func = require('modules.' + source + '.' + method) + func = require("modules." + source + "." + method) except: return { - 'code': 1, - 'msg': '未知的源或不支持的方法', - 'data': None, + "code": 1, + "msg": "未知的源或不支持的方法", + "data": None, } try: result = await func(songid) - return { - 'code': 0, - 'msg': 'success', - 'data': result - } + return {"code": 0, "msg": "success", "data": result} except FailedException as e: return { - 'code': 2, - 'msg': e.args[0], - 'data': None, + "code": 2, + "msg": e.args[0], + "data": None, } + async def info_with_query(source, songid, _, query): - return await other('info', source, songid, None) \ No newline at end of file + return await other("info", source, songid, None) diff --git a/requirements.txt b/requirements.txt index d84195e..ad2e871 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ mutagen pillow colorama ruamel-yaml +redis +redis[hiredis]