Skip to content

Commit

Permalink
Enable dynamic load for generator classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ademariag committed Aug 27, 2023
1 parent f2ff549 commit 712c3e2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 170 deletions.
17 changes: 4 additions & 13 deletions system/generators/kubernetes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,15 @@

from kapitan.inputs.kadet import inventory

from .argocd import *
from .base import *
from .certmanager import *
from .common import kgenlib
from .gke import *
from .helm import *
from .istio import *
from .networking import *
from .prometheus import *
from .rbac import *
from .storage import *
from .workloads import *

inv = inventory(lazy=True)
# Loads generators dynamically
kgenlib.load_generators(__name__, __file__)


def main(input_params):
generator = kgenlib.BaseGenerator(inventory=inv)
target_inventory = inventory(lazy=True)
generator = kgenlib.BaseGenerator(inventory=target_inventory)
store = generator.generate()
store.process_mutations(input_params.get("mutations", {}))

Expand Down
7 changes: 3 additions & 4 deletions system/generators/terraform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from kapitan.inputs.kadet import inventory

from .common import *
from .github import *
from .google import *
from .terraform import *
from .common import TerraformStore, kgenlib

kgenlib.load_generators(__name__, __file__)

logger = logging.getLogger(__name__)

Expand Down
191 changes: 38 additions & 153 deletions system/lib/kgenlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextvars
import functools
import logging
from enum import Enum
from typing import List
Expand Down Expand Up @@ -27,6 +28,26 @@
target = current_target.get()


@functools.lru_cache
def load_generators(name, path):
from importlib import import_module
from inspect import isclass
from pathlib import Path
from pkgutil import iter_modules

# iterate through the modules in the current package
package_dir = Path(path).resolve().parent
for _, module_name, _ in iter_modules([package_dir]):
# import the module and iterate through its attributes
module = import_module(f"{name}.{module_name}")
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)

if isclass(attribute):
# Add the class to this package's variables
globals()[attribute_name] = attribute


class DeleteContent(Exception):
# Raised when a content should be deleted
pass
Expand All @@ -41,15 +62,15 @@ def patch_config(config: Dict, inventory: Dict, inventory_path: str) -> None:

def register_function(func, params):
target = current_target.get()
logging.debug(
logger.debug(
f"Registering generator {func.__name__} with params {params} for target {target}"
)

my_dict = registered_generators.get()
generator_list = my_dict.get(target, [])
generator_list.append((func, params))

logging.debug(
logger.debug(
f"Currently registered {len(generator_list)} generators for target {target}"
)

Expand Down Expand Up @@ -143,47 +164,22 @@ class ContentType(Enum):


class BaseContent(BaseModel):
"""
Base class for content handling, providing methods for manipulation
and initialization from various sources.
Attributes:
content_type (ContentType): Represents the type of content. Default is ContentType.YAML.
filename (str): Represents the name of the output file. Default is "output".
"""

content_type: ContentType = ContentType.YAML
filename: str = "output"

def body(self):
"""Placeholder method intended to be overridden by subclasses."""
pass

@classmethod
def from_baseobj(cls, baseobj: BaseObj):
"""
Create a BaseContent instance from a BaseObj.
Args:
baseobj (BaseObj): Object to be converted to BaseContent.
Returns:
BaseContent: Initialized with baseobj.
"""
"""Return a BaseContent initialised with baseobj."""
return cls.from_dict(baseobj.root)

@classmethod
def from_yaml(cls, file_path) -> list:
"""
Create a list of BaseContent objects from a YAML file.
Args:
file_path (str): Path to the YAML file.
def from_yaml(cls, file_path) -> List:
"""Returns a list of BaseContent initialised with the content of file_path data."""

Returns:
List[BaseContent]: List of BaseContent objects.
"""
content_list = []
content_list = list()
with open(file_path) as fp:
yaml_objs = yaml.safe_load_all(fp)
for yaml_obj in yaml_objs:
Expand All @@ -194,18 +190,8 @@ def from_yaml(cls, file_path) -> list:

@classmethod
def from_dict(cls, dict_value):
"""
Create a BaseContent instance from a dictionary.
Args:
dict_value (dict): Dictionary for initialization.
"""Return a BaseContent initialised with dict_value."""

Returns:
BaseContent: Initialized with dict_value.
Raises:
CompileError: Error raised when importing the dictionary fails.
"""
if dict_value:
try:
obj = cls()
Expand All @@ -216,41 +202,19 @@ def from_dict(cls, dict_value):
f"error when importing item '{dict_value}' of type {type(dict_value)}: {e}"
)

def parse(self, content: dict):
"""
Parse the given content and set it as the root attribute.
Args:
content (dict): Content to be parsed.
"""
def parse(self, content: Dict):
self.root = content

@staticmethod
def findpath(obj, path):
"""
Fetch value of nested attribute within an object recursively using the path.
Args:
obj (object): Object containing the attribute.
path (str): Dot-separated path to the attribute.
Returns:
any: Value of the attribute.
"""
path_parts = path.split(".")
value = getattr(obj, path_parts[0])
if len(path_parts) == 1:
return value
else:
return BaseContent.findpath(value, ".".join(path_parts[1:]))

def mutate(self, mutations: list):
"""
Apply mutations to the content.
Args:
mutations (list): List of mutations to apply.
"""
def mutate(self, mutations: List):
for action, conditions in mutations.items():
if action == "patch":
for condition in conditions:
Expand All @@ -271,15 +235,6 @@ def mutate(self, mutations: list):
break

def match(self, match_conditions):
"""
Check if content matches the given conditions.
Args:
match_conditions (dict): Conditions to check against.
Returns:
bool: True if matches all conditions, else False.
"""
for key, values in match_conditions.items():
if "*" in values:
return True
Expand All @@ -291,36 +246,14 @@ def match(self, match_conditions):
return True

def patch(self, patch):
"""
Merge content with the given patch.
Args:
patch (dict): Values to patch onto the content.
"""
self.root.merge_update(Dict(patch), box_merge_lists="extend")


class BaseStore(BaseModel):
"""
A class that represents a store of BaseContent objects.
Attributes:
content_list (List[BaseContent]): List of BaseContent objects stored.
"""

content_list: List[BaseContent] = []

@classmethod
def from_yaml_file(cls, file_path):
"""
Create a BaseStore instance from a YAML file.
Args:
file_path (str): Path to the YAML file.
Returns:
BaseStore: A BaseStore populated with BaseContent objects from the YAML file.
"""
store = cls()
with open(file_path) as fp:
yaml_objs = yaml.safe_load_all(fp)
Expand All @@ -330,13 +263,7 @@ def from_yaml_file(cls, file_path):
return store

def add(self, object):
"""
Add an object or a list of objects to the store.
Args:
object (any): The object to add. Can be of type BaseContent, BaseStore, BaseObj, list.
"""
logging.debug(f"Adding {type(object)} to store")
logger.debug(f"Adding {type(object)} to store")
if isinstance(object, BaseContent):
self.content_list.append(object)
elif isinstance(object, BaseStore):
Expand All @@ -356,22 +283,10 @@ def add(self, object):
self.content_list.append(object)

def add_list(self, contents: List[BaseContent]):
"""
Add a list of BaseContent objects to the store.
Args:
contents (List[BaseContent]): List of BaseContent objects to add.
"""
for content in contents:
self.add(content)

def import_from_helm_chart(self, **kwargs):
"""
Import BaseContent objects from a Helm chart.
Args:
**kwargs: Keyword arguments for the HelmChart object.
"""
self.add_list(
[
BaseContent.from_baseobj(resource)
Expand All @@ -380,55 +295,25 @@ def import_from_helm_chart(self, **kwargs):
)

def apply_patch(self, patch: Dict):
"""
Apply a patch to each BaseContent in the store.
Args:
patch (Dict): A dictionary representing the patch to be applied.
"""
for content in self.get_content_list():
content.patch(patch)

def process_mutations(self, mutations: Dict):
"""
Process mutations on each BaseContent in the store.
Args:
mutations (Dict): A dictionary representing the mutations to be processed.
Raises:
CompileError: Error raised when processing mutations fails.
"""
for content in self.get_content_list():
try:
content.mutate(mutations)
except DeleteContent as e:
logging.debug(e)
logger.debug(e)
self.content_list.remove(content)
except:
raise CompileError(f"Error when processing mutations on {content}")

def get_content_list(self):
"""
Retrieve the content list of the store.
Returns:
List[BaseContent]: List of BaseContent objects stored.
"""
return getattr(self, "content_list", [])

def dump(self, output_filename=None, already_processed=False):
"""
Dump the BaseStore contents.
Args:
output_filename (str, optional): Name of the output file.
already_processed (bool, optional): Whether the content has been processed. Defaults to False.
Returns:
any: Object list or dictionary.
"""
logging.debug(f"Dumping {len(self.get_content_list())} items")
"""Return object dict/list."""
logger.debug(f"Dumping {len(self.get_content_list())} items")
if not already_processed:
for content in self.get_content_list():
if output_filename:
Expand Down Expand Up @@ -458,7 +343,7 @@ def __init__(
self.inventory = inventory
self.global_inventory = inventory_global()
self.generator_defaults = findpath(self.inventory, defaults_path)
logging.debug(f"Setting {self.generator_defaults} as generator defaults")
logger.debug(f"Setting {self.generator_defaults} as generator defaults")

if store == None:
self.store = BaseStore()
Expand All @@ -482,7 +367,7 @@ def expand_and_run(self, func, params, inventory=None):
)

if configs:
logging.debug(
logger.debug(
f"Found {len(configs)} configs to generate at {path} for target {target}"
)
for config_id, config in configs.items():
Expand Down Expand Up @@ -511,14 +396,14 @@ def expand_and_run(self, func, params, inventory=None):
"global_inventory": self.global_inventory,
"target": current_target.get(),
}
logging.debug(
logger.debug(
f"Running class {func.__name__} for {config_id} with params {local_params.keys()}"
)
self.store.add(func(**local_params))

def generate(self):
generators = registered_generators.get().get(target, [])
logging.debug(
logger.debug(
f"{len(generators)} classes registered as generators for target {target}"
)
for func, params in generators:
Expand Down

0 comments on commit 712c3e2

Please sign in to comment.