Skip to content

Commit

Permalink
Make a class to hold transform metadata, use the ID for disabling
Browse files Browse the repository at this point in the history
  • Loading branch information
TeamSpen210 committed Aug 13, 2024
1 parent ee546d1 commit f148eb3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
62 changes: 40 additions & 22 deletions src/hammeraddons/bsp_transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Transformations that can be applied to the BSP file."""
from typing import Awaitable, Callable, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union
from typing import (
Awaitable, Callable, Container, Dict, FrozenSet, List, Mapping, Optional, Protocol, Tuple,
TypeVar, Union,
)
from typing_extensions import TypeAlias
import warnings
from pathlib import Path
import inspect

import attrs

from srctools import FGD, VMF, EmptyMapping, Entity, FileSystem, Keyvalues, Output
from srctools.bsp import BSP
Expand Down Expand Up @@ -137,28 +141,42 @@ def add_code(self, ent: Entity, code: str) -> None:
self._ent_code[ent] = '{}\n{}'.format(existing, code)


TransFunc = Callable[[Context], Awaitable[None]]
TransFuncOrSync = Callable[[Context], Optional[Awaitable[None]]]
TRANSFORMS: Dict[str, TransFunc] = {}
TRANSFORM_PRIORITY: Dict[str, int] = {}
TRANSFORM_ID: Dict[str, str] = {}
TransFunc: TypeAlias = Callable[[Context], Awaitable[None]]
TransFuncOrSync: TypeAlias = Callable[[Context], Optional[Awaitable[None]]]
TransFuncT = TypeVar('TransFuncT', bound=Callable[[Context], Optional[Awaitable[None]]])


@attrs.frozen(eq=False)
class Transform:
"""A transform function."""
func: TransFunc
name: str
priority: int


TRANSFORMS: Dict[str, Transform] = {}


def trans(name: str, *, priority: int=0) -> Callable[[TransFuncOrSync], TransFunc]:
class TransProto(Protocol):
def __call__(self, func: TransFuncT) -> TransFuncT: ...


def trans(name: str, *, priority: int=0) -> TransProto:
"""Add a transformation procedure to the list."""
def deco(func: TransFuncOrSync) -> TransFunc:
name = name.strip()
if ',' in name:
raise ValueError('Commas are not allowed in names!')

def deco(func: TransFuncT) -> TransFuncT:
"""Stores the transformation."""
TRANSFORM_PRIORITY[name] = priority
TRANSFORM_ID[name] = func.__name__
if inspect.iscoroutinefunction(func):
TRANSFORMS[name] = func
return func
TRANSFORMS[name.casefold()] = Transform(func, name, priority)
else:
async def async_wrapper(ctx: Context) -> None:
"""Just freeze all other tasks to run this."""
func(ctx)
TRANSFORMS[name] = async_wrapper
return async_wrapper
TRANSFORMS[name.casefold()] = Transform(async_wrapper, name, priority)
return func
return deco


Expand All @@ -182,17 +200,17 @@ async def run_transformations(
modelcompile_dump=modelcompile_dump,
)

enabled_transforms = list(filter(lambda it: TRANSFORM_ID[it[0]] not in disabled, sorted(TRANSFORMS.items(), key=lambda tup: TRANSFORM_PRIORITY[tup[0]])))
LOGGER.info( 'Enabled transforms: {}', ', '.join([TRANSFORM_ID[it] for it, _ in enabled_transforms]))

for func_name, func in enabled_transforms:
LOGGER.info('Running "{}"...', func_name)
for transform in sorted(TRANSFORMS.values(), key=lambda trans: trans.priority):
if transform.name.casefold() in disabled:
LOGGER.info('Skipping "{}"', transform.name)
continue
LOGGER.info('Running "{}"...', transform.name)
try:
context.config = config[func_name.casefold()]
context.config = config[transform.name.casefold()]
except KeyError:
context.config = Keyvalues(func_name, [])
context.config = Keyvalues(transform.name, [])
LOGGER.debug('Config: {!r}', context.config)
await func(context)
await transform.func(context)

if context._ent_code:
LOGGER.info('Injecting VScript code...')
Expand Down
6 changes: 3 additions & 3 deletions tests/test_transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from hammeraddons.bsp_transform import Context, TransFunc, TRANSFORMS
from srctools.bsp import BSP
from srctools.filesys import VirtualFileSystem
from srctools.filesys import FileSystemChain
from srctools.game import Game
from srctools.packlist import PackList

Expand All @@ -17,7 +17,7 @@ def blank_ctx(shared_datadir: Path) -> Context:
"""Build a blank context."""
bsp = BSP(shared_datadir / 'blank.bsp')
game = Game(shared_datadir)
fsys = VirtualFileSystem({})
fsys = FileSystemChain()
return Context(
fsys,
bsp.ents,
Expand All @@ -35,6 +35,6 @@ def get_transform_func(module_name: str, transform: str) -> Callable[[Context],
sys.path.append(folder)
try:
importlib.import_module(module_name)
return TRANSFORMS[transform]
return TRANSFORMS[transform].func
finally:
sys.path.remove(folder)

0 comments on commit f148eb3

Please sign in to comment.