Add PaliGemma LoRA #464

Merged
12 commits merged into main from paligemma-lora on Jun 26, 2024

Conversation

probicheaux
Collaborator

Description

Add a class that can perform inference using LoRAs.

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

How has this change been tested? Please provide a test case or example of how you tested the change.

Locally

Any specific deployment considerations

n/a

Docs

  • Docs updated? What were the changes:

@probicheaux
Collaborator Author

  1. PaliGemma needs transformers>=4.41.1, but requirements.cogvlm.txt and requirements.groundingdino.txt pin transformers low. Can we avoid that now?
  2. This change doesn't work with get_model because we don't know a priori if the PaliGemma model is a LoRA or not. How should I handle that? Put something in the model bucket and check for that? Right now, there's a file adapter_config.json that exists if and only if the model is a LoRA. Should I use that file to check which class to load in get_model?

@@ -1,4 +1,4 @@
-transformers>=4.36.0,<4.38.0
+transformers>=4.36.0
Collaborator

I remember pinning the version due to: #355
Not sure if newer versions of transformers solve the problem, but if they do, the lower bound should probably be bumped.


self.processor = AutoProcessor.from_pretrained(self.cache_dir)


if __name__ == "__main__":
Collaborator

I know that's not part of the change, but is that needed?

Collaborator Author

sorry, is what needed? the main? no, it was just for testing, we can remove

Collaborator

yeah, main

@@ -150,6 +161,36 @@ def download_model_artefacts_from_s3(self) -> None:
raise NotImplementedError()


class LoRAPaliGemma(PaliGemma):
Collaborator

Could you post a docs example and a description of the LoRA model?
I guess this class is needed to be able to load LoRA-fine-tuned models from the HF hub, since people are posting those, but what about our platform? PaliGemma is a RoboflowInferenceModel, but it seems that we don't load weights from our hosting, which may be an indication that this is a kind of "core" model?
Also, is an HF token always required, or could we rely on their auth?

Collaborator Author

I can write up something more detailed, but the short answer is this:

  1. LoRA is a technique to train a small "diff" from some base model A.
  2. This PR assumes that users will deploy the LoRA (the diff) to Roboflow, but in order to use it, they will need to download the base model A from Hugging Face (see the sketch after this list).
  3. This doesn't reduce the amount of data transferred for the first LoRA load, but it will significantly reduce data transfer on subsequent LoRA loads -- from 6GB for a fully fine-tuned model to only 28MB for a new LoRA.
  4. This will also reduce our storage needs, because we don't need to host 6GB for each fine-tune, just 28MB for each LoRA.
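
For illustration, a minimal sketch of that loading pattern using peft and transformers; the base checkpoint ID and adapter path below are placeholders, not the values the class actually uses:

```python
from peft import PeftModel
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Base model A (~6GB): downloaded once from the Hugging Face Hub and cached.
BASE_MODEL_ID = "google/paligemma-3b-pt-224"  # illustrative base checkpoint
ADAPTER_DIR = "/path/to/lora-weights"         # ~28MB adapter pulled from Roboflow hosting

base_model = PaliGemmaForConditionalGeneration.from_pretrained(BASE_MODEL_ID)
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)

# Attach the LoRA "diff" on top of the cached base weights; only this small
# adapter has to be transferred and stored per fine-tune.
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
```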

Collaborator Author

So to be clear, we are loading weights from our hosting

Collaborator Author

> Also - is that always required to have HF token - or maybe we could rely on their auth?

I'm not sure I understand this question

Collaborator

From previous answers I see that HF tokens will be required, at least sometimes.

@capjamesg
Contributor

I have tested this implementation and successfully trained a model with LoRA.

probicheaux self-assigned this Jun 13, 2024
@PawelPeczek-Roboflow
Collaborator

Fine, as long as we resolve this #464 (comment) we are free to merge. I believe that would only take testing CogVLM and probably setting transformers>=4.41.1.

@PawelPeczek-Roboflow
Collaborator

Regarding question 2 from here:
get_model(...) internally calls get_model_type(...). It would be best if we could have that information returned by the API at that level.
If that's not feasible, relying on adapter_config.json is OK, but that would probably require having a single class for the LoRA and non-LoRA versions?
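
For reference, a minimal sketch of that fallback check, assuming the downloaded artefacts sit in a local cache directory (the helper name is illustrative, not the final implementation):

```python
import os


def is_lora_checkpoint(cache_dir: str) -> bool:
    """Return True when the cached artefacts are a LoRA adapter.

    PEFT writes adapter_config.json only for LoRA fine-tunes, so its presence
    distinguishes an adapter from a fully fine-tuned checkpoint.
    """
    return os.path.exists(os.path.join(cache_dir, "adapter_config.json"))
```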

@PawelPeczek-Roboflow
Collaborator

@probicheaux - how do we plan to move on with this?

@probicheaux
Collaborator Author

@PawelPeczek-Roboflow sorry, I've been super busy. Just fixed the get_model thing by pushing a new model_conversion param that adds peft to lora models. I also tested CogVLM in the new docker container (verifying transformers==4.41.2) and it works fine.

PawelPeczek-Roboflow merged commit 4a5e258 into main Jun 26, 2024
50 checks passed
PawelPeczek-Roboflow deleted the paligemma-lora branch June 26, 2024 06:16