From 3c0aa7613c59b9287b8a802f9e494ee1994c5bd0 Mon Sep 17 00:00:00 2001
From: yanghua
Date: Thu, 26 Sep 2024 15:44:27 +0800
Subject: [PATCH] Tag bucket

---
 pyproject.toml  |   1 +
 tosfs/consts.py |   4 +-
 tosfs/core.py   |  20 +++++
 tosfs/tag.py    | 218 ++++++++++++++++++++++++++++++++++++++++--------
 4 files changed, 205 insertions(+), 38 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 299d2d4..36a9403 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ select = [
 ignore = [
     "S101", # Use of `assert` detected
     "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+    "S108", # Probable insecure usage of temporary file or directory
     "D203", # no-blank-line-before-class
     "D213", # multi-line-summary-second-line
     "PLR0913", # Too many arguments in function definition
diff --git a/tosfs/consts.py b/tosfs/consts.py
index 98d5e62..556aac0 100644
--- a/tosfs/consts.py
+++ b/tosfs/consts.py
@@ -31,6 +31,8 @@
 LS_OPERATION_DEFAULT_MAX_ITEMS = 1000
 
+TOSFS_LOG_FORMAT = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s" # noqa: E501
+
 # environment variable names
 ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
-TOSFS_LOG_FORMAT = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s" # noqa: E501
+ENV_NAME_TOS_BUCKET_TAG_ENABLE = "TOS_BUCKET_TAG_ENABLE"
diff --git a/tosfs/core.py b/tosfs/core.py
index 807c1bb..947626d 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -25,6 +25,7 @@
 from fsspec import AbstractFileSystem
 from fsspec.spec import AbstractBufferedFile
 from fsspec.utils import setup_logging as setup_logger
+from tos.auth import CredentialProviderAuth
 from tos.exceptions import TosClientError, TosServerError
 from tos.models import CommonPrefixInfo
 from tos.models2 import (
@@ -54,6 +55,7 @@
 from tosfs.fsspec_utils import glob_translate
 from tosfs.mpu import MultipartUploader
 from tosfs.retry import retryable_func_executor
+from tosfs.tag import BucketTagMgr
 from tosfs.utils import find_bucket_key, get_brange
 
 logger = logging.getLogger("tosfs")
@@ -203,6 +205,10 @@ def __init__(
         if version_aware:
             raise ValueError("Currently, version_aware is not supported.")
+        self.tag_enabled = os.environ.get("TOS_BUCKET_TAG_ENABLE", "true") == "true"
+        if self.tag_enabled:
+            self._init_tag_manager()
+
         self.version_aware = version_aware
         self.default_block_size = (
             default_block_size or FILE_OPERATION_READ_WRITE_BUFFER_SIZE
         )
@@ -2093,12 +2099,26 @@ def _split_path(self, path: str) -> Tuple[str, str, Optional[str]]:
         bucket, keypart = find_bucket_key(path)
         key, _, version_id = keypart.partition("?versionId=")
+
+        if self.tag_enabled:
+            self.bucket_tag_mgr.add_bucket_tag(bucket)
+
         return (
             bucket,
             key,
             version_id if self.version_aware and version_id else None,
         )
 
+    def _init_tag_manager(self) -> None:
+        auth = self.tos_client.auth
+        if isinstance(auth, CredentialProviderAuth):
+            credentials = auth.credentials_provider.get_credentials()
+            self.bucket_tag_mgr = BucketTagMgr(
+                credentials.get_ak(), credentials.get_sk(), auth.region
+            )
+        else:
+            raise TosfsError("Currently only CredentialProviderAuth is supported")
+
     @staticmethod
     def _fill_dir_info(
         bucket: str, common_prefix: Optional[CommonPrefixInfo], key: str = ""
diff --git a/tosfs/tag.py b/tosfs/tag.py
index 020c73d..ea13f17 100644
--- a/tosfs/tag.py
+++ b/tosfs/tag.py
@@ -12,10 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""The module contains all the business logic for tagging TOS buckets."""
+
+import fcntl
+import functools
 import json
 import logging
 import os
 import threading
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
 
 from volcengine.ApiInfo import ApiInfo
 from volcengine.base.Service import Service
@@ -32,89 +38,227 @@
 service_info_map = {
-    "cn-beijing": ServiceInfo(OPEN_API_HOST, {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE, },
-                              Credentials("", "", "emr", "cn-beijing"), 60 * 5, 60 * 5, "http"),
-    "cn-guangzhou": ServiceInfo(OPEN_API_HOST, {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE, },
-                                Credentials("", "", "emr", "cn-guangzhou"), 60 * 5, 60 * 5, "http"),
-    "cn-shanghai": ServiceInfo(OPEN_API_HOST, {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE, },
-                               Credentials("", "", "emr", "cn-shanghai"), 60 * 5, 60 * 5, "http"),
-    "ap-southeast-1": ServiceInfo(OPEN_API_HOST, {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE, },
-                                  Credentials("", "", "emr", "ap-southeast-1"), 60 * 5, 60 * 5, "http"),
-    "cn-beijing-qa": ServiceInfo(OPEN_API_HOST, {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE, },
-                                 Credentials("", "", "emr_qa", "cn-beijing"), 60 * 5, 60 * 5, "http"),
+    "cn-beijing": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "cn-beijing"),
+        60 * 5,
+        60 * 5,
+        "http",
+    ),
+    "cn-guangzhou": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "cn-guangzhou"),
+        60 * 5,
+        60 * 5,
+        "http",
+    ),
+    "cn-shanghai": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "cn-shanghai"),
+        60 * 5,
+        60 * 5,
+        "http",
+    ),
+    "ap-southeast-1": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr", "ap-southeast-1"),
+        60 * 5,
+        60 * 5,
+        "http",
+    ),
+    "cn-beijing-qa": ServiceInfo(
+        OPEN_API_HOST,
+        {
+            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+        },
+        Credentials("", "", "emr_qa", "cn-beijing"),
+        60 * 5,
+        60 * 5,
+        "http",
+    ),
 }
 
 api_info = {
-    PUT_TAG_ACTION_NAME: ApiInfo("POST", "/", {
-        "Action": PUT_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION}, {}, {}),
-    GET_TAG_ACTION_NAME: ApiInfo("GET", "/", {
-        "Action": GET_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION}, {}, {}),
-    DEL_TAG_ACTION_NAME: ApiInfo("POST", "/", {
-        "Action": DEL_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION}, {}, {}),
+    PUT_TAG_ACTION_NAME: ApiInfo(
+        "POST",
+        "/",
+        {"Action": PUT_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
+        {},
+        {},
+    ),
+    GET_TAG_ACTION_NAME: ApiInfo(
+        "GET",
+        "/",
+        {"Action": GET_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
+        {},
+        {},
+    ),
+    DEL_TAG_ACTION_NAME: ApiInfo(
+        "POST",
+        "/",
+        {"Action": DEL_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
+        {},
+        {},
+    ),
 }
 
+
 class BucketTagAction(Service):
+    """BucketTagAction is a class to manage bucket tags."""
+
     _instance_lock = threading.Lock()
 
-    def __new__(cls, *args, **kwargs):
+    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
+        """Singleton."""
         if not hasattr(BucketTagAction, "_instance"):
             with BucketTagAction._instance_lock:
                 if not hasattr(BucketTagAction, "_instance"):
                     BucketTagAction._instance = object.__new__(cls)
         return BucketTagAction._instance
 
-    def __init__(self, access_key = None, secret_key = None, region = "cn-beijing"):
-        if region is None:
-            region = "cn-beijing"
+    def __init__(self, key: str, secret: str, region: str = "cn-beijing") -> None:
+        """Init BucketTagAction."""
         super().__init__(self.get_service_info(region), self.get_api_info())
-        if access_key is not None and secret_key is not None:
-            self.set_ak(access_key)
-            self.set_sk(secret_key)
+        self.set_ak(key)
+        self.set_sk(secret)
 
     @staticmethod
-    def get_api_info():
+    def get_api_info() -> dict:
+        """Get api info."""
         return api_info
 
     @staticmethod
-    def get_service_info(region):
+    def get_service_info(region: str) -> ServiceInfo:
+        """Get service info."""
         service_info = service_info_map.get(region)
         if service_info:
             return service_info
         elif "VOLC_REGION" in os.environ:
-            return ServiceInfo(OPEN_API_HOST, {ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE, },
-                               Credentials("", "", "emr", region), 60 * 5, 60 * 5, "http")
+            return ServiceInfo(
+                OPEN_API_HOST,
+                {
+                    ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
+                },
+                Credentials("", "", "emr", region),
+                60 * 5,
+                60 * 5,
+                "http",
+            )
         else:
             raise Exception("do not support region %s" % region)
 
-    def put_bucket_tag(self, bucket):
-        params = {"Bucket": bucket,}
+    def put_bucket_tag(self, bucket: str) -> tuple[str, bool]:
+        """Put tag for bucket."""
+        params = {
+            "Bucket": bucket,
+        }
         try:
             res = self.json(PUT_TAG_ACTION_NAME, params, json.dumps(""))
             res_json = json.loads(res)
-            logging.debug("Put tag for bucket %s is success. The result of put_Bucket_tag is %s.", bucket, res_json)
+            logging.debug("Put tag for bucket %s successfully: %s.", bucket, res_json)
             return (bucket, True)
         except Exception as e:
-            logging.error("Put tag for bucket %s is failed: %s", bucket, e)
+            logging.debug("Put tag for bucket %s failed: %s.", bucket, e)
             return (bucket, False)
 
-    def get_bucket_tag(self, bucket):
-        params = {"Bucket": bucket,}
+    def get_bucket_tag(self, bucket: str) -> bool:
+        """Get tag for bucket."""
+        params = {
+            "Bucket": bucket,
+        }
         try:
             res = self.get(GET_TAG_ACTION_NAME, params)
             res_json = json.loads(res)
             logging.debug("The result of get_Bucket_tag is %s", res_json)
             return True
         except Exception as e:
-            logging.error("Get tag for %s is failed: %s", bucket, e)
+            logging.debug("Get tag for %s failed: %s", bucket, e)
             return False
 
-    def del_bucket_tag(self, bucket):
-        params = {"Bucket": bucket,}
+    def del_bucket_tag(self, bucket: str) -> None:
+        """Delete tag for bucket."""
+        params = {
+            "Bucket": bucket,
+        }
         try:
             res = self.json(DEL_TAG_ACTION_NAME, params, json.dumps(""))
             res_json = json.loads(res)
             logging.debug("The result of del_Bucket_tag is %s", res_json)
         except Exception as e:
-            logging.error("Delete tag for %s is failed: %s", bucket, e)
+            logging.debug("Delete tag for %s failed: %s", bucket, e)
+
+
+THREAD_POOL_SIZE = 2
+TAGGED_BUCKETS_FILE = "/tmp/.emr_tagged_buckets"
+
+
+def singleton(cls: Any) -> Any:
+    """Singleton decorator."""
+    _instances = {}
+
+    @functools.wraps(cls)
+    def get_instance(*args: Any, **kwargs: Any) -> Any:
+        if cls not in _instances:
+            _instances[cls] = cls(*args, **kwargs)
+        return _instances[cls]
+
+    return get_instance
+
+
+@singleton
+class BucketTagMgr:
+    """BucketTagMgr is a class to manage bucket tags."""
+
+    def __init__(self, key: str, secret: str, region: str):
+        """Init BucketTagMgr."""
+        self.executor = ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE)
+        self.cached_bucket_set: set = set()
+        self.key = key
+        self.secret = secret
+        self.region = region
+
+    def add_bucket_tag(self, bucket: str) -> None:
+        """Add tag for bucket."""
+        collect_bucket_set = set()
+        collect_bucket_set.add(bucket)
+
+        if (
+            len(collect_bucket_set) == 0
+            or len(collect_bucket_set - self.cached_bucket_set) == 0
+        ):
+            return
+
+        tagged_bucket_from_file_set = set()
+        if os.path.exists(TAGGED_BUCKETS_FILE):
+            with open(TAGGED_BUCKETS_FILE, "r") as file:
+                tagged_bucket_from_file_set = set(file.read().split(" "))
+
+        self.cached_bucket_set = self.cached_bucket_set | tagged_bucket_from_file_set
+        need_tag_buckets = collect_bucket_set - self.cached_bucket_set
+
+        bucket_tag_service = BucketTagAction(self.key, self.secret, self.region)
+
+        for res in self.executor.map(
+            bucket_tag_service.put_bucket_tag, need_tag_buckets
+        ):
+            if res[1] is True:
+                self.cached_bucket_set.add(res[0])
+
+        with open(TAGGED_BUCKETS_FILE, "w") as fw:
+            fcntl.flock(fw, fcntl.LOCK_EX)
+            fw.write(" ".join(self.cached_bucket_set))
+            fcntl.flock(fw, fcntl.LOCK_UN)
+            fw.close()
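
Usage sketch of the new tagging path (illustrative only, not part of the patch; the
TosFileSystem keyword arguments, endpoint, and bucket name below are assumptions):

    import os

    from tosfs.core import TosFileSystem

    # Tagging is on unless the switch added in consts.py is set to something
    # other than "true"; set it to "false" to skip bucket tagging entirely.
    os.environ.setdefault("TOS_BUCKET_TAG_ENABLE", "true")

    # Hypothetical constructor arguments for this sketch.
    fs = TosFileSystem(
        endpoint_url="https://tos-cn-beijing.volces.com",
        key="<access-key>",
        secret="<secret-key>",
        region="cn-beijing",
    )

    # Any path-based call resolves the bucket in _split_path(), which hands it
    # to BucketTagMgr.add_bucket_tag(); each bucket is tagged once via the EMR
    # open API, then remembered in memory and in /tmp/.emr_tagged_buckets, so
    # repeated calls are no-ops.
    fs.ls("example-bucket/")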