Skip to content

Commit

Permalink
Tag bucket
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Sep 26, 2024
1 parent 79c81d7 commit 3c0aa76
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 38 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ select = [
ignore = [
"S101", # Use of `assert` detected
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"S108", # Probable insecure usage of temporary file or directory
"D203", # no-blank-line-before-class
"D213", # multi-line-summary-second-line
"PLR0913", # Too many arguments in function definition
Expand Down
4 changes: 3 additions & 1 deletion tosfs/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

LS_OPERATION_DEFAULT_MAX_ITEMS = 1000

TOSFS_LOG_FORMAT = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s" # noqa: E501

# environment variable names
ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
TOSFS_LOG_FORMAT = "%(asctime)s %(name)s [%(levelname)s] %(filename)s:%(lineno)d %(funcName)s : %(message)s" # noqa: E501
ENV_NAME_TOS_BUCKET_TAG_ENABLE = "TOS_BUCKET_TAG_ENABLE"
20 changes: 20 additions & 0 deletions tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from fsspec import AbstractFileSystem
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import setup_logging as setup_logger
from tos.auth import CredentialProviderAuth
from tos.exceptions import TosClientError, TosServerError
from tos.models import CommonPrefixInfo
from tos.models2 import (
Expand Down Expand Up @@ -54,6 +55,7 @@
from tosfs.fsspec_utils import glob_translate
from tosfs.mpu import MultipartUploader
from tosfs.retry import retryable_func_executor
from tosfs.tag import BucketTagMgr
from tosfs.utils import find_bucket_key, get_brange

logger = logging.getLogger("tosfs")
Expand Down Expand Up @@ -203,6 +205,10 @@ def __init__(
if version_aware:
raise ValueError("Currently, version_aware is not supported.")

self.tag_enabled = os.environ.get("TOS_TAG_ENABLED", True)
if self.tag_enabled:
self._init_tag_manager()

self.version_aware = version_aware
self.default_block_size = (
default_block_size or FILE_OPERATION_READ_WRITE_BUFFER_SIZE
Expand Down Expand Up @@ -2093,12 +2099,26 @@ def _split_path(self, path: str) -> Tuple[str, str, Optional[str]]:

bucket, keypart = find_bucket_key(path)
key, _, version_id = keypart.partition("?versionId=")

if self.tag_enabled:
self.bucket_tag_mgr.add_bucket_tag(bucket)

return (
bucket,
key,
version_id if self.version_aware and version_id else None,
)

def _init_tag_manager(self) -> None:
    """Create the bucket tag manager from the TOS client's credentials.

    Reads the auth object from ``self.tos_client`` and, when it is a
    ``CredentialProviderAuth``, stores a ``BucketTagMgr`` built from its
    access key, secret key and region in ``self.bucket_tag_mgr``.

    Raises
    ------
    TosfsError
        If the client's auth object is not a ``CredentialProviderAuth``.

    """
    auth = self.tos_client.auth
    # Guard first: any other auth type is rejected outright.
    if not isinstance(auth, CredentialProviderAuth):
        raise TosfsError("Currently only support CredentialProviderAuth type")

    creds = auth.credentials_provider.get_credentials()
    access_key = creds.get_ak()
    secret_key = creds.get_sk()
    self.bucket_tag_mgr = BucketTagMgr(access_key, secret_key, auth.region)

@staticmethod
def _fill_dir_info(
bucket: str, common_prefix: Optional[CommonPrefixInfo], key: str = ""
Expand Down
218 changes: 181 additions & 37 deletions tosfs/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""The module contains all the business logic for tagging tos buckets ."""

import fcntl
import functools
import json
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from volcengine.ApiInfo import ApiInfo
from volcengine.base.Service import Service
Expand All @@ -32,89 +38,227 @@


# Per-region ServiceInfo objects for the EMR open API, keyed by region name.
# Every entry shares the same host, Accept header, numeric settings and
# scheme; only the Credentials service name and region differ. The empty
# access/secret keys are placeholders: BucketTagAction.__init__ sets the real
# ones through set_ak()/set_sk() after construction.
# The two ``60 * 5`` values are presumably the connection and socket timeouts
# in seconds -- TODO confirm against volcengine's ServiceInfo signature.
# NOTE(review): the scheme is plain "http"; confirm this is intended for a
# request that is signed with account credentials.
service_info_map = {
    "cn-beijing": ServiceInfo(
        OPEN_API_HOST,
        {
            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
        },
        Credentials("", "", "emr", "cn-beijing"),
        60 * 5,
        60 * 5,
        "http",
    ),
    "cn-guangzhou": ServiceInfo(
        OPEN_API_HOST,
        {
            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
        },
        Credentials("", "", "emr", "cn-guangzhou"),
        60 * 5,
        60 * 5,
        "http",
    ),
    "cn-shanghai": ServiceInfo(
        OPEN_API_HOST,
        {
            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
        },
        Credentials("", "", "emr", "cn-shanghai"),
        60 * 5,
        60 * 5,
        "http",
    ),
    "ap-southeast-1": ServiceInfo(
        OPEN_API_HOST,
        {
            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
        },
        Credentials("", "", "emr", "ap-southeast-1"),
        60 * 5,
        60 * 5,
        "http",
    ),
    # The QA entry uses a distinct service name ("emr_qa") but the
    # "cn-beijing" region.
    "cn-beijing-qa": ServiceInfo(
        OPEN_API_HOST,
        {
            ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
        },
        Credentials("", "", "emr_qa", "cn-beijing"),
        60 * 5,
        60 * 5,
        "http",
    ),
}

# ApiInfo descriptors for the three bucket-tag actions, keyed by action name.
# Each one targets path "/" and carries the action name and the EMR open API
# version as query parameters; put and delete use POST, get uses GET.
api_info = {
    PUT_TAG_ACTION_NAME: ApiInfo(
        "POST",
        "/",
        {"Action": PUT_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
        {},
        {},
    ),
    GET_TAG_ACTION_NAME: ApiInfo(
        "GET",
        "/",
        {"Action": GET_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
        {},
        {},
    ),
    DEL_TAG_ACTION_NAME: ApiInfo(
        "POST",
        "/",
        {"Action": DEL_TAG_ACTION_NAME, "Version": EMR_OPEN_API_VERSION},
        {},
        {},
    ),
}


class BucketTagAction(Service):
    """BucketTagAction is a class to manage the tag of bucket.

    The class is a process-wide singleton: ``__new__`` always hands back the
    same object. ``__init__`` still runs on every construction, so the most
    recent key, secret and region win.

    The request helpers used below (``json``, ``get``, ``set_ak``,
    ``set_sk``) are not defined in this class; presumably they come from the
    volcengine ``Service`` base -- TODO confirm.
    """

    _instance_lock = threading.Lock()

    def __new__(cls, *args: Any, **kwargs: Any) -> Any:
        """Singleton."""
        # Cheap unlocked check first; the check is repeated under the lock so
        # that only one thread ever creates the instance.
        if not hasattr(BucketTagAction, "_instance"):
            with BucketTagAction._instance_lock:
                if not hasattr(BucketTagAction, "_instance"):
                    BucketTagAction._instance = object.__new__(cls)
        return BucketTagAction._instance

    def __init__(self, key: str, secret: str, region: str = "cn-beijing") -> None:
        """Init BucketTagAction.

        Args:
            key: access key set on the service via ``set_ak``.
            secret: secret key set on the service via ``set_sk``.
            region: region name used to pick the ServiceInfo.

        Raises:
            Exception: propagated from ``get_service_info`` when the region
                is not supported.

        """
        super().__init__(self.get_service_info(region), self.get_api_info())
        self.set_ak(key)
        self.set_sk(secret)

    @staticmethod
    def get_api_info() -> dict:
        """Get api info."""
        return api_info

    @staticmethod
    def get_service_info(region: str) -> ServiceInfo:
        """Get service info.

        Returns the predefined entry from ``service_info_map`` when the
        region is listed there. For any other region a ServiceInfo is built
        on the fly, but only if the ``VOLC_REGION`` environment variable is
        present (its value is not read here).

        Raises:
            Exception: if the region is not predefined and ``VOLC_REGION``
                is not set.

        """
        service_info = service_info_map.get(region)
        if service_info:
            return service_info
        elif "VOLC_REGION" in os.environ:
            return ServiceInfo(
                OPEN_API_HOST,
                {
                    ACCEPT_HEADER_KEY: ACCEPT_HEADER_JSON_VALUE,
                },
                Credentials("", "", "emr", region),
                60 * 5,
                60 * 5,
                "http",
            )
        else:
            raise Exception("do not support region %s" % region)

    def put_bucket_tag(self, bucket: str) -> tuple[str, bool]:
        """Put tag for bucket.

        Never raises: any failure is logged at debug level and reported
        through the returned flag.

        Returns:
            ``(bucket, True)`` when the request and the parsing of its JSON
            response succeed, ``(bucket, False)`` otherwise.

        """
        params = {
            "Bucket": bucket,
        }

        try:
            res = self.json(PUT_TAG_ACTION_NAME, params, json.dumps(""))
            res_json = json.loads(res)
            logging.debug("Put tag for bucket %s successfully: %s .", bucket, res_json)
            return (bucket, True)
        except Exception as e:
            logging.debug("Put tag for bucket %s failed: %s .", bucket, e)
            return (bucket, False)

    def get_bucket_tag(self, bucket: str) -> bool:
        """Get tag for bucket.

        Returns:
            True when the request and the parsing of its JSON response
            succeed, False when either raises. The response content itself
            is only logged, not inspected.

        """
        params = {
            "Bucket": bucket,
        }
        try:
            res = self.get(GET_TAG_ACTION_NAME, params)
            res_json = json.loads(res)
            logging.debug("The result of get_Bucket_tag is %s", res_json)
            return True
        except Exception as e:
            logging.debug("Get tag for %s is failed: %s", bucket, e)
            return False

    def del_bucket_tag(self, bucket: str) -> None:
        """Delete tag for bucket.

        Best effort: failures are logged at debug level and swallowed.
        """
        params = {
            "Bucket": bucket,
        }
        try:
            res = self.json(DEL_TAG_ACTION_NAME, params, json.dumps(""))
            res_json = json.loads(res)
            logging.debug("The result of del_Bucket_tag is %s", res_json)
        except Exception as e:
            logging.debug("Delete tag for %s is failed: %s", bucket, e)


# Worker count for the ThreadPoolExecutor that BucketTagMgr creates.
THREAD_POOL_SIZE = 2
# File in which BucketTagMgr records, as a space-separated list, the buckets
# that have been tagged. It is a fixed path under /tmp (ruff S108 is ignored
# in pyproject.toml).
TAGGED_BUCKETS_FILE = "/tmp/.emr_tagged_buckets"


def singleton(cls: Any) -> Any:
    """Singleton decorator.

    Replaces ``cls`` with a factory function that builds the instance on the
    first call and returns that same instance on every later call. Arguments
    passed to later calls are ignored.

    Creation is serialised with a lock, matching the locking already used by
    ``BucketTagAction.__new__``, so concurrent first calls cannot build two
    instances.

    Args:
        cls: the class to turn into a singleton.

    Returns:
        A function with ``cls``'s metadata that returns the sole instance.

    """
    _instances: dict = {}
    _lock = threading.Lock()

    @functools.wraps(cls)
    def get_instance(*args: Any, **kwargs: Any) -> Any:
        # Fast path without the lock once the instance exists; the check is
        # repeated under the lock to close the check-then-create race.
        if cls not in _instances:
            with _lock:
                if cls not in _instances:
                    _instances[cls] = cls(*args, **kwargs)
        return _instances[cls]

    return get_instance


@singleton
class BucketTagMgr:
    """BucketTagMgr is a class to manage the tag of bucket.

    It remembers which buckets have been tagged in two places: an in-memory
    set on the (singleton) instance, and the ``TAGGED_BUCKETS_FILE`` on local
    disk so that other processes and later runs can skip the remote call.
    Because of ``@singleton``, only the arguments of the first construction
    are used.
    """

    def __init__(self, key: str, secret: str, region: str):
        """Init BucketTagMgr.

        Args:
            key: access key handed to ``BucketTagAction``.
            secret: secret key handed to ``BucketTagAction``.
            region: region handed to ``BucketTagAction``.

        """
        self.executor = ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE)
        self.cached_bucket_set: set = set()
        self.key = key
        self.secret = secret
        self.region = region

    def add_bucket_tag(self, bucket: str) -> None:
        """Add tag for bucket.

        Does nothing when the bucket is already known as tagged, either in
        memory or in the shared file. Otherwise the tag is put through
        ``BucketTagAction`` and, on success, the bucket is recorded in both
        places. Tagging and bookkeeping are best effort: failures are logged
        at debug level and never raised to the caller.

        NOTE(review): a bucket whose tagging keeps failing is retried (remote
        call plus file read) on every invocation; consider a negative cache
        if this sits on a hot path.
        """
        if bucket in self.cached_bucket_set:
            return

        # Pick up buckets tagged by other processes or earlier runs.
        self.cached_bucket_set = self.cached_bucket_set | self._load_tagged_buckets()
        if bucket in self.cached_bucket_set:
            return

        bucket_tag_service = BucketTagAction(self.key, self.secret, self.region)
        newly_tagged = {
            name
            for name, tagged in self.executor.map(
                bucket_tag_service.put_bucket_tag, [bucket]
            )
            if tagged is True
        }
        if not newly_tagged:
            # Nothing changed, so there is nothing new to persist.
            return

        self.cached_bucket_set = self.cached_bucket_set | newly_tagged
        self._persist_tagged_buckets()

    @staticmethod
    def _load_tagged_buckets() -> set:
        """Return the bucket names stored in the shared file, or an empty set."""
        try:
            with open(TAGGED_BUCKETS_FILE, "r") as file:
                # Shared lock so a concurrent writer is not observed
                # mid-rewrite; closing the file releases it.
                fcntl.flock(file, fcntl.LOCK_SH)
                # split() without an argument drops empty tokens, so an empty
                # file yields an empty set instead of {""}.
                return set(file.read().split())
        except FileNotFoundError:
            return set()
        except OSError as e:
            logging.debug("Read tagged buckets file failed: %s", e)
            return set()

    def _persist_tagged_buckets(self) -> None:
        """Merge the in-memory set with the shared file and rewrite the file."""
        try:
            # "a+" creates the file when missing but, unlike "w", does not
            # truncate it on open, so the content is only replaced while the
            # exclusive lock is held.
            with open(TAGGED_BUCKETS_FILE, "a+") as file:
                fcntl.flock(file, fcntl.LOCK_EX)
                file.seek(0)
                # Keep entries another process added since the last read.
                on_disk = set(file.read().split())
                self.cached_bucket_set = self.cached_bucket_set | on_disk
                file.seek(0)
                file.truncate()
                file.write(" ".join(self.cached_bucket_set))
                # Leaving the with-block flushes and closes the file, which
                # also releases the lock.
        except OSError as e:
            logging.debug("Write tagged buckets file failed: %s", e)

0 comments on commit 3c0aa76

Please sign in to comment.