Allow nested batch_arg_name in BatchSizeFinder/Tuner.scale_batch_size() #20560

Open
ibro45 opened this issue Jan 23, 2025 · 1 comment
Labels
feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers


ibro45 commented Jan 23, 2025

Description & Motivation

Hi,

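Would it be possible to allow dot notation in batch_arg_name, for example tuner.scale_batch_size(model, batch_arg_name="dataloaders.train.batch_size"), so that the batch size can be stored on a nested attribute rather than only as a top-level one?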

Other features that take an attribute name would probably benefit from this as well. It should be simple to implement through lightning_hasattr and lightning_getattr.
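A minimal sketch of the path resolution a dotted batch_arg_name would need, written with plain getattr/setattr so it does not depend on Lightning internals. The helper names get_dotted, has_dotted and set_dotted are hypothetical (they are not part of Lightning's API), and treating dict-like levels as key lookups is an assumption about how a config such as dataloaders.train might be stored.

```python
from collections.abc import Mapping, MutableMapping
from types import SimpleNamespace
from typing import Any


def get_dotted(obj: Any, path: str) -> Any:
    """Resolve a dot-delimited path such as "dataloaders.train.batch_size"."""
    for key in path.split("."):
        # Assumption: dict-like levels are indexed by key, all others by attribute.
        obj = obj[key] if isinstance(obj, Mapping) else getattr(obj, key)
    return obj


def has_dotted(obj: Any, path: str) -> bool:
    """Return True only if every component of the path resolves."""
    try:
        get_dotted(obj, path)
    except (AttributeError, KeyError):
        return False
    return True


def set_dotted(obj: Any, path: str, value: Any) -> None:
    """Resolve the parent of the last component, then assign the value on it."""
    parent_path, _, last = path.rpartition(".")
    parent = get_dotted(obj, parent_path) if parent_path else obj
    if isinstance(parent, MutableMapping):
        parent[last] = value
    else:
        setattr(parent, last, value)


# Usage: a stand-in object with a nested, dict-backed dataloader config.
model = SimpleNamespace(dataloaders={"train": SimpleNamespace(batch_size=32)})
print(has_dotted(model, "dataloaders.train.batch_size"))  # True
set_dotted(model, "dataloaders.train.batch_size", 64)
print(get_dotted(model, "dataloaders.train.batch_size"))  # 64
```

Splitting on the last dot means only the parent path has to exist; the final component is assigned (or created) on it, and a name without dots falls back to a plain setattr on the object itself.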

Pitch

No response

Alternatives

No response

Additional context

No response

cc @lantiga @Borda


ibro45 commented Jan 24, 2025

As a workaround, I added the following overrides to my class Model(LightningModule), which enable this behavior, but I don't think this is ideal.


```python
from typing import Any

from lightning.pytorch import LightningModule


class Model(LightningModule):
    def __getattr__(self, name: str) -> Any:
        """
        Override __getattr__ to handle dot-delimited attribute access.

        Args:
            name (str): The attribute name, potentially dot-delimited.

        Returns:
            Any: The value of the nested attribute.

        Raises:
            AttributeError: If the attribute path does not exist.
        """
        # __getattr__ is only called when normal lookup fails, so ordinary
        # attributes never reach this method.
        if "." in name:
            obj = self
            try:
                # Walk the path one attribute at a time, e.g.
                # "dataloaders.train.batch_size" -> self.dataloaders.train.batch_size
                for attr in name.split("."):
                    obj = getattr(obj, attr)
                return obj
            except AttributeError as e:
                raise AttributeError(f"Attribute path '{name}' not found.") from e
        # If there is no dot, defer to nn.Module's lookup (parameters, buffers, submodules)
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Override __setattr__ to handle dot-delimited attribute setting.

        Args:
            name (str): The attribute name, potentially dot-delimited.
            value (Any): The value to set.

        Raises:
            AttributeError: If an intermediate attribute in the path does not exist.
        """
        if "." in name:
            *parents, last = name.split(".")
            obj = self
            for attr in parents:
                try:
                    obj = getattr(obj, attr)
                except AttributeError as e:
                    raise AttributeError(f"Attribute path '{'.'.join(parents)}' not found.") from e
            # Only the parent path must exist; the last attribute is set (or created) on it
            setattr(obj, last, value)
        else:
            # Use the default nn.Module behavior for single attributes
            super().__setattr__(name, value)
```
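The workaround presumably takes effect because Python's built-in hasattr, getattr and setattr accept any string, so a dotted name is routed through the overrides. A quick way to check that without installing Lightning is to apply the same two methods to a plain class; DottedAttrs and Config below are hypothetical stand-ins for LightningModule and a user config object, and since object has no __getattr__ to defer to, the no-dot case simply raises AttributeError.

```python
from typing import Any


class DottedAttrs:
    """Hypothetical plain-Python stand-in for the two overrides in the workaround."""

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails, e.g. for "a.b.c".
        if "." in name:
            obj = self
            try:
                for attr in name.split("."):
                    obj = getattr(obj, attr)
            except AttributeError as e:
                raise AttributeError(f"Attribute path '{name}' not found.") from e
            return obj
        # object has no __getattr__ to defer to, so fail like normal lookup does.
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if "." in name:
            *parents, last = name.split(".")
            obj = self
            for attr in parents:
                obj = getattr(obj, attr)
            setattr(obj, last, value)
        else:
            super().__setattr__(name, value)


class Config:
    """Hypothetical nested config object."""

    def __init__(self, **kwargs: Any) -> None:
        self.__dict__.update(kwargs)


model = DottedAttrs()
model.dataloaders = Config(train=Config(batch_size=32))

print(hasattr(model, "dataloaders.train.batch_size"))  # True
setattr(model, "dataloaders.train.batch_size", 64)
print(getattr(model, "dataloaders.train.batch_size"))  # 64
print(hasattr(model, "dataloaders.val.batch_size"))  # False
```

The downside visible here is that every attribute access and assignment on the object now passes through custom code, which is one reason a dotted-path lookup inside the batch size finder itself would be cleaner.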
