From 8dd6405efa967889fb1fb9ccb73ad2684a7019a0 Mon Sep 17 00:00:00 2001 From: Agus Date: Thu, 28 Nov 2024 13:00:01 +0100 Subject: [PATCH 01/30] Fix `StepOutput` type (#1072) --- src/distilabel/steps/typing.py | 20 ++------------------ tests/unit/steps/argilla/test_preference.py | 2 ++ 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/src/distilabel/steps/typing.py b/src/distilabel/steps/typing.py index 4f6f53d5d9..9a3e5bb586 100644 --- a/src/distilabel/steps/typing.py +++ b/src/distilabel/steps/typing.py @@ -14,24 +14,8 @@ from typing import Any, Dict, Iterator, List, Tuple, Union -StepData = List[Dict[str, Any]] -StepStatistics = Dict[str, Any] -StepOutput = Iterator[Dict[str, Union[StepData, StepStatistics]]] -r"""`StepOutput` is an alias of the typing. -A step output is a dict of the form: -{ - "outputs": [ - {"col1": "val1", "col2": "val2"}, - {"col1": "val1", "col2": "val2"}, - {"col1": "val1", "col2": "val2"}, - ], - "statistics": { - "llm": {}, - "time": 12341234, - ... - } -} -""" +StepOutput = Iterator[List[Dict[str, Any]]] + GeneratorStepOutput = Iterator[Tuple[List[Dict[str, Any]], bool]] """`GeneratorStepOutput` is an alias of the typing `Iterator[Tuple[List[Dict[str, Any]], bool]]`""" diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index ab63ee5419..ec97dc5f71 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -85,6 +85,8 @@ def test_process(self, mock_dataset) -> None: step.load() step._instruction = "instruction" step._generations = "generations" + step._ratings = "ratings" + step._rationales = "rationales" step._dataset = mock_dataset # type: ignore step._dataset.records.log = lambda x: x # type: ignore From fa13ae1d2b1a0fec87ffe2718893de786f36d639 Mon Sep 17 00:00:00 2001 From: Sara Han <127759186+sdiazlor@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:52:05 +0100 Subject: [PATCH 02/30] docs: update issue templates (#1074) --- .../1-add_documentation_report.yml | 26 +++++++ .github/ISSUE_TEMPLATE/2-bug_python.yml | 70 +++++++++++++++++++ .github/ISSUE_TEMPLATE/3-feature_request.yml | 44 ++++++++++++ ...-template.md => 4-blank-issue-template.md} | 0 .github/ISSUE_TEMPLATE/config.yml | 5 ++ .../\360\237\206\225-feature-request.md" | 20 ------ .../\360\237\220\233-bug-report.md" | 30 -------- .../\360\237\223\232-documentation-update.md" | 16 ----- 8 files changed, 145 insertions(+), 66 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/1-add_documentation_report.yml create mode 100644 .github/ISSUE_TEMPLATE/2-bug_python.yml create mode 100644 .github/ISSUE_TEMPLATE/3-feature_request.yml rename .github/ISSUE_TEMPLATE/{blank-issue-template.md => 4-blank-issue-template.md} (100%) create mode 100644 .github/ISSUE_TEMPLATE/config.yml delete mode 100644 ".github/ISSUE_TEMPLATE/\360\237\206\225-feature-request.md" delete mode 100644 ".github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" delete mode 100644 ".github/ISSUE_TEMPLATE/\360\237\223\232-documentation-update.md" diff --git a/.github/ISSUE_TEMPLATE/1-add_documentation_report.yml b/.github/ISSUE_TEMPLATE/1-add_documentation_report.yml new file mode 100644 index 0000000000..6ca07d7997 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/1-add_documentation_report.yml @@ -0,0 +1,26 @@ +name: "\U0001F4DA Add a documentation report" +description: "Have you spotted a typo or mistake in our docs?" 
+title: "[DOCS]" +labels: ["documentation"] +assignees: [] + +body: + - type: markdown + attributes: + value: "Thank you for reporting a documentation mistake! Before you get started, please [search to see](https://github.com/argilla-io/distilabel/issues) if an issue already exists for the bug you encountered." + + - type: textarea + id: doc_report + attributes: + label: "Which page or section is this issue related to?" + description: "Please include the URL and/or source." + validations: + required: false + + - type: textarea + id: doc_review + attributes: + label: "What are you documenting, or what change are you making in the documentation?" + description: "If a documentation needs to be created, please specify its coverage.\n If there's a typo or something needs revisiting, please indicate it and show code/text/screenshots." + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/2-bug_python.yml b/.github/ISSUE_TEMPLATE/2-bug_python.yml new file mode 100644 index 0000000000..b4733ebecd --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2-bug_python.yml @@ -0,0 +1,70 @@ +name: "\U0001FAB2 Bug report" +description: "Report bugs and unexpected behavior." +title: "[BUG]" +labels: ["bug", "ml-internal"] +assignees: [] + +body: + - type: markdown + attributes: + value: "Thank you for reporting a bug! Before you get started, please [search to see](https://github.com/argilla-io/distilabel/issues) if an issue already exists for the bug you encountered." + + - type: textarea + id: bug_description + attributes: + label: "Describe the bug" + description: "A clear and concise description of the bug." + validations: + required: true + + - type: textarea + id: stacktrace + attributes: + label: "To reproduce" + description: "The code to reproduce the behavior." + placeholder: | + ```python + my_python_code + ``` + validations: + required: false + + - type: textarea + id: expected_behavior + attributes: + label: "Expected behavior" + description: "A clear and concise description of what you expected to happen." + validations: + required: false + + - type: textarea + id: screenshots + attributes: + label: "Screenshots" + description: "If applicable, add screenshots to help explain your problem." + validations: + required: false + + - type: textarea + id: environment + attributes: + label: "Environment" + description: "Since version 1.16.0 you can use `python -m argilla info` command to easily get the used versions." + value: | + - Distilabel Version [e.g. 1.0.0]: + - Python Version [e.g. 3.11]: + validations: + required: false + + - type: textarea + id: additional_context + attributes: + label: "Additional context" + description: "Add any other relevant information." + validations: + required: false + + - type: markdown + attributes: + value: | + 📌 Make sure you have provided all the required information in each section so we can support you properly. diff --git a/.github/ISSUE_TEMPLATE/3-feature_request.yml b/.github/ISSUE_TEMPLATE/3-feature_request.yml new file mode 100644 index 0000000000..6572e2a457 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/3-feature_request.yml @@ -0,0 +1,44 @@ +name: "\U0001F195 Feature request" +description: "Share cool new ideas for the project." +title: "[FEATURE]" +labels: ["enhancement", "ml-internal"] +assignees: [] + + +body: + - type: markdown + attributes: + value: "Thank you for sharing your feature request! Please fill out the sections below." + + - type: textarea + id: feature_request + attributes: + label: "Is your feature request related to a problem? 
Please describe." + description: "A clear and concise description of what the problem is." + placeholder: "I'm always frustrated when..." + validations: + required: false + + - type: textarea + id: feature_description + attributes: + label: "Describe the solution you'd like" + description: "A clear and concise description of what you want to happen." + validations: + required: false + + - type: textarea + id: feature_alternatives + attributes: + label: "Describe alternatives you've considered" + description: "A clear and concise description of any alternative solutions or features you've considered." + validations: + required: false + + - type: textarea + id: additional_context + attributes: + label: "Additional context" + description: "Add any other context or screenshots about the feature request here." + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/blank-issue-template.md b/.github/ISSUE_TEMPLATE/4-blank-issue-template.md similarity index 100% rename from .github/ISSUE_TEMPLATE/blank-issue-template.md rename to .github/ISSUE_TEMPLATE/4-blank-issue-template.md diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..a9944106c4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: 🗯 Community Discussions + url: http://hf.co/join/discord + about: Our Discord Community loves to discuss distilabel and NLP topics diff --git "a/.github/ISSUE_TEMPLATE/\360\237\206\225-feature-request.md" "b/.github/ISSUE_TEMPLATE/\360\237\206\225-feature-request.md" deleted file mode 100644 index 2bd974da3a..0000000000 --- "a/.github/ISSUE_TEMPLATE/\360\237\206\225-feature-request.md" +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: "\U0001F195 Feature request" -about: Suggest an idea for this project. -title: "[FEATURE]" -labels: enhancement, ml-internal -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git "a/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" "b/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" deleted file mode 100644 index ae2c21cf68..0000000000 --- "a/.github/ISSUE_TEMPLATE/\360\237\220\233-bug-report.md" +++ /dev/null @@ -1,30 +0,0 @@ ---- -name: "\U0001F41B Bug report" -about: Create a report to help us improve. -title: "[BUG]" -labels: bug, ml-internal -assignees: '' - ---- - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Code to reproduce -```python - -``` - -**Expected behaviour** -A clear and concise description of what you expected to happen. - -**Screenshots** -If applicable, add screenshots to help explain your problem. - -**Desktop (please complete the following information):** - - Package version: - - Python version: - -**Additional context** -Add any other context about the problem here. 
diff --git "a/.github/ISSUE_TEMPLATE/\360\237\223\232-documentation-update.md" "b/.github/ISSUE_TEMPLATE/\360\237\223\232-documentation-update.md" deleted file mode 100644 index 01e8234da5..0000000000 --- "a/.github/ISSUE_TEMPLATE/\360\237\223\232-documentation-update.md" +++ /dev/null @@ -1,16 +0,0 @@ ---- -name: "\U0001F4D9 Documentation update" -about: Create a piece of documentation or update one that needs some care. -title: "[DOCS]" -labels: documentation -assignees: "" ---- - -## Which page or section is this issue related to? - - - -## What are you documenting, or what change are you making in the documentation? - - - From f8e41cdd92c25b3d16134ab4fec264d4abdb7d02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Tue, 3 Dec 2024 09:18:36 +0100 Subject: [PATCH 03/30] Update `unload` method from `vLLM` to properly free resources (#1077) --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- src/distilabel/models/llms/vllm.py | 21 +++++++++++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6415f9e1ba..ec4e222bbc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - --fuzzy-match-generates-todo - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.2 + rev: v0.8.1 hooks: - id: ruff args: [--fix] diff --git a/pyproject.toml b/pyproject.toml index bf9550c9f0..e5bcc8399a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ distilabel = "distilabel.cli.app:app" "distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin" [project.optional-dependencies] -dev = ["ruff == 0.6.2", "pre-commit >= 3.5.0"] +dev = ["ruff == 0.8.1", "pre-commit >= 3.5.0"] docs = [ "mkdocs-material >=9.5.17", "mkdocstrings[python] >= 0.24.0", diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index dd83d8489e..eed9fa012b 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import contextlib +import gc import json from functools import cached_property from typing import ( @@ -214,11 +216,30 @@ def load(self) -> None: def unload(self) -> None: """Unloads the `vLLM` model.""" + self._cleanup_vllm_model() self._model = None # type: ignore self._tokenizer = None # type: ignore CudaDevicePlacementMixin.unload(self) super().unload() + def _cleanup_vllm_model(self) -> None: + import torch + from vllm.distributed.parallel_state import ( + destroy_distributed_environment, + destroy_model_parallel, + ) + + destroy_model_parallel() + destroy_distributed_environment() + del self._model.llm_engine.model_executor + del self._model + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" From 6bb61d1685aa73a2430d89e14c7bbefcd70cef0e Mon Sep 17 00:00:00 2001 From: Agus Date: Wed, 4 Dec 2024 11:52:36 +0100 Subject: [PATCH 04/30] Add tasks to replicate Math-shepherd (#1052) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez --- docs/assets/tutorials-assets/math-sheperd.png | Bin 0 -> 98081 bytes docs/sections/pipeline_samples/index.md | 8 + .../pipeline_samples/papers/math_shepherd.md | 299 +++++++++ examples/pipe_math_shepherd.py | 74 +++ mkdocs.yml | 1 + src/distilabel/models/llms/vllm.py | 2 +- src/distilabel/steps/columns/expand.py | 161 ++++- src/distilabel/steps/tasks/__init__.py | 6 + .../steps/tasks/math_shepherd/__init__.py | 14 + .../steps/tasks/math_shepherd/completer.py | 613 ++++++++++++++++++ .../steps/tasks/math_shepherd/generator.py | 372 +++++++++++ .../steps/tasks/math_shepherd/utils.py | 318 +++++++++ .../utils/mkdocs/components_gallery.py | 2 + src/distilabel/utils/serialization.py | 2 +- tests/unit/steps/argilla/test_preference.py | 1 + tests/unit/steps/columns/test_expand.py | 252 ++++++- .../steps/tasks/math_shepherd/__init__.py | 14 + .../tasks/math_shepherd/test_completer.py | 476 ++++++++++++++ .../tasks/math_shepherd/test_generator.py | 252 +++++++ .../steps/tasks/math_shepherd/test_utils.py | 231 +++++++ 20 files changed, 3085 insertions(+), 13 deletions(-) create mode 100644 docs/assets/tutorials-assets/math-sheperd.png create mode 100644 docs/sections/pipeline_samples/papers/math_shepherd.md create mode 100644 examples/pipe_math_shepherd.py create mode 100644 src/distilabel/steps/tasks/math_shepherd/__init__.py create mode 100644 src/distilabel/steps/tasks/math_shepherd/completer.py create mode 100644 src/distilabel/steps/tasks/math_shepherd/generator.py create mode 100644 src/distilabel/steps/tasks/math_shepherd/utils.py create mode 100644 tests/unit/steps/tasks/math_shepherd/__init__.py create mode 100644 tests/unit/steps/tasks/math_shepherd/test_completer.py create mode 100644 tests/unit/steps/tasks/math_shepherd/test_generator.py create mode 100644 tests/unit/steps/tasks/math_shepherd/test_utils.py diff --git a/docs/assets/tutorials-assets/math-sheperd.png b/docs/assets/tutorials-assets/math-sheperd.png new file mode 100644 index 0000000000000000000000000000000000000000..55575c3d305ff7869e0f2ac5df904814280c97c7 GIT binary patch literal 98081 zcmZU)1C(XYt~lJbZQHiZX=B>9J#E{bwr$U}ZQFes)3*Kf{O-N)z5D;N)?TM}s*+Tt zs&Z;qk_g3blJGD%FhD>+@X}IZ%0NIMnm|Cneozo!m=wy^>8}?JOHolpX;D!kMJIbR zOB+)lAgTn{_1`pP})xU4AsEX8L 
z6`@LiG}!Ub6;}Lc*lB-MzXeV_dQ37}a;{X888|C{u{<9Mha>#Fe^f0}bdftDbV;Jv zV;~xgSoBGQ79b-4G@)Hb7*zHj2F4)hzvQUVOL|MJ3vIJON%gTNKv6<0M-&T#ZI zQ=UNQ{+wc*anti(g$>LUYE1Q~9W$`to1vIGqHPrtHRf6NQ%=cfXk<@1aYzOQz07&! zfC#P$;CSp%pS)d5HvA-2tn~$1HZlsM{-cK{WF~|5FC```b`hMO0e)>sQ6t$<)-&*}~oh>)Y9z0(~TT>T9A`e>|J7*pbev*G$@O+j3 zq8Uku{%PW3%}=5!uSg_n?_^5E!NA18L?Qq~L`1~rWMamnEGF@9@~=035(^g>2OdU7 zcXxLNcUA^_Cv!$-Zf34-}?!ygdaXfM;X`%myoiO6OyM8L@V6guOJix$!W!sSc8|%vi66k*) z1dKic4fS6DlIj5eUkIo-gD(cBkf1d2s7U_>U~~YS|3V-=^+G{`EFnUu5>fuog7{m= zyCEeuAZ|`fO7bI<4oLO?7XT6G6%>G&njUx?6DLSHJl~jZbU5NmOGiB0QL3VSagHa+ zD=v=dT6#AyH3d(QQZ>*N6C{}_D=Z{ks`(D?^%`ctdH=|yY%mo90rC%g5`zsGG0Vw~ z87!1Br*m6lsH&+sf6t<#6j+rOyMOm5r<0z-m_)=+loYWL&nELiY7;1^w=`r$W#vE=c@{&|XP877 zSIm-<5)pZNb4gk8=x{QlqB%KHj&-JpwJ3Bo8lfDqESVipbDwJ4#&L&L}lmvw4)c8R^3A98e*B?M^Z>Ghy(-Uf;J*4$*r(!n44zj2`pZ77Mw>{9HmoLfp7h zD<+{AD75Z##D0we;_8yH9-+#<5PaFIcGUTTEz&3heYbnvC>V=e*t*+xD7be(f@o(I z&Re774INjtmsqPB?js!!yQH8MYf0G=uMEW#Vwv%Xb*$yU7lcg83CQ#6oyy-I$CuN( zCpe%T+?yEz*Zg3f*BxlgKD)8H&*i!>?_O6Wq>*LI?lH%PTd0 zF!|iX10=Nvpy(mlgw^Ev*;+1uDrLlPB}LF6+g&9c+4Fbi0`|w{qjeABde=)0RdwHW z!0Al;t+3jC5nep-28x`&voU9}jB(h*G~CKdqSNeusND{%uA>>lpIKrgHfc0dzKEX5F*W=vDrr_y0UhF=-q13` z=SoL_TkqQd|FUht!wYeT;qh7>WlASIH)CG<{6 zdyg##J2x%X=y~z$%)*C-evv((UTUH_ndPs<&PLrPL6dX-J5p#I`^Yw(dCDU#Grjd0 zJUUdN|E;8+tMQB|XNe$W@?8d0W^mTeo8KnO3wX+hvOtfbS7nSd4JY_@SiiqLnH!>s z9#Oyv!Lm(6kAH}}6bX&hKGs)0I)1j$CA=J%fy*mq!o|vTrbWYfVPO&k#PW$?^D5Nv zwc)D!4$YTL6VbLLzJVpp%WAB^y#HOYrOh zydW@y2Cv3Txb+wL@8$SV45by$X7vQ8n)>i+pp}Xc?N)unh&1rztF7{yT6mfk z%5x=S9hPgs4web?mb&n@a0X1eFr$n6I@6mHFd4gr73*n^Rx26COwcpR^)HPv0hWv4 z@OXU4jMLTxuU6Gs`RxTfgE;wd7 z2$iN=7A*AyT~$Bhh)YEb6uHqKTL^k94oa_8H6VrEG?;C7Ky03M&}0Y+4EHAZxu)6V zOq^VJe4$?vb9)00pE>2P9Og70XG6lS8@y3nSX_Ds++gb&g>b7wIYYi98@pu9z}uSi zp<5$;ARz9Z^zgkvw=nR6cqL6?Y0Tlvs&6LcWUK0Lc;H0w(2tv1sEf@yIH(J6zg(}h z-y+o1Y=o_gPtKGj<}|^-(=J!XWEf7hT~zi^jgs$d*59_hJ~ z^Bq`Hx@Gb{|FQ9g2iL9^#e8jrHeE=z0*pG8C9ZRdaFqWEL&{U#Gioml^}`K+p3=TA zkbT%;Zk5nooxZPVJfX-Ry1yS^o5}MSzO7{DITdVg;7*>aP@OQy+Iy%eR@RYA>>Do? z^za`3XMY`cd%Z;?eU4yv#Pz0`!DyzKRYln)orEyCC|h}gKde_4(sy-48JP?!j2%!v z71pw@j{J1VKBlJ=47e>!_{tA@0|#qVt(fkR+H$4me)3bS+dkfRn{hX%BmF9{`K9dH z*53qo)~)>$fO{p0k5eJEp!VA|fWo#6)|xyap>F^DL=d+QYQi*guP34aXS3LP@M)b5 z4}KD2XTTw>+XaWwR>^pqiFMR??Xf7;Y3EKZ$S9`vgq}jl*tZy^n_iA4EGKB+olXAd#PX(z~|UHSIDE7@6#tjXfVa! z1K6hlm~X!SATc`$kU+(K{qX~JTTY#X{KH^H0i2rWOl3ue;aolFmE(5gYB8w8)r$7O zacc0-77V$VX5?cJ)G&e3H~Hi|?ON^^_tO~P=17vERp7J+p4Uhp$EOkN-~^(Zw-sD4 z8KbNH7uZA|)qI%@#1xnq=5h_v1*#f}?-WNYP4F2h%Mh5T5Ze}>Kaxi4_bO7_+zd%u zniw@zO(ITh&&anC62@VmGA@n^VKgw545?v%|6o#pTXy-!N-k;{&e z8&u)K6Xwzdg3@AI=1ql5ul#ldzj#|k^2N04r=G=qMf???cL2S9$Xqs8!T`4`V!BYI zje^uJ-yqCc51BCpgYtCsUs`bmv zM`}KS0xLYbcBeLu0&8&th(|tbXJ8B2LPO_5t*R)FLf5*WNrD?rj~#XWp?Yxlw^EzK z8u%2<>jT+j#hZaYNsJ9syCu~(yt}*zbbW`lQn33{(0)7*Ids1vjyamtlO?!=yXaYG zL-XVY-WXfLsC~pVI%6?C4|6k2A_DY>OPNQ6P3ti@UUJj2Svc%xGPmwf^p-&;Zzjw2 z=*ijEd}6;CCpH4{WwFR{)|Fwb0tsmLy}ysz%z9F<|IgGpCv-#8>H5Z&+*BriwrmF_4bH13%EBZ}ES(z&3F zaV&dZMXBz-{_M(a(*v&|4<@_$+WNY_&iXxmAORB!i`}jxBxuYP5nG!X1>u3d;t-)! 
z3W0Gr&_YrOl8R)Ox{+o1P|_r$Ca?((vrq^mjUo9u#@`}#yURq#q{WR$__GhTs)B#v?M=o8(>#p#13svT$D*fuGZM{^iimV+=y zd)&q>fJyY#RiI=iHF+je2MNE?;~~%D7|76~qT?^v)dYrfk)$q?$ClLHf>bv!OjM$W zaWMdIB>-Svr%FyqgEl}$)NvKG=jim4h^!?P^hWyFx6$m0`BL+8MRj=LKYe=Mo z343;m%XJfzW|1~O>yce<^%*0-JmRXu1I0@(i8uzD0IiWiEzv)HWkN*Z2A1rNAViFD2=wj8UcOcO)|gq2On`zcDwylMpt3O354!xJSWeTT^;W11>f%%O@x zX4JY1Et;?IfHpakUDWvehKUg~GRS#qM_Q8wrSq)*VdhniUKwx7sFzxwVHFd;e*%9PDYBgSvkT!Io6_85}Jb;97~Tp$(CL6j|K z(eY7yBUZ_wW&lk&NKk2CAC*XA^UvaJl_mI~PpN+J@utvf2j-B5%U~9Huq)x73$r+K z9Ld~3OBF&#+bPUeT^3S+LyXmtNzG9sNnu>p159+Q)q&2dxuY%@ihAQsjgcA1UR^8K zXLQKp4iy}5uvlsURR^YCaei_F_Aw=SSd4P2&zKxUhV5@?M?s1CO)Aui~pXK8P` zOgT;~U#;RrCH6^yk08R^{ca$BD7C-eZXe`OZ|5NDRy&kBj!)5Co}+`lo*}pl`TNBr zvM9maUVWSi9e!QEd#5=>l7GT?OV@5)psr{VLM zQ#%E3T(UOHVrFw@e1h%+z`TfQ2{I>aCj zZpz~_8~%ZiOK*TXU$+f#=6)gJya=bK%^-4m1W~OPKvCJGovWhr?6uPH|3R%_gXM># zkd9hUr=)ME(20b*7t4=}XmGRdteDM%8D(8^T%<~yXDUzb?SY*dnho~Bi2fMkOF70Sk5P4`wn}%bUw&a`+KGFH+R$M zp~XUmV_kGx^u@hbVJ{vc+nMlOppJ5l4lI1)8FSI2lg(AK(BUcTnJdh6nQ-hJVNF+v zdxs#w<{Q}Kp~4jn7tY0_KA?Yi87-F^44B-v~CvnS1{a5DR~frXml6dHAiBLBI;V+0`q zxIKdNc>uH9*123#o#{(LClF>i1Pnc#ukvU4APafsM;GaIa<;r1h7PMx=AR$obIJO8 zm3J`$u*yoWGUCB__a;WISDu!sZMNF5K_RM!mdugE9ehC~yRq7~e#I>tmUC~wW-#G! z0H11;8)|}kGF2vdL3sUq7|16VjTi{npa<67b~t ziB*~dUbHQkaEW)l!ib?4JAL)=ZvBA+>*O2zib!C(OXzaZxQ)sVJJ=n~Z>Y@#mY(<4 zJDHjD+F~I4vuFmg2lp`U2g(oG>1T7s?LBQFGgJ{16S&;zF%eN%JT|In-|co!V!CK= z$4={&?W+_{VeBHD56)JfMpg+vOGeKQQ zO|-m@SSC`)p#UM07JM&T5)W=PT3yS|%}{-aj9J*>Iy%7cAd+dSSZ~SgEdBoSK|Js# z1%6(hGV%nSM*%J1oS$OnI)0ODCZkEm^Y_dc2@Iv+b{NKOQr+GkvIPxHVLaRjDJ3Pe zWs*J2PAQ{)R3$M5TH_%?%SXIJg<-sNX^*x%2?^i)pYP_#NbLl|^SvLM05-yN=+;Zw zqDj6cNH6Qyo(72R(ukKXU%zd2X@=#xKfj$f{;KifNS^~U0JbEhRt-YTet_gWtIvnJOzr(5DB#@xVF70BYW_wekN&C)!j zARNtRkGD9Y$JW^+BYxOc8~o>w->xx@|@qPA8sB3*?_=1rlY}dej<#$VCkrEfm)iHJUmH z^7BfriVqt=q-*j{WGR8TVm?hs9aYX3nc4yl{NR)wa60EJkB+o)ewD5qjZVEtq-}Aw zbeoL-IH0@kP8QZCSM6PpR)^UlT{`tA^FcE^ouGaM*fJueyWv;?efn(;S*Rw}jng=0 zv7ZJx%GUWPp4S1nC6#eyhO-h=E^>M(0bkB_P!&3Gc@_;02 z!fU#;3RQ&(CR(W~R?+V*7&04k$MSatydQ2A?#^gMo+2+kQM$E7$8!CsLXa8Z>zns_ zMLZ{i-09upN?%kNZkNoRZP3-Kss$11Q4ft&VD3u~*+hTALibWE_l%C zFF$A4iA)hb_3dP8PcZ!Lx7KZ5*C*x64gwjc9~PFzh_g@`fI@~gstk*gl{z(@PLL2E zA1+S&xV!0v*tL~WjB23Vl3eBalKK_RJ z7eX-A_e;GE#8dk9lGuAFD@|4~E-p=mk2~B5xorXaXv{LRnxE+$P>q=Odp9p73H|0no0(jM%mq#p_O$T?y|CSQA;Wk83~MTunp?he;hX5#2|Dw5V8sTr^Cu( zVsIZit+^t%Qj~<{ir%fw;}kn2@-j4Pb3nz?q$U!9eF-X@J$RFi6xv>qD&{8P-_A4Q ziS}i2XuP^2HR*;X+ZFP@0RmF1Hx=r!)f|FifBFyZ4lyTTulv0%VB6(-?;M|pdAb($ zg?ST}O0O>#`>3E)@E9G0_IrGs@auIsioSa;K3Sbu%}Gm=Lb?j;c#|Mxw@Lh&p7#FE z@{=#EdR4*O%NWtJL^FeBDFd@E-KFF2L_^0n9FfRXc3xI+7a)I2RN{6J!5h&{u@y~TkWOAcT=v!O@WU*}K z6W+4OI)H|PNmA2^dK1Ia2)k3>2bJJD{CkkfD$R0j@#I*cgGBoBl*I6zM*y3*liXQ6 zbvjFIKD4M$bF2+#c)ZTIhy(enlJZ#x3R5Se&pZ2EOc~p6j_KY8u4Ge}oAuNTWvfelRT)SY5fs-* zs{?|9hV~~_{M&6gns!u{{efRO9S)srQ|fN%Yuptb*XLCAuN&&(j7vYI2@SrwPpIdJ zLz0)V6!*I{Z?~Zs&KqKt8~LsobmNY#o(o-u`Q^`+)0P=gWT-cIC^$-ag%~SzWE_kf zg4BG~h+RdEG^@ahj1*Dk8&<2WB8GywCLOsWCaRja8vM*zvb1RPs?=17lck!#nJcP~ zEqKaCLfP_tR8z6F+QiU7z$vEGTBZRSvG6HKHRU)8p=`|d_>;<=rl}!mfq&$2 z{C669;;3Vnz5=a!GUc|>9Ls3wun7oz%^bQ7hrtrsw>Fm(R@7N$M=Ulo>NTAIASrDfMV#WEFhCr+!zzoBDdzDEet66{#33v*1NXAaZ%518W46t6->|N3bv zohKK}LKtrJ=4k&2qsq99xcO1^q@^8ed{AI&v7l1%t&~0g>MWH1;(1@CXDnuVLif;D z4CFCWRGZnMm)@IJ4}?dRceEd4Uh+L6!WpyBvhe%L4f8od!4*LbE=J+!ETfa-u@e!a zjDK#Xpi|Bow^XYU}=k`G6ZN5Wnt^%jRYwf84<>tPA4K_qty;2 zlRqj4duH|Zo1LF<_4{@FxlQ_OYdB@gMqE>Rzo#>2?6$qe(#}^|;@o-cmwfRuS_q@u zY;h8QOy#H>K}n)nMt{rd3_4^izRr(LDv=SW=hA*3wU7LX|5G}?7*>PuYMsj|=aSb2 zc;!sQ%SqMzm?UKsj4Ow8u12@}H2ivgIf;TE4^`qmR4ao!0h-HHHm%RU>wILigPbOn zF)F7igPB$4Ax+cU>z*mFgs&C9#Q!&uM3!4s5` 
zP;hIt+w4B_Xuc5sv(mt7Lc`wR8|#8JD@73;m~SnYW+G74c?S}4rMpp52+Xb#`OIA6 z79G+cU~7yuJcAprlYQi({T5M^&o)RwlJ_&Qe24iV>{>%dpYB`+`44OYnLz2RrzR1{ zR&J;Snbv^JLVj^-kbknMU_u4Aga*D_Raq&{d~D zBB&GC@8yEw`F!{5u#V#;pr+ccVE?^D1wrmvx3%gA;oc=C^1LB^7I_lvXd7JwH=~U8 zxLn&e0-A48TmYL*%&*8uzaiALu9D>tE|D#eYLTbVpO2|MQaQ|Zh!$w{XL7bSLaC)q zAwqv^H?kDtgw`jp&=VBSRttprV%p*2!@MYCDaaL4Ylq|fogwmiaB^G$kfVi?%Ac9f zfeyyc7uuUNO;YfKM#jp$HyEP4&Fh@Kiwa%bvN7XGwN{K{r(s^{%PkIjw(<-nPz1x( z{7g~+aWBNXN7tuTh1qsbG`&xXDCPS8^J%y$s4Wx1iUG&I}WR1 zQf_e9+fC`@vD}xNrp<1(mGFB#|itKn3Xtu(T{Rs zhYl%4$DrTg&roP- za|y2+HR`zX&&G2sf?mM>2;%$j)Z5iE&_B3fBX2|lI+Q_*L zMMu9=QlPcJ;sU-R1GbR@atkR&Z1c$IbTjG$F(7M#^0TA%UV^EyJU?*B#w09w+u*;h zVGHa8aicAzVc*7%sE+d&uS%F`jiO^wf9J>4-`!`QJ)%I`bQ6sLxNQEckf!R<>3_Qz zx@iqPT3P~nH2WrzTlm33^QMln8`^ohWPH;=y5hA7U1}GvFt&=!%Zd@d8WH?TI5<2o4kIso$$jwH{?}~CFLC_ zY)?q0KI{E30@}QEKw>$!#B2k{;2n+e%>G549Igfhh`6lq{y{Y5w;ZVQb$CBE>A?f| zQU}t2vm-J(9bGp(VE!RvxNa=gW-`1?97(_hEpvrhc<3IfAeKw(5EO+@Jz1={fd>uS zLBY!K_Esci{$!7lLnd8W|jb^ zUntu^x|yNomObdy{G1H&bdB-LzPx7LAFOXYuX|G!;^{8mPnJ;~?Ky3ho(l{=Q>i|V zVMU;M#HG>pf@DuMj5_cYyKQoe^xRN%(C{A4hR(oGnT(qGpMTfWpn&o5#6={v?3H=V zAp(+=7Fm=+4}CDX+Qj<2RsEX6eM8yTbETGfqdfwjGrRJab1kEySQ)z3j81itJ)`4c zVvVLb9bgKri_0NMhbf`yJ@6^0L2OOG#|GM-aAy5VgR0yzErvCtEd%)mT2f%JGIv(t zo{!FKSixOTgWq4$ZrRvncQ%~SQ=%!jK(xKY=-5zM^3JIR<`+?kKIany;Fe{Jg25IhL?hc*-~}pY2fQHxWwe#o=dn0hI%0 z+$KLpSlfBEKfeDo3mJ)A7+V^SF6VeMw$Lr}@akt=h5YLZ?9b)36NnHhUUD@=N?T2n z7_Gy3X{Cd|XMwM`n3vhm(nDD1m}lFxn^gQZjn#?iDoow3vhqL-T;rv%7of!ILR%Si z!!K!|N{wl=atO?PT{?kLNI{{dYH@w+;Et!`0OrolW3Y(TI3YBL=@;ulRfCfGXDlwN zNua@T6g3`g9W~_<^!`zH$?m46>q8t@g^>0+z9)dc1XIa041)@sBpp zkMtiH2^=c*C7q5EKeAwRJWdO1#6x{cqy~r5nTxRpYWYW=d9i!9QR49_llbaW4%>TP zM{`zOMbR5EIi`#)(bdvjPlA5p`yWp2x5!6Pz7Ep#)+M0Rja&9M%{x?O3F4?^9`7`2 zLn#k-t<+B;7l*kB(|4j+tQVU7R;R^{YoZZqmXM#xf1;cqbVsV)$H1LZ*H8+T!kiqA z86Ha6W6n1%X9EaBOUP>%id$0Dy%PyYTflZ*!{xZX8ZB43S!m9JN$^xqRn#i*hXZ`} zN$@J&`v+)P&e5QMq~TrF$f3ejNqE;;-}I?~Ej11R{R-*c!`A_rDr3hqrn~OhKIW-a z8txbN`Lfn6E@GBS!~)`Q@3tU^tGH$*7Y=v@m~y#LZ@req1|vgf)Rx4qQCvcjLR|&v zJ@tH45ml2cd!*WU*sXE#w4(b%-LAQ8F+)qvRjImSWkUsX0#DUvsh?^O4%mavyF{7p&ectd_0b%=7V6?O zMc;h*rpVcA_TEU@slfGYUVzE#(6*6c^Ayc}Mj2mytd9qj6(gR>xgAppUp%wJJ2Bs-w9d_?f zQzkTxBcrM%>A`GiY&E)}E#$&FOQ;bu(cA*)ISNuFj@Q2_D_k-o=c4TPA1!oK(t~j%!rH=C!L`O5;tp4_$L^oYS^e?Al zj4_^c+q;VQ?J#KZ#Gt5D2-=*u!CotE@@$A98$eFemCYaiRvwRS{s~vhc4}nr10OD^ zKZL#wlfF=&Ud5BB#|_YnrrE8%7|IjXtAke*oyTo~!ses*aRm`yw&ejfl>8?;d52Ss z=(2Sc(M)Tw1H+1}4e*E7*B1#T&^T{4&+$crHB*>&l&3%HT8I&%C03_`ETE+H4->Ea zBpSwnh@-<4HXkuC_3Su-HDc$q4_HOYYy;2LZcXJvhX_u>zx-)25c1=s*CD7m^ICrU zO@aUR;t{-@k;`|?qId}SE_9!S+yHfn;Ybspw@An zJ6dr(@o4Zus>}Ui62iy@9=^fTvv)K^PTB9rlH zPlz5bztA$8leK};`wYNBC<~Lzp;bq=r)-y*xfZZP?(tbpq&hAQKHFKaSO>^N*MI)B z&1k=fntw2Ax>Tc(1VjSRy$6A(&~XYhpglXnyA76W+{bP)(f!X z2<}5W>py#Pxqdi6o6SQCUnA@}JnSgbHOQg%vFNs>z8PZ4Pg&@VfNmps(JQIX1 z1oYOGq*F$C9Ij>;+vR22%RreNwL_ng@WQjQl1yjs=1wdMVyGy3ER^KQV*#;LbCAn5 zzCw9^EZ_7zA^W$Q4M|YF!^6qKJEiun+&okF&s+U*~rM`*kA8*GJ8z)?{GhSuQI z@2XHZc$Lz@R>xD=oMJ06dOYJG758b>RlFP{i6!*cV<@1f>s;QnxY&P2yF>DrLnY9d z&i)p-8ctt@I#mAwUij&~O0x5BUt~>96t5jJ)B+iR%LU5J4l9sTg|pyC+({h` zCOCRM#RU&|UxD}<8BE`w+Yj_3FYKd*UbqTjrJ_x23UQIi#O-_oTmj!qG3YUAv&5g&350=MG>hW=r%sX(w z!&Qr|;YR5SG|D$LNp=*NmzLV`dSo)oz7KTLFco-epZXPuvbC>yHHvNxOG4A{3^1NVbhoBY zO`VHhsF=3tUnWlGzAZLDMZe&6D+Ei2<1o$QLKQ}7Y(n+)8;RQ(lS`7j-c*dqu6ETL zO3Y7Z?LTGkW7R(b#*_QF@c-Gvki%p<`gHO@E}30pUvXqdK{4TDcFked*@#4mR$T=_bsGWA||Nhq5{m zG;ui?PKN0Fv@=};pD1gs?gRd>vNI`iO>0ehcp)n?3e1#g1z!fx!z*jaR(5*w$6=a& zD|h3=xU7qewi4QnP$VFE6~3pxA5`5|Yi~R(syjKyQ8}N0^vz$|WNU565G|aM}mp zr`fQ+2@WYT`Os&`BWC)4G}>g~{oe8A|dD=ng&-N z$!EJrpBYk_Cg|vxggd$tM0I8Yc4{_4XeA_m 
zDf`?_iMJUEXc15n_AOQcdnPXkOBYV`p%*GOsYo_mx5rG{kHBSVzu?!bZ#=yYn)af3 zB$Ct9{{pJE4KEk`(m^-r zoe)pG$XPt0)^jR8v;S_)!E2xjvq)yZ`AkDmmRkuD-5YJj99#$PGKs-#DNmDSNon5z zx_E1Ms$t6*tj=|Tra3pcQ)su_J~B$Xd~dm$t&-9dTeeD4#|^pOv;%VmUezS~&iE@m zjYa-47;hY0=ohqBvhMAHE0@S}Q8);}I$On=x|%#YBHauT2@9)r1ok zhbLb8E!)XHzPM@TBmfqZ(Rc?4H@S8m5}zJ-er$Viu4K&oSeol^BGO-YvI9#@QU+X2 zgkW&T_Ju@YI-}KSLKaHI{lSE|U^O9*#8?b&p_pVsPy8#X3)$edkCfC{s*s^C7M~ZE zPNzN=rn)!L@%4I7Qm4%Ucy@NSfzI$h2{uVJ}-B`upL>s$yr&^8nt@F z#PrG1zH3mff5%z`g+{NHEEj|U{?M5ScA>bj!NJ)U z{BaqU4z;C)umJmsM%n*i#hI#-2^n8N9 z4FP>2=>k44?!?L#z){c|>bC2Bo-HQ5 zL}_@j@e7sD5uRuVn={z`s^$@eoIX~b_$^Q#Imytb~_Ka|6b|Jj(cKsV09YQu= z=}_;-^QBK1S24>v=w%xJPKWvrQICDB&@UMD*~V=lo0rpC5N9na;M4ulq@$T;OjRm@ zsDF0eUw!-wK5U`x2st2H?Kf$z>R=9-&~Y$Aol>l+&yu+#`CUEyIT>^23JOs|T!ZEQ z-ss>VF;T5PhYvXqhdf!0bJfTS1V{LnCw@;GS=4o+?PfC{`wKW zBT?y=;Pv(QQ}YBC2?zn?0 z$m4R!o9AC0a`IQMAPO+KDS2dcp7IXj%$ITDhTl267KK81gW=Rz78#5miRg-FB)ZZU z^0y4|y1$62K20;eKA99W9f5TF0Jnccyt=sy&e@tq-t~pq`vfkat1)P`+f*FxAWkO? z|3`@*-=WaPAZbhU1%CVCq7KDj23KYQq1=Fj-30iDPXs%EJb@|@Y)wm^UsgK3u8CJ^ zcIhD1bRl--u2p<*G&@>rrmfNAdRBQw4*$EL-|)bEBmGknvp{rpH*@>zh17(01cQa@ zjRxaIC?0cDcK3D_f4dy#itFWROMyj4M^hac+7OpRP5xcuzThB&;yo|upvW)(+SX64 z#63;v{}`Poa}pvb&CiKj3J3bHy>w+SMD#+7lz3;46~P4YSKoABxsB==LcW9fWaJ3Q z03|7@cl1tJADF+Lq66cSCUSUVL!y~l0#Ca`I5Gu?TZ4y6uxS90{!Nja|8f~u>AOLo z0(@@}T`F`j$j4PhL7iEXXBH&~TWI~TB}In+nGLLfw!TvWe} z!ECR0F`dWaaR(}qN}`7jFj}6TpTnb~qPkpcV58Hj4}848&OaZ`6~ngJZAO}or}hpF zQSb}C#_}-PV#U9}JZ;NbUT4DhRB-m}*!K80tY*Fi02*BPYU`QMz(_HUUUzEfkB_qWIxuN5Cx?;~-sf*WHhEFNvVVsDJgp@yK4;h< z>h5e;|V_c|?}8`@Fb(5WDa605@4J(kS#s@GpGF`w5SGWSAP~ za<;(^ zg2Q2hI-JT53Ja5*5~WoDQI~|#c6OrQdB?(PL}bF25mC>~T7T#Go5q*I;C7 zvZo9V+5Z5hBYoeW^~u4Q4}G6H-h;&CnqR>?v;A~4tjoiy*5QIZR3P{k)@YXPNo8l? z{}SUQ@OYd8+ldOBD{1a{+=1_HR77~4Xn@K0%N{nY^dg*r?Y#rq!AL=e>kY^-i@-{g zJ*-7px~aY*gk%@0En5(Jdf`%yVSm$o35Vu8t8eRGULFw>AdLTLt;Kfy^X={J#-vd^ z3S*@urVHsdA^q8Cu1LDBv3W0~EdrIYZ)OGw9+zW6(t1#>t%yo&h@40X5G0B(PoW@m zdfRA-xf8sYt<__GYk7vJg@4zXNnh;!I|M(ohJ+4NQnOiWZ?@?dcBBdN&%s{Ix`P2o zsk>_CG}Y(qy-cC8Hp@Ao~=dG1sDLw~96>fh>JyLVmdf31~HZMu(B8p`k3s;78r#DuiIr~iQB zUZ0=G1^<2HjmJ>=$I$t_={ZTywT(1cc)G#-FiB|O!5d9tsRjZY49Q5HH-#2`dB9p- z$n`m~_$IxB*7E0|oG|3^Iw(Xs)9_ywF+&2c^~dMBW1VMOF|}q(9ya7lgA>46j>YSP zMJ$oTm(F1922|zwa-BwR_*>WbYNOfh>z&yyms%n&H#$u_=AMp2t&(5A68#RMcezV?4N+Ol!S11~L zuvD(dRjp7gF|Kyl`fxfMT3${b`?VH=g!ezS003c!6Qv(iGuW_BQLExz;gw{y-8llI z|Hj+IttMr&C9?aVMQ9oUhAioWUO*u4SZ{2m_8JL6h7#wC5^PjaNe>GK1tZ{~KtQHc zole^A{E?f=WkIp#beTlq=jqF(g4ATWf8un$0nz6EdTjCj{dYLUc@kb?1&qY6dbb5P zfyrwvBVtNvINm(Dm;fWg&ieHPqc#cA(1WJ4GdY(pxvI5Sjym;4^vv%H08>f)zq(HA<_PqM%D@i z+=-^Y5;o_r{o7vjjIOx6ZBUHjX-UOZUyM?OX4yK65y`Ic3wyD564Llby}1Ibet)dXmgLXQ9h7ty9~|Idn3r?MLcwtB~t@pDi8>~wQv2I z;FZr(J~a{i-?zs{N_boC-aKZw>1Y;f-6I-xzacVGUe#+R7D69ZqL+PtWkDq|+IK9UUAF z-k&Z;aTfz~AMWnJKD$mi?AdhiutH{vM`f{gr!$?qawQ(;*$jUh3ibI zKpk;^NDiLP^%$$^P9T=Fhf=kQ`#$qAe98K>?Z0=%`l#nx@A*tQZuG>zi0!dxUof6< zAR78?qW_D_lSRl&EF&lwNU2aH_WnhF5tnkh|3+{Hr&_JSEF&oy**hvOIJEonNIx9E zlA+mP1c{1>IEW56dnrn595dzQwV6t*>p(E4U$LQQ+L3(+qFP2EZvB2VR?d76YaYKfzZy3*n zownIOvCZUhjaI+KwUeb5wmY6W430|AJbF#PcTHctZl~JG^tJ)5_Bg&cVU*TxAD)3t zJz=v&RT|Fw8LsxPjWaGx=61Z_A7aw`NOo=lZ^GENUwd0E#KLBGkEplYWcg^#S23_Y#e!uu%98EMbehTcbdhEEB4f7RebssT z9lKu7XHrURo>W)2pv!|o)Q<=ko1!l-#TADK}?ObGHx^@cTz94E_obM80o2JPm`I+M;{#bXAEn!-u`^H zo3>#tGVU!Yb6(G1$bLHa*D+FPhcMdCh1rf?Z?8lBNa5b3c#pbE9oM3>05Uzh_MXpM z$5^$qM-=8|b{QMhZNfK(ZLvM9Y?zALGflJR*Ul*F{2}(w)@-g0GORNNK7utmF})>Q zeOEbInkX5#SSfbROiJwp|DM{$V+q}QcC~XRl4%PD%*$9$PP)P+a3!C87$+I-kELM5 zrKdl*$$dE+Ya4H5d_(<)(+X>)zl9nRXK`wdB9mtr%`@4&y}f64cXJJ4djQo$ZjJAO 
zQ9REWfTFq8)C+3$b8<>cNxkL7+kilR-;Q4TkN;!H4}ThtQD=YCzlA-r-R-%9^}g@i zJO7Kt7FRn8$^uAgvlzkbPu_o|{_~Bwu&o#j&l^c@=II|MHzR{_AKmWuiSBfvKHY^D zm8Ie5*Nj1Q>_Ihgh!z`Dz}GKPYuDN0r%jJb?sM1CnQ*Sw3aDgV3{O0!AR0715u%KD z)LM*etTUT@y4eh(EHoWxRyP_IK9vfi6@TiUlD1Y#S(Jmfx6n9l>jCG5Z!O>Y@GfjO zdD3+J`mJVfLM+WbDcqu^MAM`ocGqiX)Vjqh&X7Be@=g?c}HtaYJc>CQK ziKFONqC-9rwZarCrR0$gjaC2o#hE?x2ktQ!;&%HowS=Qbai-nkjQf%JizfwLKL4{` zQtvuc$8L+E>i2unj$p{=Km3ji9HGQtJ zHMKW>;tvt+ZQXlS79>kk0=KT|R+~K+_JV+YQTo9PDy`1cZeta&dBQ8#b3busY`Pzb zrR%{fUti$ z;o^lYuN{M?d*Vs;!lbNLz4UJH+Zpod32q{d7ah>#`QYi}{l)XFyM~}h^yz~Mwb}@K zPoj@jYOP#2r^EN}T-|C9>Ykmg;>tHek|FM&SLW=P>A3i<>d{(v|Fr1w3PIGF=W-}h zwE`kt)2|k9EJQdpb;(H=u4;+o%7nhWNhnwK(aD(9n)RvZbf#?078gek-iww?aokQR zJYy#GnNwR{iQ3VTY*lv~L&YF@n^W7J24=5aDZ7Md~FIb8GY;SKL1Fn5@Pmy|yaF8XVo?jHmbR(;6Bjr{nJio3XjH_TACV?8?lF zHhph!*WKG%H5&9#HWS|Eb%fQJgY9m=?PqiyQOu6aQH*{D1B11mn1C!72+ebGYorb- z=zuE47;SnaqXc>${vCRPR#E*X4tQdfntBO)6QolVFf$_)Lj*hVlYkYTOC?RS*j-yg zGchr-b8#8zbVbq?p&=(HzcT?-NAVV)ba7uM(Dhig>^m*DIZ2V z52A^cd$45&RYv+Y=(7sx7C{^a8!cQM&uM>fWG+>c;2k>j5Kv#eP*G8RG%Hcq^=`Jt zGfBZ=ai&h|d8MhPzt#KxbL*Kmuu_*%bDl6Be;dVIMcu^W^4>H4p;S)Du{3c=qBb0w zN@Oo33R9T{0Uj?}&fMT3>XqX>aeaG8?7czlL0kYWc3r7%g%+dyg}4|2@%J?)=Vw_~ zBh@(01-u{M|J=9%+dV5mF~NBKe6y%1%8}BOr$NBQCOIkDVMtwM@1Og5W3$B51|69^ zf@5k;|L5_=qIbkZTG~Be2UIa-Yi+9&m_T+|zFsZst3+2BU@%-X8^ChKInPGf*VeSu z`qbD8e!bH+)EXTfJ`x`2*_SnNZupBf3+wf>$`@4zJ9v6-{>+#$DCb9^gA_IqS6d^Y zn|Dn=HOE`NvPSq_!TuVT7q8YKy*H_#zPLn8I=7)`jfMAAwx=he!r~qO;q<=HGN2no zdP4Pxo<`)RDKGEW+IEk~OUNJHR-teOv*i|$e;P~&v0ukE(e{fc2BZCoYfT?^9NKT# z^6-vJ6-8k74s5_)0(Xyl%tcOFk^0N)jjTX6p(2I0l`Cmn&UjVS6#5&Gxues1LA<*SX0q73c!$^O4ed{+(CKm^NX&#qMD7Dn4?~pFxfF}i-}R@)v3AcR zV@k(yFs4d_AYI#wmt-;Zm+c^Vk~0I~PMy(}53l!qBU1=CI;c5j%KykeLUDPa0>!70 zbHFwVg`B0%3w6H%Z#hkPn#hVko#UW3!D?52!!Fl$jkKgx&o}GXW4Ry;uF9cKkcGfh z&g6G*VlY`e#{U?_H=iP08~D^x%VoR6P;N+GZMH8`1jJm_~eF@b9^in$Dpy zcmT7ka6+6wJ-8B)%+co$&Ck<+1qM$K8RCja;TiDAQ`LK@BTjx?)b`5t`gpf>blm%B zTB*=UHe0TW1pbD;Jzpf1%Hcwv&EXCV3h4ld@qDprIm|Rua#YzC@zZ2BZ`FnGg3j z9d6;Hnb`ON6fnoxTrm?lk-y{wdQO+Rc{x-sZCz)e>L0T6xQhdx_ zro#1pxYh|oUr{93B#8Zdl>R0SKj#^e!CUPM#^)!Cpi<-|_%qx{ z4ewOLWWfk~vS|D-r^E$z^Yxsq&OsTjc?%}tzKIAGiAmfxdRO=wRP-S1%**(DMdX9U zL{YamtRPU8@DCf(9uWx5zFx=k$lu?{eto~nmd~PtEwS~6;1LLhod5iu;tVUzc%q*l z`Acr#kEB74ZTdkgz2-EK)FbPBu_A4r3#MCeWZL$@^d~$cwtngPug{X2-;+tZZd*Y= zv&9O-CK~yIUY{S()=fgx+Ig#FhRNMcw9hO+A&f4i&nlZ$n24S(VTbv3N(dB1M2VBw8?Q1m+w>_M|t$I$r;61iMUiUQH|M zS%$x_l9^7j6IHjV-bCZ{7z!Q-C8E_xrW@rP+Y&Ki-LTif!0&oYwp>K=Py_3S-U$fJn@UZn(8zL zPxuX&hU13uvrbb|$!m|<7HG9b=Uk&FENxR3mO=ya!xSQ_*;7(=-TRFqBl7E%zdys_ zOm0b15-A}^FG5;0Z!~&PNwrdJtWv{r4VSCUBooV3eq9+pZm#Dfnxb0%(e@Q(_y`Qs z3Jd}?mwi`2l=^@=u*|;9No@=K-QlA99P$K`Gm^~NB6ZffK&36OgnbV%2AnK?+nC>B zi&POI-kt?*BeeS}G70Q8DHmW=80gsy_{qF)h53a$P6x3qmt^#Zs@M3Er2%O`oF}un zunR2n<9HHHe~LqiApfHd3l>Wxw<*Iy8Wer&Hd`Q{j)thz3x>iR&SFEmov{TDrTt+k3~QNK-Y*`wj{AG!WX6a~VGkD@%M8s=E`(^rD9!wTT7! 
zyutAI6wt9ul_H3%jvCN=4WQt$Fd)Ix6Pz{TTJp0Ox%OAGz7}a>f@m&g5E2_pONPCC zi0WTvkF2@dgKZ#*zgdJQ?Cw30Ixk_FrBe{qH)pyiR^7gZ71#d-X~8C2p#7cb6jfem z6sz23!Hbf)!7`Awb10903)f-sBO^Jb>#}>^NGMHK?KeE&rs4%Tc->D*1aAy(HFTeQ zIDdExWaZ27ThgMm>R03Bpl>_RR}N8Pbv5-_GJSY3%|zJnr zWf1EqpI@7+?H36p#&*$f-3X)`424L;{}v3II@GQ0J1tJulh2VBDc+rQ_C`RBzNJ##8Rod0|aWVcI>N%bt;%wDf__ zzfrOh$nGQso9KfGDpi_?_>uZpUHs$X=>njSbxGNX7KGQbWK!$@g^OvJX?Q8Q)NAwG1P z&bg4+DxN2$9pxN8l16Qf%caQtX$Zx3l?bb0Bx$%b)r2H)iQ+Jn@9^Y!6D^HROI&bz za}p#C4vVQ0@}k51J9DP3lI=^+dw%B?kIw0@>Uq$uMEF8BKM$WRY#0MRX&E{R&3QMD zA^*VeBuvQDCXOig}tUQ`&rA%FUpX4jbGI^6uH<31Iyj8lay zm^?oS=6nzC0M^_DAIBgyx=MDNnnJ^9d!oU?4WFncCwwM!8)~*4Be{A)XN__xj%iqr&4-UX|Cay(?d~ZB@EY1;tLn~IW1nlmMUSm8Y_gt| z?C#-|5vX~u0rP}W-znMC)AXcVRndvMxhhfr!oim{qd;g}83R8p9$6xNeT?G# zr+ug9UsUPy?~Jv0N!Z<(AET?X+f_+bnGcp5zy>s2d@^+7r|K%=NuOwxRGj+P+n`Nq zjvzQJt8+V62-gK1BsuLNV9^VHulDm zXu9m&z^Wtld^0#xJO2c-l(Uvg{I3BlBZVD&z_v9LUcAQnMAD{exbP)@-LXtv=}+)|0eu$wP~tPAp+sBvZ%G{GP;YtHDXni{YnL`dqim!@mhUUpgK-+~9#rdUMym&H$x7b~CuqjP;XRn8lnD zaA*xpe$nXR)ayU?gZBq}>R+3ffN0`*LV3KoE7(kDL8y_3Fiac=CWZ`q?nsbKCw}JfjZweP}fepND%=MAY4YyPlVw~*l0MRy}%hBWYCa&UK84Z@hL zI!TiV8bKM)FZ@ZlYP8LxNhH&~^$qZYv5gnXR=#FDA~A=;txwfu(b$jDUn|MWl@0&p0d z&`y^%7q9~Yir!BC;PC;ssB1f3ayV%8@OA!q9A^Db7{)OCv(mW9jgENI*gq-F%mVfg zOO2;SZE!#?+h^A#>2`&t%%0{XN8y3g}s!SYn1T+pfp@ z{5@q+f5l{$HxApu1e99)lV<VK}a2meY4;ilc)XMnAmw3!{8 zMjj0KM=ZS0M4i36Pw!2``I#Z?`>%1E%jYL^kE4h4C>n*9b9F~Xq<3s=^GVu3$3O`F z+&8)e_j~=s&-TCH(nlCK2?zHP5*ZtL9MEN&${pMSQQl}ybw&;)KbVbix$X<$0gLrJ zLUouF5a9nE{=U#a2vuEbSQFh;c`vay^0x;42gHPb#yD{R-;tOzf3MwgrOH@zvX(Cs z>~<=K(_Nh7lR5U+e~rFfKEI$!w-b`#p(4<%EG&AfH9C|(|5V3vFK>&OfF!rbgYn)* z2#O;M(<2|=T>tCh|M{h3`22yv_(=6dxAissEXZz;jVMp-mrS8i1Tv>FgQ&uHjviV zS5~f5tp9W7|9u_1f_OYzg=VeU^^+r;LaF=%sU|P4D?c=cr-^-^?EYWTa|8WG2K|FX z^^2OC`t$XUP?L(%X(_hqYMHr#w6NR#3x&02E9j@t^!53P6I5U)Td;+PzQ=(BbZKcx z4Gn}_sX_-RI%8Qh`?qu6%>C;SV3x;yz=_EK9&$)nUg${z3%fkr+rpJa>|+|oYfcx^ zE48q8u@Lt$DWd#fpe|N7I*_9x&`rc8wY9l7F{pw88b5nN*`D_(_~y!0NN+ALRYofx z;R6ZPN^H19PZz@Wed>=rsF*P(JUo!DF0K>~4;0wr!=V1qFQ>B=9E?lw{C`A7kx`bY z-FBl7IyI26!UCn{>k4%jhLN9O4k&_>2(0(^u6C>NEp#2xM@2^C1-%~HOKoJop>IMh zm@CVYU^S4g*N+0?TycNJOLV?_4y|VtSp_qoF<9V9q-Bt0!vcL|(_}XzkjAz(ic2J2 zEviYFiopbA0rP7^YkREejU0?ic;;-id0--9l@{1+bC4+zU4(1HsJv0I3+;Wp+^99W zftb*ykMNwzuon?@rcVH-wBK2%^UfE7NGwM5A}LoKudmNK)*UeyFuRc7fi4b*6U3w{ zRp_Sw;hMb>)cw)?G=6hc=De$C!=e+{^HXSdHunXk>`L@EA2b5TjoVH?5D5}CN26-Z zvbu>?v}0Ppoa{atvjY;Dh^mbk(;K(fD_r0uR=xQR#*(d;EU{HLgj=}(c1XVJ?9SQ7 zI8IZ<7cI&Al~%h07X5jvr6!9rQ-&g+<+$p^W+Od&Q6!0t~fNui8M;34CDH= z%jEFWof%NdoY+M&D0SKD>G@Q&UQ?OR0G_sLv|%L zC#_(`AhV|xI6w#uDAZa(ft)dAM$QHA@VDImvuKh<_>sM_k7(n79)h>tsXJfP1DXbf z`#k~pX|EgJBLexK6~f$p!dr0ox~3+ho{fk-s>PD~@c(zk-ur>}XVbdh4S3?M>}#(< zmiXZpu8;t&($sOCuhSOOn01k6M1gN%f zzWf>D@1kP+VUUxce7GO-?JN7gWk%W6_%2ofegweLrkYT|AK556jbdl(szq76h^d)3xEvQNVn$#>R^6;{#viEZ^;!SN<`32-O-(U5Xwcg zba4b9F|h224Dq${!Y=Dc@ygnoS^=gzEf~bZOJ9))$rQf) z945F{JvjCQBXFo3mN zJT~$(`P}_NqaWT)b5*BhS3LZ3XF#mssQC~7`t_sUq^8A@ky8f7n|Ey zsd_65H+yM8S6(u@LV9Yj56oagd*ytC3@2Odwtrh&mSv{U09{*+>)_;?KwVY`PGvT% zzI|^Apuvm!iPTYm0`o~+ii^QaZcmWR^wr*)DN_!&cfP|602rPap44nKiJtgSFaZZp zh{s+|1>4J5Ue=ri1pe} zQ&ytmV_AWJ^0T?992Nu+hQ0fm3Pds==P`8tsFgr$kGZP&U(tKQ6X;x zKkV4q;Egd>hF76RZ_!TBSgOGWv|0wflzr4>v;Ap|&j%w1o|+%*wTuGW5j6r}{i$z4 zzkXC?!|YT|vD}`aMZ8H27b|4Km{&zyXW#3_2=9gg<>kvPiypBN?%|(BRf9QYNE>Gk~)Z6^YKC8rihp5%_}i3=G>(*%^zftgMr*l_R$r~z4x6;<#r`avP!+=h zP{onKh>WOfwA{X1p5GLT#I`*Ory*@zyax}aUVLM2ae50Gvi%}2PJAj9j1@(J00Q(8 z!Ol=tLdw7%%$ML(5)|MpSTGNOOr<=nwWoa$d1}C(mD$mUkNpq4SCPwR4{W(wBmFS% z`frIGi_OjtQAZ+)%zrYKw$Sl(294K;OVxM1!C1=()Iof{#aa_(ml=0!Yl}^IJdZ!H 
z&U6;XrfR#xIS>F5Nl8o71K`0f6PU2^^&7ya^UGLhw%SAkXw!YbBr)4)A|DgC#2P#t zizjq@IDr<4#t4Qd6zKzC=-YJ}thO=!Am9!FDo6*b_MmUVbS4W(U#-nfe~u-;i~KQ= zGwZO@&@wne?8|w4Q_1|v;r>xc<1`)7fS>xVV zwMO3}`H`25f6m6qvK_-1W;I3*TaNe|rRpR%R6kQ)ZH@uq)HHhd_#-fl zRfn5Py$u+(S6KjaHD@f*4%c(#x}RL1NCVc^f5z9oVG2vNhp7{yO1mQ-tSmaPq=bVh zEf-H(8rsalY~bgPQHe&CIrC^K-_+yZaVkXzD#sQ5Uda8nbT5hkXk{7hEZ zrgTZ9j&HPka_ku10zmrdc~$-VMEHviyfKP$K`Q!&}W%l_%2X0RF7!?_mksp`objmhLv3`Z*g9sPCSa zdpw)~28A?>j&#ef`K2WT66y{*3Y(yjX$-#5;8OzqP$u^Sq3>ahR$H)2nUuBee&MX@ zC0B$RGt21LPRCOiu}o&?e@Zp5eJjd)#-U~<~ z?{s^H0+utwqp^4`%?8f`mj7V+kLQc_4=2)V8ygM*1;N2!9?zHK?1#kr06*NX=?vz< zgW*Va{wjb01k2lDNg@Doo|K=w&`-HUDxHZcEEd4WHmXHC^K0f1^~J{t3lugHAq4CB zuS!iRUIcweU&-kzM(_W?T7S>grJde{19oes=3?R}^dL{RH-d4kfYKz{tG|$; z71WlEc#W7_>eE|S?69<3ko6Gb@$`6Hwjc!kfsWO)1f`RxZV_ubpZAC3a7b#!S4m=gI7IBo?oFTE zzD<{=33=GJ4%*?9&VID^2OgQ(tB7B_PizY;$x^;k!c?~tY79Vemkjj`)5rU;s* zYDwm0pGkgmcY9@#NCfyTP5P0}w!cOn40JqM8csFEL+B??1YW=nuRHWkVi=-xjn?1~ z)_-ZsC#iY@&*mMVHGF`gv`FFl5b(@_M2<=i*0fsN|C_g1T1^rDD&Gb5oKdovnYn&)bhWsOZWh&M%TSPdp?tdr@hja6V29c1 zdK1#v$S9FQH4do;;F|i9$Q2i`x)KzP1O)|U>^gkDIC6M>XUvnu28A#P5JJQXnQgY% zWCDeO7h9}1?LJ@aLBYWttpFYT2&t(4Dp#r;HyP{pc<8+M_MRHccRpjL4I^0q$EQslAN|?85 za0Nhi4_?)J$^Qsz4Me0ls4W)#QScRK_)sziLR5AzS6B@Z|B^XdD-qAIdx}(aMoXdt z)8wo}23_1>aCSmste-c)E#OdBuKyXZ&3`3(T@I*WdR&u`)D9cky#xk*{XE93%#ZxD zB1r#mW*8KLE!x}F>qCSE&{MXEodtUF;T4e?)#D_=D%JW598`F(mrg!6c?10xU-5hI zHcahgm8ulJTR*87(@w}aR^YT(J^M$OA(tOW4iDd~*Kmghn1p--`@E>0vAF`Y11s^S zKIme$ViWPi5e03dc0EEe(b6XP3Z&{KGx*4sD>M;3V{Ae4x!6LDz?i&AX{I6KILw8BLWRbI*e{=AnNVd# z2?Q%&nuv5mY|L!NelcqBV~sQD9@g8jH~eK8Aum)}tct+bgf$%WqwxJ3ip7_@64a8! z>0}Nbn@L({^YKTfO@eZ2Bp4!pZK0k5L;6uDAcS7SxeCOaB*+4lA|PPP#Rd4<4P2mG z4|)lStuXXZxoOk$z;$1uPbT8W#HT#AB{KG4Xo7|;zJoajdgZ8Qqe0p+>zOO!DwNfJ zAbAVgQSNTBRMzeJ5*^TI5Cj$$77YyzkB$zmR5mw7Li*FQRA7N*zC{Pnyl`wKyh%C9 zj4~GC`008(bJJ$4%bjryEgNB2uUM&21R7wU4ql9Sy4p;gEaY{+FQS4JYby`}>&v zCOxu8NB@okxJ`J)Ul|z~4wzJ9akvG3zQ5S5QLF_O2>7WDDa_`uQ!zAvZ(RHCzQD8o zbU?-1YyN#7v_WUMPv!)pnpt6W4Af;-Ff1mC{?{zeUi%wi%_U3@Ej0IXD9;9rQAyU5 zfe}A!IodF!D;$HQY_%b(&HP^{355y;kh4BsO|CmE^mDM-AFuG%@$?5gaUr5l&{*`Z zInNew9%H!6w6@S#9LM?2RjRBnAsPH2H(#&c9#0v85Bg5DeVc`gbB!-tT7PBjG2m}n z5o4YO_dnx!_sPN^cH3h9#!9VLt9NUpHCl;uaJkzGw65hFbVIh~O#vGA?tu=vW=G6VSX>NP(EUBV!1=&t_|E>ef{W9Oht7LOwX%`kKZNy6i^b}!~g2+X*K& zNn&dSN}Vp&*V$VFaV8N=k}o&AvW=~T)oDfr)xr^%4GzZ=<5t)27pt{eeAXKhmy4ya z6iQ`%)~pG8E&wU=T6lvMVW9|IM>RD34o<653BnX-TJ#2%Z+GEo5rwAL`8AZmN1WWZ zY;dYme1udN)~)Atrb&R_&Hf7OlMRTd8M)Ec<87~V31iwC++gQQjb+|@c#srd%m{e?l+<S52Z2iX|+Xq$rHfjIOk$o1A zGw7+#T&hACUG1SA=X8&nj_w+E?l%gsE@3ajp_1&@ZRb?=SOf;TLE|LH3=HK57xA9) z7;6~Mm8lCe^;|yBr-w{^rE_1{kL+&H16s}6We>&d;*K#;ZXpq%0) zP&Y?2>@A(1(aoF9|oi2ARbVzz4s`&a& z04p`hw5}?T*|5CPW>0A9C>C;dAjG0!v?|;ZR^&}!ovOxYY-}71aJ3x(j7eE45gV-L z3!qKP;^FVl1}>@xLYk5B@$p!`2V8fZ+1c1SJznFG$^e)6K@hWpicJ5gd1Zo55hUR> zrZu;OY=@b=GZ~h1PlJ0e%{%7z^mIQG3_u87K>uEB&!DT>#Mb;=lhj!K_uH)Np2ffi z(a%}l^xmB7>f^60Z^iy+;{;iSpeD{S-TNcq6lFooORVnRXpX0v8o`r(jV2BZsZKwC zL2j{Y&fxjDq!6ZxFf<}bQ2mnAllD*zit4|x7m(rAzAxJc-|al;7CLZqMWJH(^b5^` zFl+9X-2=#6a)G4ks!?R8ofEL#bPFPr6`$uTCYwl2F>!|PkS)L2}it<#xSZ7py)_(LW3)%iXdt`XKF`lO-Etve{%r8>- zxj9BZdQW4iyH?yiV*?x}OEp~U2`7>AK}+Uk*<=u9 zy|svl!)NzD_%^$DhAIvnd6NNEECPtJ~ zY$7aA#eZY5gJ+uqkI+9bA;J=H89DPU!w|~w-z~}H3Xsm~Kfc$jsM;wviwM4fH#ssq zQ0^`{?sfL6%Havzi}Z*C$Q=g{hjkNX|r#5yZ1Wta=*t4zL=2-n5h( zB$L0ZgMgkhdb<1{xo8BlMHcXKi&eqptGk948m!57#13v7*+*dM*u`o%#Cz&}uuDcV z65@l(D=xoX<|G>!e5VQXj6mIhB**>kkG?S>wCqe{UOS3aLn}RuLIStoSk7K_5WI2- zYg{~==H>&=d3or|&Ygy0RCDpdz9MR^B1iM&aBMIiSJ9U9)LR%&F7sAz)Na|ifba3* zZr6~ZQgVmEPqH}NAG|Ip_JTaLQ~HbfpQUMv&_gE&1FJX`tfq~Xj~{;EfRETzki${y 
zl>8}VA96+a7Ds2#v?v^bBLC8W1w4`(E79MwO~a_Q?_!m#E9J$fo-5rzv71N+nqMIh zTd_JwkqHVycMs4>(Tg5QIB<--zPg>D(q{GI*Vj_-F)^E4J2ZPP;j0zoMs;F zmH~Xk6C~8hIWpBaYL^o_3X+Yf;v!q_X^hIoMF_plZ#{Qz)hNW z2g8z4)T{)jLPkwTU{vURx9dN>5MFLC@Wj%RMv#2T|B~zU&sAoCaM1fJDb6#OMP;}c z;!lqokSEE8QrrtW)R=v#T@}IvtjZ{kouZIBt3<9WDhwjuz$*K*BPOQU3nn;z27|;E z8&{2KEmmvGA9Jd!sOe2E)w1+Rx5rFQb;XUkxJ{;q70?4NqCCO_Js*IQk&E!di-?0k|JDpW;d zKuuQ52rz-zajdoz;d4}&6>Y%fowi$T{a)hOqdke7$p5OP*cF{_cA)|JTK&~&8tRsG z0ioo%v><;`pV&4v(+?SPmXWD7bVh5?)9GxU!;xr=9~00l^-bUMsHSY4)$ZFXPb8G8 z+;XQCQ=0Su$ib3)rEDh;CP&b9?kz@GO#COA|$PiK%Lt=p?(foxG9{j_H%zG^dGpl-~TAjmdWVZ|bb}U@TpP zTO6lCb;xp^*`d`Ls=~Iq<~^*-zYIzlES})KycdJW^DQ7hhITzuB966Y;Z&ULv|K{* z)0T7a0jW+^8%lgFd|tqWX#e^bU#h$9^!NjM8%`c%N)V6}`>)|av?blsi1jC@c&vX= zv4*zswT^})dgW(CD2@8!3k!X$H`-|Hw)9RYt>2WuDPp%MogvixpJhN{x%6n;x;Oyu zEb~Gf5ihw$)a-5VlqrUw@3{aWCJQ%{{a?((zV2^?PTwNU4T_@GAH#`MUD&VCr#D}$ zj)hUJtv0kX00!az0}iTG(>M+p2J9l`DU|)sjLQwtruVVGKmBWN+yQH5q!_C$Sw!)6 z9mMhdY4jjc9z7!QPVX3NtEL{5bG3SC0=koKw}R{;J12o@@_FQBka1o~Qg{Hlz}~oD z!R=E?F#g@eO#S`FJr%i!JH$KozR)L{B<<)90@d^rp+;zawa%^{+^2z2+c$IeFD{<% z4`^rG@2#_*s# z@Wi%%w8Ym*8E~^PRQ>sa1N`Si(ZrmkZ6HMi_7iWzIX z^k$IP;e@78<@(mO*Bo9^zuH<^_dM<(ceHi_y;*tehl&?af)4#RQL47x;97aY{vW>H zF)q^p+}E9KPqtlewrw}rwr$(CZM)t$VZvk^lbf21^RBhmI(zSP_Bp@jb02!@tsB?> zy2ZrGSjhuo;fL1qFPa@LWcU;0*E`Vr`s03GFzsEfPH^}6AuvUr3o~o1VEA=Mcry%%;k%g4{#rR%zByFurV9j`f^AN;nx)at9g(&n*AFe?3AB z)}t{3`;VESMWrN7-b@HPXl8gK;6W3bN*K_svn_K$TH*Sb-#u9vm zlJ{J}R>>8N*SAoci7B8MzzXd3Liwoj7bcg!nfLq(51q8H+q{IKG7(9A=~OXgZLjE{ z!r-G!g>KW@qT=Atq-Go)y78le18;JDBOIlCJ{XR7O0%^G3neC|g!{N(H7?$#8HV%& zA%p+@7HL9+U%fTaWT$NCqr@JR#N706!l5+}7Y61$`gjJB-a=nZ6_ahV1dR|1{Jert znxSV%3;6TlQO`b5z0Y$>7j6~0c05V2+?&QKj<2&WS8^p@&B&xSW=*G7v|?N(IcE2C zY0Or1LX@6CnhsQ88NV7MR!}!wTM3!dlN<&6blaG+6xK*|^8p14pS8Q{88vmK-Muv` z@l2-bR8q>VHvH9dEcu;A{r3$N+z6NNKmT?jr`|s5)1bGjhB0wqO*ZLOheXB4DU%fB zLKeyxqTw|}{JtXG(@Ibj`5YJCIQv|Vh4o7lthI$pLT_}--R~pqI>k#<%@q^U*JCguccE4Fj zqG$?Oo;a1zCz8ZPk-9%rJR($w_tC$)D!QRu&yT`&J~x^iwUHx*Dj=KENPzkkt3u^z zHP*|b+bw8ofC=1wl*F*fQpUB9b?A?ajPYsn!={FC1j$uUOw;{`ip6ihQYL$r3Rp{zjNPc4%)-a)x;P>Azs*4RVyyjv}J`cQY zYg}0FC0os(I<)exb|n2kp{zHuJtvCCxtjK3k7Hkw-^cY^3u?Ql@Ef*wtrRyUC>vr@q!+Px=}4k+F7L$X8DNsg%g8;Lte z?Y)xV~s( z*aW)}j*Gf185PL)SwZ2SoljjZ0;meHyF>M!3$T3wM2WU%hbF?@14jZVH}lm_8%?MK zp(iEo+3i*H##kQLH1P*`wM~DMAuQY*LNvx^_Pd;ge5B!jECqhzRjK9W8xS;M@g|Hp z-}|$NIOq+P+CXu?7QZkZ0_jP?9R|mqf?S>c*@uF;$Cf-R)3(WB4@YPO;BZf4rEB6A z2OsiY-q2z-RI#~#$M3_deUC8qdD?FRERM~+PB%|g0EQ5E=Lnw0PZKA(FFmj$1n@@G z4;EM1*u?fV=mAdtq?cznbm?=E$rvnVBdYv(+z9#f4Z+SI#af>83A~?k6!4fOwe-=| zm&I^viHpJeb>$5ZFp*owRu7=;^x3xaw^OS&(c`}@Fv${d)(!-hk)`S~YF$Jg=AS_I zQ&&0!&EJ@dM^Ae|2>Alm>_s|lPwU(UJ#DUmox~<5-dN+#RVa|n>WiAd?8NO-senHi zG)no4{WB6Q!QWL@(xz8}ErT>s)1WVqc=+9pmijVoDIqOSW)9-*`Y}t$9Zu8EpbYTu z$P5g6gQq0YA4RYC=tar6_;$~gG`P`laCl)7%@kT`c;|34eYP0SS6ODk^5^Ot2-Qd$ zv@(dxNeig8^4e`~>bwdj6^r;OMPte+!aycQj!I4!WkU6e>t_gu3W&6B(hQim-9%|B zNpKinc^60&=7j`p@y5J$2VWjsdfwx>9&TON1HK`XLz^+A_yjH-vT}k1-4ltfr?PH; zk0!ld9(h)-v`Etfy;&W7cE-p&q>m_Lm&+RtuE@QgC$&yNJ8!KX6Iw+BfxKpE+Nacr zIpO8M%gt8F2dTBHII1jY@Qv)==Vp1`-tyG~993%sosPSbpnz+-?2{5I0zAB^z`f6{ zx2xR#H9zzAP);~L+8%KcfUzLR1(k$!V{{a!vTw^sby$mPf!7@G5cIy+V@RtvkUX>N zEL4jA#|+_PV^ick@)v2#w|)&1lg2a=P8-6Pb{gw@NHmV?89UEE1HrsQNp?Jkn#HMiq1;sGSrt+LwKVQA=QyyS(*ii5KwMKHPX zuAdk$lThSFWsPgh*fAxQ#(lh^Be^9b+`eb1DbNph4CD^!v-@MNM zGK`o`O#RQ2aqkNpOy#edyq^=ih(zRpO}0rOzZ+5`aO$sZMx2tT-hSUP>b0Nh42<$a zZEe>c2B}&M7No&M^;@C3qZ(D*_(bB&Zh9W=i!XraO>iXhnShu(y|ogjug3{%CpmDo z=6Za26TqZ`NGY#ybd=oqXeb`kmgT2P5PWv?SLnyll>#bS5Pbtf%or1PS3!61>HZW8 zcV_2T-cl*XWZ(s#J4hy$*Yv)+2@46HClp9jC=>{)=s*&S$D3W|c6lEQ@VDRRlX&;U 
zQR9YVtGs$TI4^$}UZ2N;a5V;+$3p2uB7Wll;DDL8%WZR`5b}+{RGyve6EK6^0E(gL z*_`Wgy%~Y2!P4p8J4>%nD&aMybYa@E5s>)u3TIhxnV#XYj>Cq(dXlco;=aV-=U0es zCD7XAmu&|Lz0PNxvgnpH;W3tqYN!}5aV3ki-*il6@Juj1V_sR%|S~hMfOPJ zgdlt#H@{p9Yz3Z&w*Lcf4(NtaZw*)y@x~giDhQ5nb4ts*aK2>Nf$A4a{M>FSBf!Fn z04-dAH32XL&O8-OMW8D4|A?MZ&sC;S_VUyup%zAgdjH|6KJpWJou9+mVW&|ET`CR) zrCD6oH+OJw+#Y8X8ywehljIf$BfW{gYwoXD@IipW!hjgtdc2DaTs|P+`U{*KI!-JzW83-Vy|kXn8SAE2%X}{)D#4qYV~=sj3*Grc6D*vfe<785fFzKOt3$dFRrq7KMKK&;0cWQ zEn!a2aYU8X8WSj0HyyXyWuzzc0UmyCGkltWwee)1aIg#E^!SX~r^pnZ@tChdZ;$>7 z{nbkC^iP9KUU$9NwVn*sAC#c#k2Jb?8lPvMisTvDBzauM%O>R@lUr$8?+N!nP?UJl%XHvBkdf7A^Zoqv zN2v5O*jt1G5TSR3dnac2si`YU>cnKfTRE!@KR&x&(DWArm!Iufn77zWa^kWwy(W4@ za^@|ANSK95l>u~Ghlym`JO-ti|vV98piO@qSXZkjh8T+e+wy-t~44Ziw zgtI(jZd`}uC53t0VWSg>X!w}YehEs3huMv3@81)APJG@d`x}_^JyKHY%vBLtB+Pl< z=3$(aJ43(jVa^}Vk?+Z>3rC(#xqTndLva3Rt*9j<`q5l)g3xNNhVAkDN@nhQjrJUk z7KxlK?im6h>`kuA;U*U3YUe*Ky^?5TTg$jG1^%Hbg_CP-H5ewp+84H41sVRnLlJt= zf!kDr@%-8Mb9zGSA{Gc$)UQ`b3E!Fwn94UTb545tSR-bUmsYf3CVj$%*QTj$nxh3j zR>FEc8%s;9d3& z%nPm`{&Rw<_GM&_cwi3}ivqYc1|&FXY-wg(SD{ch6wE#?I^6yj_gFBMLdpYS&zUon zCLv>YM++b+y|JPJhFmN{bl5i!Zxu8g?=bqCF?~i_E7EZvE*+&5f)%@!wA`3A9SnO*Br=g~?1hn#n1mxuS?qwtIlst~oiQfM%x0$}x=)shjZ=Q&41E;+*y zF5aEAu3eg99(ORvr%_@AO#6udvegWO;5g+rE0r`F6>w0IED4f=vIM+)HGH<1Pv73| zi#h1oZ+1uw_`U3%oQ!uDMgrvEF94gmO zI{yRCK+Dyb; zz0P7p4<)D35YKkN&mOuZVq#!SZV)j-NB-9ZW!yzSOuD9n=xWl-8c_zKnscVCU!rkp zjsNDrJaZk&y`;Nyp#)eFwy%7}MYsa#pX=)PjMaQM@0*>0Q^@*_6ALTlPUD`II>-NQ zKK1|N2?@0~_+O|aKyIo_=N;x$MxV!~Q&t0k!2AZ}{o{ZoP$OeoB3)eaXR2i&*)(nj zouzhCAqzE;f>Hmh%pa!wNj}xr!ZgjU97r|B&rp%qQGvCz&hY|YD)aYWELX~Mmkeo^ zH8)#;&~FXZTgfy(l*RhGW^vo zuK=5%IzTbJOUgiZ5)?g>&pqc=X7M;AtYVRoA=YqAnDc*NlP%#oKvAgdQhfy~I{<;U zw+a?9t2|kjYjDVmO0_JM9^Xy+07m!q{}+CEYW@F<9~xMol#Kdn?btIUOk|4u$@0B@ zH;dV4y4G3~HMGE|t1{M%$N0uAlKNpKtyOs?)Kkgb00$Q>zZWMaFYm&OI4Vw?LD_%y zR1qUStN~IqcF =JS!*uYmlj-3?ZoPPco&sF;|8HZh8-nF4`aggf`AYyIn-zBI!g zH@KUfeh7(R93B^n)eK1T-!#dYf#_j0h?yY!K__=}Q~_5yr%DH_QTHjwh8~CytCr+C5Jx2-v3jV(PhetfiI6#WY3F@ ztb1X+Z6UPwBp8Xg8s8M9uoLb+WwMa84bc4)HzAN&Tn6pSP*cwT4fO*9&C|npI7L!5}ElG|_%&NM{s4}Yf6akhs6OI!_JTR3)mbJDvrpj*g%&)!x6i0>i(fkUU-xA@o zrhtYR)E%x7k{E=s32unB@(~%r$wcg;vHrp{?HBNdSb5`hez<1htmZCJ@Qf__j;M#t z=*94H6K`AGQDdO#it|A>u88?Jk(TX0ggH*z?Dao+-?wj<9N$ zKBir1Q{Zq5?YJW0*xv?n9x6uN>sY_-%oh7qqqr?^;>uGGiJ%)FJvY~&$q$Vt{KJ@M z_=>3RGgf({yG#yrZn({n(L?53QfR&d}gTF-}cme-^WUurvV=$#Qa+_z8ZxrVE zyqrFaEfmT#CP9p(h-RuT_%*{tILkypk!ZHoR?jOMrsQ?Ro2Xp&aCk*Ug`R2Jl$DLe z(Bi4r`RUwt_Hyl*;BGR*njZ$@1rJ@T1$=sSb*5DBH6d#8Hne%SWf-(PL+=ph9$Q^) zrQo**{GwccYIQ4h;y}8PLMq>D?GF3o(q>G0hfz4}2#MIHp@n`HFR#F($y^xS)r!vQ zV1)r~f3+ZeY+wCP>2&ddEa@=ICNWL6 zv-s^usgb>N_O_~Rh|PksLNq9s-}^W}Z>3N@@G9S@OH%b^6vZ%ov!R?n$Gp9^PQmlMzS#kY1WJB<-6}H*_Y+L62vZMMa>ov z&E1z&6-g^b;G{p%V;XJB=L3EMxLH6pgf~xPJ@RAen~*4aEmtiRQ!LLQk)8$V-UV^L zA!9pk>Y#85Z>w^fFm`EGa}-)x8+qR%U{q*vc1squg3w9$uL@t#j zAGly=CQNpLg5Mi9WR+ZBdZuhC6K?b&WGO~{tc8u?=!{Y9cY!0WGYpEEYV_XeFg7^G5GP1l%nMq0^sX#3P49o2)9EaqHJo87Yr#`LUR!GJsUeaUu$>bP;yugOy+!NxhmQzZY1HLdLrO zlk?-t0_?$T_8OVkgQE`FPaNez7myCPV(hs71`oLJ%_R_ zsu;o*AZRSv-od~PqbG&5!{;_tU2Gn1L6|*L%M+;8qNxU`7%y4t>sC0WH}~^Dl`Ye$ z{Kj$3{57Oe*ot$ul!(E4WSyJJ3sQv-w_45tEM{=Lyu5~j!UNzyHw!>FNXk=EAM_kd zO-Jl67CH#ld84liCF5{_?d5$SINo(LK9|e!B&0wM2*N?T|Kf4j7)@s|ZawPn4)}NT zzmg-LD-2^&Zsbsix%;e0%bAQkf|_Ax-MF%3--k1_*QcAr9`}oid?o^bk+6BSNpZb4{_8KMPyA*nDTc7Lh60`LexenusDP7T3txv}U|6aE z`d-J41Ft3rTk(h+H0`-zwFpW$EZ=o-&|iBa+p<}9&AI))Yh*g5EJ!&EIgAN{9}((y zF)P&@Fb#`0C{5l6SnmV_h#nh1n%W)Wz83O^DfNejOEQ%c0oq9Vg^M6=P`%Q z`YMgi9?_2o(9UDb2g6A~gm64#rK)u9QpoK=r7Q^AZG&-`=X8LkfhciRWPhF&AXU3D zNQ7RKhZBQ~-~4Tmy2)bGC5=u^)!cN(2eHR$0$J7WS`GZ%IT#jRNwoO2UOCeLNj;lp 
z=qY+Sa$%V6317Ju9Y(_1QK%2Mlp1BO*(iM_@W3|A4`!$PK_V23asGeBP1D8yPq-=V zuk7l{yX0TS#zOGOcFC3-PUsEuZ1iI2A6?Pi_e`Z(iyk2N!8NBBvoFB^SDZ-biU8qm8md{Q5MvXo$O39h7Ak|R%-VIehG3}s)AXMQDE3J=VWC$ zZAPWMxcT25NXTaaf;S87wz^m z*J-;kMT1VSiRTUpmlthJ$L5YfGn?g=>a{r`18DaKk*lE>rXBqa%*=X->1ZXuIHICIpoY11$F87w{!1s5_6_XEvcpW1y5neuGIYC z5&N0HM-u5Q3(92ggHj<+hp5rD-bnU-Kzs)X#O*HJM{$``wX`Q$hJJ0RjyFPMO@Vvf z!F~!NM0g6$_)pBjFkMriZPa)^CH;OqJ3JHLOjuJMmXeD%y4TDwJplz~&x%8f&HJV0M6(4|x z26hP#lYk(v`7ogeK9flw-o_jr1;w&XihV7==bUEYnNCAPW1Lbdbf;FMcAQ4dZ~f}> zGAoPTXL}~j4>{UA&N2IZxfa_Z$qDs5r)&x({;TBON|~RMw=**n8^nocyPj_#AR(E5 zW+D9f0-aNd*dK{xYyJV^KusHd^E^C6>E`wQjqYKpK927pYDJX-6Vp`~-)<^z>c^1Q zqp){s)?8-LpE?f{V_(a&sZaVHM`Vwpoj#St?W&i7cF0&cWLLyYun5c0FL#KnS`mQoSd9vmYQJP^K&sGiPv5CtWR^YDw*@5 zrG!=Za+;d#yO_RAZD)UA3yyK0dGO)whDk}WIrRaQGhnN;389){ zLVhFZ^A*ATJ9RgS!IfZs5c$xclmljKGHK&LeLHW617wjF ztQ^_OW_W%_$i zKx`6<0^b3hSnXAg`>RFnOmNf_F^FqRk(}As0koO)b;DamV=;&XOpg*pmbAQ-AW$N1 zzI=s!V?Kt8#^?!$Wr9)JH`^TOfwszsb*#iP{wEgj45{)PMG;t2IO2SH#=S)DtPA%e z186c!IU1Q1R*rwOf4I1f!)FGVx#~EureP*$IQ$Ar2iU5ps+xE#sy-DJxkyZ?W)c_{ zqw-@N$*J+ema?l{qyj!5y|GWUG$-*yAdpFG-1pOJhi6pQ2EaKphmk)#-&`RCoOxSl zq5o(CxzO<|)Ojkk^Ger<6)Q)ZkAFfzwzUaEpciH5du=s#b)6ER!#IpJbIS76VXnVr zF7Zf-B0zKMgqtv{{2*R&N8l*a^W>!H_G!rL9q&S%mYts$n{L|Qrw;Q`_Caml zCfNzz5!pi9;MY|c!B6N33YohXu6a*mfhC9|%-nrO@Zg%9?MbdZ$Fcb&J=;M@_D-2( zOv9@8He5atibs-vIooQbFrZd!V zWC&z`di(*NDm(q$2ku7v_L@6{YZ-^uaHg?B;-t)LS0 z*$AC4u2J-CyHY98mq&lLLM-?U>2$Rpo>F`#Nd9uRp3EQoZf7eY9;DGkcc5=-&V&Y; z!|o^_qce!*Q%RzriknH1Bv~w8Mkez{`D#2ux-Ht9Xg!PL9PCCcfu@vl6dy}EDJz!~ z(<9k45>3^|KwG0%j&9XZ7#MCfRtH$X6n(cFN-Fxh_3DoVd0n%x8Z4k zbqzj=O-aT6MFqQFqamDX*`gkRT9gh)Moz90zJ;#>5&%=$h-(L6Vg)O1^CewgCx&Mj zHmd1Tv9WwOWl&6ZE{QGH@aWG#{qhDK00S|mbwv&@9iAL#@*Va<(`!BXj;AKGqW#)^z6h_lN5NRdmt&M7Uo)SAwcCSkf4(a*F5j#fT38Z{l zWbRxK>(9UoSDM=;mu#oasZG{V|N1_5@NrxV>u4I1KxYmTz@j;QfWhHiio^g;)50e{ zAa4)?ixvk3Alc;9mo_Cjvtm>0_J&g&COWJ(i!pWD9pXtTE?9nTsuD#7SZ@V75gQH! z0Fx2kVkuK%YS2%Z3T15K{sl3ub28TG8;zmIAXcDI=X6Sm$`ZgR-q14j@=rXmpC<%$ z23ze82&AQ@XQ8urdo&CoeV^N<&~6TBY{=s0v*k~v$>7LVhmw|+m8_i2N5#g^OxD-? z-X|9WGZscgMZ{`CQZZXdkz!4!kpfigtj}WcU;@>yAvG-}ojFnB^MOHd{e%1a$;z75 z$d#eGa%N-0LJG2GuX*y`&56zCm+*QBFY%oN3xJJ_c!3J`_N_q?{%c}>o}1u38-KR% z3_13Cs;VvOV9ytaAyBzQV(VV~R6mbDIrp8+NO1u$7AFGxCxX!bK4F4?P92fH-cK+u z)oGLp_f?B_6qxZp4Fq6&%AoyOiu&8U7#|hR)W2a`JI0&r7o3|6gZRrtX&Q<$*~j)c z{5AJU5?ggvQqM@&Tju8pz@MGrYUYmTR@W!^CMB!<(tD*wAESmsUq%M?qi)846^RX7 ziSfrxamWfCB_pLePkqQUY{#M-fXh!{|BEVU>5xIFkXi)yE8TS6Mnkz1EPA@sFo#mO zTZl)_h|E~qMMkviF?Y_S2maPglRA$3A1oY&H{N>qn-Z#&9)f1s4pyO;N@tS^Xl2es zNZ|$#^|@9u5V5NV8H+TTGirQup&xWTeMAp!+xpqW{05jCQc*?hlW~oG_bmjW z(WN4x8E~Y)d?4_G#R{ZdV45O$_x6FNzWEFbX`r3yu}b3gUZ+KOWGamOVMOQlnpJG+ zZ!IRYua<*x;&vdPsAxMKRw5T<7pd2|Gx}4&w1{UBE^n=zxxTm!WI=v+@KFQAf+El9 z)$!Q#u=NTR|1xW2`lbJ6?<1Us@9by<+*LN>_0~-!i^Yi?BJUxP=Eo9}mkCf-555ts zbP~bn_7@t0fan4+`1_87TKrPHdU%8^`}F|jvC}6EJk|#kF%ecWrrKkiM9_adMMyESD>EpL(ne$oY9!~? 
zutu;uyS4uHA-_IVhvy=11D5{DVO3dq{24K&-V2jw*6Q$!x7Lp|C8lk|_O)hxWk8Y~ zTzGx%^ZQ1?UvhhT2PWT{8`T}Vx_5S;*wx|%OI1JuM+y~1u8BdIBXLl6?>|acTdF^y zuD9+C0msWh%G4M(&+KfqrHMGe)XFkt%%pk46Iy$HDm~vuPPOrCytva0Ml0R{$8rl* zJ@w>Y*#UxgKpUKP>+p;+HJ>|WN#$?geTCwugRf62YIK>ZT`=~eXJ*y@ppTd{u z*V0RpF@K0v0TfiBKB9&sb*XLsLrH3sOZdLjV!}fqvW{Ne+&S<38=`eo5R&a&a9m)7UbSM5j`6LlW5X}p|IU&!r@>GV!pU~saY1L z2^+Q00P=%Ieo>PtR@LJ&EQaeMW>BY3yLH#X0&bEGtQu&=|ti940yff(dD zumKn}BWL_;?dF1_Hc$MYzmaqRaP7jfV>?ol-x|eBeT2L}mBY!gz&aZaw52sOaWJ!c z4rin~#%kN;P=7m6TuS1T`OVq9+nIw~D8;Fpwl%yFzERtceausg269KY&uZTj%CtE^ z#qVKgojf20rbg4-O)Kg}$Ams@n0M9N#NP>|C8173ad99wQ*23dpJw3LWT8mi|G;lL zQ^H?cTQ8M*2CI8tM+}YHRK+O~DXylam5xscF&Ock^NyGiKW(QKd8Y@TpgOTAeG0nlVt2)So1NH#sZFUT<&$v zVspy*l52 zzBlYyD|3rB`W~v?3@DS|$;IHWj#%ir=jm|6C*Xe=wK|cwan_Vo*&FC;J@KRD$$91Jz5)k%1Hn56wM6Iw* z|Mk4H!8%E|)RTSb5esY(wa}%rI9q5ME_+SW0`RSneSP1CA+pgUPy)cB>U3Zf@Rp}i za)PjL!j~^f0rt@{85iM&M}`E+F*loW#ex1L*Mxsi3{*9$XntTuwa0|yRueHHdPk{| zW5PqytnXrCJsJQ6T)t)C9m^}tQ3)_EENmC?T@UmW)PqWBF(xUNuri7w88B@VG^77| ztm>E1KpbL3h$wJQN>JzcW|BsQaHEIt`LB83Usb`wjyEMYT+?y7AGlSpfx@Au+gpaB zn0d>E5H1*^4(@U6F;rL00CByX3UE?}*X?Ok?LUWXTk~bkIanb9>XLH)Q+N z@Kop&7lH~w8#06b(UkCIE|MSd!_}phcUN*rG^0$J{G&r6r6)+ntYyq)<*hnn5|n`Y zc<-clFYludNZPVbU5TP|*_fs_V(8a=64}30yV$*)y?avgD5~kgX?fZP;UYz>lS{js zZP?WYd&*=QcvFQT=`g$maEv%>R>uod2D8LqOo9gdu|GS?O8kyV(CszIe^6>HxY2Ko z5jsEkebdo7gVpMM-s*n#17c#MznS*>>Z)Ti#l;DA4fkkqMw8jFku!vxX zDklAqj+J7V8r9Xj5^f}laB_ENoVrCPWDETCW4B*H3}u8)8eUP55L_Kk-!!tb3(4p6 zjLGx#HthoWCmt>(rVOrl;^5fEW7PY6JBdi1HXy)Tuq-gl8qd# z&h5@GFJs~gWvGC`ileOtNIz~rW}AM;#FQcRG=HZFMn%yCPK{;@n4tX>F7j=3=yJU` z822`Sp`FxICy&;drq`&}(Xg`BGq9CbI8J$WKIYmh97im7VOxwmYELN9J57G>Nr{o6 zi{Ymfkd(~1L^+8Oq>1X;937M4JcO54zgGV%h)S8xW!p)!s=3kURZa2u@76g;{rjY1 z-)QF05YbT7HVOXjeOXEmF{UIi2EHqrUa4v2311a7Qq4Lg?OANrHUuRpUAAn7Rt!YWcWNMIr1Y zc`fCb=l>pRRxy_=dk|3eiwY6idz!Ka3l7d%y+#jrFfcqShH^Uczp8!VCD>6xFswO< z6~UP)9**yM)sI6kNJawA-kHH2YS3yHV?{%)gL6&3!?5x~uNXBsE@VKOeT>OdagUky zP9Er#vOH2`DXbK*TqBF?xJ26ZK%fNBa`n5GJYSp;ojb`7`6X3@%MmLKQ$%7eTx19x z(bjcty{Ix*7ZQ5RD2~tw;IZkfPDr4?9>DpTod7{4q(bj*_|%KGvDxzy6|u!-=zAoX zl}ay?gFW|{Bu!B(wi)+R_2-f>CW^18^AK&QvKIE<>OqBnyW9rJ6Hv4XlFld(F z=EfQg+>%3II|0#SIhE-O0-zlB$oXq+BCg)@oAS;2Z>YZ@9EFm&$Vr`#k+Yi3QAIl( zCq~RZ@iWXCv0pxnDyAQq9T8jH8Sd2Z;%c{>i2MqlSzIYGVoz5%_+S|5`Z71apb${!UyQRTcqAYjA4iM$-5gDA;YeOkW@ zZ)(PihmtZRyI$3(YwH}52O%0I1x@Nmuc<4hAb-y@{W@Ie>rYGEpJ;0T(NhY7dS@pb zvPm8k`JpYYaVMSiSe0>A!Gd8ueA~GEK%B;d!2Q2d%}s^-;JbRVi)O|J&SVJ@haN`L z0EH68v4j2QTs65l=m<1>EIBE3IKq})@M3%c00PdA2b zX8Z{J31bZlY>gbbAbkBxSuJSKgiDGAJRDdfbGWfZs0p3*U1%zFS7va=t5xu)jEQeE zmV(PDUf+^Ke!N*4g&o7^bzekyYla$&2@+U0Oy%L2!D>#y^uIn1rM`_a%6nwxal7Gx z0N{`}kwn#uko<)I{KN=zLpp5tQ0h&%tX+VNkX{dGD#(*oJ#irTAf%#jGQBM-#9Z7% zmRKLbuLmK{1B|^cDups<39oyWc67|8)<`1HAnf0uLbHYXwp(WrH#=W7>Dd$p1xD@g z4IdY375_Iid$;6v5_yqYMd0+|D%zqsAg+aS@u*z#02?l?mMj-5ZW2K#vfDNU;PgRb zRWV^1W7{67GjIP~b|(C`>aT~OJUV2}X_XrPLqs*GO6f*_!gxoQyz3et>Yv*uJ34bn z8Z23~)$dipU^Z7y61y{-HJgg;dhUaqv#LiJE@DctO2O(}uV2Jdp%p_t@wdPgb>p6^ z7HG9d+*DQW{N|@e&4-jjv#4LO&+&keIy;|z>xEquxFsxv;d&U9+rPdF>}s`;?GYQ& zYHY2hNu-+D1ZxUS5z`4DW%Bb_7M~okoO>VmUnbMTNdBSt*D!l6t4P<_H+%Z9hRw{J z7i275zHg)^C8n|LoQ)#K_VM16-WN64smkn)9*;dvu=8_cGAT(JRJ)kC&SH13uW*;1 z>vAcShNIh_fvf?3F?I>)Mn_$4?7&FB#_j#!8&Rnu@N zF)tT>NX|R>@Q2uZK$VK(7{mW)SX~8hUDUdsk6I9va^uIlBC@@8xlPV-j5nbMmF9v+ zW?{8>(4KT6{zdP|l?)MP!N8D(j)X0DJ&X*^)VifOY*4{BL3FKJXY6eoZpRgw`PX|{Yze(+ird#24kl-yWjtJ=J5 zrNCw0^+Ep}3j@v$zJH|w13x`l0{5lQcWC>#ChvA~?b=(9{VfTaLfaXA&AqkHRDjiX zrb>zRiAP3E#89DOaTcF5DdNA^!_oc4G0nEFL8hVLBd!apUWlg;enM?fI%(-owJjV##yia zZx=v=0464;UcV1!+bC|k7Nb^EkV;Pmqi)2_?%>E+xK^c0>1%25XN)l9-G@9$1Mnn! 
zhDz9qe;MxS_ow4MwwBbs&c1S|!07Y8LT!9o1w9#v^P+#!{R=ppgsgD?wgZx}r{_^s z(~NKiz=*hj%GF#i)QKPbbj|FpI1`_;=OUK6_d0i>vSDy{x6$Ffmc9L?)Pn|Pcsj7B zDn)s?Na~R*?f7a06$vcydfhN^f0-y6JtC$|&=v)GX9ek0aFmMC(>&qlV7V0(h%$MG(Fe zOl%@+S{Z_%8&h;Kz{Sic-J5H-nQ!DQ9cUX?JL2u2H{bSA#RpP2vrr>tVs)+fp(z;s zch}H=fN4u~O)Ge_02pT6$Qy;^!1fRxTKyVKO-1S`}A}=?< zT1EymhQ_70iAUX=Z)B^2U9=$>F$Z+Aa+witPpmW#q@<|~UAXA8{?@3i_TGnE&N^p> znqBFoKqUyUTc5K~4uaf=cy6mV_tZ;g4-AWI#Y4LPR>Y2IPqZ_Cw0i4v-3J-xtyE`R zH~q68pv1{bBI~Kr!NDD9Ipj1Vi>LLqXiN2J0=^=b8jdZg6YBavF^2IzD-T+0l{}Pg zDxbF7PLBxl&?vVl?6ziHG~VhV`6F3Gn$$D_k0=J#;%WDGJE~bHOQaJr|6DIR`0odR z(F(#3f0M6m(i;CI#S#(7_Ws*ai-mwG1JjW;+bfjmj;{Vj^~h_b_UHQX4cPT!OUkH| zz24-tEJj`|y`K-YwnM%ifL>D9eG%xCz#AkAN+QQxt(V9!a~)FtAOF=tNtkb*V}p&} z@lbOt{OH|TR{QeL5=(HUCNyJ#?83#*R+H2DU4#t6T7)c6oRfsj;ofd>ZN3Y0@If{J zz&)UkVtr4~scs!$+;s)r{%l7%U>8y@{9w&Efg0HH-iOZTvVBF|K98naGw3#+B=^9F zo2+`-wSil9MtZLpQ@66hMKKi(w#ZYdQGdkLbC?pQEbkJ(@t5_cHPGSKP2_RE2xL_0 zHR+kH%q>8DHsDl??l8$q9nATz!4Z_RVw$(;@5oxp>3Wvhw;pILfG)Pq&@Sw0#IE)? z6H7ABoShhrI$QeqP$~E0a$?>aLXufdi0CEq6S7;g-od&t8szsPT?}usJy{=b8@Zw zPtM*!-7E`_^d?^-RX=QI)Jwd5{y`mOCNaeb>lZFLm}~cIal~Z_Cm5vTG(Y zZ75=F5BqxBhQ1e4GaovKf`)ijV7%L1tK&OrTf=K;e?|xp8s$NF>itP+CeNm2s2?9T zRNj7aRBW7^yL|I8{D70HiIg`iaUt2BED_4L3tcK2zgQ?Dad+J1hOmYfUA#LZCAo8I zz==Qn8A*l6)*5v*@vm}2h2biWV3YU4oka5Tam0WyWjd}3n zBQe`Xd)1z|M+uEc+~iOU|MT7+ofYyu-5k6KcZN*OWYmh(|CbX71%p_89oA$(5Q_B4 zYe!X}v#dV7c`5ID&m`?h7s!CMt5SxvQ+=}0SpBFSaMWjQsLuaFz;Hgj3d~(RBNQY( zN^fa*h|cx)6x48R&H(vr?ayt{31z}>8f;+1$CkSUdQ>J-3uV|s={TLDs&gx?io=SYKE4uLWbnO_$9Q9JSy!`|sB0hraU({Mf6(7A& zWZUaChg$g6JAvn^-#{Wr^@~9UZ#CmhPNCJhgvp8t?u_7<@t&4u{=2ru4||)Xi)WV( ztH_Fne2BQg>v#D)26KM4_i+G;%hfH+9#JcI_9@DuP9}$(UU`4yndTEMYeH@kUjpM@3E*K0 zwx7AgN(lX=Lbx ## Examples diff --git a/docs/sections/pipeline_samples/papers/math_shepherd.md b/docs/sections/pipeline_samples/papers/math_shepherd.md new file mode 100644 index 0000000000..ca5b8e9653 --- /dev/null +++ b/docs/sections/pipeline_samples/papers/math_shepherd.md @@ -0,0 +1,299 @@ +--- +hide: toc +--- + +# Create datasets to train a Process Reward Model using Math-Shepherd + +This example will introduce [Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations](https://arxiv.org/abs/2312.08935), an innovative math process reward model (PRM) which assigns reward scores to each step of math problem solutions. Specifically, we will present a recipe to create datasets to train such models. The final sections contain 2 pipeline examples to run the pipeline depending with more or less resources. + +## Replica + +Unlike traditional models that only look at final answers (Output Reward Models or ORM), this system evaluates each step of a mathematical solution and assigns reward scores to individual solution steps. Let's see the Figure 2 from the paper, which makes a summary of the labelling approach presented in their work. + +![Math-Shepherd framework](../../../assets/tutorials-assets/math-sheperd.png) + +In the traditional ORM approach, the annotation was done depending on the final outcome, while the Process Reward Model (PRM) allows labelling the different steps that lead to a solution, making for a richer set of information. + +### Steps involved + +- [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/): This step is in charge of generating solutions for the instruction. Depending on the value set for the `M`, this step can be used to generate both the `golden_solution`, to be used as a reference for the labeller, or the set of `solutions` to be labelled. 
For the `solutions` column we want some diversity, to allow the model to reach both good and bad solutions and give the labeller a representative sample, so it may be better to use a "weaker" model.
+
+- [`MathShepherdCompleter`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdcompleter/). This task does the job of the `completer` in the paper, generating completions as presented in Figure 2, section 3.3.2. It doesn't generate a column on its own, but updates the steps generated in the `solutions` column from the [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/), using the `golden_solution` as the reference to label the data. So in order for this step to work, we need both of these columns in our dataset. Depending on the type of dataset, we may already have access to the `golden_solution`, even if it goes by a different name, but that is not usually the case for the `solutions`.
+
+- [`FormatPRM`](https://distilabel.argilla.io/dev/components-gallery/task/formatprm/). This step does the auxiliary job of preparing the data to follow the format defined in the paper, with two columns: `input` and `label`. After running the [`MathShepherdCompleter`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdcompleter/), we have raw data that can be formatted as the user wants. Using [`ExpandColumns`](https://distilabel.argilla.io/latest/components-gallery/steps/expandcolumns/) and this step, one can directly obtain the same format presented in the dataset shared in the paper: [peiyi9979/Math-Shepherd](https://huggingface.co/datasets/peiyi9979/Math-Shepherd?row=0).
+
+## Data preparation
+
+For this example, just as in the original paper, we are using the [openai/gsm8k](https://huggingface.co/datasets/openai/gsm8k) dataset. We only need a dataset with instructions to be solved (in this case it corresponds to the `question` column), and we can generate everything else using our predefined steps.
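Before building the pipeline it helps to keep the target format in mind: after the `ExpandColumns` and `FormatPRM` steps described above, each row boils down to an `input`/`label` pair in the style of the public Math-Shepherd dataset. The snippet below is a minimal, hand-written sketch of roughly what such a row looks like; the problem and step values are illustrative only (not pipeline output), and the `ки` step tag and `+`/`-` labels are the ones used in the paper's dataset, so check the `FormatPRM` reference if you need a different tag or format.

```python
# Hand-written illustration of a single training row after ExpandColumns + FormatPRM.
# Values are made up for readability; this is not actual pipeline output.
example_row = {
    # "input": the instruction plus one candidate solution, with a tag closing each step
    "input": (
        "A box holds 12 pencils. How many pencils are in 3 boxes?\n"
        "Step 1: Each box holds 12 pencils, so 3 boxes hold 12 * 3 = <<12*3=36>>36 pencils. ки\n"
        "Step 2: The answer is: 36 ки"
    ),
    # "label": the same text, but every step tag replaced by the hard-estimation label (+ or -)
    "label": (
        "A box holds 12 pencils. How many pencils are in 3 boxes?\n"
        "Step 1: Each box holds 12 pencils, so 3 boxes hold 12 * 3 = <<12*3=36>>36 pencils. +\n"
        "Step 2: The answer is: 36 +"
    ),
}

print(example_row["input"])
print(example_row["label"])
```

Keep in mind this is only an orientation sketch; the exact serialisation is handled by `FormatPRM` (including the `format="trl"` variant used in the scaled-up pipeline below).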
+ +## Building the pipeline + +The pipeline uses `openai/gsm8k` as reference, but the pipeline can be applied to different datasets, keep in mind the prompts can be modified with the current definition, by tweaking the `extra_rules` and `few_shots` in each task: + +```python +from datasets import load_dataset + +from distilabel.steps.tasks import MathShepherdCompleter, MathShepherdGenerator, FormatPRM +from distilabel.models import InferenceEndpointsLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import CombineOutputs, ExpandColumns + +ds_name = "openai/gsm8k" + +ds = load_dataset(ds_name, "main", split="test").rename_column("question", "instruction").select(range(3)) # (1) + +with Pipeline(name="Math-Shepherd") as pipe: + model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct" + model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct" + + llm_70B = InferenceEndpointsLLM( + model_id=model_id_70B, + tokenizer_id=model_id_70B, + generation_kwargs={"max_new_tokens": 1024, "temperature": 0.6}, + ) + llm_8B = InferenceEndpointsLLM( + model_id=model_id_8B, + tokenizer_id=model_id_8B, + generation_kwargs={"max_new_tokens": 2048, "temperature": 0.6}, + ) # (2) + + generator_golden = MathShepherdGenerator( + name="golden_generator", + llm=llm_70B, + ) # (3) + generator = MathShepherdGenerator( + name="generator", + llm=llm_8B, + use_default_structured_output=True, # (9) + M=5 + ) # (4) + completer = MathShepherdCompleter( + name="completer", + llm=llm_8B, + use_default_structured_output=True, + N=4 + ) # (5) + + combine = CombineOutputs() + + expand = ExpandColumns( + name="expand_columns", + columns=["solutions"], + split_statistics=True, + ) # (6) + formatter = FormatPRM(name="format_prm") # (7) + + [generator_golden, generator] >> combine >> completer >> expand >> formatter # (8) +``` + +1. Will use just 3 rows from the sample dataset, and rename the "question" to "instruction", to set the expected value for the [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/). + +2. We will use 2 different LLMs, `meta-llama/Meta-Llama-3.1-70B-Instruct` (a stronger model for the `golden_solution`) and `meta-llama/Meta-Llama-3.1-8B-Instruct` (a weaker one to generate candidate solutions, and the completions). + +3. This [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/) task, that uses the *stronger* model, will generate the `golden_solution` for us, the step by step solution for the task. + +4. Another [`MathShepherdGenerator`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdgenerator/) task, but in this case using the *weaker* model will generate candidate `solutions` (`M=5` in total). + +5. Now the [`MathShepherdCompleter`](https://distilabel.argilla.io/dev/components-gallery/task/mathshepherdcompleter/) task will generate `n=4` *completions* for each step of each candidate solution in the `solutions` column, and label them using the `golden_solution` as shown in Figure 2 in the paper. This step will add the label (it uses [+ and -] tags following the implementation in the paper, but these values can be modified) to the `solutions` column in place, instead of generating an additional column, but the intermediate completions won't be shown at the end. + +6. 
The [`ExpandColumns`](https://distilabel.argilla.io/latest/components-gallery/steps/expandcolumns/) step expands the list of solutions so that each solution gets its own row paired with the instruction; as we set M=5, we now have 5 instruction-solution pairs per original row. We set `split_statistics` to True to ensure the `distilabel_metadata` is split accordingly, otherwise the number of tokens for each solution would count as the tokens needed for the whole list of solutions generated. One can omit both this and the following step and process the data for training as preferred.
+
+7. And finally, the [`FormatPRM`](https://distilabel.argilla.io/dev/components-gallery/task/formatprm/) generates two columns, `input` and `label`, which prepare the data for training as presented in the original Math-Shepherd dataset.
+
+8. Both the `generator_golden` and `generator` can be run in parallel as there's no dependency between them, and after that we combine the results and pass them to the `completer`. Finally, we use the `expand` and `formatter` to prepare the data in the expected format to train the Process Reward Model as defined in the original paper.
+
+9. Generate structured outputs to make them easier to parse; otherwise the models frequently fail to produce an easy-to-parse list.
+
+## Script and final dataset
+
+To see all the pieces in place, take a look at the full pipeline:
+
+??? Run
+
+    ```python
+    python examples/pipe_math_shepherd.py
+    ```
+
+??? "Full pipeline"
+
+    ```python title="pipe_math_shepherd.py"
+    --8<-- "examples/pipe_math_shepherd.py"
+    ```
+
+    The resulting dataset can be seen at: [plaguss/test_math_shepherd_prm](https://huggingface.co/datasets/plaguss/test_math_shepherd_prm).
+
+### Pipeline with vLLM and ray
+
+This section contains an alternative way of running the pipeline that produces a bigger dataset. To showcase how to scale the pipeline, we use [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) for the 3 generation tasks, which greatly improves the final quality as it follows the given prompt much more closely. Also, we use `vLLM` and 3 nodes (one per task in this case) to scale up the generation process.
+
+??? 
Tip "Math-Shepherd's bigger pipeline" + + ````python + from datasets import load_dataset + + from distilabel.models import vLLM + from distilabel.steps import StepResources + from distilabel.pipeline import Pipeline + from distilabel.steps import CombineOutputs, ExpandColumns + from distilabel.steps.tasks import ( + FormatPRM, + MathShepherdCompleter, + MathShepherdGenerator, + ) + + ds_name = "openai/gsm8k" + + ds = ( + load_dataset(ds_name, "main", split="test") + .rename_column("question", "instruction") + ) + + + with Pipeline(name="Math-Shepherd").ray() as pipe: # (1) + + model_id_72B = "Qwen/Qwen2.5-72B-Instruct" + + llm_72B = vLLM( + model=model_id_72B, + tokenizer=model_id_72B, + extra_kwargs={ + "tensor_parallel_size": 8, # Number of GPUs per node + "max_model_len": 2048, + }, + generation_kwargs={ + "temperature": 0.5, + "max_new_tokens": 4096, + }, + ) + + generator_golden = MathShepherdGenerator( + name="golden_generator", + llm=llm_72B, + input_batch_size=50, + output_mappings={"model_name": "model_name_golden_generator"}, + resources=StepResources(replicas=1, gpus=8) # (2) + ) + generator = MathShepherdGenerator( + name="generator", + llm=llm_72B, + input_batch_size=50, + M=5, + use_default_structured_output=True, + output_mappings={"model_name": "model_name_generator"}, + resources=StepResources(replicas=1, gpus=8) + ) + completer = MathShepherdCompleter( + name="completer", + llm=llm_72B, + N=8, + use_default_structured_output=True, + output_mappings={"model_name": "model_name_completer"}, + resources=StepResources(replicas=1, gpus=8) + ) + + combine = CombineOutputs() + + expand = ExpandColumns( + name="expand_columns", + columns=["solutions"], + split_statistics=True, + + ) + formatter = FormatPRM(name="format_prm", format="trl") # (3) + + [generator_golden, generator] >> combine >> completer >> expand >> formatter + + + if __name__ == "__main__": + distiset = pipe.run(use_cache=False, dataset=ds, dataset_batch_size=50) + if distiset: + distiset.push_to_hub("plaguss/test_math_shepherd_prm_ray") + + ```` + + 1. Transform the pipeline to run using `ray` backend. + + 2. Assign the resources: number of replicas 1 as we want a single instance of the task in a node, and number of GPUs equals to 8, using a whole node. Given that we defined the script in the slurm file to use 3 nodes, this will use all the 3 available nodes, with 8 GPUs for each of these tasks. + + 3. Prepare the columns in the format expected by `TRL` for training. + +Click to see the slurm file used to run the previous pipeline. It's our go to `slurm` file, using 3 8xH100 nodes. + +??? 
Tip "Slurm file" + + ```bash + #!/bin/bash + #SBATCH --job-name=math-shepherd-test-ray + #SBATCH --partition=hopper-prod + #SBATCH --qos=normal + #SBATCH --nodes=3 + #SBATCH --exclusive + #SBATCH --ntasks-per-node=1 + #SBATCH --gpus-per-node=8 + #SBATCH --output=./logs/%x-%j.out + #SBATCH --err=./logs/%x-%j.err + #SBATCH --time=48:00:00 + + set -ex + + module load cuda/12.1 + + echo "SLURM_JOB_ID: $SLURM_JOB_ID" + echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST" + + source .venv/bin/activate + + # Getting the node names + nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") + nodes_array=($nodes) + + # Get the IP address of the head node + head_node=${nodes_array[0]} + head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + + # Start Ray head node + port=6379 + ip_head=$head_node_ip:$port + export ip_head + echo "IP Head: $ip_head" + + # Generate a unique Ray tmp dir for the head node + head_tmp_dir="/tmp/ray_tmp_${SLURM_JOB_ID}_head" + + echo "Starting HEAD at $head_node" + srun --nodes=1 --ntasks=1 -w "$head_node" \ + ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --dashboard-host=0.0.0.0 \ + --dashboard-port=8265 \ + --temp-dir="$head_tmp_dir" \ + --block & + + # Give some time to head node to start... + sleep 10 + + # Start Ray worker nodes + worker_num=$((SLURM_JOB_NUM_NODES - 1)) + + # Start from 1 (0 is head node) + for ((i = 1; i <= worker_num; i++)); do + node_i=${nodes_array[$i]} + worker_tmp_dir="/tmp/ray_tmp_${SLURM_JOB_ID}_worker_$i" + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" \ + ray start --address "$ip_head" \ + --temp-dir="$worker_tmp_dir" \ + --block & + sleep 5 + done + + # Give some time to the Ray cluster to gather info + sleep 60 + + # Finally submit the job to the cluster + RAY_ADDRESS="http://$head_node_ip:8265" ray job submit --working-dir pipeline -- python -u pipeline_math_shepherd_ray.py + ``` + +??? Tip "Final dataset" + + The resulting dataset can be seen at: [plaguss/test_math_shepherd_prm_ray](https://huggingface.co/datasets/plaguss/test_math_shepherd_prm_ray). + diff --git a/examples/pipe_math_shepherd.py b/examples/pipe_math_shepherd.py new file mode 100644 index 0000000000..5d1ec73789 --- /dev/null +++ b/examples/pipe_math_shepherd.py @@ -0,0 +1,74 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datasets import load_dataset + +from distilabel.models import InferenceEndpointsLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import CombineOutputs, ExpandColumns +from distilabel.steps.tasks import ( + FormatPRM, + MathShepherdCompleter, + MathShepherdGenerator, +) + +ds_name = "openai/gsm8k" + +ds = ( + load_dataset(ds_name, "main", split="test") + .rename_column("question", "instruction") + .select(range(3)) +) + + +with Pipeline(name="Math-Shepherd") as pipe: + model_id_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct" + model_id_8B = "meta-llama/Meta-Llama-3.1-8B-Instruct" + + llm_70B = InferenceEndpointsLLM( + model_id=model_id_8B, + tokenizer_id=model_id_8B, + generation_kwargs={"max_new_tokens": 1024, "temperature": 0.5}, + ) + llm_8B = InferenceEndpointsLLM( + model_id=model_id_8B, + tokenizer_id=model_id_8B, + generation_kwargs={"max_new_tokens": 2048, "temperature": 0.7}, + ) + + generator_golden = MathShepherdGenerator( + name="golden_generator", + llm=llm_70B, + ) + generator = MathShepherdGenerator( + name="generator", + llm=llm_8B, + M=5, + ) + completer = MathShepherdCompleter(name="completer", llm=llm_8B, N=4) + + combine = CombineOutputs() + + expand = ExpandColumns( + name="expand_columns", + columns=["solutions"], + split_statistics=True, + ) + formatter = FormatPRM(name="format_prm") + [generator_golden, generator] >> combine >> completer >> expand >> formatter + + +if __name__ == "__main__": + distiset = pipe.run(use_cache=False, dataset=ds) + distiset.push_to_hub("plaguss/test_math_shepherd_prm") diff --git a/mkdocs.yml b/mkdocs.yml index b892b18275..15c7e73a16 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -212,6 +212,7 @@ nav: - UltraFeedback: "sections/pipeline_samples/papers/ultrafeedback.md" - APIGen: "sections/pipeline_samples/papers/apigen.md" - CLAIR: "sections/pipeline_samples/papers/clair.md" + - Math Shepherd: "sections/pipeline_samples/papers/math_shepherd.md" - Examples: - Benchmarking with distilabel: "sections/pipeline_samples/examples/benchmarking_with_distilabel.md" - Structured generation with outlines: "sections/pipeline_samples/examples/llama_cpp_with_outlines.md" diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index eed9fa012b..7665ea4221 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -223,7 +223,7 @@ def unload(self) -> None: super().unload() def _cleanup_vllm_model(self) -> None: - import torch + import torch # noqa from vllm.distributed.parallel_state import ( destroy_distributed_environment, destroy_model_parallel, diff --git a/src/distilabel/steps/columns/expand.py b/src/distilabel/steps/columns/expand.py index 709ca4bc66..989924cf8a 100644 --- a/src/distilabel/steps/columns/expand.py +++ b/src/distilabel/steps/columns/expand.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from itertools import zip_longest from typing import TYPE_CHECKING, Any, Dict, List, Union -from pydantic import field_validator +from pydantic import field_validator, model_validator +from typing_extensions import Self from distilabel.steps.base import Step, StepInput @@ -34,6 +36,19 @@ class ExpandColumns(Step): columns: A dictionary that maps the column to be expanded to the new column name or a list of columns to be expanded. If a list is provided, the new column name will be the same as the column name. + encoded: A bool to inform Whether the columns are JSON encoded lists. 
If this value is + set to True, the columns will be decoded before expanding. Alternatively, to specify + columns that can be encoded, a list can be provided. In this case, the column names + informed must be a subset of the columns selected for expansion. + split_statistics: A bool to inform whether the statistics in the `distilabel_metadata` + column should be split into multiple rows. + If we want to expand some columns containing a list of strings that come from + having parsed the output of an LLM, the tokens in the `statistics_{step_name}` + of the `distilabel_metadata` column should be splitted to avoid multiplying + them if we aggregate the data afterwards. For example, with a task that is supposed + to generate a list of N instructions, and we want each of those N instructions in + different rows, we should split the statistics by N. + In such a case, set this value to True. Input columns: - dynamic (determined by `columns` attribute): The columns to be expanded into @@ -68,9 +83,66 @@ class ExpandColumns(Step): # >>> result # [{'instruction': 'instruction 1', 'generation': 'generation 1'}, {'instruction': 'instruction 1', 'generation': 'generation 2'}] ``` + + Expand the selected columns which are JSON encoded into multiple rows: + + ```python + from distilabel.steps import ExpandColumns + + expand_columns = ExpandColumns( + columns=["generation"], + encoded=True, # It can also be a list of columns that are encoded, i.e. ["generation"] + ) + expand_columns.load() + + result = next( + expand_columns.process( + [ + { + "instruction": "instruction 1", + "generation": '["generation 1", "generation 2"]'} + ], + ) + ) + # >>> result + # [{'instruction': 'instruction 1', 'generation': 'generation 1'}, {'instruction': 'instruction 1', 'generation': 'generation 2'}] + ``` + + Expand the selected columns and split the statistics in the `distilabel_metadata` column: + + ```python + from distilabel.steps import ExpandColumns + + expand_columns = ExpandColumns( + columns=["generation"], + split_statistics=True, + ) + expand_columns.load() + + result = next( + expand_columns.process( + [ + { + "instruction": "instruction 1", + "generation": ["generation 1", "generation 2"], + "distilabel_metadata": { + "statistics_generation": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + } + ], + ) + ) + # >>> result + # [{'instruction': 'instruction 1', 'generation': 'generation 1', 'distilabel_metadata': {'statistics_generation': {'input_tokens': [6], 'output_tokens': [6]}}}, {'instruction': 'instruction 1', 'generation': 'generation 2', 'distilabel_metadata': {'statistics_generation': {'input_tokens': [6], 'output_tokens': [6]}}}] + ``` """ columns: Union[Dict[str, str], List[str]] + encoded: Union[bool, List[str]] = False + split_statistics: bool = False @field_validator("columns") @classmethod @@ -88,6 +160,22 @@ def always_dict(cls, value: Union[Dict[str, str], List[str]]) -> Dict[str, str]: return value + @model_validator(mode="after") + def is_subset(self) -> Self: + """Ensure the "encoded" column names are a subset of the "columns" selected. + + Returns: + The "encoded" attribute updated to work internally. + """ + if isinstance(self.encoded, list): + if not set(self.encoded).issubset(set(self.columns.keys())): + raise ValueError( + "The 'encoded' columns must be a subset of the 'columns' selected for expansion." 
+ ) + if isinstance(self.encoded, bool): + self.encoded = list(self.columns.keys()) if self.encoded else [] + return self + @property def inputs(self) -> "StepColumns": """The columns to be expanded.""" @@ -110,6 +198,11 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore Yields: The expanded rows. """ + if self.encoded: + for input in inputs: + for column in self.encoded: + input[column] = json.loads(input[column]) + yield [row for input in inputs for row in self._expand_columns(input)] def _expand_columns(self, input: Dict[str, Any]) -> List[Dict[str, Any]]: @@ -121,11 +214,75 @@ def _expand_columns(self, input: Dict[str, Any]) -> List[Dict[str, Any]]: Returns: The expanded rows. """ + metadata_visited = False expanded_rows = [] - for expand_column, new_column in self.columns.items(): # type: ignore + # Update the columns here to avoid doing the validation on the `inputs`, as the + # `distilabel_metadata` is not defined on Pipeline creation on the DAG. + columns = self.columns + if self.split_statistics: + columns["distilabel_metadata"] = "distilabel_metadata" + + for expand_column, new_column in columns.items(): # type: ignore data = input.get(expand_column) + input, metadata_visited = self._split_metadata( + input, len(data), metadata_visited + ) + rows = [] for item, expanded in zip_longest(*[data, expanded_rows], fillvalue=input): rows.append({**expanded, new_column: item}) expanded_rows = rows return expanded_rows + + def _split_metadata( + self, input: Dict[str, Any], n: int, metadata_visited: bool = False + ) -> None: + """Help method to split the statistics in `distilabel_metadata` column. + + Args: + input: The input data. + n: Number of splits to apply to the tokens (if we have 12 tokens and want to split + them 3 times, n==3). + metadata_visited: Bool to prevent from updating the data more than once. + + Returns: + Updated input with the `distilabel_metadata` updated. + """ + # - If we want to split the statistics, we need to ensure that the metadata is present. + # - Metadata can only be visited once per row to avoid successive splitting. + # TODO: For an odd number of tokens, this will miss 1, we have to fix it. + if ( + self.split_statistics + and (metadata := input.get("distilabel_metadata", {})) + and not metadata_visited + ): + for k, v in metadata.items(): + if ( + not v + ): # In case it wasn't found in the metadata for some error, skip it + continue + if k.startswith("statistics_") and ( + "input_tokens" in v and "output_tokens" in v + ): + # For num_generations>1 we assume all the tokens should be divided by n + # TODO: The tokens should always come as a list, but there can + # be differences + if isinstance(v["input_tokens"], list): + input_tokens = [value // n for value in v["input_tokens"]] + else: + input_tokens = [v["input_tokens"] // n] + if isinstance(v["input_tokens"], list): + output_tokens = [value // n for value in v["output_tokens"]] + else: + output_tokens = [v["output_tokens"] // n] + + input["distilabel_metadata"][k] = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } + metadata_visited = True + # Once we have updated the metadata, Create a list out of it to let the + # following section to expand it as any other column. 
+ if isinstance(input["distilabel_metadata"], dict): + input["distilabel_metadata"] = [input["distilabel_metadata"]] * n + return input, metadata_visited diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 98974b00db..8e96d59f0a 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -43,6 +43,9 @@ ) from distilabel.steps.tasks.magpie.base import Magpie from distilabel.steps.tasks.magpie.generator import MagpieGenerator +from distilabel.steps.tasks.math_shepherd.completer import MathShepherdCompleter +from distilabel.steps.tasks.math_shepherd.generator import MathShepherdGenerator +from distilabel.steps.tasks.math_shepherd.utils import FormatPRM from distilabel.steps.tasks.pair_rm import PairRM from distilabel.steps.tasks.prometheus_eval import PrometheusEval from distilabel.steps.tasks.quality_scorer import QualityScorer @@ -81,6 +84,9 @@ "InstructionBacktranslation", "Magpie", "MagpieGenerator", + "MathShepherdGenerator", + "MathShepherdCompleter", + "FormatPRM", "PairRM", "PrometheusEval", "QualityScorer", diff --git a/src/distilabel/steps/tasks/math_shepherd/__init__.py b/src/distilabel/steps/tasks/math_shepherd/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/steps/tasks/math_shepherd/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/steps/tasks/math_shepherd/completer.py b/src/distilabel/steps/tasks/math_shepherd/completer.py new file mode 100644 index 0000000000..3606c4ec98 --- /dev/null +++ b/src/distilabel/steps/tasks/math_shepherd/completer.py @@ -0,0 +1,613 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Union + +from jinja2 import Template +from pydantic import PositiveInt +from typing_extensions import override + +from distilabel.constants import DISTILABEL_METADATA_KEY +from distilabel.steps.base import StepInput +from distilabel.steps.tasks.base import Task +from distilabel.steps.tasks.math_shepherd.utils import ( + parse_json_response, + split_solution_steps, +) +from distilabel.utils.itertools import batched + +if TYPE_CHECKING: + from distilabel.models.llms.typing import LLMStatistics + from distilabel.steps.tasks.typing import ChatType + from distilabel.steps.typing import StepColumns, StepOutput + + +SYSTEM_PROMPT = """\ +You are a math teacher who helps students by breaking down word problems into clear, logical steps. +When given a problem statement and any number of initial step, generate the remaining steps needed to reach the final answer. +Each step should: + +- Build logically on previous steps +- Explain the reasoning in natural language +- Lead to the final answer +- Multiple solution paths are acceptable +- Steps should be concise but clear +- Each calculation should be shown explicitly +- The final answer must be clearly stated +- The number of steps may vary based on the solution approach + +# Format requirements: +- Each step should be numbered sequentially, continuing from the last given step +- The final step should clearly state "The answer is: [result]" +- Each step can use different approaches but must be mathematically valid + +{{ extra_rules }}{{ few_shots }}{{ structured_prompt }}""" + +SYSTEM_PROMPT_STRUCTURED: Final[str] = """ +Your answer must adhere to the following format, with each step by step solution in a separate object: +``` +[ + { + "solution": "Step i: The step i\nStep i+1: The step i+1\n...\nThe answer is: [Your final answer]", + }, + ... (more solutions as required) +] +``` +""" + +RULES_GSM8K: Final[str] = """\ +# Rules: +- All calculations must be shown within <<>> brackets +- Basic operations: use * for multiplication, / for division, + for addition, - for subtraction +- Write the full calculation and result, e.g., <<5*10=50>>50 +""" + +FEW_SHOTS_GSM8K: Final[str] = """ +# Examples: +## Input +Krystian works in the library. He borrows an average of 40 books every day. Every Friday, his number of borrowed books is about 40% higher than the daily average. How many books does he borrow in a week if the library is open from Monday to Friday? +Step 1: On Friday, Krystian borrows 40 * 0.4 = <<40*0.4=16>>16 more books than on a regular day. + +## Output 1 +Step 2: On Friday, Krystian borrows 40 + 16 = <<40+16=56>>56 books in total. +Step 3: For the other 4 days (Monday to Thursday), he borrows 40 * 4 = <<40*4=160>>160 books. +Step 4: The total books for the week is 160 + 56 = <<160+56=216>>216. The answer is: 216 + +## Output 2 +Step 2: In total, he borrows 40 + 16 = <<40+16=56>>56 books on Friday. +Step 3: For the whole week (4 regular days plus Friday), the total is (40 * 4) + 56 = <<(40*4)+56=216>>216. The answer is: 216 + +## Output 3 +Step 2: On Friday, he borrows 40 + 40/100 * 40 = <<40+40/100*40=56>>56 books. +Step 3: In a week, he borrows 5.7 * 7 = <<5.7*7=40>>40 books. The answer is: 40""" + + +TEMPLATE: str = """Generate {{ N }} example solutions to the same problem, separated by a single `---` and nothing else. +Response format: +``` +Step i: step i explanation. +Step i+1: step i+1 explanation. +The answer is: X + +--- + +Step i: step i explanation. 
+Step i+1: step i+1 explanation. +The answer is: Y +``` + +This is the problem: +{{ instruction }} +""" + + +TEMPLATE_STRUCTURED: str = """Generate {{ N }} example solutions for the following problem:\n{{ instruction }}""" + + +# Type aliases +StepSolution = List[str] +Completions = List[StepSolution] + + +class MathShepherdCompleter(Task): + """Math Shepherd Completer and auto-labeller task. + + This task is in charge of, given a list of solutions to an instruction, and a golden solution, + as reference, generate completions for the solutions, and label them according to the golden + solution using the hard estimation method from figure 2 in the reference paper, Eq. 3. + The attributes make the task flexible to be used with different types of dataset and LLMs, and + allow making use of different fields to modify the system and user prompts for it. Before modifying + them, review the current defaults to ensure the completions are generated correctly. + + Attributes: + system_prompt: The system prompt to be used in the completions. The default one has been + checked and generates good completions using Llama 3.1 with 8B and 70B, + but it can be modified to adapt it to the model and dataset selected. + extra_rules: This field can be used to insert extra rules relevant to the type of dataset. + For example, in the original paper they used GSM8K and MATH datasets, and this field + can be used to insert the rules for the GSM8K dataset. + few_shots: Few shots to help the model generating the completions, write them in the + format of the type of solutions wanted for your dataset. + N: Number of completions to generate for each step, correspond to N in the paper. + They used 8 in the paper, but it can be adjusted. + tags: List of tags to be used in the completions, the default ones are ["+", "-"] as in the + paper, where the first is used as a positive label, and the second as a negative one. + This can be updated, but it MUST be a list with 2 elements, where the first is the + positive one, and the second the negative one. + + Input columns: + - instruction (`str`): The task or instruction. + - solutions (`List[str]`): List of solutions to the task. + - golden_solution (`str`): The reference solution to the task, will be used + to annotate the candidate solutions. + + Output columns: + - solutions (`List[str]`): The same columns that were used as input, the "solutions" is modified. + - model_name (`str`): The name of the model used to generate the revision. + + Categories: + - text-generation + - labelling + + References: + - [`Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations`](https://arxiv.org/abs/2312.08935) + + Examples: + Annotate your steps with the Math Shepherd Completer using the structured outputs (the preferred way): + + ```python + from distilabel.steps.tasks import MathShepherdCompleter + from distilabel.models import InferenceEndpointsLLM + + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + generation_kwargs={ + "temperature": 0.6, + "max_new_tokens": 1024, + }, + ) + task = MathShepherdCompleter( + llm=llm, + N=3, + use_default_structured_output=True + ) + + task.load() + + result = next( + task.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. 
How much in dollars does she make every day at the farmers' market?", + "golden_solution": ["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", "The answer is: 18"], + "solutions": [ + ["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", "The answer is: 18"], + ['Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking.', 'Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day.', 'Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.', 'The answer is: 18'], + ] + }, + ] + ) + ) + # [[{'instruction': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + # 'golden_solution': ["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.", "The answer is: 18"], + # 'solutions': [["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. -", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.", "The answer is: 18"], ["Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking. +", "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day. +", "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.", "The answer is: 18"]]}]] + ``` + + Annotate your steps with the Math Shepherd Completer: + + ```python + from distilabel.steps.tasks import MathShepherdCompleter + from distilabel.models import InferenceEndpointsLLM + + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + generation_kwargs={ + "temperature": 0.6, + "max_new_tokens": 1024, + }, + ) + task = MathShepherdCompleter( + llm=llm, + N=3 + ) + + task.load() + + result = next( + task.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "golden_solution": ["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", "The answer is: 18"], + "solutions": [ + ["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", "The answer is: 18"], + ['Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking.', 'Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day.', 'Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.', 'The answer is: 18'], + ] + }, + ] + ) + ) + # [[{'instruction': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. 
How much in dollars does she make every day at the farmers' market?", + # 'golden_solution': ["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.", "The answer is: 18"], + # 'solutions': [["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. -", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.", "The answer is: 18"], ["Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking. +", "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day. +", "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.", "The answer is: 18"]]}]] + ``` + + Citations: + + ``` + @misc{wang2024mathshepherdverifyreinforcellms, + title={Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations}, + author={Peiyi Wang and Lei Li and Zhihong Shao and R. X. Xu and Damai Dai and Yifei Li and Deli Chen and Y. Wu and Zhifang Sui}, + year={2024}, + eprint={2312.08935}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2312.08935}, + } + ``` + """ + + system_prompt: Optional[str] = SYSTEM_PROMPT + extra_rules: Optional[str] = RULES_GSM8K + few_shots: Optional[str] = FEW_SHOTS_GSM8K + N: PositiveInt = 1 + tags: list[str] = ["+", "-"] + + def load(self) -> None: + super().load() + + if self.system_prompt is not None: + self.system_prompt = Template(self.system_prompt).render( + extra_rules=self.extra_rules or "", + few_shots=self.few_shots or "", + structured_prompt=SYSTEM_PROMPT_STRUCTURED + if self.use_default_structured_output + else "", + ) + if self.use_default_structured_output: + self._template = Template(TEMPLATE_STRUCTURED) + else: + self._template = Template(TEMPLATE) + + @property + def inputs(self) -> "StepColumns": + return ["instruction", "solutions", "golden_solution"] + + @property + def outputs(self) -> "StepColumns": + return ["model_name"] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + messages = [ + { + "role": "user", + "content": self._template.render( + instruction=input["instruction"], N=self.N + ), + } + ] + if self.system_prompt: + messages.insert(0, {"role": "system", "content": self.system_prompt}) + return messages # type: ignore + + def _parse_output(self, output: Union[str, None]) -> list[list[str]]: + if output is None: + return [[""]] * self.N + + if self.N > 1: + output_transformed = ( # type: ignore + self._format_structured_output(output) + if self.use_default_structured_output + else output.split("---") + ) + examples = [split_solution_steps(o) for o in output_transformed] + # In case there aren't the expected number of completions, we fill it with "", or short the list. + # This shoulnd't happen if the LLM works as expected, but it's a safety measure as it can be + # difficult to debug if the completions don't match the solutions. 
+        if len(examples) < self.N:
+            examples.extend([""] * (self.N - len(examples)))  # type: ignore
+        elif len(examples) > self.N:
+            examples = examples[: self.N]
+        else:
+            output_transformed = (
+                self._format_structured_output(output)[0]
+                if self.use_default_structured_output
+                else output
+            )
+            examples = [split_solution_steps(output_transformed)]
+        return examples
+
+    def _format_structured_output(self, output: str) -> list[str]:
+        default_output = [""] * self.N if self.N else [""]
+        if parsed_output := parse_json_response(output):
+            solutions = parsed_output["solutions"]
+            extracted_solutions = [solution["solution"] for solution in solutions]
+            if len(extracted_solutions) != self.N:
+                extracted_solutions = default_output
+            return extracted_solutions
+        return default_output
+
+    def format_output(
+        self,
+        output: Union[str, None],
+        input: Union[Dict[str, Any], None] = None,
+    ) -> Dict[str, Any]:
+        """Does nothing."""
+        return {}
+
+    def process(self, inputs: StepInput) -> "StepOutput":
+        """Generates completions for the given solutions, and annotates
+        each step following the logic in Figure 2 of the paper, using the hard estimation (Eq. 3).
+
+        Args:
+            inputs: Inputs to the step.
+
+        Yields:
+            Annotated inputs with the completions.
+        """
+
+        # A list with all the inputs to be passed to the LLM. Needs another structure to
+        # find them afterwards
+        prepared_inputs = []
+        # Data structure with the indices of the elements.
+        # (i, j, k) where i is the input, j is the solution, and k is the completion
+        input_positions = []
+        golden_answers = []
+        for i, input in enumerate(inputs):
+            instruction = input["instruction"]
+            golden_solution = input["golden_solution"]  # This is a single solution
+            golden_answers.append(golden_solution[-1])
+            # This contains a list of solutions
+            solutions = input["solutions"]
+            for j, solution in enumerate(solutions):
+                # For each solution, which has K steps, we have to generate N completions
+                # for the first K-2 steps (-2 because the last 2 steps are the last step, and
+                # the answer itself, which can be directly compared against the golden answer)
+                prepared_completions = self._prepare_completions(instruction, solution)
+                prepared_inputs.extend(prepared_completions)
+                input_positions.extend(
+                    [(i, j, k) for k in range(len(prepared_completions))]
+                )
+
+        # Send the elements in batches to the LLM to speed up the process
+        final_outputs = []
+        # Added here to simplify testing in case we don't have anything to process
+        # TODO: Ensure the statistics have the same shape as all the outputs, raw_outputs, and raw_inputs
+        statistics = []
+        total_raw_outputs = []
+        total_raw_inputs = []
+        for inner_batch in batched(prepared_inputs, self.input_batch_size):  # type: ignore
+            outputs = self.llm.generate_outputs(
+                inputs=inner_batch,
+                num_generations=1,
+                **self.llm.get_generation_kwargs(),  # type: ignore
+            )
+
+            formatted_outputs = []
+            stats = []
+            raw_outputs = []
+            raw_inputs = []
+            for i, output in enumerate(outputs):
+                generation = output["generations"][0]
+                raw_inputs.append(inner_batch[i])
+                raw_outputs.append(generation or "")
+                formatted_outputs.append(self._parse_output(generation))
+                stats.append(output["statistics"])
+
+            final_outputs.extend(formatted_outputs)
+            statistics.extend(stats)
+            total_raw_outputs.extend(raw_outputs)
+            total_raw_inputs.extend(raw_inputs)
+
+        yield self._auto_label(  # type: ignore
+            inputs,
+            final_outputs,
+            input_positions,
+            golden_answers,
+            statistics,
+            total_raw_outputs,
+            total_raw_inputs,
+        )
+
+    def _prepare_completions(
+        self, instruction: str, steps: list[str]
+    ) -> List["ChatType"]:
+        """Helper method that, given an instruction and a solution (a list of steps), creates
+        the texts to be completed by the LLM.
+
+        Args:
+            instruction: Instruction of the problem.
+            steps: List of steps that are part of the solution.
+
+        Returns:
+            List of ChatType, where each ChatType is the prompt corresponding to one of the steps
+            to be completed.
+        """
+        prepared_inputs = []
+        # Use the number of completions that correspond to a given instruction/steps pair
+        # to find afterwards the input that corresponds to a given completion (to do the labelling)
+        num_completions = len(steps[:-2])
+        for i in range(1, num_completions + 1):
+            to_complete = instruction + " " + "\n".join(steps[:i])
+            prepared_inputs.append(self.format_input({"instruction": to_complete}))
+
+        return prepared_inputs
+
+    def _auto_label(
+        self,
+        inputs: StepInput,
+        final_outputs: list[Completions],
+        input_positions: list[tuple[int, int, int]],
+        golden_answers: list[str],
+        statistics: list["LLMStatistics"],
+        raw_outputs: list[str],
+        raw_inputs: list[str],
+    ) -> StepInput:
+        """Labels the steps in place (in the inputs) and returns the inputs.
+
+        Args:
+            inputs: The original inputs.
+            final_outputs: List of generations from the LLM.
+                It's organized as a list where the elements sent to the LLM are
+                grouped together, then each element contains the completions, and
+                each completion is a list of steps.
+            input_positions: A list of tuples generated in the `process` method,
+                each containing (i, j, k), where i is the index of the input, j is the
+                index of the solution, and k is the index of the completion.
+            golden_answers: List of golden answers for each input.
+            statistics: List of statistics from the LLM.
+            raw_outputs: List of raw outputs from the LLM.
+            raw_inputs: List of raw inputs to the LLM.
+
+        Returns:
+            The annotated inputs.
+        """
+        for i, (instruction_i, solution_i, step_i) in enumerate(input_positions):
+            input = inputs[instruction_i]
+            solutions = input["solutions"]
+            n_completions = final_outputs[i]
+            label = f" {self.tags[1]}"
+            for completion in n_completions:
+                if len(completion) == 0:
+                    # This can be a failed generation
+                    label = ""  # Everything stays the same
+                    self._logger.info("Completer failed due to empty completion")
+                    continue
+                if completion[-1] == golden_answers[instruction_i]:
+                    label = f" {self.tags[0]}"
+                    # If we found one, it's enough as we are doing Hard Estimation
+                    continue
+            # In case we had no solutions from the previous step, otherwise we would have
+            # an IndexError
+            if not solutions[solution_i]:
+                continue
+            solutions[solution_i][step_i] += label
+            inputs[instruction_i]["solutions"] = solutions
+
+        for i, input in enumerate(inputs):
+            solutions = input["solutions"]
+            new_solutions = []
+            for solution in solutions:
+                if not solution or (len(solution) == 1):
+                    # The generation may fail to produce the expected
+                    # completions, or it may add an extra empty completion,
+                    # so we skip it.
+                    # Another possible error is having a list of solutions
+                    # with a single item, so when we call .pop, we are left
+                    # with an empty list, so we skip it too.
+                    new_solutions.append(solution)
+                    continue
+
+                answer = solution.pop()
+                label = (
+                    f" {self.tags[0]}"
+                    if answer == golden_answers[i]
+                    else f" {self.tags[1]}"
+                )
+                solution[-1] += " " + answer + label
+                new_solutions.append(solution)
+
+            # Only add the solutions if the data was properly parsed
+            input["solutions"] = new_solutions if new_solutions else input["solutions"]
+            input = self._add_metadata(
+                input, statistics[i], raw_outputs[i], raw_inputs[i]
+            )
+
+        return inputs
+
+    def _add_metadata(self, input, statistics, raw_output, raw_input):
+        """Adds the `distilabel_metadata` to the input.
+
+        This method comes for free in the general Tasks, but as we have reimplemented the `process`,
+        we have to repeat it here.
+
+        Args:
+            input: The input to add the metadata to.
+            statistics: The statistics from the LLM.
+            raw_output: The raw output from the LLM.
+            raw_input: The raw input to the LLM.
+
+        Returns:
+            The input with the metadata added, if applicable.
+        """
+        input["model_name"] = self.llm.model_name
+
+        if DISTILABEL_METADATA_KEY not in input:
+            input[DISTILABEL_METADATA_KEY] = {}
+        # If the solutions are split afterwards, the statistics should be split
+        # to avoid counting extra tokens
+        input[DISTILABEL_METADATA_KEY][f"statistics_{self.name}"] = statistics
+
+        # Set some defaults in case something failed and we got None, otherwise when reading
+        # the parquet files using pyarrow, the following error will appear:
+        # ArrowInvalid: Schema
+        if self.add_raw_input:
+            input[DISTILABEL_METADATA_KEY][f"raw_input_{self.name}"] = raw_input or [
+                {"content": "", "role": ""}
+            ]
+        if self.add_raw_output:
+            input[DISTILABEL_METADATA_KEY][f"raw_output_{self.name}"] = raw_output or ""
+        return input
+
+    @override
+    def get_structured_output(self) -> dict[str, Any]:
+        """Creates the json schema to be passed to the LLM, to enforce generating
+        a dictionary with the output which can be directly parsed as a python dictionary.
+
+        The schema corresponds to the following:
+
+        ```python
+        from pydantic import BaseModel, Field
+
+        class Solution(BaseModel):
+            solution: str = Field(..., description="Step by step solution leading to the final answer")
+
+        class MathShepherdCompleter(BaseModel):
+            solutions: list[Solution] = Field(..., description="List of solutions")
+
+        MathShepherdCompleter.model_json_schema()
+        ```
+
+        Returns:
+            JSON Schema of the response to enforce.
+        """
+        return {
+            "$defs": {
+                "Solution": {
+                    "properties": {
+                        "solution": {
+                            "description": "Step by step solution leading to the final answer",
+                            "title": "Solution",
+                            "type": "string",
+                        }
+                    },
+                    "required": ["solution"],
+                    "title": "Solution",
+                    "type": "object",
+                }
+            },
+            "properties": {
+                "solutions": {
+                    "description": "List of solutions",
+                    "items": {"$ref": "#/$defs/Solution"},
+                    "title": "Solutions",
+                    "type": "array",
+                }
+            },
+            "required": ["solutions"],
+            "title": "MathShepherdCompleter",
+            "type": "object",
+        }
diff --git a/src/distilabel/steps/tasks/math_shepherd/generator.py b/src/distilabel/steps/tasks/math_shepherd/generator.py
new file mode 100644
index 0000000000..d9ab565e54
--- /dev/null
+++ b/src/distilabel/steps/tasks/math_shepherd/generator.py
@@ -0,0 +1,372 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, Final, Optional, Union
+
+from jinja2 import Template
+from pydantic import PositiveInt
+from typing_extensions import override
+
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.math_shepherd.utils import (
+    parse_json_response,
+    split_solution_steps,
+)
+
+if TYPE_CHECKING:
+    from distilabel.steps.tasks.typing import ChatType
+    from distilabel.steps.typing import StepColumns
+
+
+SYSTEM_PROMPT = """\
+You are a math tutor that helps students solve math problems by breaking them down into clear, logical steps. Follow these guidelines:
+
+# For each step:
+- Clearly explain the reasoning
+- Show the calculated result for any arithmetic calculation
+- Present intermediate calculations clearly
+- Use clear, concise language to explain the mathematical reasoning
+
+# Format requirements:
+- Number each step starting with "Step 1:"
+- The final step should clearly state "The answer is: [result]"
+- Keep explanations clear and concise
+
+{{ extra_rules }}{{ few_shots }}{{ structured_prompt }}"""
+
+
+SYSTEM_PROMPT_STRUCTURED: Final[str] = """
+Your answer must adhere to the following format, with each step by step solution in a separate object:
+```
+[
+    {
+        "solution": "Step 1: Your first step\nStep 2: Your second step\n...\nThe answer is: [Your final answer]",
+    },
+    ... (more solutions as required)
+]
+```
+"""
+
+
+RULES_GSM8K: Final[str] = """\
+# Rules:
+- All calculations must be shown within <<>> brackets
+- Basic operations: use * for multiplication, / for division, + for addition, - for subtraction
+- Write the full calculation and result, e.g., <<5*10=50>>50
+"""
+
+FEW_SHOTS_GSM8K: Final[str] = """
+# Examples:
+## Instruction
+A store sells notebooks for $3 each. If you buy 5 or more, you get a 20% discount. How much would you pay for 6 notebooks?
+
+## Solution
+Step 1: Calculate the regular price for 6 notebooks: 6 * $3 = <<6*3=18>>18 dollars
+Step 2: Calculate the 20% discount amount: 18 * 20/100 = <<18*20/100=3.6>>3.6 dollars
+Step 3: Subtract the discount from the regular price: 18 - 3.6 = <<18-3.6=14.4>>14.4 dollars. The answer is: 14.4
+
+## Instruction
+A recipe calls for 2.5 cups of flour to make 12 cookies. How many cups of flour are needed to make 30 cookies?
+
+## Solution
+Step 1: Find out how many cups of flour are needed per cookie: 2.5 ÷ 12 = <<2.5/12=0.208333>>0.208333 cups
+Step 2: Calculate the flour needed for 30 cookies: 0.208333 * 30 = <<0.208333*30=6.25>>6.25 cups. The answer is: 6.25
+"""
+
+RULES_MATH: Final[str] = """\
+# Rules:
+- Always wrap mathematical expressions in $ symbols
+- Use LaTeX-style math notation with $ symbols for mathematical expressions
+- Format operations and equations properly using LaTeX notation within $ symbols
+- Keep explanations precise and mathematically rigorous
+- Use $\boxed{}$ notation only in the final step
+"""
+
+FEW_SHOTS_MATH: Final[str] = """
+# Examples
+## Input
+Find the sum of the first three perfect squares greater than 50.
+
+## Output
+Step 1: The first perfect square greater than 50 is $8^2 = 64$.
+Step 2: The second perfect square is $9^2 = 81$.
+Step 3: The third perfect square is $10^2 = 100$.
+Step 4: The sum is $64 + 81 + 100 = 245$.
+Step 5: Therefore, the answer is $\boxed{245}$. The answer is: 245
+
+## Input
+What is the value of $2^5 + 3^3$?
+
+## Output
+Step 1: Calculate $2^5 = 32$.
+Step 2: Calculate $3^3 = 27$.
+Step 3: Add the results: $32 + 27 = 59$.
+Step 4: Therefore, the answer is $\boxed{59}$. The answer is: 59
+"""
+
+TEMPLATE: str = """{% if M %}Generate {{ M }} example solutions to the following problem, separated by a single `---`. This is your problem:{% endif %}
+{{ instruction }}"""
+
+TEMPLATE_STRUCTURED: str = """{% if M %}Generate {{ M }} diverse solutions, even if they are incorrect. This is the problem:{% endif %}
+{{ instruction }}"""
+
+
+class MathShepherdGenerator(Task):
+    """Math Shepherd solution generator.
+
+    This task is in charge of generating completions for a given instruction, in the format expected
+    by the Math Shepherd Completer task. The attributes make the task flexible to be used with different
+    types of datasets and LLMs, but we provide examples for the GSM8K and MATH datasets as presented
+    in the original paper. Before modifying them, review the current defaults to ensure the completions
+    are generated correctly. This task can be used to generate the golden solutions for a given problem if
+    not provided, as well as possible solutions to be then labeled by the Math Shepherd Completer.
+    Only one of `solutions` or `golden_solution` will be generated, depending on the value of M.
+
+    Attributes:
+        system_prompt: The system prompt to be used in the completions. The default one has been
+            checked and generates good completions using Llama 3.1 with 8B and 70B,
+            but it can be modified to adapt it to the model and dataset selected.
+            Take into account that the system prompt includes 2 variables in the Jinja2 template,
+            {{extra_rules}} and {{few_shots}}. These variables are used to include extra rules, for example
+            to steer the model towards a specific type of response, and few shots to add examples.
+            They can be modified to adapt the system prompt to the dataset and model used without needing
+            to change the full system prompt.
+        extra_rules: This field can be used to insert extra rules relevant to the type of dataset.
+            For example, in the original paper they used GSM8K and MATH datasets, and this field
+            can be used to insert the rules for the GSM8K dataset.
+        few_shots: Few shots to help the model generate the completions, written in the
+            format of the type of solutions wanted for your dataset.
+        M: Number of completions to generate for each step. By default it is set to 1, which will
+            generate the "golden_solution". In this case select a stronger model, as it will be used
+            as the source of truth during labelling. If M is set to a number greater than 1, the task
+            will generate a list of completions to be labeled by the Math Shepherd Completer task.
+
+    Input columns:
+        - instruction (`str`): The task or instruction.
+
+    Output columns:
+        - golden_solution (`str`): The step by step solution to the instruction.
+            It will be generated if M is equal to 1.
+        - solutions (`List[List[str]]`): A list of possible solutions to the instruction.
+            It will be generated if M is greater than 1.
+        - model_name (`str`): The name of the model used to generate the revision.
+ + Categories: + - text-generation + + References: + - [`Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations`](https://arxiv.org/abs/2312.08935) + + Examples: + Generate the solution for a given instruction (prefer a stronger model here): + + ```python + from distilabel.steps.tasks import MathShepherdGenerator + from distilabel.models import InferenceEndpointsLLM + + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + generation_kwargs={ + "temperature": 0.6, + "max_new_tokens": 1024, + }, + ) + task = MathShepherdGenerator( + name="golden_solution_generator", + llm=llm, + ) + + task.load() + + result = next( + task.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + }, + ] + ) + ) + # [[{'instruction': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + # 'golden_solution': '["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.", "The answer is: 18"]'}]] + ``` + + Generate M completions for a given instruction (using structured output generation): + + ```python + from distilabel.steps.tasks import MathShepherdGenerator + from distilabel.models import InferenceEndpointsLLM + + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + generation_kwargs={ + "temperature": 0.7, + "max_new_tokens": 2048, + }, + ) + task = MathShepherdGenerator( + name="solution_generator", + llm=llm, + M=2, + use_default_structured_output=True, + ) + + task.load() + + result = next( + task.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + }, + ] + ) + ) + # [[{'instruction': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + # 'solutions': [["Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. -", "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.", "The answer is: 18"], ["Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking. +", "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day. 
+", "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.", "The answer is: 18"]]}]] + ``` + """ + + system_prompt: Optional[str] = SYSTEM_PROMPT + extra_rules: Optional[str] = RULES_GSM8K + few_shots: Optional[str] = FEW_SHOTS_GSM8K + M: Optional[PositiveInt] = None + + def load(self) -> None: + super().load() + if self.system_prompt is not None: + self.system_prompt = Template(self.system_prompt).render( + extra_rules=self.extra_rules or "", + few_shots=self.few_shots or "", + structured_prompt=SYSTEM_PROMPT_STRUCTURED + if self.use_default_structured_output + else "", + ) + if self.use_default_structured_output: + self._template = Template(TEMPLATE_STRUCTURED) + else: + self._template = Template(TEMPLATE) + + @property + def inputs(self) -> "StepColumns": + return ["instruction"] + + @property + def outputs(self) -> "StepColumns": + if self.M: + return ["solutions", "model_name"] + return ["golden_solution", "model_name"] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + messages = [ + { + "role": "user", + "content": self._template.render( + instruction=input["instruction"], + M=self.M, + ), + } + ] + if self.system_prompt: + messages.insert(0, {"role": "system", "content": self.system_prompt}) + return messages + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + output_name = "solutions" if self.M else "golden_solution" + + if output is None: + input.update(**{output_name: None}) + return input + + if self.M: + output_parsed = ( + self._format_structured_output(output) + if self.use_default_structured_output + else output.split("---") + ) + solutions = [split_solution_steps(o) for o in output_parsed] + else: + output_parsed = ( + self._format_structured_output(output)[0] + if self.use_default_structured_output + else output + ) + solutions = split_solution_steps(output_parsed) + + input.update(**{output_name: solutions}) + return input + + @override + def get_structured_output(self) -> dict[str, Any]: + """Creates the json schema to be passed to the LLM, to enforce generating + a dictionary with the output which can be directly parsed as a python dictionary. + + The schema corresponds to the following: + + ```python + from pydantic import BaseModel, Field + + class Solution(BaseModel): + solution: str = Field(..., description="Step by step solution leading to the final answer") + + class MathShepherdGenerator(BaseModel): + solutions: list[Solution] = Field(..., description="List of solutions") + + MathShepherdGenerator.model_json_schema() + ``` + + Returns: + JSON Schema of the response to enforce. 
+ """ + return { + "$defs": { + "Solution": { + "properties": { + "solution": { + "description": "Step by step solution leading to the final answer", + "title": "Solution", + "type": "string", + } + }, + "required": ["solution"], + "title": "Solution", + "type": "object", + } + }, + "properties": { + "solutions": { + "description": "List of solutions", + "items": {"$ref": "#/$defs/Solution"}, + "title": "Solutions", + "type": "array", + } + }, + "required": ["solutions"], + "title": "MathShepherdGenerator", + "type": "object", + } + + def _format_structured_output(self, output: str) -> list[str]: + default_output = [""] * self.M if self.M else [""] + if parsed_output := parse_json_response(output): + solutions = parsed_output["solutions"] + extracted_solutions = [o["solution"] for o in solutions] + if len(extracted_solutions) != self.M: + extracted_solutions = default_output + return extracted_solutions + return default_output diff --git a/src/distilabel/steps/tasks/math_shepherd/utils.py b/src/distilabel/steps/tasks/math_shepherd/utils.py new file mode 100644 index 0000000000..bed56e1ff9 --- /dev/null +++ b/src/distilabel/steps/tasks/math_shepherd/utils.py @@ -0,0 +1,318 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import TYPE_CHECKING, Any, Literal, Union + +import orjson + +from distilabel.steps.base import Step, StepInput + +if TYPE_CHECKING: + from distilabel.steps.typing import StepColumns, StepOutput + + +def split_solution_steps(text: str) -> list[str]: + """ + Split a step-by-step solution text into individual components. + Returns a list of steps and the final answer. + """ + # Pattern to match: + # 1. Steps starting with "Step N:" and capturing all content until the next step or answer + # 2. The final answer starting with "The answer is:" + pattern = r"Step \d+:.*?(?=Step \d+:|The answer is:|$)|The answer is:.*" + + # Find all matches, strip whitespace + matches = [match.strip() for match in re.findall(pattern, text, re.DOTALL)] + + return matches + + +class FormatPRM(Step): + """Helper step to transform the data into the format expected by the PRM model. + + This step can be used to format the data in one of 2 formats: + Following the format presented + in [peiyi9979/Math-Shepherd](https://huggingface.co/datasets/peiyi9979/Math-Shepherd?row=0), + in which case this step creates the columns input and label, where the input is the instruction + with the solution (and the tag replaced by a token), and the label is the instruction + with the solution, both separated by a newline. + Following TRL's format for training, which generates the columns prompt, completions, and labels. + The labels correspond to the original tags replaced by boolean values, where True represents + correct steps. + + Attributes: + format (Literal["math-shepherd", "trl"]): The format to use for the PRM model. + "math-shepherd" corresponds to the original paper, while "trl" is a format + prepared to train the model using TRL. 
+ step_token (str): String that serves as a unique token denoting the position + for predicting the step score. + tags (list[str]): List of tags that represent the correct and incorrect steps. + This only needs to be informed if it's different than the default in + `MathShepherdCompleter`. + + Input columns: + - instruction (`str`): The task or instruction. + - solutions (`list[str]`): List of steps with a solution to the task. + + Output columns: + - input (`str`): The instruction with the solutions, where the label tags + are replaced by a token. + - label (`str`): The instruction with the solutions. + - prompt (`str`): The instruction with the solutions, where the label tags + are replaced by a token. + - completions (`List[str]`): The solution represented as a list of steps. + - labels (`List[bool]`): The labels, as a list of booleans, where True represents + a good response. + + Categories: + - text-manipulation + - columns + + References: + - [`Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations`](https://arxiv.org/abs/2312.08935) + - [peiyi9979/Math-Shepherd](https://huggingface.co/datasets/peiyi9979/Math-Shepherd?row=0) + + Examples: + Prepare your data to train a PRM model with the Math-Shepherd format: + + ```python + from distilabel.steps.tasks import FormatPRM + from distilabel.steps import ExpandColumns + + expand_columns = ExpandColumns(columns=["solutions"]) + expand_columns.load() + + # Define our PRM formatter + formatter = FormatPRM() + formatter.load() + + # Expand the solutions column as it comes from the MathShepherdCompleter + result = next( + expand_columns.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. 
+", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"]] + }, + ] + ) + ) + result = next(formatter.process(result)) + # result[0]["input"] + # "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. ки\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. ки\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 ки" + # result[0]["label"] + # "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +" + ``` + + Prepare your data to train a PRM model with the TRL format: + + ```python + from distilabel.steps.tasks import FormatPRM + from distilabel.steps import ExpandColumns + + expand_columns = ExpandColumns(columns=["solutions"]) + expand_columns.load() + + # Define our PRM formatter + formatter = FormatPRM(format="trl") + formatter.load() + + # Expand the solutions column as it comes from the MathShepherdCompleter + result = next( + expand_columns.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. 
The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"], ["Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", "Step 2: Calculate the amount of white fiber needed: Since it\'s half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +"]] + }, + ] + ) + ) + + result = next(formatter.process(result)) + # { + # "instruction": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + # "solutions": [ + # "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + # "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", + # "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +" + # ], + # "prompt": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + # "completions": [ + # "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required.", + # "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber.", + # "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3" + # ], + # "labels": [ + # true, + # true, + # true + # ] + # } + ``` + + Citations: + + ``` + @misc{wang2024mathshepherdverifyreinforcellms, + title={Math-Shepherd: Verify and Reinforce LLMs Step-by-step without Human Annotations}, + author={Peiyi Wang and Lei Li and Zhihong Shao and R. X. Xu and Damai Dai and Yifei Li and Deli Chen and Y. 
Wu and Zhifang Sui},
+            year={2024},
+            eprint={2312.08935},
+            archivePrefix={arXiv},
+            primaryClass={cs.AI},
+            url={https://arxiv.org/abs/2312.08935},
+        }
+        ```
+    """
+
+    format: Literal["math-shepherd", "trl"] = "math-shepherd"
+    step_token: str = "ки"
+    tags: list[str] = ["+", "-"]
+
+    def model_post_init(self, __context: Any) -> None:
+        super().model_post_init(__context)
+        if self.format == "math-shepherd":
+            self._formatter = self._format_math_shepherd
+        else:
+            self._formatter = self._format_trl
+
+    @property
+    def inputs(self) -> "StepColumns":
+        return ["instruction", "solutions"]
+
+    @property
+    def outputs(self) -> "StepColumns":
+        if self.format == "math-shepherd":
+            return ["input", "label"]
+        return ["prompt", "completions", "labels"]
+
+    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
+        """Prepares the data in the format expected by the PRM model.
+
+        Each input is formatted in place using the formatter selected via the `format` attribute.
+
+        Args:
+            inputs: A list of dictionaries with the input data.
+
+        Yields:
+            A list of dictionaries with the output data.
+        """
+        for input in inputs:
+            self._formatter(input)
+
+        yield inputs  # type: ignore
+
+    def _format_math_shepherd(
+        self, input: dict[str, str]
+    ) -> dict[str, Union[str, list[str]]]:
+        instruction = input["instruction"]
+        replaced = []
+        # At this stage, the "solutions" column can only contain a single solution,
+        # and the last item of each solution is the tag.
+        solution = input["solutions"]
+        for step in solution:
+            # Check that there's a string, because the step that generated
+            # the solutions could have failed, and we would have an empty list.
+            replaced.append(step[:-1] + self.step_token if len(step) > 1 else step)
+
+        input["input"] = instruction + " " + "\n".join(replaced)
+        input["label"] = instruction + " " + "\n".join(solution)
+
+        return input  # type: ignore
+
+    def _format_trl(
+        self, input: dict[str, str]
+    ) -> dict[str, Union[str, list[str], list[bool]]]:
+        input["prompt"] = input["instruction"]
+        completions: list[str] = []
+        labels: list[bool] = []
+        for step in input["solutions"]:
+            token = step[-1]
+            completions.append(step[:-1].strip())
+            labels.append(True if token == self.tags[0] else False)
+
+        input["completions"] = completions  # type: ignore
+        input["labels"] = labels  # type: ignore
+
+        return input  # type: ignore
+
+
+def parse_json_response(json_str: str) -> Union[dict[str, Any], None]:
+    """Helper function to clean and parse JSON strings generated by LLMs.
+    Some common errors may appear (see the REPLACEMENTS dictionary) that need to be fixed before parsing,
+    but the JSON is valid otherwise.
+ """ + + try: + # First try parsing as-is + return orjson.loads(json_str) + except orjson.JSONDecodeError: + # Apply all replacements + for old, new in REPLACEMENTS.items(): + json_str = json_str.replace(old, new) + + try: + # Try parsing after replacements + return orjson.loads(json_str) + except orjson.JSONDecodeError: + # If still failing, try more aggressive cleaning + + # Remove any non-ASCII characters + json_str = re.sub(r"[^\x00-\x7F]+", "", json_str) + + # Remove any remaining escape sequences except valid ones + json_str = re.sub(r'\\([^"\\\/bfnrt])', r"\1", json_str) + + try: + return orjson.loads(json_str) + except orjson.JSONDecodeError: + # Failed to parse JSON after all cleaning attempts + return None + + +# Dictionary of common LLM JSON artifacts and their replacements +REPLACEMENTS: dict[str, str] = { + # Escape sequence issues + "\\)": ")", # Incorrectly escaped parentheses + "\\]": "]", # Incorrectly escaped brackets + "\\}": "}", # Incorrectly escaped braces + "\\`": "`", # Incorrectly escaped backticks + "\\'": "'", # Incorrectly escaped single quotes + '\\\\"': '\\"', + '\\"': '"', # Incorrectly escaped double quotes + "\\\\n": "\\n", # Double escaped newlines + "\\\\t": "\\t", # Double escaped tabs + "\\\\r": "\\r", # Double escaped carriage returns + # # Markdown artifacts + # '```json\n': '', # Markdown code block start + # '\n```': '', # Markdown code block end + # '`': '', # Inline code markers + # Common mathematical symbols that might be escaped + "\\<": "<", # Less than + "\\>": ">", # Greater than + "\\=": "=", # Equals + "\\+": "+", # Plus + "\\-": "-", # Minus + "\\*": "*", # Asterisk + "\\|": "|", # Pipe + # Unicode escaping issues + "\\u0022": '"', # Double quote + "\\u0027": "'", # Single quote + "\\u005C": "\\", # Backslash + # # Other common issues + # '\n\n': '\n', # Multiple newlines + # '\t\t': '\t', # Multiple tabs +} diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index 621f4b61dc..08877d3cb7 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -92,6 +92,7 @@ "load": ":material-file-download:", "execution": ":octicons-code-16:", "save": ":material-content-save:", + "labelling": ":label:", } _STEP_CATEGORY_TO_DESCRIPTION = { @@ -111,6 +112,7 @@ "load": "Load steps are used to load the data.", "execution": "Executes python functions.", "save": "Save steps are used to save the data.", + "labelling": "Labelling steps are used to label the data.", } assert list(_STEP_CATEGORY_TO_DESCRIPTION.keys()) == list( diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py index 873ff20721..0aaebf261b 100644 --- a/src/distilabel/utils/serialization.py +++ b/src/distilabel/utils/serialization.py @@ -206,7 +206,7 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: "_name": getattr(obj, k).__name__, "_values": {x.name: x.value for x in v}, # type: ignore } - elif isinstance(v, list): + elif isinstance(v, list) and len(v) > 0: obj_list = getattr(obj, k) if isinstance(obj_list, list) and isinstance( obj_list[0], RuntimeParametersMixin diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index ec97dc5f71..f0cb377bbe 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -83,6 +83,7 @@ def test_process(self, mock_dataset) -> None: ) with patch.object(PreferenceToArgilla, "load"): 
step.load() + step._instruction = "instruction" step._generations = "generations" step._ratings = "ratings" diff --git a/tests/unit/steps/columns/test_expand.py b/tests/unit/steps/columns/test_expand.py index c88702a658..a257f4e096 100644 --- a/tests/unit/steps/columns/test_expand.py +++ b/tests/unit/steps/columns/test_expand.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +from typing import Union + +import pytest + from distilabel.pipeline.local import Pipeline from distilabel.steps.columns.expand import ExpandColumns @@ -26,19 +31,248 @@ def test_always_dict(self) -> None: assert expand_columns.columns == {"column1": "column1", "column2": "column2"} - def test_process(self) -> None: + @pytest.mark.parametrize( + "encoded, split_statistics, values, stats", + [ + ( + False, + False, + [ + { + "column1": [1, 2, 3], + "column2": ["a", "b", "c"], + "distilabel_metadata": { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + } + ], + [ + { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + ], + ), + ( + ["column1", "column2"], + False, + [ + { + "column1": json.dumps([1, 2, 3]), + "column2": json.dumps(["a", "b", "c"]), + "distilabel_metadata": { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + } + ], + [ + { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + ], + ), + ( + False, + True, + [ + { + "column1": [1, 2, 3], + "column2": ["a", "b", "c"], + "distilabel_metadata": { + "statistics_column1": { + "input_tokens": [12], + "output_tokens": [12], + }, + "statistics_column2": { + "input_tokens": [12], + "output_tokens": [12], + }, + }, + } + ], + [ + { + "statistics_column1": { + "input_tokens": [4], + "output_tokens": [4], + }, + "statistics_column2": { + "input_tokens": [4], + "output_tokens": [4], + }, + }, + { + "statistics_column1": { + "input_tokens": [4], + "output_tokens": [4], + }, + "statistics_column2": { + "input_tokens": [4], + "output_tokens": [4], + }, + }, + { + "statistics_column1": { + "input_tokens": [4], + "output_tokens": [4], + }, + "statistics_column2": { + "input_tokens": [4], + "output_tokens": [4], + }, + }, + ], + ), + ( + False, + True, + [ + { + "column1": [1, 2, 3], + "column2": ["a", "b", "c"], + "distilabel_metadata": { + "statistics_column1": { + "input_tokens": [793], + "output_tokens": [361], + }, + "statistics_column2": { + "input_tokens": [202], + 
"output_tokens": [100], + }, + }, + } + ], + [ + { + "statistics_column1": { + "input_tokens": [264], + "output_tokens": [120], + }, + "statistics_column2": { + "input_tokens": [67], + "output_tokens": [33], + }, + }, + { + "statistics_column1": { + "input_tokens": [264], + "output_tokens": [120], + }, + "statistics_column2": { + "input_tokens": [67], + "output_tokens": [33], + }, + }, + { + "statistics_column1": { + "input_tokens": [264], + "output_tokens": [120], + }, + "statistics_column2": { + "input_tokens": [67], + "output_tokens": [33], + }, + }, + ], + ), + ], + ) + def test_process( + self, + encoded: Union[bool, list[str]], + split_statistics: bool, + values: list[dict[str, Union[list[int], list[str], str]]], + stats: dict[str, dict[str, int]], + ) -> None: expand_columns = ExpandColumns( - name="expand_columns", columns=["column1", "column2"], - pipeline=Pipeline(name="unit-test"), + encoded=encoded, + split_statistics=split_statistics, ) - result = next( - expand_columns.process([{"column1": [1, 2, 3], "column2": ["a", "b", "c"]}]) - ) + result = next(expand_columns.process(values)) assert result == [ - {"column1": 1, "column2": "a"}, - {"column1": 2, "column2": "b"}, - {"column1": 3, "column2": "c"}, + { + "column1": 1, + "column2": "a", + "distilabel_metadata": stats[0], + }, + { + "column1": 2, + "column2": "b", + "distilabel_metadata": stats[1], + }, + { + "column1": 3, + "column2": "c", + "distilabel_metadata": stats[2], + }, ] diff --git a/tests/unit/steps/tasks/math_shepherd/__init__.py b/tests/unit/steps/tasks/math_shepherd/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/steps/tasks/math_shepherd/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/steps/tasks/math_shepherd/test_completer.py b/tests/unit/steps/tasks/math_shepherd/test_completer.py new file mode 100644 index 0000000000..c5e8092cd3 --- /dev/null +++ b/tests/unit/steps/tasks/math_shepherd/test_completer.py @@ -0,0 +1,476 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import textwrap +from typing import TYPE_CHECKING, Any, Dict, List + +import pytest + +from distilabel.steps.tasks.math_shepherd.completer import MathShepherdCompleter +from tests.unit.conftest import DummyLLM + +if TYPE_CHECKING: + from distilabel.models.llms.typing import GenerateOutput + + +class MathShepherdCompleterLLM(DummyLLM): + N: int = 3 + + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "math-shepherd-completer" + + def generate( # type: ignore + self, inputs: Dict[str, Any], num_generations: int = 1 + ) -> List["GenerateOutput"]: + if self.N == 1: + response = textwrap.dedent(""" + Step 1: Determine the total number of eggs Janet collects per day: Janet's ducks lay 16 eggs per day. + Step 2: Calculate the number of eggs Janet uses for herself per day: She eats three for breakfast and bakes muffins with four eggs, for a total of 3 + 4 = <<3+4=7>>7 eggs. + Step 3: Calculate the number of eggs Janet has left to sell per day: 16 - 7 = <<16-7=9>>9 eggs. + Step 4: Calculate the total amount Janet makes at the farmers' market per day: 9 * $2 = <<9*2=18>>18 dollars. + + The answer is: $18""") + else: + response = textwrap.dedent(""" + Step 2: Janet's ducks lay 16 eggs per day, and she uses 7 for eating and baking. So the number of eggs she has left is 16 - 7 = <<16-7=9>>9. + Step 3: Janet sells the remaining 9 eggs for $2 each, so she makes 9 * 2 = <<9*2=18>>18 dollars every day at the farmers' market. + The answer is: 18 + + --- + + Step 2: Janet's ducks lay 16 eggs per day, and she uses 3 for eating and bakes 4 for her friends, so she has 16 - 7 = <<16-7=9>>9 eggs left. + Step 3: Selling the 9 eggs at $2 each, she makes 9 * 2 = <<9*2=18>>18 dollars every day. + The answer is: 18 + + --- + + Step 2: Janets ducks lay 16 eggs per day. She eats 3 and bakes 4, so she has 16 - (3 + 4) = 16 - 7 = 9 eggs left. + Step 3: She sells the 9 eggs for $2 each, which means she makes 9 * $2 = $<<9*2=18>>18. + The answer is: 18""") + return [ + { + "generations": [response] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + for _ in range(len(inputs)) + ] + + +DUMMY_STEPS = [ + "Step 1: Determine the total number of eggs Janet collects per day: Janet's ducks lay 16 eggs per day.", + "Step 2: Calculate the number of eggs Janet uses for herself per day: She eats three for breakfast and bakes muffins with four eggs, for a total of 3 + 4 = <<3+4=7>>7 eggs.", + "Step 3: Calculate the number of eggs Janet has left to sell per day: 16 - 7 = <<16-7=9>>9 eggs.", + "Step 4: Calculate the total amount Janet makes at the farmers' market per day: 9 * $2 = <<9*2=18>>18 dollars.", + "The answer is: $18", +] + + +class TestMathShepherdCompleter: + @pytest.mark.parametrize( + "steps, num_completions", + [ + (DUMMY_STEPS, 3), + # This would be the same case as having the problem already solved in a single step, + # there's nothing else we have to do + (DUMMY_STEPS[-2:], 0), + # Check there aren't errors if no solutions were provided + ([DUMMY_STEPS[0]], 0), + ([], 0), + ], + ) + def test_prepare_completions(self, steps: List[str], num_completions: int) -> None: + task = MathShepherdCompleter(llm=MathShepherdCompleterLLM(N=1), N=1) + task.load() + instruction = "Krystian works in the library. He borrows an average of 40 books every day. Every Friday, his number of borrowed books is about 40% higher than the daily average. 
How many books does he borrow in a week if the library is open from Monday to Friday?" + prepared_inputs = task._prepare_completions(instruction, steps) + assert len(prepared_inputs) == num_completions + + def test_process(self) -> None: + task = MathShepherdCompleter( + llm=MathShepherdCompleterLLM(N=3), + N=3, + ) + task.load() + result = next( + task.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "golden_solution": [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + "solutions": [ + [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + [ + "Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking.", + "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day.", + "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.", + "The answer is: 18", + ], + ], + }, + ] + ) + ) + + assert result == [ + { + "golden_solution": [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "model_name": "math-shepherd-completer", + "solutions": [ + [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. +", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. The answer is: 18 +", + ], + [ + "Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking. +", + "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day. +", + "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18. 
The answer is: 18 +", + ], + ], + "distilabel_metadata": { + "statistics_math_shepherd_completer_0": { + "input_tokens": [12], + "output_tokens": [12], + }, + "raw_input_math_shepherd_completer_0": [ + { + "role": "system", + "content": 'You are a math teacher who helps students by breaking down word problems into clear, logical steps.\nWhen given a problem statement and any number of initial step, generate the remaining steps needed to reach the final answer.\nEach step should:\n\n- Build logically on previous steps\n- Explain the reasoning in natural language\n- Lead to the final answer\n- Multiple solution paths are acceptable\n- Steps should be concise but clear\n- Each calculation should be shown explicitly\n- The final answer must be clearly stated\n- The number of steps may vary based on the solution approach\n\n# Format requirements:\n- Each step should be numbered sequentially, continuing from the last given step\n- The final step should clearly state "The answer is: [result]"\n- Each step can use different approaches but must be mathematically valid\n\n# Rules:\n- All calculations must be shown within <<>> brackets\n- Basic operations: use * for multiplication, / for division, + for addition, - for subtraction\n- Write the full calculation and result, e.g., <<5*10=50>>50\n\n# Examples:\n## Input\nKrystian works in the library. He borrows an average of 40 books every day. Every Friday, his number of borrowed books is about 40% higher than the daily average. How many books does he borrow in a week if the library is open from Monday to Friday?\nStep 1: On Friday, Krystian borrows 40 * 0.4 = <<40*0.4=16>>16 more books than on a regular day.\n\n## Output 1\nStep 2: On Friday, Krystian borrows 40 + 16 = <<40+16=56>>56 books in total.\nStep 3: For the other 4 days (Monday to Thursday), he borrows 40 * 4 = <<40*4=160>>160 books.\nStep 4: The total books for the week is 160 + 56 = <<160+56=216>>216. The answer is: 216\n\n## Output 2\nStep 2: In total, he borrows 40 + 16 = <<40+16=56>>56 books on Friday.\nStep 3: For the whole week (4 regular days plus Friday), the total is (40 * 4) + 56 = <<(40*4)+56=216>>216. The answer is: 216\n\n## Output 3\nStep 2: On Friday, he borrows 40 + 40/100 * 40 = <<40+40/100*40=56>>56 books.\nStep 3: In a week, he borrows 5.7 * 7 = <<5.7*7=40>>40 books. The answer is: 40', + }, + { + "role": "user", + "content": "Generate 3 example solutions to the same problem, separated by a single `---` and nothing else.\nResponse format:\n```\nStep i: step i explanation.\nStep i+1: step i+1 explanation.\nThe answer is: X\n\n---\n\nStep i: step i explanation.\nStep i+1: step i+1 explanation.\nThe answer is: Y\n```\n\nThis is the problem:\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + }, + ], + "raw_output_math_shepherd_completer_0": """\nStep 2: Janet's ducks lay 16 eggs per day, and she uses 7 for eating and baking. So the number of eggs she has left is 16 - 7 = <<16-7=9>>9. +Step 3: Janet sells the remaining 9 eggs for $2 each, so she makes 9 * 2 = <<9*2=18>>18 dollars every day at the farmers' market. +The answer is: 18 + +--- + +Step 2: Janet's ducks lay 16 eggs per day, and she uses 3 for eating and bakes 4 for her friends, so she has 16 - 7 = <<16-7=9>>9 eggs left. 
+Step 3: Selling the 9 eggs at $2 each, she makes 9 * 2 = <<9*2=18>>18 dollars every day. +The answer is: 18 + +--- + +Step 2: Janets ducks lay 16 eggs per day. She eats 3 and bakes 4, so she has 16 - (3 + 4) = 16 - 7 = 9 eggs left. +Step 3: She sells the 9 eggs for $2 each, which means she makes 9 * $2 = $<<9*2=18>>18. +The answer is: 18""", + }, + } + ] + + def test_auto_label(self): + inputs = [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "golden_solution": [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + "solutions": [ + [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + [ + "Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking.", + "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day.", + "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18.", + "The answer is: 18", + ], + ], + }, + ] + N = 3 + task = MathShepherdCompleter( + llm=MathShepherdCompleterLLM(N=N), + N=N, + add_raw_input=False, + add_raw_output=False, + ) + task.load() + final_outputs = [ + [ + [ + "Step 2: Janet sells 9 duck eggs at the farmers' market, so she makes 9 * $1 = $<<9*1=9>>9 from selling the eggs.", + "The answer is: $9", + ], + [ + "Step 1: Janet lays 16 eggs per day, eats 3 for breakfast, uses 4 for baking, so she has 16 - 3 - 4 = 9 eggs left.", + "Step 2: Since Janet sells 9 eggs a day, and each egg is sold for $1, she makes 9 * $1 = $<<9*1=9>>9.", + "The answer is: $9", + ], + [ + "Step 1: Janet lays 16 eggs per day, eats 3, uses 4 for baking which leaves her with 16 - 3 - 4 = 9 eggs.", + "Step 2: Since she sells the eggs for $1 each, she makes 9 * $1 = $<<9*1=9>>9.", + "The answer is: $9", + ], + ], + [ + [ + "Step 3: To determine how many eggs Jan's sells at the market, we need to subtract the eggs she uses (7) from the total number of eggs laid (16), which is 16 - 7 = <<16-7=9>>9.", + "Step 4: Since she sells 9 eggs for $2 each, we multiply 9 * 2 = <<9*2=18>>18 to find out her daily earnings.", + "The answer is: 18", + ], + [ + "Step 2: Jan's ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking.", + "Step 3: To find the number of eggs Jan's sells at the market, we subtract the eggs she uses (7) from the total number of eggs laid (16), which is 16 - 7 = <<16-7=9>>9.", + "Step 4: Since she sells 9 eggs for $2 each, we multiply 9 * 2 = <<9*2=18>>18 to find out her daily earnings.", + "The answer is: 18", + ], + [ + "Step 2: Jan's ducks lay 16 eggs per day, and she uses 7 for eating and baking.", + "Step 3: To find the number of eggs Jan's sells at the market, we calculate 16 - 7 = <<16-7=9>>9.", + "Step 4: Since she sells 9 eggs for $2 each, we multiply 9 * 2 = <<9*2=18>>18 to find out her daily earnings.", + "The answer is: 18", + ], + ], + [ + [ + "Step 1: Janet's ducks lay 16 eggs per day. 
She eats 3 eggs and bakes 4 eggs.", + "Step 2: So, she uses 3 + 4 = <<3+4=7>>7 eggs for eating and baking.", + "Step 3: She sells the remaining eggs, which is 16 - 7 = <<16-7=9>>9 duck eggs every day.", + "Step 4: She sells each egg for $2, so the total amount she makes is 9 * 2 = <<9*2=18>>18 dollars every day.", + "The answer is: 18", + ], + [ + "Step 1: Janet's ducks lay 16 eggs per day.", + "Step 2: She eats 3 eggs and bakes 4 eggs, which is a total of 3 + 4 = <<3+4=7>>7 eggs.", + "Step 3: She sells the remaining eggs, which is 16 - 7 = <<16-7=9>>9 duck eggs every day.", + "Step 4: Since she sells each egg for $2, she makes 9 * 2 = <<9*2=18>>18 dollars every day.", + "The answer is: 18", + ], + [ + "Step 1: Janet's ducks lay 16 eggs per day.", + "Step 2: She consumes 7 eggs for eating and baking, which means she has 16 - 7 = <<16-7=9>>9 eggs left.", + "Step 3: She sells each egg for $2, so she makes 9 * 2 = <<9*2=18>>18 dollars every day.", + "The answer is: 18", + ], + ], + ] + + golden_answers = ["The answer is: 18", "The answer is: 18"] + input_positions = [(0, 0, 0), (0, 1, 0), (0, 1, 1)] + statistics = [ + {"input_tokens": [12], "output_tokens": [12]}, + {"input_tokens": [12], "output_tokens": [12]}, + {"input_tokens": [12], "output_tokens": [12]}, + ] + raw_outputs = [ + [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. -", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. The answer is: 18 +", + ], + [ + "Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking. +", + "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day. +", + "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18. The answer is: 18 +", + ], + ] + raw_inputs = [ + { + "role": "system", + "content": 'You are a math teacher who helps students by breaking down word problems into clear, logical steps.\nWhen given a problem statement and any number of initial step, generate the remaining steps needed to reach the final answer.\nEach step should:\n\n- Build logically on previous steps\n- Explain the reasoning in natural language\n- Lead to the final answer\n- Multiple solution paths are acceptable\n- Steps should be concise but clear\n- Each calculation should be shown explicitly\n- The final answer must be clearly stated\n- The number of steps may vary based on the solution approach\n\n# Format requirements:\n- Each step should be numbered sequentially, continuing from the last given step\n- The final step should clearly state "The answer is: [result]"\n- Each step can use different approaches but must be mathematically valid\n\n# Rules:\n- All calculations must be shown within <<>> brackets\n- Basic operations: use * for multiplication, / for division, + for addition, - for subtraction\n- Write the full calculation and result, e.g., <<5*10=50>>50\n\n# Examples:\n## Input\nKrystian works in the library. He borrows an average of 40 books every day. Every Friday, his number of borrowed books is about 40% higher than the daily average. How many books does he borrow in a week if the library is open from Monday to Friday?\nStep 1: On Friday, Krystian borrows 40 * 0.4 = <<40*0.4=16>>16 more books than on a regular day.\n\n## Output 1\nStep 2: On Friday, Krystian borrows 40 + 16 = <<40+16=56>>56 books in total.\nStep 3: For the other 4 days (Monday to Thursday), he borrows 40 * 4 = <<40*4=160>>160 books.\nStep 4: The total books for the week is 160 + 56 = <<160+56=216>>216. 
The answer is: 216\n\n## Output 2\nStep 2: In total, he borrows 40 + 16 = <<40+16=56>>56 books on Friday.\nStep 3: For the whole week (4 regular days plus Friday), the total is (40 * 4) + 56 = <<(40*4)+56=216>>216. The answer is: 216\n\n## Output 3\nStep 2: On Friday, he borrows 40 + 40/100 * 40 = <<40+40/100*40=56>>56 books.\nStep 3: In a week, he borrows 5.7 * 7 = <<5.7*7=40>>40 books. The answer is: 40', + }, + { + "role": "user", + "content": "Generate 3 example solutions to the same problem, separated by a single `---` and nothing else.\nResponse format:\n```\nStep i: step i explanation.\nStep i+1: step i+1 explanation.\nThe answer is: X\n\n---\n\nStep 2: step i explanation.\nStep 3: step i+1 explanation.\nThe answer is: Y\n```\n\nThis is the problem:\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + }, + ] + results = task._auto_label( + inputs, + final_outputs, + input_positions, + golden_answers, + statistics, + raw_outputs, + raw_inputs, + ) + assert results == [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "golden_solution": [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + "solutions": [ + [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day. -", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market. The answer is: 18 +", + ], + [ + "Step 1: Janets ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking. +", + "Step 2: So she sells 16 - 7 = <<16-7=9>>9 duck eggs every day. +", + "Step 3: Those 9 eggs are worth 9 * $2 = $<<9*2=18>>18. The answer is: 18 +", + ], + ], + "model_name": "math-shepherd-completer", + "distilabel_metadata": { + "statistics_math_shepherd_completer_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, + } + ] + + def test_auto_label_with_errors(self): + inputs = [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. 
How much in dollars does she make every day at the farmers' market?", + "golden_solution": [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + "solutions": [[], [], []], + }, + ] + N = 3 + task = MathShepherdCompleter( + llm=MathShepherdCompleterLLM(N=N), + N=N, + add_raw_input=False, + add_raw_output=False, + ) + task.load() + final_outputs = [ + [ + [ + "Step 2: Janet sells 9 duck eggs at the farmers' market, so she makes 9 * $1 = $<<9*1=9>>9 from selling the eggs.", + "The answer is: $9", + ], + [ + "Step 1: Janet lays 16 eggs per day, eats 3 for breakfast, uses 4 for baking, so she has 16 - 3 - 4 = 9 eggs left.", + "Step 2: Since Janet sells 9 eggs a day, and each egg is sold for $1, she makes 9 * $1 = $<<9*1=9>>9.", + "The answer is: $9", + ], + [ + "Step 1: Janet lays 16 eggs per day, eats 3, uses 4 for baking which leaves her with 16 - 3 - 4 = 9 eggs.", + "Step 2: Since she sells the eggs for $1 each, she makes 9 * $1 = $<<9*1=9>>9.", + "The answer is: $9", + ], + ], + [ + [ + "Step 3: To determine how many eggs Jan's sells at the market, we need to subtract the eggs she uses (7) from the total number of eggs laid (16), which is 16 - 7 = <<16-7=9>>9.", + "Step 4: Since she sells 9 eggs for $2 each, we multiply 9 * 2 = <<9*2=18>>18 to find out her daily earnings.", + "The answer is: 18", + ], + [ + "Step 2: Jan's ducks lay 16 eggs per day, and she uses 3 + 4 = <<3+4=7>>7 for eating and baking.", + "Step 3: To find the number of eggs Jan's sells at the market, we subtract the eggs she uses (7) from the total number of eggs laid (16), which is 16 - 7 = <<16-7=9>>9.", + "Step 4: Since she sells 9 eggs for $2 each, we multiply 9 * 2 = <<9*2=18>>18 to find out her daily earnings.", + "The answer is: 18", + ], + [ + "Step 2: Jan's ducks lay 16 eggs per day, and she uses 7 for eating and baking.", + "Step 3: To find the number of eggs Jan's sells at the market, we calculate 16 - 7 = <<16-7=9>>9.", + "Step 4: Since she sells 9 eggs for $2 each, we multiply 9 * 2 = <<9*2=18>>18 to find out her daily earnings.", + "The answer is: 18", + ], + ], + [ + [ + "Step 1: Janet's ducks lay 16 eggs per day. 
She eats 3 eggs and bakes 4 eggs.", + "Step 2: So, she uses 3 + 4 = <<3+4=7>>7 eggs for eating and baking.", + "Step 3: She sells the remaining eggs, which is 16 - 7 = <<16-7=9>>9 duck eggs every day.", + "Step 4: She sells each egg for $2, so the total amount she makes is 9 * 2 = <<9*2=18>>18 dollars every day.", + "The answer is: 18", + ], + [ + "Step 1: Janet's ducks lay 16 eggs per day.", + "Step 2: She eats 3 eggs and bakes 4 eggs, which is a total of 3 + 4 = <<3+4=7>>7 eggs.", + "Step 3: She sells the remaining eggs, which is 16 - 7 = <<16-7=9>>9 duck eggs every day.", + "Step 4: Since she sells each egg for $2, she makes 9 * 2 = <<9*2=18>>18 dollars every day.", + "The answer is: 18", + ], + [ + "Step 1: Janet's ducks lay 16 eggs per day.", + "Step 2: She consumes 7 eggs for eating and baking, which means she has 16 - 7 = <<16-7=9>>9 eggs left.", + "Step 3: She sells each egg for $2, so she makes 9 * 2 = <<9*2=18>>18 dollars every day.", + "The answer is: 18", + ], + ], + ] + + golden_answers = ["The answer is: 18", "The answer is: 18"] + input_positions = [(0, 0, 0), (0, 1, 0), (0, 1, 1)] + statistics = [ + {"input_tokens": [12], "output_tokens": [12]}, + {"input_tokens": [12], "output_tokens": [12]}, + {"input_tokens": [12], "output_tokens": [12]}, + ] + raw_outputs = [[], []] + raw_inputs = [ + { + "role": "system", + "content": 'You are a math teacher who helps students by breaking down word problems into clear, logical steps.\nWhen given a problem statement and any number of initial step, generate the remaining steps needed to reach the final answer.\nEach step should:\n\n- Build logically on previous steps\n- Explain the reasoning in natural language\n- Lead to the final answer\n- Multiple solution paths are acceptable\n- Steps should be concise but clear\n- Each calculation should be shown explicitly\n- The final answer must be clearly stated\n- The number of steps may vary based on the solution approach\n\n# Format requirements:\n- Each step should be numbered sequentially, continuing from the last given step\n- The final step should clearly state "The answer is: [result]"\n- Each step can use different approaches but must be mathematically valid\n\n# Rules:\n- All calculations must be shown within <<>> brackets\n- Basic operations: use * for multiplication, / for division, + for addition, - for subtraction\n- Write the full calculation and result, e.g., <<5*10=50>>50\n\n# Examples:\n## Input\nKrystian works in the library. He borrows an average of 40 books every day. Every Friday, his number of borrowed books is about 40% higher than the daily average. How many books does he borrow in a week if the library is open from Monday to Friday?\nStep 1: On Friday, Krystian borrows 40 * 0.4 = <<40*0.4=16>>16 more books than on a regular day.\n\n## Output 1\nStep 2: On Friday, Krystian borrows 40 + 16 = <<40+16=56>>56 books in total.\nStep 3: For the other 4 days (Monday to Thursday), he borrows 40 * 4 = <<40*4=160>>160 books.\nStep 4: The total books for the week is 160 + 56 = <<160+56=216>>216. The answer is: 216\n\n## Output 2\nStep 2: In total, he borrows 40 + 16 = <<40+16=56>>56 books on Friday.\nStep 3: For the whole week (4 regular days plus Friday), the total is (40 * 4) + 56 = <<(40*4)+56=216>>216. The answer is: 216\n\n## Output 3\nStep 2: On Friday, he borrows 40 + 40/100 * 40 = <<40+40/100*40=56>>56 books.\nStep 3: In a week, he borrows 5.7 * 7 = <<5.7*7=40>>40 books. 
The answer is: 40', + }, + { + "role": "user", + "content": "Generate 3 example solutions to the same problem, separated by a single `---` and nothing else.\nResponse format:\n```\nStep i: step i explanation.\nStep i+1: step i+1 explanation.\nThe answer is: X\n\n---\n\nStep 2: step i explanation.\nStep 3: step i+1 explanation.\nThe answer is: Y\n```\n\nThis is the problem:\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + }, + ] + results = task._auto_label( + inputs, + final_outputs, + input_positions, + golden_answers, + statistics, + raw_outputs, + raw_inputs, + ) + assert results == [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "golden_solution": [ + "Step 1: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.", + "Step 2: She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.", + "The answer is: 18", + ], + "solutions": [[], [], []], + "model_name": "math-shepherd-completer", + "distilabel_metadata": { + "statistics_math_shepherd_completer_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, + } + ] diff --git a/tests/unit/steps/tasks/math_shepherd/test_generator.py b/tests/unit/steps/tasks/math_shepherd/test_generator.py new file mode 100644 index 0000000000..14ccc87533 --- /dev/null +++ b/tests/unit/steps/tasks/math_shepherd/test_generator.py @@ -0,0 +1,252 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import pytest + +from distilabel.steps.tasks.math_shepherd.generator import ( + FEW_SHOTS_GSM8K, + RULES_GSM8K, + SYSTEM_PROMPT, + MathShepherdGenerator, +) +from tests.unit.conftest import DummyLLM + +if TYPE_CHECKING: + from distilabel.models.llms.typing import GenerateOutput + + +class MathShepherdGeneratorLLM(DummyLLM): + M: Optional[int] = None + structured: bool = False + with_error: bool = False + + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "math-shepherd-generator" + + def generate( + self, inputs: Dict[str, Any], num_generations: int = 1 + ) -> List["GenerateOutput"]: + if self.structured: + if self.M: + solutions = [ + { + "solution": "Step 1: Calculate the number of eggs Janet lays per day: 16 eggs/day. She eats 3 for breakfast and bakes 4 for muffins, so she uses a total of 3 + 4 = 7 eggs per day. Calculate the number of eggs she has left: 16 - 7 = <<16-7=9>>9 eggs/day. 
Step 2: Calculate the amount of money Janet makes from selling eggs at the farmers' market: 9 eggs/day * $2/egg = <<9*2=18>>$18/day. Step 3: Calculate the total number of books borrowed from Monday to Thursday: 40 * 4 = <<40*4=160>>160 books. Step 4: Calculate the total number of books borrowed in the entire week: 160 + 56 = <<160+56=216>>216 books The answer is: 18" + } + for _ in range(self.M) + ] + else: + solutions = [ + { + "solution": "Step 1: Calculate the number of eggs Janet lays per day: 16 eggs/day. She eats 3 for breakfast and bakes 4 for muffins, so she uses a total of 3 + 4 = 7 eggs per day. Calculate the number of eggs she has left: 16 - 7 = <<16-7=9>>9 eggs/day. Step 2: Calculate the amount of money Janet makes from selling eggs at the farmers' market: 9 eggs/day * $2/egg = <<9*2=18>>$18/day. Step 3: Calculate the total number of books borrowed from Monday to Thursday: 40 * 4 = <<40*4=160>>160 books. Step 4: Calculate the total number of books borrowed in the entire week: 160 + 56 = <<160+56=216>>216 books The answer is: 18" + } + ] + response = json.dumps({"solutions": solutions}) + else: + response = """ +Step 1: Calculate the number of books borrowed on a regular day (Monday to Thursday): +40 books per day + +Step 2: Calculate the number of books borrowed on Friday, which is 40% higher than the daily average: +40 * 40/100 = <<40*40/100=16>>16 books +40 + 16 = <<40+16=56>>56 books + +Step 3: Calculate the total number of books borrowed from Monday to Thursday: +40 * 4 = <<40*4=160>>160 books + +Step 4: Calculate the total number of books borrowed in the entire week: +160 + 56 = <<160+56=216>>216 books + +The answer is: 216 books.""" + if self.M: + response = "---".join([response for _ in range(self.M)]) + return [ + { + "generations": [response] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + for _ in range(len(inputs)) + ] + + +class TestMathShepherdGenerator: + @pytest.mark.parametrize( + "system_prompt", + [ + None, # Use the default + SYSTEM_PROMPT, + ], + ) + @pytest.mark.parametrize( + "extra_rules, few_shots", + [ + (None, None), # Use the default + (RULES_GSM8K, FEW_SHOTS_GSM8K), + ], + ) + @pytest.mark.parametrize( + "M", + [None, 5], + ) + def test_format_input( + self, + system_prompt: Optional[str], + extra_rules: Optional[str], + few_shots: Optional[str], + M: Optional[int], + ) -> None: + task = MathShepherdGenerator( + llm=MathShepherdGeneratorLLM(), + system_prompt=system_prompt, + extra_rules=extra_rules, + few_shots=few_shots, + M=M, + ) + task.load() + + result = task.format_input( + input={ + "instruction": "Krystian works in the library. He borrows an average of 40 books every day. Every Friday, his number of borrowed books is about 40% higher than the daily average. How many books does he borrow in a week if the library is open from Monday to Friday?" 
+ } + ) + rendered_system_prompt = "" + if system_prompt: + rendered_system_prompt = result[0]["content"] + if extra_rules: + assert RULES_GSM8K in rendered_system_prompt + if few_shots: + assert FEW_SHOTS_GSM8K in rendered_system_prompt + if M: + assert ( + "Generate 5 example solutions to the following problem," + in result[1]["content"] + ) + else: + if M: + assert ( + "Generate 5 example solutions to the following problem," + in result[0]["content"] + ) + + @pytest.mark.parametrize( + "M, output_name", + [ + (None, "golden_solution"), + (3, "solutions"), + ], + ) + def test_process(self, M: Optional[int], output_name: str) -> None: + task = MathShepherdGenerator(llm=MathShepherdGeneratorLLM(M=M), M=M) + task.load() + + result = next(task.process([{"instruction": ""}]))[0][output_name] + if M: + assert len(result) == 3 + assert all(len(r) == 5 for r in result) + else: + assert len(result) == 5 + + @pytest.mark.parametrize( + "M, output_name", + [ + # (None, "golden_solution"), # golden solution doesn't need structured output + (1, "solutions"), + (3, "solutions"), + ], + ) + def test_process_structured_output( + self, M: Optional[int], output_name: str + ) -> None: + task = MathShepherdGenerator( + llm=MathShepherdGeneratorLLM(M=M, structured=True), + use_default_structured_output=True, + M=M, + ) + task.load() + + result = next(task.process([{"instruction": ""}]))[0][output_name] + if M: + assert len(result) == M + assert all(len(r) == 5 for r in result) + else: + assert len(result) == 5 + + @pytest.mark.parametrize( + "M, output_name, expected", + [ + # (None, "golden_solution"), + ( + 2, + "solutions", + { + "instruction": "", + "solutions": [ + [ + "Step 1: Calculate the number of eggs Janet lays per day: 16 eggs/day. She eats 3 for breakfast and bakes 4 for muffins, so she uses a total of 3 + 4 = 7 eggs per day. Calculate the number of eggs she has left: 16 - 7 = <<16-7=9>>9 eggs/day.", + "Step 2: Calculate the amount of money Janet makes from selling eggs at the farmers' market: 9 eggs/day * $2/egg = <<9*2=18>>$18/day.", + "Step 3: Calculate the total number of books borrowed from Monday to Thursday: 40 * 4 = <<40*4=160>>160 books.", + "Step 4: Calculate the total number of books borrowed in the entire week: 160 + 56 = <<160+56=216>>216 books", + "The answer is: 18", + ], + [ + "Step 1: Calculate the number of eggs Janet lays per day: 16 eggs/day. She eats 3 for breakfast and bakes 4 for muffins, so she uses a total of 3 + 4 = 7 eggs per day. Calculate the number of eggs she has left: 16 - 7 = <<16-7=9>>9 eggs/day.", + "Step 2: Calculate the amount of money Janet makes from selling eggs at the farmers' market: 9 eggs/day * $2/egg = <<9*2=18>>$18/day.", + "Step 3: Calculate the total number of books borrowed from Monday to Thursday: 40 * 4 = <<40*4=160>>160 books.", + "Step 4: Calculate the total number of books borrowed in the entire week: 160 + 56 = <<160+56=216>>216 books", + "The answer is: 18", + ], + ], + }, + ), + ( + 3, + "solutions", + { + "instruction": "", + "solutions": [[], [], []], + }, + ), + ], + ) + def test_format_output(self, M: int, output_name: str, expected) -> None: + task = MathShepherdGenerator( + llm=MathShepherdGeneratorLLM(M=M), + M=M, + use_default_structured_output=True, + ) + task.load() + + # This is just to force an error in the generation so that the default is + # returned and we can test it easily + if M == 3: + M += 1 + + solutions = [ + { + "solution": "Step 1: Calculate the number of eggs Janet lays per day: 16 eggs/day. 
She eats 3 for breakfast and bakes 4 for muffins, so she uses a total of 3 + 4 = 7 eggs per day. Calculate the number of eggs she has left: 16 - 7 = <<16-7=9>>9 eggs/day. Step 2: Calculate the amount of money Janet makes from selling eggs at the farmers' market: 9 eggs/day * $2/egg = <<9*2=18>>$18/day. Step 3: Calculate the total number of books borrowed from Monday to Thursday: 40 * 4 = <<40*4=160>>160 books. Step 4: Calculate the total number of books borrowed in the entire week: 160 + 56 = <<160+56=216>>216 books The answer is: 18" + } + for _ in range(M) + ] + solutions = json.dumps({"solutions": solutions}) + result = task.format_output(solutions, {"instruction": ""}) + assert result == expected diff --git a/tests/unit/steps/tasks/math_shepherd/test_utils.py b/tests/unit/steps/tasks/math_shepherd/test_utils.py new file mode 100644 index 0000000000..8dff5da8ec --- /dev/null +++ b/tests/unit/steps/tasks/math_shepherd/test_utils.py @@ -0,0 +1,231 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.steps import ExpandColumns +from distilabel.steps.tasks.math_shepherd.utils import FormatPRM, parse_json_response + + +class TestFormatPRM: + def test_process(self) -> None: + # As the expected data format is a JSON encoded string of a list, will use + # ExpandColumns initially as that would be a prior step in the pipeline + expand_columns = ExpandColumns( + columns=["solutions"], + split_statistics=True, + ) + expand_columns.load() + result = next( + expand_columns.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. 
+", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + ], + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [793], + "output_tokens": [361], + } + }, + } + ] + ) + ) + + formatter = FormatPRM() + formatter.load() + result = next(formatter.process(result)) + + assert result == [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + "input": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. ки\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. ки\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 ки", + "label": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [264], + "output_tokens": [120], + } + }, + }, + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + "input": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. 
How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. ки\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. ки\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 ки", + "label": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [264], + "output_tokens": [120], + } + }, + }, + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + "input": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. ки\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. ки\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 ки", + "label": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. 
The answer is: 3 +", + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [264], + "output_tokens": [120], + } + }, + }, + ] + + def test_process_trl(self) -> None: + expand_columns = ExpandColumns( + columns=["solutions"], + split_statistics=True, + ) + expand_columns.load() + result = next( + expand_columns.process( + [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + ], + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [793], + "output_tokens": [361], + } + }, + } + ] + ) + ) + + formatter = FormatPRM(format="trl") + formatter.load() + result = next(formatter.process(result)) + + assert result == [ + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + "prompt": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. 
How much in dollars does she make every day at the farmers' market?", + "completions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required.", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber.", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3", + ], + "labels": [True, True, True], + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [264], + "output_tokens": [120], + } + }, + }, + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + "prompt": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "completions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required.", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber.", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3", + ], + "labels": [True, True, True], + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [264], + "output_tokens": [120], + } + }, + }, + { + "instruction": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", + "solutions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber. +", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +", + ], + "prompt": "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. 
How much in dollars does she make every day at the farmers' market?", + "completions": [ + "Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required.", + "Step 2: Calculate the amount of white fiber needed: Since it's half that much, we can multiply 2 by 0.5 (which is the same as dividing by 2): 2 * 0.5 = <<2*0.5=1>>1 bolt of white fiber.", + "Step 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3", + ], + "labels": [True, True, True], + "distilabel_metadata": { + "statistics_completer_math_shepherd_completer_0": { + "input_tokens": [264], + "output_tokens": [120], + } + }, + }, + ] + + +def test_parse_json_response(): + json_str = """{"solutions": [{"solution": "It\\'s 2 * 0.5 = 1 \\) bolt of white fiber"}]}""" + response = parse_json_response(json_str) + assert response == { + "solutions": [{"solution": "It's 2 * 0.5 = 1 ) bolt of white fiber"}] + } From 55d9e5d0927bfef2fb0b3a5b53689d9a0a374dc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 4 Dec 2024 15:27:35 +0100 Subject: [PATCH 05/30] Add `load_groups` argument to `run` (#1075) Co-authored-by: Agus --- .github/workflows/test.yml | 6 + .../load_groups_and_execution_stages.md | 122 +++++++++ mkdocs.yml | 1 + src/distilabel/pipeline/_dag.py | 73 +++-- src/distilabel/pipeline/base.py | 254 ++++++++++++++++-- src/distilabel/pipeline/local.py | 13 +- src/distilabel/pipeline/ray.py | 12 +- src/distilabel/pipeline/step_wrapper.py | 2 +- src/distilabel/pipeline/typing.py | 20 +- tests/integration/test_load_groups.py | 105 ++++++++ tests/unit/pipeline/test_base.py | 135 +++++++++- tests/unit/pipeline/test_dag.py | 48 ++++ tests/unit/pipeline/test_local.py | 1 + 13 files changed, 731 insertions(+), 61 deletions(-) create mode 100644 docs/sections/how_to_guides/advanced/load_groups_and_execution_stages.md create mode 100644 tests/integration/test_load_groups.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 31fe54301b..ed7e1df8be 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,6 +50,12 @@ jobs: if: steps.cache.outputs.cache-hit != 'true' run: ./scripts/install_dependencies.sh + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 + if: ${{ matrix.python-version == '3.12' && github.event_name == 'workflow_dispatch' && inputs.tmate_session }} + with: + limit-access-to-actor: true + - name: Lint run: make lint diff --git a/docs/sections/how_to_guides/advanced/load_groups_and_execution_stages.md b/docs/sections/how_to_guides/advanced/load_groups_and_execution_stages.md new file mode 100644 index 0000000000..4b13451952 --- /dev/null +++ b/docs/sections/how_to_guides/advanced/load_groups_and_execution_stages.md @@ -0,0 +1,122 @@ +# Load groups and execution stages + +By default, the `distilabel` architecture loads all steps of a pipeline at the same time, as they are all supposed to process batches of data in parallel. However, loading all steps at once can waste resources in two scenarios: when using `GlobalStep`s that must wait for upstream steps to complete before processing data, or when running on machines with limited resources that cannot execute all steps simultaneously. In these cases, steps need to be loaded and executed in distinct **load stages**. + +## Load stages + +A load stage represents a point in the pipeline execution where a group of steps are loaded at the same time to process batches in parallel. 
These stages are required because:
+
+1. Some kinds of steps, like the `GlobalStep`s, need to receive all the data at once from their upstream steps, i.e. they need their upstream steps to have finished their execution. It would be wasteful to load a `GlobalStep` at the same time as other steps of the pipeline, as that would take resources (from the machine or cluster running the pipeline) that wouldn't be used until the upstream steps have finished.
+2. When running on machines or clusters with limited resources, it may not be possible to load and execute all steps simultaneously, as they would need to access the same limited resources (memory, CPU, GPU, etc.).
+
+Having said that, the first element that will create a load stage when executing a pipeline is the [`GlobalStep`][distilabel.steps.base.GlobalStep], as it marks and divides a pipeline into three stages: one stage with the upstream steps of the global step, one stage with the global step, and one final stage with the downstream steps of the global step. For example, the following pipeline will contain three stages:
+
+```python
+from typing import TYPE_CHECKING
+
+from distilabel.pipeline import Pipeline
+from distilabel.steps import LoadDataFromDicts, StepInput, step
+
+if TYPE_CHECKING:
+    from distilabel.typing import StepOutput
+
+
+@step(inputs=["instruction"], outputs=["instruction2"])
+def DummyStep(inputs: StepInput) -> "StepOutput":
+    for input in inputs:
+        input["instruction2"] = "miau"
+    yield inputs
+
+
+@step(inputs=["instruction"], outputs=["instruction2"], step_type="global")
+def GlobalDummyStep(inputs: StepInput) -> "StepOutput":
+    for input in inputs:
+        input["instruction2"] = "miau"
+    yield inputs
+
+
+with Pipeline() as pipeline:
+    generator = LoadDataFromDicts(data=[{"instruction": "Hi"}] * 50)
+    dummy_step_0 = DummyStep()
+    global_dummy_step = GlobalDummyStep()
+    dummy_step_1 = DummyStep()
+
+    generator >> dummy_step_0 >> global_dummy_step >> dummy_step_1
+
+if __name__ == "__main__":
+    load_stages = pipeline.get_load_stages()
+
+    for i, steps_stage in enumerate(load_stages[0]):
+        print(f"Stage {i}: {steps_stage}")
+
+    # Output:
+    # Stage 0: ['load_data_from_dicts_0', 'dummy_step_0']
+    # Stage 1: ['global_dummy_step_0']
+    # Stage 2: ['dummy_step_1']
+```
+
+As we can see, the `GlobalStep` divided the pipeline execution into three stages.
+
+## Load groups
+
+While `GlobalStep`s automatically divide pipeline execution into stages, we may need fine-grained control over how steps are loaded and executed within each stage. This is where **load groups** come in.
+
+Load groups allow you to specify which steps of the pipeline have to be loaded together within a stage. This is particularly useful when running on resource-constrained environments where all the steps cannot be executed in parallel.
+
+Let's see how it works with an example:
+
+```python
+from datasets import load_dataset
+
+from distilabel.llms import vLLM
+from distilabel.pipeline import Pipeline
+from distilabel.steps import StepResources
+from distilabel.steps.tasks import TextGeneration
+
+dataset = load_dataset(
+    "distilabel-internal-testing/instruction-dataset-mini", split="test"
+).rename_column("prompt", "instruction")
+
+with Pipeline() as pipeline:
+    text_generation_0 = TextGeneration(
+        llm=vLLM(
+            model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            extra_kwargs={"max_model_len": 1024},
+        ),
+        resources=StepResources(gpus=1),
+    )
+
+    text_generation_1 = TextGeneration(
+        llm=vLLM(
+            model="HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            extra_kwargs={"max_model_len": 1024},
+        ),
+        resources=StepResources(gpus=1),
+    )
+
+if __name__ == "__main__":
+    load_stages = pipeline.get_load_stages(load_groups=[[text_generation_1.name]])
+
+    for i, steps_stage in enumerate(load_stages[0]):
+        print(f"Stage {i}: {steps_stage}")
+
+    # Output:
+    # Stage 0: ['text_generation_0']
+    # Stage 1: ['text_generation_1']
+
+    distiset = pipeline.run(dataset=dataset, load_groups=[[text_generation_0.name]])
+```
+
+In this example, we're working with a machine that has a single GPU, but the pipeline includes two instances of the `TextGeneration` task, both using `vLLM` and requesting 1 GPU. We cannot execute both steps in parallel. To fix that, we use the `load_groups` argument of the `run` method to specify that the `text_generation_0` step has to be executed in isolation in its own stage. This way, we can run the pipeline on a single-GPU machine by executing the steps in different stages (sequentially) instead of in parallel.
+
+Some key points about load groups:
+
+1. Load groups are specified as a list of lists, where each inner list represents a group of steps that should be loaded together.
+2. As with `GlobalStep`s, each load group creates a new load stage, dividing the pipeline into three parts: one for the upstream steps, one for the steps in the load group, and one for the downstream steps.
+
+### Load groups modes
+
+In addition, `distilabel` allows passing some modes to the `load_groups` argument that will handle the creation of the load groups:
+
+- `"sequential_step_execution"`: when passed, it will create a load group for each step, i.e. the steps of the pipeline will be executed sequentially, as shown in the sketch below.
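+
+A minimal sketch of this mode, assuming the same `pipeline` and `dataset` defined in the example above:
+
+```python
+if __name__ == "__main__":
+    # One load group (and therefore one stage) is created per step, so only one
+    # step is loaded and executed at a time on the single available GPU.
+    distiset = pipeline.run(
+        dataset=dataset,
+        load_groups="sequential_step_execution",
+    )
+```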
diff --git a/mkdocs.yml b/mkdocs.yml
index 15c7e73a16..9654d03530 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -191,6 +191,7 @@ nav:
       - Structured data generation: "sections/how_to_guides/advanced/structured_generation.md"
       - Offline Batch Generation: "sections/how_to_guides/advanced/offline_batch_generation.md"
       - Specifying requirements for pipelines and steps: "sections/how_to_guides/advanced/pipeline_requirements.md"
+      - Load groups and execution stages: "sections/how_to_guides/advanced/load_groups_and_execution_stages.md"
       - Using CLI to explore and re-run existing Pipelines: "sections/how_to_guides/advanced/cli/index.md"
       - Using a file system to pass data of batches between steps: "sections/how_to_guides/advanced/fs_to_pass_data.md"
       - Assigning resources to a step: "sections/how_to_guides/advanced/assigning_resources_to_step.md"
diff --git a/src/distilabel/pipeline/_dag.py b/src/distilabel/pipeline/_dag.py
index 5962ecc4f0..e8f930413f 100644
--- a/src/distilabel/pipeline/_dag.py
+++ b/src/distilabel/pipeline/_dag.py
@@ -22,6 +22,7 @@
     Generator,
     Iterable,
     List,
+    Optional,
     Set,
     Tuple,
     Type,
@@ -290,13 +291,24 @@ def get_total_replica_count(self) -> int:
         """
         return sum([self.get_step_replica_count(step_name) for step_name in self.G])

-    def get_steps_load_stages(self) -> Tuple[List[List[str]], List[List[str]]]:
+    def get_steps_load_stages(  # noqa: C901
+        self, load_groups: Optional[List[List[str]]] = None
+    ) -> Tuple[List[List[str]], List[List[str]]]:
         """Gets the stages in which the `Step`s of the `Pipeline` should be loaded. Stages
-        are determined by `GlobalStep`s as they receive all the data at once, which means
+        are determined by:
+
+        - `GlobalStep`s, as they receive all the data at once, which means that a
+        `GlobalStep` is not required to be loaded until all its previous steps have
+        finished their execution, and the successors of the global step are not
+        required to be loaded until the global step has finished.
+        - `load_groups`, which determine which steps have to be loaded together and in
+        isolation with respect to the rest.
+
+        Args:
+            load_groups: a list containing lists of steps that have to be loaded together
+                in a stage. Defaults to `None`.
+ Returns: A tuple with the first element containing asorted list by stage containing lists with the names of the steps of the stage, and the second element a list @@ -309,35 +321,54 @@ def _get_stage_last_steps(stage_steps: List[str]) -> List[str]: [node for node in subgraph.nodes() if subgraph.out_degree(node) == 0] ) - stages = [] - current_stage = [] - stages_last_steps = [] + if load_groups is None: + load_groups = [] - steps_sorted = list(nx.topological_sort(self.G)) - for i, step_name in enumerate(steps_sorted): + # Create a load group for each global step + for step_name in self.G: step: "_Step" = self.get_step(step_name)[STEP_ATTR_NAME] - if not step.is_global: - current_stage.append(step_name) - else: - previous_step = None - if i > 0: - previous_step_name = steps_sorted[i - 1] - previous_step = self.get_step(previous_step_name)[STEP_ATTR_NAME] - if not previous_step or not previous_step.is_global: + if step.is_global: + load_groups.append([step_name]) + + # Sort load groups by steps position in the DAG + topological_sort = list(nx.topological_sort(self.G)) + load_groups = sorted(load_groups, key=lambda x: topological_sort.index(x[0])) + + # Create load groups for the rest of the steps that don't belong to any load group + stages: List[List[str]] = [] + current_stage: List[str] = [] + grouped_steps: List[str] = [step for group in load_groups for step in group] + for step_name in topological_sort: + if step_name in grouped_steps: + # If a stage was being created, finish it as we've reached a step belonging + # to another load stage + if current_stage: stages.append(current_stage) - stages_last_steps.append(_get_stage_last_steps(current_stage)) - stages.append([step_name]) - stages_last_steps.append([step_name]) - current_stage = [] + current_stage = [] + + # Append the load group of this step + for group in load_groups: + if step_name in group and group not in stages: + stages.append(group) + break + else: + current_stage.append(step_name) if current_stage: stages.append(current_stage) - stages_last_steps.append(_get_stage_last_steps(current_stage)) + + # No stage was created, so we have a single stage with all the steps of the pipeline + if not stages: + stages.append(topological_sort) + + stages_last_steps = [] + for stage in stages: + stages_last_steps.append(_get_stage_last_steps(stage)) return stages, stages_last_steps def validate(self) -> None: - """Validates that the `Step`s included in the pipeline are correctly connected and + """Validates that the `Step`s included in the pipeline are correctly connected, and have the correct inputs and outputs. 
Raises: diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index b9a3d4bdc2..978f9be7d1 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -47,7 +47,7 @@ from distilabel.pipeline.batch import _Batch from distilabel.pipeline.batch_manager import _BatchManager from distilabel.pipeline.write_buffer import _WriteBuffer -from distilabel.steps.base import GeneratorStep +from distilabel.steps.base import GeneratorStep, _Step from distilabel.steps.generators.utils import make_generator_step from distilabel.utils.logging import setup_logging, stop_logging from distilabel.utils.notebook import in_notebook @@ -70,10 +70,11 @@ from distilabel.pipeline.routing_batch_function import RoutingBatchFunction from distilabel.pipeline.typing import ( InputDataset, + LoadGroups, PipelineRuntimeParametersInfo, StepLoadStatus, ) - from distilabel.steps.base import Step, _Step + from distilabel.steps.base import Step class _CacheLocation(TypedDict): """Dictionary to store the filenames and directories of a cached pipeline. @@ -125,6 +126,7 @@ def get_pipeline(cls) -> Union["BasePipeline", None]: _STEP_LOAD_FAILED_CODE = -666 _STEP_NOT_LOADED_CODE = -999 +_STEP_UNLOADED_CODE = -1000 _PIPELINE_DEFAULT_NAME = "__default_pipeline_name__" @@ -201,6 +203,7 @@ def __init__( self._batch_manager: Optional["_BatchManager"] = None self._write_buffer: Optional["_WriteBuffer"] = None + self._steps_input_queues: Dict[str, "Queue"] = {} self._steps_load_status: Dict[str, int] = {} self._steps_load_status_lock = threading.Lock() @@ -220,6 +223,7 @@ def __init__( self._current_stage = 0 self._stages_last_batch: List[List[str]] = [] + self._load_groups = [] self.requirements = requirements or [] @@ -277,6 +281,7 @@ def signature(self) -> str: def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, + load_groups: Optional["LoadGroups"] = None, use_cache: bool = True, storage_parameters: Optional[Dict[str, Any]] = None, use_fs_to_pass_data: bool = False, @@ -293,6 +298,14 @@ def run( Args: parameters: A dictionary with the step name as the key and a dictionary with the runtime parameters for the step as the value. Defaults to `None`. + load_groups: A list containing lists of steps that have to be loaded together + and in isolation with respect to the rest of the steps of the pipeline. + This argument also allows passing the following modes: + + - "sequential_step_execution": each step will be executed in a stage i.e. + the execution of the steps will be sequential. + + Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. storage_parameters: A dictionary with the storage parameters (`fsspec` and path) @@ -345,8 +358,10 @@ def run( self._set_pipeline_name() # Validate the pipeline DAG to check that all the steps are chainable, there are - # no missing runtime parameters, batch sizes are correct, etc. - self.dag.validate() + # no missing runtime parameters, batch sizes are correct, load groups are valid, + # etc. + self._load_groups = self._built_load_groups(load_groups) + self._validate() self._set_pipeline_artifacts_path_in_steps() @@ -430,6 +445,24 @@ def dry_run( self._dry_run = False return distiset + def get_load_stages( + self, load_groups: Optional["LoadGroups"] = None + ) -> Tuple[List[List[str]], List[List[str]]]: + """Convenient method to get the load stages of a pipeline. 
+ + Args: + load_groups: A list containing list of steps that has to be loaded together + and in isolation with respect to the rest of the steps of the pipeline. + Defaults to `None`. + + Returns: + A tuple with the first element containing asorted list by stage containing + lists with the names of the steps of the stage, and the second element a list + sorted by stage containing lists with the names of the last steps of the stage. + """ + load_groups = self._built_load_groups(load_groups) + return self.dag.get_steps_load_stages(load_groups) + def _add_dataset_generator_step( self, dataset: "InputDataset", batch_size: int = 50 ) -> None: @@ -473,6 +506,108 @@ def get_runtime_parameters_info(self) -> "PipelineRuntimeParametersInfo": runtime_parameters[step_name] = step.get_runtime_parameters_info() return runtime_parameters + def _built_load_groups( + self, load_groups: Optional["LoadGroups"] = None + ) -> List[List[str]]: + if load_groups is None: + return [] + + if load_groups == "sequential_step_execution": + return [[step_name] for step_name in self.dag] + + return [ + [ + step.name if isinstance(step, _Step) else step + for step in steps_load_group + ] # type: ignore + for steps_load_group in load_groups + ] + + def _validate(self) -> None: + """Validates the pipeline DAG to check that all the steps are chainable, there are + no missing runtime parameters, batch sizes are correct and that load groups are + valid (if any).""" + self.dag.validate() + self._validate_load_groups(self._load_groups) + + def _validate_load_groups(self, load_groups: List[List[Any]]) -> None: # noqa: C901 + """Checks that the provided load groups are valid and that the steps can be scheduled + to be loaded in different stages without any issue. + + Args: + load_groups: the load groups to be checked. + + Raises: + DistilabelUserError: if something is not OK when checking the load groups. + """ + + def check_predecessor_in_load_group( + step_name: str, load_group: List[str], first: bool + ) -> Union[str, None]: + if not first and step_name in load_group: + return step_name + + for predecessor_step_name in self.dag.get_step_predecessors(step_name): + # Immediate predecessor is in the same load group. This is OK. + if first and predecessor_step_name in load_group: + continue + + # Case: A -> B -> C, load_group=[A, C] + # If a non-immediate predecessor is in the same load group and an immediate + # predecessor is not , then it's not OK because we cannot load `step_name` + # before one immediate predecessor. + if step_name_in_load_group := check_predecessor_in_load_group( + predecessor_step_name, load_group, False + ): + return step_name_in_load_group + + return None + + steps_included_in_load_group = [] + for load_group_num, steps_load_group in enumerate(load_groups): + for step_name in steps_load_group: + if step_name not in self.dag.G: + raise DistilabelUserError( + f"Step with name '{step_name}' included in group {load_group_num} of" + " the `load_groups` is not an step included in the pipeline. 
Please," + " check that you're passing the correct step name and run again.", + page="sections/how_to_guides/advanced/load_groups_and_execution_stages", + ) + + node = self.dag.get_step(step_name) + step: "_Step" = node[constants.STEP_ATTR_NAME] + + if step_name_in_load_group := check_predecessor_in_load_group( + step_name, steps_load_group, True + ): + # Improve this user error message + raise DistilabelUserError( + f"Step with name '{step_name}' cannot be in the same load group" + f" as the step with name '{step_name_in_load_group}'. '{step_name_in_load_group}'" + f" is not an immediate predecessor of '{step_name}' and there are" + " immediate predecessors that have not been included.", + page="sections/how_to_guides/advanced/load_groups_and_execution_stages", + ) + + if step.is_global and len(steps_load_group) > 1: + raise DistilabelUserError( + f"Global step '{step_name}' has been included in a load group along" + " more steps. Global steps cannot be included in a load group with" + " more steps as they will be loaded in a different stage to the" + " rest of the steps in the pipeline by default.", + page="sections/how_to_guides/advanced/load_groups_and_execution_stages", + ) + + if step_name in steps_included_in_load_group: + raise DistilabelUserError( + f"Step with name '{step_name}' in load group {load_group_num} has" + " already been included in a previous load group. A step cannot be in more" + " than one load group.", + page="sections/how_to_guides/advanced/load_groups_and_execution_stages", + ) + + steps_included_in_load_group.append(step_name) + def _init_steps_load_status(self) -> None: """Initialize the `_steps_load_status` dictionary assigning 0 to every step of the pipeline.""" @@ -752,6 +887,9 @@ def _save_stages_status(self) -> None: }, ) + def _get_steps_load_stages(self) -> Tuple[List[List[str]], List[List[str]]]: + return self.dag.get_steps_load_stages(self._load_groups) + def _load_stages_status(self, use_cache: bool = True) -> None: """Try to load the stages status from cache, or initialize it if cache file doesn't exist or cache is not going to be used.""" @@ -762,7 +900,7 @@ def _load_stages_status(self, use_cache: bool = True) -> None: else: self._current_stage = 0 self._stages_last_batch = [ - [] for _ in range(len(self.dag.get_steps_load_stages()[0])) + [] for _ in range(len(self._get_steps_load_stages()[0])) ] def _refresh_pipeline_from_cache(self) -> None: @@ -887,17 +1025,25 @@ def _setup_write_buffer(self, use_cache: bool = True) -> None: def _print_load_stages_info(self) -> None: """Prints the information about the load stages.""" - stages, _ = self.dag.get_steps_load_stages() + stages, _ = self._get_steps_load_stages() msg = "" for stage, steps in enumerate(stages): steps_to_be_loaded = self._steps_to_be_loaded_in_stage(stage) msg += f"\n * Stage {stage}:" - for step in steps: - msg += f"\n - '{step}'" - if step not in steps_to_be_loaded: + for step_name in steps: + step: "Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME] + if step.is_generator: + emoji = "🚰" + elif step.is_global: + emoji = "🌐" + else: + emoji = "🔄" + msg += f"\n - {emoji} '{step_name}'" + if step_name not in steps_to_be_loaded: msg += " (results cached, won't be loaded and executed)" + legend = "\n * Legend: 🚰 GeneratorStep 🌐 GlobalStep 🔄 Step" self._logger.info( - f"⌛ The steps of the pipeline will be loaded in stages:{msg}" + f"⌛ The steps of the pipeline will be loaded in stages:{legend}{msg}" ) def _run_output_queue_loop_in_thread(self) -> threading.Thread: @@ -911,6 
+1057,8 @@ def _run_output_queue_loop_in_thread(self) -> threading.Thread: def _output_queue_loop(self) -> None: """Loop to receive the output batches from the steps and manage the flow of the batches through the pipeline.""" + self._create_steps_input_queues() + if not self._initialize_pipeline_execution(): return @@ -939,6 +1087,7 @@ def _output_queue_loop(self) -> None: # If there is another load stage and all the `last_batch`es from the stage # have been received, then load the next stage. if self._should_load_next_stage(): + self._wait_current_stage_to_finish() if not self._update_stage(): break @@ -946,6 +1095,13 @@ def _output_queue_loop(self) -> None: self._finalize_pipeline_execution() + def _create_steps_input_queues(self) -> None: + """Creates the input queue for all the steps in the pipeline.""" + for step_name in self.dag: + self._logger.debug(f"Creating input queue for '{step_name}' step...") + input_queue = self._create_step_input_queue(step_name) + self._steps_input_queues[step_name] = input_queue + def _initialize_pipeline_execution(self) -> bool: """Load the steps of the required stage to initialize the pipeline execution, and requests the initial batches to trigger the batch flowing in the pipeline. @@ -1046,7 +1202,7 @@ def _register_stages_last_batch(self, batch: "_Batch") -> None: Args: batch: The last batch received from a step. """ - _, stages_last_steps = self.dag.get_steps_load_stages() + _, stages_last_steps = self._get_steps_load_stages() stage_last_steps = stages_last_steps[self._current_stage] if batch.step_name in stage_last_steps: self._stages_last_batch[self._current_stage].append(batch.step_name) @@ -1072,7 +1228,7 @@ def _should_load_next_stage(self) -> bool: Returns: `True` if the next stage should be loaded, `False` otherwise. """ - _, stage_last_steps = self.dag.get_steps_load_stages() + _, stage_last_steps = self._get_steps_load_stages() there_is_next_stage = self._current_stage + 1 < len(stage_last_steps) stage_last_batches_received = ( self._stages_last_batch[self._current_stage] @@ -1132,6 +1288,8 @@ def _run_load_queue_loop(self) -> None: self._steps_load_status[step_name] += 1 elif status == "unloaded": self._steps_load_status[step_name] -= 1 + if self._steps_load_status[step_name] == 0: + self._steps_load_status[step_name] = _STEP_UNLOADED_CODE else: # load failed self._steps_load_status[step_name] = _STEP_LOAD_FAILED_CODE @@ -1164,7 +1322,7 @@ def _steps_to_be_loaded_in_stage(self, stage: int) -> List[str]: """ assert self._batch_manager, "Batch manager is not set" - steps_stages, _ = self.dag.get_steps_load_stages() + steps_stages, _ = self._get_steps_load_stages() return [ step @@ -1172,6 +1330,36 @@ def _steps_to_be_loaded_in_stage(self, stage: int) -> List[str]: if not self._batch_manager.step_has_finished(step) ] + def _get_steps_load_status(self, steps: List[str]) -> Dict[str, int]: + """Gets the a dictionary containing the load status of the provided steps. + + Args: + steps: a list containing the names of the steps to get their load status. + + Returns: + A dictionary containing the load status of the provided steps. 
+ """ + return { + step_name: replicas + for step_name, replicas in self._steps_load_status.items() + if step_name in steps + } + + def _wait_current_stage_to_finish(self) -> None: + """Waits for the current stage to finish.""" + stage = self._current_stage + steps = self._steps_to_be_loaded_in_stage(stage) + self._logger.info(f"⏳ Waiting for stage {stage} to finish...") + with self._stop_called_lock: + while not self._stop_called: + filtered_steps_load_status = self._get_steps_load_status(steps) + if all( + replicas == _STEP_UNLOADED_CODE + for replicas in filtered_steps_load_status.values() + ): + self._logger.info(f"✅ Stage {stage} has finished!") + break + def _run_stage_steps_and_wait(self, stage: int) -> bool: """Runs the steps of the specified stage and waits for them to be ready. @@ -1195,11 +1383,7 @@ def _run_stage_steps_and_wait(self, stage: int) -> bool: with self._stop_called_lock: while not self._stop_called: with self._steps_load_status_lock: - filtered_steps_load_status = { - step_name: replicas - for step_name, replicas in self._steps_load_status.items() - if step_name in steps - } + filtered_steps_load_status = self._get_steps_load_status(steps) self._logger.debug( f"Steps from stage {stage} loaded: {filtered_steps_load_status}" ) @@ -1217,7 +1401,14 @@ def _run_stage_steps_and_wait(self, stage: int) -> bool: replicas_message = "" for step_name, replicas in filtered_steps_load_status.items(): step_replica_count = self.dag.get_step_replica_count(step_name) - if replicas == step_replica_count: + # It can happen that the step is very fast and it has done all the + # work and have finished its execution before checking if it has + # been loaded, that's why we also considered the step to be loaded + # if `_STEP_UNLOADED_CODE`. + if ( + replicas == step_replica_count + or replicas == _STEP_UNLOADED_CODE + ): num_steps_loaded += 1 replicas_message += f"\n * '{step_name}' replicas: {max(0, replicas)}/{step_replica_count}" @@ -1246,14 +1437,19 @@ def _handle_stop(self) -> None: # Wait for the input queue to be empty, which means that all the steps finished # processing the batches that were sent before the stop flag. - for step_name in self.dag: - self._wait_step_input_queue_empty(step_name) + self._wait_steps_input_queues_empty() self._consume_output_queue() if self._should_load_next_stage(): self._current_stage += 1 + def _wait_steps_input_queues_empty(self) -> None: + self._logger.debug("Waiting for steps input queues to be empty...") + for step_name in self.dag: + self._wait_step_input_queue_empty(step_name) + self._logger.debug("Steps input queues are empty!") + def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]: """Waits for the input queue of a step to be empty. 
@@ -1286,7 +1482,7 @@ def _check_step_not_loaded_or_finished(self, step_name: str) -> bool: num_replicas = self._steps_load_status[step_name] # The step has finished (replicas = 0) or it has failed to load - if num_replicas in [0, _STEP_LOAD_FAILED_CODE]: + if num_replicas in [0, _STEP_LOAD_FAILED_CODE, _STEP_UNLOADED_CODE]: return True return False @@ -1330,7 +1526,7 @@ def _run_steps(self, steps: Iterable[str]) -> None: """ for step_name in steps: step: "Step" = self.dag.get_step(step_name)[constants.STEP_ATTR_NAME] - input_queue = self._create_step_input_queue(step_name=step_name) + input_queue = self._steps_input_queues[step.name] # type: ignore # Set `pipeline` to `None` as in some Python environments the pipeline is not # picklable and it will raise an error when trying to send the step to the process. @@ -1358,6 +1554,9 @@ def _run_steps(self, steps: Iterable[str]) -> None: def _add_batches_back_to_batch_manager(self) -> None: """Add the `Batch`es that were sent to a `Step` back to the `_BatchManager`. This method should be used when the pipeline has been stopped prematurely.""" + self._logger.debug( + "Adding batches from step input queues back to the batch manager..." + ) for step_name in self.dag: node = self.dag.get_step(step_name) step: "_Step" = node[constants.STEP_ATTR_NAME] @@ -1376,7 +1575,10 @@ def _add_batches_back_to_batch_manager(self) -> None: self._logger.debug( f"Adding batch back to the batch manager: {batch}" ) - input_queue.put(None) + if self._check_step_not_loaded_or_finished(step_name): + # Notify the step to stop + input_queue.put(None) + self._logger.debug("Finished adding batches back to the batch manager.") def _consume_output_queue(self) -> None: """Consumes the `Batch`es from the output queue until it's empty. This method should @@ -1442,12 +1644,12 @@ def _manage_batch_flow(self, batch: "_Batch") -> None: # Step ("this", the one from which the batch was received) has enough data on its # buffers to create a new batch while new_batch := self._batch_manager.get_batch(step.name): # type: ignore - # if new_batch := self._batch_manager.get_batch(step.name): # type: ignore self._send_batch_to_step(new_batch) - else: self._request_more_batches_if_needed(step) else: + # Case in which the pipeline only contains a `GeneratorStep` so we constanly keep + # requesting batch after batch as there is no downstream step to consume it if len(self.dag) == 1: self._request_batch_from_generator(step.name) # type: ignore diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index a35100e156..e8716f1ade 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -44,10 +44,9 @@ from queue import Queue from distilabel.distiset import Distiset - from distilabel.pipeline.typing import InputDataset + from distilabel.pipeline.typing import InputDataset, LoadGroups from distilabel.steps.base import _Step - _SUBPROCESS_EXCEPTION: Union[Exception, None] = None @@ -148,6 +147,7 @@ def ray( def run( self, parameters: Optional[Dict[Any, Dict[str, Any]]] = None, + load_groups: Optional["LoadGroups"] = None, use_cache: bool = True, storage_parameters: Optional[Dict[str, Any]] = None, use_fs_to_pass_data: bool = False, @@ -160,6 +160,14 @@ def run( Args: parameters: A dictionary with the step name as the key and a dictionary with the runtime parameters for the step as the value. Defaults to `None`. 
+ load_groups: A list containing lists of steps that have to be loaded together + and in isolation with respect to the rest of the steps of the pipeline. + This argument also allows passing the following modes: + + - "sequential_step_execution": each step will be executed in a stage i.e. + the execution of the steps will be sequential. + + Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. storage_parameters: A dictionary with the storage parameters (`fsspec` and path) @@ -203,6 +211,7 @@ def run( if distiset := super().run( parameters=parameters, + load_groups=load_groups, use_cache=use_cache, storage_parameters=storage_parameters, use_fs_to_pass_data=use_fs_to_pass_data, diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py index 4b8ff509e3..30d2e5a47e 100644 --- a/src/distilabel/pipeline/ray.py +++ b/src/distilabel/pipeline/ray.py @@ -32,7 +32,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from distilabel.distiset import Distiset - from distilabel.pipeline.typing import InputDataset + from distilabel.pipeline.typing import InputDataset, LoadGroups from distilabel.steps.base import _Step @@ -79,6 +79,7 @@ def __init__( def run( self, parameters: Optional[Dict[str, Dict[str, Any]]] = None, + load_groups: Optional["LoadGroups"] = None, use_cache: bool = True, storage_parameters: Optional[Dict[str, Any]] = None, use_fs_to_pass_data: bool = False, @@ -91,6 +92,14 @@ def run( Args: parameters: A dictionary with the step name as the key and a dictionary with the runtime parameters for the step as the value. Defaults to `None`. + load_groups: A list containing lists of steps that have to be loaded together + and in isolation with respect to the rest of the steps of the pipeline. + This argument also allows passing the following modes: + + - "sequential_step_execution": each step will be executed in a stage i.e. + the execution of the steps will be sequential. + + Defaults to `None`. use_cache: Whether to use the cache from previous pipeline runs. Defaults to `True`. storage_parameters: A dictionary with the storage parameters (`fsspec` and path) @@ -129,6 +138,7 @@ def run( if distiset := super().run( parameters=parameters, + load_groups=load_groups, use_cache=use_cache, storage_parameters=storage_parameters, use_fs_to_pass_data=use_fs_to_pass_data, diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py index 8b33da933d..1caa3a3e38 100644 --- a/src/distilabel/pipeline/step_wrapper.py +++ b/src/distilabel/pipeline/step_wrapper.py @@ -177,7 +177,7 @@ def _generator_step_process_loop(self) -> None: offset = batch.seq_no * step.batch_size # type: ignore self.step._logger.info( - f"🧬 Starting yielding batches from generator step '{self.step.name}'." + f"🚰 Starting yielding batches from generator step '{self.step.name}'." f" Offset: {offset}" ) diff --git a/src/distilabel/pipeline/typing.py b/src/distilabel/pipeline/typing.py index 690acecaae..3e796948aa 100644 --- a/src/distilabel/pipeline/typing.py +++ b/src/distilabel/pipeline/typing.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + TypedDict, + TypeVar, + Union, +) if TYPE_CHECKING: import pandas as pd @@ -53,3 +62,12 @@ class StepLoadStatus(TypedDict): InputDataset = Union["Dataset", "pd.DataFrame", List[Dict[str, str]]] """Alias for the types we can process as input dataset.""" + +LoadGroups = Union[List[List[Any]], Literal["sequential_step_execution"]] +"""Alias for the types that can be used as load groups. + +- if `List[List[Any]]`, it's a list containing lists of steps that have to be loaded in +isolation. +- if "sequential_step_execution", each step will be loaded in a different stage i.e. only +one step will be executed at a time. +""" diff --git a/tests/integration/test_load_groups.py b/tests/integration/test_load_groups.py new file mode 100644 index 0000000000..a30efeb877 --- /dev/null +++ b/tests/integration/test_load_groups.py @@ -0,0 +1,105 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING +from unittest import mock + +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromDicts, StepInput, step + +if TYPE_CHECKING: + from distilabel.typing import StepOutput + + +@step(inputs=["instruction"], outputs=["instruction2"]) +def DummyStep(inputs: StepInput) -> "StepOutput": + for input in inputs: + input["instruction2"] = "miau" + yield inputs + + +@step(inputs=["instruction"], outputs=["instruction2"]) +def DummyStep2(*inputs: StepInput) -> "StepOutput": + outputs = [] + for rows in zip(*inputs): + combined = {} + for row in rows: + combined.update(row) + outputs.append(combined) + yield outputs + + +@step(inputs=["instruction"], outputs=["instruction2"], step_type="global") +def GlobalDummyStep(inputs: StepInput) -> "StepOutput": + for input in inputs: + input["instruction2"] = "miau" + yield inputs + + +def test_load_groups() -> None: + with Pipeline() as pipeline: + generator = LoadDataFromDicts(data=[{"instruction": "Hi"}] * 50) + dummy_step_0 = DummyStep() + dummy_step_1 = DummyStep() + dummy_step_2 = DummyStep2() + global_dummy_step = GlobalDummyStep() + dummy_step_3 = DummyStep() + dummy_step_4 = DummyStep() + dummy_step_5 = DummyStep() + + ( + generator + >> [dummy_step_0, dummy_step_1] + >> dummy_step_2 + >> global_dummy_step + >> dummy_step_3 + >> [dummy_step_4, dummy_step_5] + ) + + with mock.patch.object( + pipeline, "_run_stage_steps_and_wait", wraps=pipeline._run_stage_steps_and_wait + ) as run_stage_mock: + # `dummy_step_0` should be executed in isolation + pipeline.run(load_groups=[[dummy_step_0.name], [dummy_step_3.name]]) + + assert run_stage_mock.call_count == 6 + + +def test_load_groups_sequential_step_execution() -> None: + with Pipeline() as pipeline: + generator = LoadDataFromDicts(data=[{"instruction": "Hi"}] * 50) + dummy_step_0 = DummyStep() + dummy_step_1 = DummyStep() + dummy_step_2 = DummyStep2() + global_dummy_step = 
GlobalDummyStep() + dummy_step_3 = DummyStep() + dummy_step_4 = DummyStep() + dummy_step_5 = DummyStep() + + ( + generator + >> [dummy_step_0, dummy_step_1] + >> dummy_step_2 + >> global_dummy_step + >> dummy_step_3 + >> [dummy_step_4, dummy_step_5] + ) + + with mock.patch.object( + pipeline, "_run_stage_steps_and_wait", wraps=pipeline._run_stage_steps_and_wait + ) as run_stage_mock: + # `dummy_step_0` should be executed in isolation + pipeline.run(load_groups="sequential_step_execution") + + assert run_stage_mock.call_count == 8 diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 77faf25d14..3cb680eb06 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -97,6 +97,70 @@ def test_get_pipeline(self) -> None: class TestBasePipeline: + def test_get_load_stages(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + load_stages = pipeline.get_load_stages(load_groups=[[step2.name]]) + + assert load_stages == ( + [[generator.name, step.name], [step2.name], [step3.name]], + [[step.name], [step2.name], [step3.name]], + ) + + def test_get_load_stages_sequential_step_execution(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + load_stages = pipeline.get_load_stages(load_groups="sequential_step_execution") + + assert load_stages == ( + [[generator.name], [step.name], [step2.name], [step3.name]], + [[generator.name], [step.name], [step2.name], [step3.name]], + ) + + @pytest.mark.parametrize( + "load_groups, expected", + [ + ([["step_0", "step_1"], ["step_2"]], [["step_0", "step_1"], ["step_2"]]), + ("sequential_step_execution", [["step_0"], ["step_1"], ["step_2"]]), + ], + ) + def test_built_load_groups( + self, load_groups: Any, expected: List[List[str]] + ) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep(name="step_0") + step = DummyStep1(name="step_1") + step2 = DummyStep1(name="step_2") + + generator >> [step, step2] + + assert pipeline._built_load_groups(load_groups) == expected + + def test_built_load_groups_with_step_class(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + + generator >> [step, step2] + + assert pipeline._built_load_groups([[generator], [step, step2]]) == [ + [generator.name], + [step.name, step2.name], + ] + def test_aggregated_steps_signature(self) -> None: with DummyPipeline(name="dummy") as pipeline_0: generator = DummyGeneratorStep() @@ -372,6 +436,7 @@ def test_run_stage_steps_and_wait(self, caplog) -> None: generator >> [step, step2] >> step3 >> step4 pipeline._load_batch_manager() + pipeline._create_steps_input_queues() pipeline._steps_load_status = { # type: ignore generator.name: 1, step.name: 1, @@ -396,6 +461,7 @@ def test_run_stage_steps_and_wait_with_failing_step(self, caplog) -> None: pipeline._init_steps_load_status() pipeline._load_batch_manager() + pipeline._create_steps_input_queues() pipeline._steps_load_status[generator.name] = _STEP_LOAD_FAILED_CODE # type: ignore caplog.set_level(logging.INFO) @@ -414,6 +480,7 @@ def test_run_stage_steps_and_wait_stop_called(self) -> None: pipeline._init_steps_load_status() pipeline._load_batch_manager() + 
pipeline._create_steps_input_queues() pipeline._stop_called = True assert pipeline._run_stage_steps_and_wait(stage=0) is False @@ -489,6 +556,16 @@ def test_create_step_input_queue(self) -> None: pipeline.dag.get_step(generator_name)[INPUT_QUEUE_ATTR_NAME], Queue ) + def test_create_steps_input_queues(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + steps = [DummyStep1() for _ in range(5)] + + generator >> steps + + pipeline._create_steps_input_queues() + assert len(pipeline._steps_input_queues) == 6 + def test_run_steps(self) -> None: with DummyPipeline(name="unit-test-pipeline") as pipeline: generator = DummyGeneratorStep() @@ -497,18 +574,10 @@ def test_run_steps(self) -> None: generator >> step >> global_step - pipeline._create_step_input_queue = mock.MagicMock() pipeline._run_step = mock.MagicMock() + pipeline._create_steps_input_queues() pipeline._run_steps(steps=[generator.name, step.name]) # type: ignore - pipeline._create_step_input_queue.assert_has_calls( - [ - mock.call(step_name=step.name), - mock.call(step_name=generator.name), - ], - any_order=True, - ) - pipeline._run_step.assert_has_calls( [ mock.call(step=mock.ANY, input_queue=mock.ANY, replica=0), @@ -528,6 +597,7 @@ def test_add_batches_back_to_batch_manager(self) -> None: step_name: str = step.name # type: ignore pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) + pipeline._init_steps_load_status() generator_queue = Queue() pipeline.dag.set_step_attr( generator_name, INPUT_QUEUE_ATTR_NAME, generator_queue @@ -1259,6 +1329,53 @@ def test_optional_name(self): assert pipeline.name == "pipeline_dummy_generator_step_0_dummy_step1_0" + def test_validate_load_groups_step_not_in_pipeline(self) -> None: + pipeline = DummyPipeline() + + with pytest.raises( + ValueError, + match="Step with name 'random' included in group 0 of the `load_groups` is not an step included in the pipeline.", + ): + pipeline._validate_load_groups(load_groups=[["random"]]) + + def test_validate_load_groups_including_global_step(self) -> None: + pipeline = DummyPipeline() + step = DummyGlobalStep(pipeline=pipeline) + step_0 = DummyStep1() + with pytest.raises( + ValueError, + match=f"Global step '{step.name}' has been included in a load group.", + ): + pipeline._validate_load_groups(load_groups=[[step.name, step_0.name]]) + + def test_validate_load_groups_duplicate_step(self) -> None: + pipeline = DummyPipeline() + dummy_step_1 = DummyStep1(pipeline=pipeline) + + with pytest.raises( + ValueError, + match=f"Step with name '{dummy_step_1.name}' in load group 1 has already been included in a previous load group.", + ): + pipeline._validate_load_groups( + load_groups=[[dummy_step_1.name], [dummy_step_1.name]] + ) + + def test_validate_load_groups_non_immediate_predecessor(self) -> None: + pipeline = DummyPipeline() + generator_step_1 = DummyGeneratorStep(pipeline=pipeline) + dummy_step_1 = DummyStep1(pipeline=pipeline) + dummy_step_2 = DummyStep1(name="demon", pipeline=pipeline, input_batch_size=7) + + generator_step_1 >> dummy_step_1 >> dummy_step_2 + + with pytest.raises( + ValueError, + match=f"Step with name '{dummy_step_2.name}' cannot be in the same load group as the step with name '{generator_step_1.name}'.", + ): + pipeline._validate_load_groups( + load_groups=[[generator_step_1.name, dummy_step_2.name]] + ) + class TestPipelineSerialization: @pytest.mark.parametrize( diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py index a5b55520f4..6a6163b75e 
100644 --- a/tests/unit/pipeline/test_dag.py +++ b/tests/unit/pipeline/test_dag.py @@ -337,6 +337,54 @@ def test_get_steps_load_stages_simple(self) -> None: ], ) + def test_get_steps_load_stages_with_load_groups(self) -> None: + with Pipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep(name="dummy_generator_step") + dummy_step_0 = DummyStep1() + dummy_step_1 = DummyStep1() + dummy_step_2 = DummyStep1() + dummy_step_3 = DummyStep1() + dummy_step_4 = DummyStep1() + dummy_step_5 = DummyStep1() + dummy_step_6 = DummyStep1() + dummy_step_7 = DummyStep1() + + ( + generator + >> dummy_step_0 + >> [dummy_step_1, dummy_step_2] + >> dummy_step_3 + >> dummy_step_4 + >> dummy_step_5 + >> dummy_step_6 + >> dummy_step_7 + ) + + assert pipeline.dag.get_steps_load_stages( + load_groups=[ + [dummy_step_5.name], + [dummy_step_0.name, dummy_step_1.name], + [dummy_step_4.name], + ] + ) == ( + [ + [generator.name], + [dummy_step_0.name, dummy_step_1.name], + [dummy_step_2.name, dummy_step_3.name], + [dummy_step_4.name], + [dummy_step_5.name], + [dummy_step_6.name, dummy_step_7.name], + ], + [ + [generator.name], + [dummy_step_1.name], + [dummy_step_3.name], + [dummy_step_4.name], + [dummy_step_5.name], + [dummy_step_7.name], + ], + ) + def test_validate_first_step_not_generator( self, dummy_step_1: "Step", dummy_step_2: "Step" ) -> None: diff --git a/tests/unit/pipeline/test_local.py b/tests/unit/pipeline/test_local.py index d661f21a95..9ac14f6309 100644 --- a/tests/unit/pipeline/test_local.py +++ b/tests/unit/pipeline/test_local.py @@ -40,6 +40,7 @@ def test_run_steps(self, step_wrapper_mock: mock.MagicMock) -> None: pipeline._manager = mock.MagicMock() pipeline._output_queue = mock.MagicMock() pipeline._load_queue = mock.MagicMock() + pipeline._create_steps_input_queues() pipeline._run_steps( steps=[dummy_generator.name, dummy_step_1.name, dummy_step_2.name] # type: ignore ) From 63c75c59f0cbef109305a40b8c20cce0a9e3428d Mon Sep 17 00:00:00 2001 From: Agus Date: Mon, 9 Dec 2024 11:27:32 +0100 Subject: [PATCH 06/30] Add TextGenerationWithImage task (#1066) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../examples/text_generation_with_image.md | 116 ++++++++++ docs/sections/pipeline_samples/index.md | 7 + examples/text_generation_with_image.py | 41 ++++ mkdocs.yml | 3 + src/distilabel/models/llms/openai.py | 6 + src/distilabel/models/llms/vertexai.py | 14 +- src/distilabel/pipeline/base.py | 7 +- src/distilabel/steps/tasks/__init__.py | 2 + .../steps/tasks/math_shepherd/completer.py | 8 +- .../steps/tasks/math_shepherd/utils.py | 10 +- src/distilabel/steps/tasks/text_generation.py | 19 +- .../steps/tasks/text_generation_with_image.py | 215 ++++++++++++++++++ src/distilabel/steps/tasks/typing.py | 31 ++- src/distilabel/steps/typing.py | 2 +- src/distilabel/utils/image.py | 26 +++ .../utils/mkdocs/components_gallery.py | 1 + src/distilabel/utils/template.py | 47 ++++ tests/unit/models/llms/test_vertexai.py | 28 --- tests/unit/steps/argilla/test_preference.py | 9 +- .../tasks/test_text_generation_with_image.py | 103 +++++++++ 20 files changed, 629 insertions(+), 66 deletions(-) create mode 100644 docs/sections/pipeline_samples/examples/text_generation_with_image.md create mode 100644 examples/text_generation_with_image.py create mode 100644 src/distilabel/steps/tasks/text_generation_with_image.py create mode 100644 
src/distilabel/utils/image.py create mode 100644 src/distilabel/utils/template.py create mode 100644 tests/unit/steps/tasks/test_text_generation_with_image.py diff --git a/docs/sections/pipeline_samples/examples/text_generation_with_image.md b/docs/sections/pipeline_samples/examples/text_generation_with_image.md new file mode 100644 index 0000000000..0978edb325 --- /dev/null +++ b/docs/sections/pipeline_samples/examples/text_generation_with_image.md @@ -0,0 +1,116 @@ +--- +hide: toc +--- + +# Text generation with images in `distilabel` + +Answer questions about images using `distilabel`. + +Image-text-to-text models take in an image and text prompt and output text. In this example we will use an LLM [`InferenceEndpointsLLM`](https://distilabel.argilla.io/dev/components-gallery/llms/inferenceendpointsllm/) with [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) to ask a question about an image, and [`OpenAILLM`](https://distilabel.argilla.io/dev/components-gallery/llms/openaillm/) with `gpt-4o-mini`. We will ask a simple question to showcase how the [`TextGenerationWithImage`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgenerationwithimage/) task can be used in a pipeline. + +=== "Inference Endpoints - meta-llama/Llama-3.2-11B-Vision-Instruct" + + ```python + from distilabel.models.llms import InferenceEndpointsLLM + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage + from distilabel.steps import LoadDataFromDicts + + + with Pipeline(name="vision_generation_pipeline") as pipeline: + loader = LoadDataFromDicts( + data=[ + { + "instruction": "What’s in this image?", + "image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + ], + ) + + llm = InferenceEndpointsLLM( + model_id="meta-llama/Llama-3.2-11B-Vision-Instruct", + ) + + vision = TextGenerationWithImage( + name="vision_gen", + llm=llm, + image_type="url" # (1) + ) + + loader >> vision + ``` + + 1. The *image_type* can be a url pointing to the image, the base64 string representation, or a PIL image, take a look at the [`TextGenerationWithImage`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgenerationwithimage/) for more information. + + Image: + + ![Image](https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg) + + Question: + + > What’s in this image? + + Response: + + > This image depicts a wooden boardwalk weaving its way through a lush meadow, flanked by vibrant green grass that stretches towards the horizon under a calm and inviting sky. The boardwalk runs straight ahead, away from the viewer, forming a clear pathway through the tall, lush green grass, crops or other plant types or an assortment of small trees and shrubs. This meadow is dotted with trees and shrubs, appearing to be healthy and green. The sky above is a beautiful blue with white clouds scattered throughout, adding a sense of tranquility to the scene. While this image appears to be of a natural landscape, because grass is... 
+ +=== "OpenAI - gpt-4o-mini" + + ```python + from distilabel.models.llms import OpenAILLM + from distilabel.pipeline import Pipeline + from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage + from distilabel.steps import LoadDataFromDicts + + + with Pipeline(name="vision_generation_pipeline") as pipeline: + loader = LoadDataFromDicts( + data=[ + { + "instruction": "What’s in this image?", + "image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + ], + ) + + llm = OpenAILLM( + model="gpt-4o-mini", + ) + + vision = TextGenerationWithImage( + name="vision_gen", + llm=llm, + image_type="url" # (1) + ) + + loader >> vision + ``` + + 1. The *image_type* can be a url pointing to the image, the base64 string representation, or a PIL image, take a look at the [`VisionGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/visiongeneration/) for more information. + + Image: + + ![Image](https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg) + + Question: + + > What’s in this image? + + Response: + + > The image depicts a scenic landscape featuring a wooden walkway or path that runs through a lush green marsh or field. The area is surrounded by tall grass and various shrubs, with trees likely visible in the background. The sky is blue with some wispy clouds, suggesting a beautiful day. Overall, it presents a peaceful natural setting, ideal for a stroll or nature observation. + + +The full pipeline can be run at the following example: + +??? Note "Run the full pipeline" + + ```python + python examples/text_generation_with_image.py + ``` + + ```python title="text_generation_with_image.py" + --8<-- "examples/text_generation_with_image.py" + ``` + +A sample dataset can be seen at [plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct](https://huggingface.co/datasets/plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct). diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md index a4789283e2..1c95b60b18 100644 --- a/docs/sections/pipeline_samples/index.md +++ b/docs/sections/pipeline_samples/index.md @@ -161,6 +161,13 @@ hide: toc [:octicons-arrow-right-24: Example](examples/exam_questions.md) +- __Text generation with images in distilabel__ + + --- + + Ask questions about images using distilabel. + + [:octicons-arrow-right-24: Example](examples/text_generation_with_image.md) diff --git a/examples/text_generation_with_image.py b/examples/text_generation_with_image.py new file mode 100644 index 0000000000..0d5b837b18 --- /dev/null +++ b/examples/text_generation_with_image.py @@ -0,0 +1,41 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from distilabel.models.llms import InferenceEndpointsLLM +from distilabel.pipeline import Pipeline +from distilabel.steps import LoadDataFromDicts +from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage + +with Pipeline(name="vision_generation_pipeline") as pipeline: + loader = LoadDataFromDicts( + data=[ + { + "instruction": "What’s in this image?", + "image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + } + ], + ) + + llm = InferenceEndpointsLLM( + model_id="meta-llama/Llama-3.2-11B-Vision-Instruct", + ) + + vision = TextGenerationWithImage(name="vision_gen", llm=llm, image_type="url") + + loader >> vision + + +if __name__ == "__main__": + distiset = pipeline.run(use_cache=False) + distiset.push_to_hub("plaguss/test-vision-generation-Llama-3.2-11B-Vision-Instruct") diff --git a/mkdocs.yml b/mkdocs.yml index 9654d03530..f5a98be65d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,6 +94,8 @@ theme: watch: - src/distilabel +strict: true + # Extensions markdown_extensions: - attr_list @@ -220,6 +222,7 @@ nav: - Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md" - Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md" - Create questions and answers for a exam: "sections/pipeline_samples/examples/exam_questions.md" + - Text generation with images in distilabel: "sections/pipeline_samples/examples/text_generation_with_image.md" - API Reference: - Step: - "api/step/index.md" diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index a9ccc90dfa..e58c0b42ce 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -293,6 +293,12 @@ async def agenerate( # type: ignore "top_p": top_p, "stop": stop, } + # Check if it's a vision generation task, in that case "stop" cannot be used or raises + # an error in the API. + if isinstance( + [row for row in input if row["role"] == "user"][0]["content"], list + ): + kwargs.pop("stop") if response_format is not None: kwargs["response_format"] = response_format diff --git a/src/distilabel/models/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py index c617b7bcf2..8f5dc28bbd 100644 --- a/src/distilabel/models/llms/vertexai.py +++ b/src/distilabel/models/llms/vertexai.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from pydantic import PrivateAttr, validate_call +from typing_extensions import TypedDict from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput @@ -27,6 +28,15 @@ from distilabel.llms.typing import LLMStatistics +class VertexChatItem(TypedDict): + role: Literal["user", "model"] + content: str + + +VertexChatType = List[VertexChatItem] +"""VertexChatType is a type alias for a `list` of `dict`s following the VertexAI conversational format.""" + + class VertexAILLM(AsyncLLM): """VertexAI LLM implementation running the async API clients for Gemini. 
@@ -121,7 +131,7 @@ def _chattype_to_content(self, input: "StandardInput") -> List["Content"]: @validate_call async def agenerate( # type: ignore self, - input: StandardInput, + input: VertexChatType, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 978f9be7d1..168599f782 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -95,6 +95,9 @@ class _CacheLocation(TypedDict): stages_file: Path +LoadStages = tuple[list[list[str]], list[list[str]]] + + class _GlobalPipelineManager: """Class to manage the global pipeline instance that will be used by the steps when created within a pipeline context. @@ -445,9 +448,7 @@ def dry_run( self._dry_run = False return distiset - def get_load_stages( - self, load_groups: Optional["LoadGroups"] = None - ) -> Tuple[List[List[str]], List[List[str]]]: + def get_load_stages(self, load_groups: Optional["LoadGroups"] = None) -> LoadStages: """Convenient method to get the load stages of a pipeline. Args: diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 8e96d59f0a..aa0460c3e1 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -54,6 +54,7 @@ from distilabel.steps.tasks.structured_generation import StructuredGeneration from distilabel.steps.tasks.text_classification import TextClassification from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration +from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback from distilabel.steps.tasks.urial import URIAL @@ -101,4 +102,5 @@ "CLAIR", "UltraFeedback", "URIAL", + "TextGenerationWithImage", ] diff --git a/src/distilabel/steps/tasks/math_shepherd/completer.py b/src/distilabel/steps/tasks/math_shepherd/completer.py index 3606c4ec98..5d3fdd7e15 100644 --- a/src/distilabel/steps/tasks/math_shepherd/completer.py +++ b/src/distilabel/steps/tasks/math_shepherd/completer.py @@ -528,7 +528,13 @@ def _auto_label( return inputs - def _add_metadata(self, input, statistics, raw_output, raw_input): + def _add_metadata( + self, + input: dict[str, Any], + statistics: list["LLMStatistics"], + raw_output: Union[str, None], + raw_input: Union[list[dict[str, Any]], None], + ) -> dict[str, Any]: """Adds the `distilabel_metadata` to the input. This method comes for free in the general Tasks, but as we have reimplemented the `process`, diff --git a/src/distilabel/steps/tasks/math_shepherd/utils.py b/src/distilabel/steps/tasks/math_shepherd/utils.py index bed56e1ff9..978496996f 100644 --- a/src/distilabel/steps/tasks/math_shepherd/utils.py +++ b/src/distilabel/steps/tasks/math_shepherd/utils.py @@ -53,12 +53,12 @@ class FormatPRM(Step): correct steps. Attributes: - format (Literal["math-shepherd", "trl"]): The format to use for the PRM model. + format: The format to use for the PRM model. "math-shepherd" corresponds to the original paper, while "trl" is a format prepared to train the model using TRL. - step_token (str): String that serves as a unique token denoting the position + step_token: String that serves as a unique token denoting the position for predicting the step score. - tags (list[str]): List of tags that represent the correct and incorrect steps. 
+ tags: List of tags that represent the correct and incorrect steps. This only needs to be informed if it's different than the default in `MathShepherdCompleter`. @@ -110,10 +110,6 @@ class FormatPRM(Step): ) ) result = next(formatter.process(result)) - # result[0]["input"] - # "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. ки\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. ки\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 ки" - # result[0]["label"] - # "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Step 1: Determine the amount of blue fiber needed: 2 bolts of blue fiber are required. +\nStep 2: Calculate the amount of white fiber needed: Since it's half that much, we can divide 2 by 2: 2 / 2 = <<2/2=1>>1 bolt of white fiber. +\nStep 3: Add the amount of blue and white fiber: 2 (blue) + 1 (white) = <<2+1=3>>3 bolts of fiber in total. The answer is: 3 +" ``` Prepare your data to train a PRM model with the TRL format: diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index daabe5525b..b6620430cc 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from jinja2 import Template @@ -21,6 +20,7 @@ from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.base import Task from distilabel.utils.chat import is_openai_format +from distilabel.utils.template import check_column_in_template if TYPE_CHECKING: from distilabel.steps.tasks.typing import ChatType @@ -218,23 +218,6 @@ def model_post_init(self, __context: Any) -> None: def load(self) -> None: super().load() - def check_column_in_template(column, template): - pattern = ( - r"(?:{%.*?\b" - + re.escape(column) - + r"\b.*?%}|{{\s*" - + re.escape(column) - + r"\s*}})" - ) - if not re.search(pattern, template): - raise DistilabelUserError( - ( - f"You required column name '{column}', but is not present in the template, " - "ensure the 'columns' match with the 'template' to avoid errors." - ), - page="components-gallery/tasks/textgeneration/", - ) - for column in self.columns: check_column_in_template(column, self.template) diff --git a/src/distilabel/steps/tasks/text_generation_with_image.py b/src/distilabel/steps/tasks/text_generation_with_image.py new file mode 100644 index 0000000000..8494afc9db --- /dev/null +++ b/src/distilabel/steps/tasks/text_generation_with_image.py @@ -0,0 +1,215 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Literal, Union
+
+from jinja2 import Template
+from PIL import Image
+from pydantic import Field
+
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.text_generation import (
+    TextGeneration,
+    check_column_in_template,
+)
+from distilabel.utils.image import image_to_str
+
+if TYPE_CHECKING:
+    from PIL.Image import Image
+
+    from distilabel.steps.tasks.typing import ChatType
+    from distilabel.steps.typing import StepColumns
+
+
+class TextGenerationWithImage(TextGeneration):
+    """Text generation with images using an `LLM` given a prompt.
+
+    `TextGenerationWithImage` is a pre-defined task that allows passing a custom prompt using the
+    Jinja2 syntax. By default, an `instruction` is expected in the inputs, but using the
+    `template` and `columns` attributes one can define a custom prompt and the columns expected
+    from the text. Additionally, an `image` column is expected containing either a URL, a
+    base64 encoded image or a PIL image. This task inherits from `TextGeneration`,
+    so all the functionality available in that task related to the prompt will be available
+    here too.
+
+    Attributes:
+        system_prompt: The system prompt to use in the generation.
+            If not provided, no system prompt will be used. Defaults to `None`.
+        template: The template to use for the generation. It must follow the Jinja2 template
+            syntax. If not provided, it will assume the text passed is an instruction and
+            construct the appropriate template.
+        columns: A string with the column, or a list with the columns expected in the template.
+            Take a look at the examples for more information. Defaults to `instruction`.
+        image_type: The type of the image provided; this will be used to preprocess the image if necessary.
+            Must be one of "url", "base64" or "PIL".
+
+    Input columns:
+        - dynamic (determined by `columns` attribute): By default, it will be set to `instruction`.
+            The columns can point to either a `str` or a `list[str]` to be used in the template.
+        - image: The column containing the image URL, base64 encoded image or PIL image.
+
+    Output columns:
+        - generation (`str`): The generated text.
+        - model_name (`str`): The name of the model used to generate the text.
+ + Categories: + - text-generation + + References: + - [Jinja2 Template Designer Documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/) + - [Image-Text-to-Text](https://huggingface.co/tasks/image-text-to-text) + - [OpenAI Vision](https://platform.openai.com/docs/guides/vision) + + Examples: + Answer questions from an image: + + ```python + from distilabel.steps.tasks import TextGenerationWithImage + from distilabel.models.llms import InferenceEndpointsLLM + + vision = TextGenerationWithImage( + name="vision_gen", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Llama-3.2-11B-Vision-Instruct", + ), + image_type="url" + ) + + vision.load() + + result = next( + vision.process( + [ + { + "instruction": "What’s in this image?", + "image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + ] + ) + ) + # result + # [ + # { + # "instruction": "What\u2019s in this image?", + # "image": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + # "generation": "Based on the visual cues in the image...", + # "model_name": "meta-llama/Llama-3.2-11B-Vision-Instruct" + # ... # distilabel_metadata would be here + # } + # ] + # result[0]["generation"] + # "Based on the visual cues in the image, here are some possible story points:\n\n* The image features a wooden boardwalk leading through a lush grass field, possibly in a park or nature reserve.\n\nAnalysis and Ideas:\n* The abundance of green grass and trees suggests a healthy ecosystem or habitat.\n* The presence of wildlife, such as birds or deer, is possible based on the surroundings.\n* A footbridge or a pathway might be a common feature in this area, providing access to nearby attractions or points of interest.\n\nAdditional Questions to Ask:\n* Why is a footbridge present in this area?\n* What kind of wildlife inhabits this region" + ``` + + Answer questions from an image stored as base64: + + ```python + # For this example we will assume that we have the string representation of the image + # stored, but will just take the image and transform it to base64 to ilustrate the example. 
+ import requests + import base64 + + image_url ="https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + img = requests.get(image_url).content + base64_image = base64.b64encode(img).decode("utf-8") + + from distilabel.steps.tasks import TextGenerationWithImage + from distilabel.models.llms import InferenceEndpointsLLM + + vision = TextGenerationWithImage( + name="vision_gen", + llm=InferenceEndpointsLLM( + model_id="meta-llama/Llama-3.2-11B-Vision-Instruct", + ), + image_type="base64" + ) + + vision.load() + + result = next( + vision.process( + [ + { + "instruction": "What’s in this image?", + "image": base64_image + } + ] + ) + ) + ``` + """ + + image_type: Literal["url", "base64", "PIL"] = Field( + default="url", + description="The type of the image provided, this will be used to preprocess if necessary.", + ) + + @property + def inputs(self) -> "StepColumns": + columns = super().inputs + columns["image"] = True + return columns + + def load(self) -> None: + Task.load(self) + + for column in self.columns: + check_column_in_template( + column, self.template, page="components-gallery/tasks/visiongeneration/" + ) + + self._template = Template(self.template) + + def _transform_image(self, image: Union[str, "Image"]) -> str: + """Transforms the image based on the `image_type` attribute.""" + if self.image_type == "url": + return image + + if self.image_type == "base64": + return f"data:image/jpeg;base64,{image}" + + # Othwerwise, it's a PIL image + return f"data:image/jpeg;base64,{image_to_str(image)}" + + def _prepare_message_content(self, input: dict[str, Any]) -> "ChatType": + """Prepares the content for the template and returns the formatted messages.""" + fields = {column: input[column] for column in self.columns} + img_url = self._transform_image(input["image"]) + return [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": self._template.render(**fields), + }, + { + "type": "image_url", + "image_url": { + "url": img_url, + }, + }, + ], + } + ] + + def format_input(self, input: dict[str, Any]) -> "ChatType": + """The input is formatted as a `ChatType` assuming that the instruction + is the first interaction from the user within a conversation.""" + messages = self._prepare_message_content(input) + + if self.system_prompt: + messages.insert(0, {"role": "system", "content": self.system_prompt}) + + return messages # type: ignore diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py index 920a94c3b9..d0d22a6811 100644 --- a/src/distilabel/steps/tasks/typing.py +++ b/src/distilabel/steps/tasks/typing.py @@ -15,12 +15,32 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union from pydantic import BaseModel -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict + + +class TextContent(TypedDict, total=False): + type: Required[Literal["text"]] + text: Required[str] + + +class ImageUrl(TypedDict): + url: Required[str] + """Either a URL of the image or the base64 encoded image data.""" + + +class ImageContent(TypedDict, total=False): + """Type alias for the user's message in a conversation that can include text or an image. 
+ It's the standard type for vision language models: + https://platform.openai.com/docs/guides/vision + """ + + type: Required[Literal["image_url"]] + image_url: Required[ImageUrl] class ChatItem(TypedDict): - role: str - content: str + role: Literal["system", "user", "assistant"] + content: Union[str, list[Union[TextContent, ImageContent]]] ChatType = List[ChatItem] @@ -69,5 +89,6 @@ class InstructorStructuredOutputType(TypedDict, total=False): """StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`.""" StructuredInput = Tuple[StandardInput, Union[StructuredOutputType, None]] """StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it.""" -FormattedInput = Union[StandardInput, StructuredInput] -"""FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s.""" +FormattedInput = Union[StandardInput, StructuredInput, ChatType] +"""FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated +by `format_input` and expected by the `LLM`s, as well as `ConversationType` for the vision language models.""" diff --git a/src/distilabel/steps/typing.py b/src/distilabel/steps/typing.py index 9a3e5bb586..720037a74f 100644 --- a/src/distilabel/steps/typing.py +++ b/src/distilabel/steps/typing.py @@ -15,7 +15,7 @@ from typing import Any, Dict, Iterator, List, Tuple, Union StepOutput = Iterator[List[Dict[str, Any]]] - +"""`StepOutput` is an alias of the typing `Iterator[List[Dict[str, Any]]]`""" GeneratorStepOutput = Iterator[Tuple[List[Dict[str, Any]], bool]] """`GeneratorStepOutput` is an alias of the typing `Iterator[Tuple[List[Dict[str, Any]], bool]]`""" diff --git a/src/distilabel/utils/image.py b/src/distilabel/utils/image.py new file mode 100644 index 0000000000..aa9d09089c --- /dev/null +++ b/src/distilabel/utils/image.py @@ -0,0 +1,26 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import base64 +import io + +from PIL import Image + + +# TODO: Once we merge the image generation, this function can be reused +def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: + """Converts a PIL Image to a base64 encoded string.""" + buffered = io.BytesIO() + image.save(buffered, format=image_format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index 08877d3cb7..005f74748e 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -115,6 +115,7 @@ "labelling": "Labelling steps are used to label the data.", } + assert list(_STEP_CATEGORY_TO_DESCRIPTION.keys()) == list( _STEPS_CATEGORY_TO_ICON.keys() ) diff --git a/src/distilabel/utils/template.py b/src/distilabel/utils/template.py new file mode 100644 index 0000000000..df825f852c --- /dev/null +++ b/src/distilabel/utils/template.py @@ -0,0 +1,47 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from distilabel.errors import DistilabelUserError + + +def check_column_in_template( + column: str, template: str, page: str = "components-gallery/tasks/textgeneration/" +) -> None: + """Checks if a column is present in the template, and raises an error if it isn't. + + Args: + column: The column name to check in the template. + template: The template of the Task to be checked, the input from the user. + page: The page to redirect the user for help . Defaults to "components-gallery/tasks/textgeneration/". + + Raises: + DistilabelUserError: Custom error if the column is not present in the template. + """ + pattern = ( + r"(?:{%.*?\b" + + re.escape(column) + + r"\b.*?%}|{{\s*" + + re.escape(column) + + r"\s*}})" + ) + if not re.search(pattern, template): + raise DistilabelUserError( + ( + f"You required column name '{column}', but is not present in the template, " + "ensure the 'columns' match with the 'template' to avoid errors." + ), + page=page, + ) diff --git a/tests/unit/models/llms/test_vertexai.py b/tests/unit/models/llms/test_vertexai.py index 529fbf332a..a40f8df33f 100644 --- a/tests/unit/models/llms/test_vertexai.py +++ b/tests/unit/models/llms/test_vertexai.py @@ -46,19 +46,6 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: ) llm._aclient.generate_content_async = AsyncMock(return_value=mocked_completion) - with pytest.raises( - ValueError, match="`VertexAILLM only supports the roles 'user' or 'model'." 
- ): - await llm.agenerate( - input=[ - {"role": "system", "content": ""}, - { - "role": "test", - "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - }, - ] - ) - result = await llm.agenerate( input=[ {"role": "model", "content": ""}, @@ -89,21 +76,6 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: nest_asyncio.apply() - with pytest.raises( - ValueError, match="`VertexAILLM only supports the roles 'user' or 'model'." - ): - llm.generate( - inputs=[ - [ - {"role": "system", "content": ""}, - { - "role": "test", - "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - }, - ], - ] - ) - result = llm.generate( inputs=[ [ diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index f0cb377bbe..1c99f2f5c4 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -92,7 +92,14 @@ def test_process(self, mock_dataset) -> None: step._dataset.records.log = lambda x: x # type: ignore assert list( - step.process([{"instruction": "test", "generations": ["test", "test"]}]) + step.process( + [ + { + "instruction": "test", + "generations": ["test", "test"], + } + ] + ) ) == [[{"instruction": "test", "generations": ["test", "test"]}]] assert step._dataset.records # type: ignore diff --git a/tests/unit/steps/tasks/test_text_generation_with_image.py b/tests/unit/steps/tasks/test_text_generation_with_image.py new file mode 100644 index 0000000000..7ce0b7742f --- /dev/null +++ b/tests/unit/steps/tasks/test_text_generation_with_image.py @@ -0,0 +1,103 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union + +import numpy as np +import pytest +from PIL import Image + +from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage +from tests.unit.conftest import DummyAsyncLLM + +img_str = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABkAGQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDw9whjujGGK7EOS3fv2HfJxz0/ixuDrgqv2jciofJjUKiZG7A7jAxgE55z1+b74jkfzBcMWZfkVRsQYbHZsdM4JzzkjJz94OMg23hIALxIACevKnPBGemed3rz98EU1Z+n/toSVtwupVZ7krEQsipyeMcA/rjPJPqdx+anTiZVuMNhfJi38bdwIBHpnse+cbvmxupJ3mfz2YhGaKMsB8u5cA9Mc9j7/e5+9SzFSt0QikGNCGckEZ5yPc+nPBz82N4UI2S+X/to7p6jZB5guGwqkRIdu7bxgdBgbucHuep55YOdVjS9VlCsYkOHbnJIOVPGQevfg5wcbwXEnNyvmAkxRqSp4bgE5wBnnnvkjPzffBJuj+2fMwV4EHQrnJVgCMjPTP8AFnrz98NO6VvL/wBsJd0guFmVrkSGNXMUZI4XKkAjA/i/hOec/e5+8ImQQpOrFWLImDg55w2ePYd8g57/AHg0fvBc7AmwIDk4U4BGMDPJ9ue57bhPdSNFJOiKcSQxAnGM/KrZ4AzkjPcd8scPRH7Kt2/9tDrYZcghrk4VwVX5mzkEnOQc8/rnJPON1LO/k/aEZXBkjRQTxkcNk465wD3Hfk4YJNcEtdBGwHVVbDY3Ac8468gHqeRnk/NS3BZmuHkVlLQpgMNpOcEHqOo57k5zz96iG135f+2lT313FddqXXlFoovLTcrH72ecc9s8gc9AecbhGw2LchDLGrRoGCtuDngkE8cZBYdfujr96pJyE+1hGbY6ISS2ck84JPqecc9P4sbgXAAM5VQo8tBwSwyQCRnj39emfm+/RFp2v5f+2hJakWprtvTwfmVW5HJyAc/jnPfPq33iUmpGM3f7oKEEaYCjA+6PYf1+rfeJQvhXovyFr1HSqI3mV42jYxhlXHY4Pr0IOQefx+9Trpjvm+980UYJVQA3yg88DrjOeckZ+b71E5K+cjRlWaNMBlwcYznj1GD75zz96iSIJHcAExnyo229mzg45wSOc8Z6DqPmD/lfp/7aLrqx7xLEt4AQFEaMu3ockEDk579t3TPI+cMnLYnADIAiBjlQG/Lrn73Gc4zz96lmMkbXQlRgXRcZXkg8g9ehHPfPB5+8JJpDKL0kBT5UY5KksQQCQRjOeT/ET1O4guFFtJddv/bP6/4cp7tlZyCbk9cjjAyMk5xnPpn16d/vCaYQr9pGN37mMRsq9+Cc4xg4B5+b/gX3ws6uFuAsiriGLftYKGGBx0G7nB4znG75vv0XOGa4fzMbo4yFVcbs4POcfU9ckZ+b79EW218v/bRO0nd7iTOyPdqJAQ8S5IGNwyDg88+vfJGefv0l1E/mXG/ch2I5BGd2Rnr6EHPfPB5HzUt15ckkxMQVvJjKg8Y+UcgYGc/jwSfm+/THLSJcuVVcovYjvkd/T6568/eDgtE/T/20E73aZNKFCXuPLKmKMAoNoHIwByMn1+9nBPzffEM2VWdVLKdqbg7glvUg45BOG4Pp97G4SSOVF2GwzPEgyhO0ZIYjtnp1OQcZ5++GGQf6YTnEiDBOSSSwPPP167v/AGYKC27af+2jva7X9LXoPv40SSUNlSsUW0CIfMSo74GARk5GcnHLffpJPMk+1tIqqxjVum3IyMdTk5BB756nP3gtzJGrXScx7o4wqgdeh7Y4PXvnj733w102R3IYKxMMbDdlWGQGyMgZ689c5zzjeFCXw38v/bRN293+v61ItRwbrIXb8i9gM8Dn8evvnq33iVHdtun6AYUDAxjge3+T6nqSn0XovyC1ieUxgzqkLhWRdu49OhyPr178ev3qU7hHcfvEBEKIVjOAw44wMA8gHvkjPP3gtwrJ9o8xOqpgsuDzyD+I56nOc8/eEcsiuZmlTLmNVUgZweOeMdgeTnPuTuFQtZfL/wBtCUetgl8orOYgEXahCk5Oe+D6Z7c9vvY3VJcqm6cLJjbFHjhRu4A9vrxnnn5vv0+7jiWW4DZV/JjaMYPOQCeuOxzn5v8AgWd9RvJs+1AzmTzEAyu7nJDYPPbHOcgkcZ4YTDo15f8AtoPVXW6/IddkLNO2XHmQocKOCSFODnHuc4OcdW+/TDII1ulVsCWFAR8wzyre2enfP44DB8zf8fO503NEnCdDyDj3x685Izz98I4DLdvGoCKijBI457c8+uOT1PONwIpWSfl/7aLlbGkGGO5T513RrkjO05IbB9u46jjv94OuJHL3DvECZI0BIUgDIBz2zwOpznk8n5qW4WWRrmQblXy037zgsDgg++SN2OT35wWpSSsd4QkiGSFAd7HnJDe2c4yM545wcbwR6S9P/bRsjuVkBkEiEErGRiMLkbflJwO45z368/eoeWKQXDPFtcxIqYXhSMemOoB5Oe+ck7wk5Iln3xuHaNcbhjIIBz75HOefXn71EiCMzq2Y90alVC43A4Izz0xg988dfvBws0reX/tvYTa+4SVFiMyyqDKUTZgcDIBz27d+c9ec7hPO7RC5HQyQxA4yA
QQrdMDPQHnOevzffEckZ2XAE0bBUTJTjd7e5B64zkjPI+YNmj8nzkEuRsXJTo2ecH+fGRxkZHzUoxvbXt/7b9w7EF0rLOQxJOAcnvkZz+v/ANc9aKffBVnXZ90xocemVBPYf57t94lGtlfsvyC99SxIUl+2Nt4WNACVUEsMDPBHUZPG4nqc8uC4VnFw8igNsQrmPaSD0P4rz3z15+8FkQbbvzV2usUZH3eTx9M5BzxnPXn74Jnmf7W7ps3xoW+XZkHBX3ORg9843HP3hNO1l8v/AG0aa6fd9/4ELSMEuQCRvRc5G0kZBHGec8Hv68/eDn3wi6KHfHJGoZiWX7xDDr1PHQ56ZGcBqddkrJOWiYEoi5kPOSAdwIwDuxkZzwc8n5qUMXhvSZAT5a5OfvHcCe4z69+mcHG8ONnZry/9tB/3thbgSMblxLuxFGJGBChgccYwNxyAe+SCfm5an3XzLdMgXBiiLEnBPAPoMknnHPr82N4jcu8dyVYQr5KExqMbxwQOcEjv3JIB5wWEc6+Z58iMGUBGYkgnJHOCR6knHJ7/ADY3URitL+X/ALaEbD3XfHcsFgZRFHkj5dpwOnAyeCCOc8nnG8SOyyR3zFSpMaYBI9R05Gc9f4j3wfvhk4ljW4wzorQxeYrHBfIDDsMgnDY5zwfmxuolCzfa5FbywiICqsMMeMjPfkZ7njPPLgglovT/ANtEr8um3/DiHe6Xsmcfu1Dcj5vmHvz0z3PGcHG4LLIifahCWMbxKhGWOTwx6YGMqeDn8cb6hYvtnwDgqFJDcYznHHXJGe/rz1Fi4heL7UqoI08qMlSexwRjpkHqBzkc/NjeHHRr5f8Ato2rt3RFOhLT+ZF5TiNHClgMggcjuc5B4zkc8/eC+ZF5N0Akg3RKoJbcNwIJ5BHXBI6/Qn5wtxIy/aSCCskaKdoKDBwwGO54HXOeTz96mu8aJPsLfPEinDZGeCQencZ79O/3gR2Sfl/7aS09mRXylbgZUqTGhORjOVBz0HXrnvnqepKbeYFwQIzGAB8pIPbqD6HqOvXqepKFsvRfkNK2jJ59xM7AkAxoOm3cMA8gYz0B7+vP3qdOjkzGRgHEEbjK7SwIHY4ycHORnPXn71SXkSiS4LblxDFs+XAOVB54HXk55z1y2d9Muv8AWXB3lB5SDCLgNwCAfyznnJGct96lTa0a8v8A20Vno0EzjfeFVkTeiqfmyG5BOeeQcbh16d/vBJSMTmf7xiQoDEQSTgg+3y5Oec5zz98LKix/ahHuAESLkEbT0yO2c4yOvTPP3wyRpnS5Z5OSqq2xR8+PUjjtnvkgHn7wdPZW8v8A23+mU022xHIk89mIjxEoRUUAEccY47DPcnqc5LCSVN4uS8TRlYUYByM545B4yCCWHXjnnlxG7F47hn2SMQvzkYOfUe/r1zknB+8HXChXmSUMsgiQrkg54HPQcHOcjOffO4OO6Xp/7b+AmreQyVWQzKyr/q1IyoU44wR+H1znPP3qklkj3XSgAb4xxncdwIJII7dfXt1++Gyq7NOcGMCFTjaE3LxtyO+Rhu5OM88tT5MTx3MnlgERxk7mGc9yDxnPXHJwcnOC4ILZvy/9tEno1f7iM7IFuYzuO6JVDZOM5DdiM5x7j68MFaI+XctISHCq43Dlt3156NnjOcZwR8wGuiY7hUVB5kaodvyAKCOw6nheue5OT8wdNNHIbpiisXRNrHsRjJ4xyffPcnJ+cKPMmvl/7aNe7ewsgaL7ZkH95EuSSe7K3qM9M/xevP3wSSlVuwn3ZI0XhSvHDe3pnnOcZ5OGBcwFWuMHGI42fLZyxAJwSBkZ57+vzAb6JYoVjuticCOMpkngnBPp78c8f3vviY2aT9P/AG0N3fuV74g3TEDAIB785Gf89fqepKZdFjMN6hTtXAC44xx+nfv1yc5JVdF6L8gvfUtMUiW8WN1KsiqAhbGCQxHvgj3HGRn7wbMXj+0Isi7SiK21Qu8cEA+vY98kZ5+9T5lIa7KloV8lAVBHzn5ep4yDjcOp4B55emyuyfagNzCWNdxyW5JDHnI44J5yPrgNUxTaXfT/ANtDvpqOnhRGuYyCNsaMmV5JODnORgEEnjdn3++ImfCTKcfMibcrg4xnsP8A9fXn7wmbYsd55bAhok7EdSGx29Pf15xvC3K83J3YYwxsRnGQQDjkDPOD39fm++Kg3dX8v/bQvqRkmNbxUKlWjUMVfjqDjnG7ntz0zzjcCUtH542OokjTrxkY3Z6d8A859efvBd8ckV2zMGby12HHJOefx656/Q/fV1wgie4XlB5EYUEY3AhTnAwOevf1+b79ELJq/l/7aJ6PQSZuLqR0kRnjQDd3zg5PTrjcM5P1+8HTRqgu8jIEUeM+pIPByPc/xZ68/fEMyhDNhtxZFJJ3fxDceo5/H8M/eqbywkF6EkkVfKjJHA8zJBwc44/iwM/dHUDeEla1n2/9tKdnqNuUSJ7hQxBMaFFUcMCAec9u+eeg+998RSW7qs7OHBUIx3HltwznJHOev055HzCQEvHeuspQNGpYZyZDuHBJI4745PAODgsGjYYbx4htXaoO5iOCc/jyBxk/jjcCN1a77f8Ato1u7f1uFwFd7iRF3DC/MT0J6/U9fXv1+9Sygj7Qdu3EaBsEYPT884z36Z5+8GuBG10sqksYwIzs6HIIPBxyuTn5s5/4EJphJGbxRKCjQpkjjIJVgOoz6/xZIzzjeHDpby/9tFJ6u6Kt+E+1EoSVZVbJzkkgE5z7/X6t94lO1IMLw7sZKIeFwMFQfx69ec9ctncSkvhXovyEWLlFSGViNzFIBlh03Rlyfz4/HJyearGdtkxCgb1VMAkAD73rz0HXPr15ooqruz+X/tgb0035fqKHzZzuVXJ8uPgYwME547/KP59eaex+0RzzygGT5FBAxj5Sc8dT8vU9cknJ5oooiv3n3f8AtpSXu/15iXyLBOUQYV4o5MHnBZAxAPpkn9Op5p8qho5myRlY+B05Qvj8wP65PNFFFLVxv5f+2lLr/XRi3LmBrgLyJ4oi2WPG5Q5788jvn16gEJeILe5eNCxWW3jc5Y8FkWQ/UZ9c/nzRRWNFtyin/XwmM3rL1H3Ci3inCE4kjhzkn+JPMP6jofr1ANMv/luinUPBE5OBnJjDfzP49Tk80UVvT+Nei/KA2yO7fbKQFX5oY+gxj5VPb+vXqcnmpLqT7O8saKu2aCInPUZVX4x7+ufU5IBooqdvuX/tpD0Wncr3pzc7j1ZEY/UqD/X6+uTRRRSWy9Eay3Z//9k=" +np.random.seed(42) +img_pil = Image.fromarray(np.random.randint(0, 255, (100, 100, 3)), "RGB") + + +class TestTextGenerationWithImage: + def test_format_input(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithImage(llm=llm, image_type="url") + 
task.load() + + assert task.format_input({"instruction": "test", "image": "123kjh123"}) == [ + { + "role": "user", + "content": [ + {"text": "test", "type": "text"}, + {"type": "image_url", "image_url": {"url": "123kjh123"}}, + ], + } + ] + + def test_format_input_with_system_prompt(self) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithImage(llm=llm, system_prompt="test", image_type="url") + task.load() + + assert task.format_input({"instruction": "test", "image": "123kjh123"}) == [ + {"role": "system", "content": "test"}, + { + "role": "user", + "content": [ + {"text": "test", "type": "text"}, + {"type": "image_url", "image_url": {"url": "123kjh123"}}, + ], + }, + ] + + @pytest.mark.parametrize( + "image_type, image, expected", + [ + ("url", "123kjh123", "123kjh123"), + ("base64", img_str, f"data:image/jpeg;base64,{img_str}"), + ("PIL", img_pil, f"data:image/jpeg;base64,{img_str}"), + ], + ) + def test_process( + self, image_type: str, image: Union[str, "Image.Image"], expected: str + ) -> None: + llm = DummyAsyncLLM() + task = TextGenerationWithImage(llm=llm, image_type=image_type) + task.load() + result = next(task.process([{"instruction": "test", "image": image}])) + + assert result == [ + { + "instruction": "test", + "image": image, + "generation": "output", + "distilabel_metadata": { + "raw_output_text_generation_with_image_0": "output", + "raw_input_text_generation_with_image_0": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "test"}, + { + "type": "image_url", + "image_url": {"url": expected}, + }, + ], + } + ], + "statistics_text_generation_with_image_0": { + "input_tokens": 12, + "output_tokens": 12, + }, + }, + "model_name": "test", + } + ] From a8588fd404e6883d570019dde78e46ad28d1ef48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 12 Dec 2024 16:31:19 +0100 Subject: [PATCH 07/30] Create columns with `LLM` returned extra keys (#1078) --- .../advanced/structured_generation.md | 2 +- src/distilabel/distiset.py | 13 +- src/distilabel/models/llms/base.py | 36 +-- .../llms/huggingface/inference_endpoints.py | 119 +++++++-- src/distilabel/models/llms/openai.py | 70 +++++- src/distilabel/models/llms/typing.py | 18 ++ src/distilabel/models/llms/utils.py | 10 +- src/distilabel/models/llms/vllm.py | 232 +++++++----------- src/distilabel/steps/tasks/base.py | 57 +++-- .../tasks/structured_outputs/outlines.py | 12 +- .../steps/tasks/structured_outputs/utils.py | 14 +- .../huggingface/test_inference_endpoints.py | 169 ++++++++++++- tests/unit/models/llms/test_openai.py | 48 +++- tests/unit/models/llms/test_vllm.py | 67 +++-- 14 files changed, 606 insertions(+), 261 deletions(-) diff --git a/docs/sections/how_to_guides/advanced/structured_generation.md b/docs/sections/how_to_guides/advanced/structured_generation.md index 6d6ed034eb..3eb1da99af 100644 --- a/docs/sections/how_to_guides/advanced/structured_generation.md +++ b/docs/sections/how_to_guides/advanced/structured_generation.md @@ -21,7 +21,7 @@ The [`LLM`][distilabel.models.llms.LLM] has an argument named `structured_output We will start with a JSON example, where we initially define a `pydantic.BaseModel` schema to guide the generation of the structured output. !!! 
NOTE - Take a look at [`StructuredOutputType`][distilabel.steps.tasks.structured_outputs.outlines.StructuredOutputType] to see the expected format + Take a look at [`StructuredOutputType`][distilabel.steps.tasks.typing.StructuredOutputType] to see the expected format of the `structured_output` dict variable. ```python diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 8e52c667d3..e934f9d340 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -184,9 +184,20 @@ def _get_card( """ sample_records = {} for name, dataset in self.items(): - sample_records[name] = ( + record = ( dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + for key, value in record.items(): + # If list is too big, the `README.md` generated will be huge so we truncate it + if isinstance(value, list): + length = len(value) + if length < 10: + continue + record[key] = value[:10] + record[key].append( + f"... (truncated - showing 10 of {length} elements)" + ) + sample_records[name] = record readme_metadata = {} if repo_id and token: diff --git a/src/distilabel/models/llms/base.py b/src/distilabel/models/llms/base.py index 4657360afb..785668cbee 100644 --- a/src/distilabel/models/llms/base.py +++ b/src/distilabel/models/llms/base.py @@ -334,7 +334,7 @@ def get_last_hidden_states( ) def _prepare_structured_output( - self, structured_output: Optional["StructuredOutputType"] = None + self, structured_output: "StructuredOutputType" ) -> Union[Any, None]: """Method in charge of preparing the structured output generator. @@ -431,7 +431,7 @@ def event_loop(self) -> "asyncio.AbstractEventLoop": @abstractmethod async def agenerate( self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any - ) -> List[Union[str, None]]: + ) -> "GenerateOutput": """Method to generate a `num_generations` responses for a given input asynchronously, and executed concurrently in `generate` method. """ @@ -591,8 +591,8 @@ def _prepare_kwargs( def merge_responses( - responses: List[Dict[str, Any]], n: int = 1 -) -> List[Dict[str, Any]]: + responses: List["GenerateOutput"], n: int = 1 +) -> List["GenerateOutput"]: """Helper function to group the responses from `LLM.agenerate` method according to the number of generations requested. 
@@ -612,19 +612,27 @@ def chunks(lst, n): for i in range(0, len(lst), n): yield list(islice(lst, i, i + n)) - # Split responses into groups of size n - grouped_responses = list(chunks(responses, n)) + extra_keys = [ + key for key in responses[0].keys() if key not in ("generations", "statistics") + ] result = [] - for group in grouped_responses: - first = group[0] + for group in chunks(responses, n): merged = { - "generations": sum((r["generations"] for r in group), []), - "statistics": { - key: sum((r["statistics"][key] for r in group), []) - for key in first["statistics"] - }, + "generations": [], + "statistics": {"input_tokens": [], "output_tokens": []}, } + for response in group: + merged["generations"].append(response["generations"][0]) + # Merge statistics + for key in response["statistics"]: + if key not in merged["statistics"]: + merged["statistics"][key] = [] + merged["statistics"][key].append(response["statistics"][key][0]) + # Merge extra keys returned by the `LLM` + for extra_key in extra_keys: + if extra_key not in merged: + merged[extra_key] = [] + merged[extra_key].append(response[extra_key][0]) result.append(merged) - return result diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index c60199452b..d4e53f1ed2 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -16,10 +16,20 @@ import random import sys import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) from pydantic import ( Field, + PositiveInt, PrivateAttr, SecretStr, ValidationError, @@ -31,7 +41,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.typing import GenerateOutput, Logprob from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import ( @@ -45,12 +55,15 @@ from huggingface_hub import AsyncInferenceClient from huggingface_hub.inference._generated.types.chat_completion import ( ChatCompletionOutput, + ChatCompletionOutputComplete, ) from huggingface_hub.inference._generated.types.text_generation import ( TextGenerationOutput, ) from transformers import PreTrainedTokenizer + from distilabel.models.llms.typing import Logprob + class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): """InferenceEndpoints LLM implementation running the async API client. @@ -338,15 +351,15 @@ def prepare_input(self, input: "StandardInput") -> str: def _get_structured_output( self, input: FormattedInput - ) -> Union[Dict[str, Any], None]: + ) -> Tuple["StandardInput", Union[Dict[str, Any], None]]: """Gets the structured output (if any) for the given input. Args: input: a single input in chat format to generate responses for. Returns: - The structured output that will be passed as `grammer` to the inference endpoint - or `None` if not required. + The input and the structured output that will be passed as `grammar` to the + inference endpoint or `None` if not required. 
""" structured_output = None @@ -377,7 +390,7 @@ def _get_structured_output( "value" ].model_json_schema() - return structured_output + return input, structured_output async def _generate_with_text_generation( self, @@ -387,26 +400,28 @@ async def _generate_with_text_generation( frequency_penalty: Optional[float] = None, temperature: float = 1.0, do_sample: bool = False, - top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, top_p: Optional[float] = None, + top_k: Optional[int] = None, typical_p: Optional[float] = None, stop_sequences: Union[List[str], None] = None, return_full_text: bool = False, seed: Optional[int] = None, watermark: bool = False, ) -> GenerateOutput: - structured_output = self._get_structured_output(input) - - completion = None + input, structured_output = self._get_structured_output(input) + prompt = self.prepare_input(input) + generation: Union["TextGenerationOutput", None] = None try: - completion: "TextGenerationOutput" = await self._aclient.text_generation( # type: ignore - prompt=self.prepare_input(input), # type: ignore + generation = await self._aclient.text_generation( # type: ignore + prompt=prompt, max_new_tokens=max_new_tokens, do_sample=do_sample, typical_p=typical_p, repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, temperature=temperature, + top_n_tokens=top_n_tokens, top_p=top_p, top_k=top_k, stop_sequences=stop_sequences, @@ -423,25 +438,42 @@ async def _generate_with_text_generation( f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return prepare_output( - [completion.generated_text], - input_tokens=[ - compute_tokens(self.prepare_input(input), self._tokenizer.encode) - if self._tokenizer - else 0 - ], + generations=[generation.generated_text] if generation else [None], + input_tokens=[compute_tokens(prompt, self._tokenizer.encode)], # type: ignore output_tokens=[ - completion.details.generated_tokens if completion.details else 0 + generation.details.generated_tokens + if generation and generation.details + else 0 ], + logprobs=self._get_logprobs_from_text_generation(generation) + if generation + else None, # type: ignore ) + def _get_logprobs_from_text_generation( + self, generation: "TextGenerationOutput" + ) -> Union[List[List[List["Logprob"]]], None]: + if generation.details is None or generation.details.top_tokens is None: + return None + + return [ + [ + [ + {"token": top_logprob["text"], "logprob": top_logprob["logprob"]} + for top_logprob in token_logprobs + ] + for token_logprobs in generation.details.top_tokens + ] + ] + async def _generate_with_chat_completion( self, input: "StandardInput", max_new_tokens: int = 128, frequency_penalty: Optional[float] = None, logit_bias: Optional[List[float]] = None, + logprobs: bool = False, presence_penalty: Optional[float] = None, seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, @@ -449,15 +481,19 @@ async def _generate_with_chat_completion( tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[PositiveInt] = None, top_p: Optional[float] = None, ) -> GenerateOutput: message = None + completion: Union["ChatCompletionOutput", None] = None + output_logprobs = None try: - completion: "ChatCompletionOutput" = await self._aclient.chat_completion( # type: ignore + completion = await self._aclient.chat_completion( # type: ignore messages=input, # type: ignore 
max_tokens=max_new_tokens, frequency_penalty=frequency_penalty, logit_bias=logit_bias, + logprobs=logprobs, presence_penalty=presence_penalty, # NOTE: here to ensure that the cache is not used and a different response is # generated every time @@ -467,25 +503,43 @@ async def _generate_with_chat_completion( tool_choice=tool_choice, # type: ignore tool_prompt=tool_prompt, tools=tools, # type: ignore + top_logprobs=top_logprobs, top_p=top_p, ) - choice = completion.choices[0] + choice = completion.choices[0] # type: ignore if (message := choice.message.content) is None: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {choice.finish_reason}" ) + if choice_logprobs := self._get_logprobs_from_choice(choice): + output_logprobs = [choice_logprobs] except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) return prepare_output( - [message], - input_tokens=[completion.usage.prompt_tokens], - output_tokens=[completion.usage.completion_tokens], + generations=[message], + input_tokens=[completion.usage.prompt_tokens] if completion else None, + output_tokens=[completion.usage.completion_tokens] if completion else None, + logprobs=output_logprobs, ) + def _get_logprobs_from_choice( + self, choice: "ChatCompletionOutputComplete" + ) -> Union[List[List["Logprob"]], None]: + if choice.logprobs is None: + return None + + return [ + [ + {"token": top_logprob.token, "logprob": top_logprob.logprob} + for top_logprob in token_logprobs.top_logprobs + ] + for token_logprobs in choice.logprobs.content + ] + def _check_stop_sequences( self, stop_sequences: Optional[Union[str, List[str]]] = None, @@ -517,6 +571,7 @@ async def agenerate( # type: ignore max_new_tokens: int = 128, frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, logit_bias: Optional[List[float]] = None, + logprobs: bool = False, presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None, seed: Optional[int] = None, stop_sequences: Optional[List[str]] = None, @@ -524,6 +579,8 @@ async def agenerate( # type: ignore tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None, tool_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[PositiveInt] = None, + top_n_tokens: Optional[PositiveInt] = None, top_p: Optional[float] = None, do_sample: bool = False, repetition_penalty: Optional[float] = None, @@ -549,6 +606,9 @@ async def agenerate( # type: ignore This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. + logprobs: whether to return the log probabilities or not. This argument is exclusive + to the `chat_completion` method and will be used only if `tokenizer_id` + is `None`. Defaults to `False`. presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model likelihood to talk about new topics. This argument is exclusive to @@ -569,6 +629,12 @@ async def agenerate( # type: ignore tools: a list of tools definitions that the LLM can use. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. + top_logprobs: the number of top log probabilities to return per output token + generated. 
This argument is exclusive to the `chat_completion` method and + will be used only if `tokenizer_id` is `None`. Defaults to `None`. + top_n_tokens: the number of top log probabilities to return per output token + generated. This argument is exclusive of the `text_generation` method and + will be only used if `tokenizer_id` is not `None`. Defaults to `None`. top_p: the top-p value to use for the generation. Defaults to `1.0`. do_sample: whether to use sampling for the generation. This argument is exclusive of the `text_generation` method and will be only used if `tokenizer_id` is not @@ -602,6 +668,7 @@ async def agenerate( # type: ignore max_new_tokens=max_new_tokens, frequency_penalty=frequency_penalty, logit_bias=logit_bias, + logprobs=logprobs, presence_penalty=presence_penalty, seed=seed, stop_sequences=stop_sequences, @@ -609,6 +676,7 @@ async def agenerate( # type: ignore tool_choice=tool_choice, tool_prompt=tool_prompt, tools=tools, + top_logprobs=top_logprobs, top_p=top_p, ) @@ -620,6 +688,7 @@ async def agenerate( # type: ignore repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, temperature=temperature, + top_n_tokens=top_n_tokens, top_p=top_p, top_k=top_k, stop_sequences=stop_sequences, diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index e58c0b42ce..c53122fa63 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import orjson -from pydantic import Field, PrivateAttr, SecretStr, validate_call +from pydantic import Field, PositiveInt, PrivateAttr, SecretStr, validate_call from distilabel import envs from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException @@ -32,8 +32,11 @@ from openai.types import Batch as OpenAIBatch from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion + from openai.types.chat.chat_completion import Choice as OpenAIChoice + from openai.types.completion import Completion as OpenAICompletion from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import Logprob _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" @@ -233,6 +236,8 @@ async def agenerate( # type: ignore input: FormattedInput, num_generations: int = 1, max_new_tokens: int = 128, + logprobs: bool = False, + top_logprobs: Optional[PositiveInt] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, temperature: float = 1.0, @@ -249,6 +254,9 @@ async def agenerate( # type: ignore `1`. max_new_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`. + logprobs: whether to return the log probabilities or not. Defaults to `False`. + top_logprobs: the number of top log probabilities to return per output token + generated. Defaults to `None`. frequency_penalty: the repetition penalty to use for the generation. Defaults to `0.0`. presence_penalty: the presence penalty to use for the generation. 
Defaults to @@ -285,6 +293,8 @@ async def agenerate( # type: ignore kwargs = { "messages": input, # type: ignore "model": self.model, + "logprobs": logprobs, + "top_logprobs": top_logprobs, "max_tokens": max_new_tokens, "n": num_generations, "frequency_penalty": frequency_penalty, @@ -307,10 +317,22 @@ async def agenerate( # type: ignore kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore + if structured_output: + # NOTE: `instructor` doesn't work with `n` parameter, so it will always return + # only 1 choice. + statistics = self._get_llm_statistics(completion._raw_response) + if choice_logprobs := self._get_logprobs_from_choice( + completion._raw_response.choices[0] + ): + output_logprobs = [choice_logprobs] + else: + output_logprobs = None return prepare_output( - [completion.model_dump_json()], - **self._get_llm_statistics(completion._raw_response), + generations=[completion.model_dump_json()], + input_tokens=statistics["input_tokens"], + output_tokens=statistics["output_tokens"], + logprobs=output_logprobs, ) return self._generations_from_openai_completion(completion) @@ -327,6 +349,7 @@ def _generations_from_openai_completion( A list of strings containing the generated responses for the input. """ generations = [] + logprobs = [] for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( # type: ignore @@ -334,14 +357,38 @@ def _generations_from_openai_completion( f" Finish reason was: {choice.finish_reason}" ) generations.append(content) + if choice_logprobs := self._get_logprobs_from_choice(choice): + logprobs.append(choice_logprobs) + + statistics = self._get_llm_statistics(completion) + return prepare_output( + generations=generations, + input_tokens=statistics["input_tokens"], + output_tokens=statistics["output_tokens"], + logprobs=logprobs, + ) + + def _get_logprobs_from_choice( + self, choice: "OpenAIChoice" + ) -> Union[List[List["Logprob"]], None]: + if choice.logprobs is None or choice.logprobs.content is None: + return None - return prepare_output(generations, **self._get_llm_statistics(completion)) + return [ + [ + {"token": top_logprob.token, "logprob": top_logprob.logprob} + for top_logprob in token_logprobs.top_logprobs + ] + for token_logprobs in choice.logprobs.content + ] def offline_batch_generate( self, inputs: Union[List["FormattedInput"], None] = None, num_generations: int = 1, max_new_tokens: int = 128, + logprobs: bool = False, + top_logprobs: Optional[PositiveInt] = None, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, temperature: float = 1.0, @@ -359,6 +406,9 @@ def offline_batch_generate( `1`. max_new_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`. + logprobs: whether to return the log probabilities or not. Defaults to `False`. + top_logprobs: the number of top log probabilities to return per output token + generated. Defaults to `None`. frequency_penalty: the repetition penalty to use for the generation. Defaults to `0.0`. presence_penalty: the presence penalty to use for the generation. 
Defaults to @@ -388,6 +438,8 @@ def offline_batch_generate( inputs=inputs, **{ "model": self.model, + "logprobs": logprobs, + "top_logprobs": top_logprobs, "max_tokens": max_new_tokens, "n": num_generations, "frequency_penalty": frequency_penalty, @@ -684,8 +736,12 @@ def _name_for_openai_files(self, file_no: int) -> str: return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl" @staticmethod - def _get_llm_statistics(completion: "OpenAIChatCompletion") -> "LLMStatistics": + def _get_llm_statistics( + completion: Union["OpenAIChatCompletion", "OpenAICompletion"], + ) -> "LLMStatistics": return { - "input_tokens": [completion.usage.prompt_tokens if completion else 0], - "output_tokens": [completion.usage.completion_tokens if completion else 0], + "output_tokens": [ + completion.usage.completion_tokens if completion.usage else 0 + ], + "input_tokens": [completion.usage.prompt_tokens if completion.usage else 0], } diff --git a/src/distilabel/models/llms/typing.py b/src/distilabel/models/llms/typing.py index 512c76b471..cfa4ec382f 100644 --- a/src/distilabel/models/llms/typing.py +++ b/src/distilabel/models/llms/typing.py @@ -14,9 +14,26 @@ from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, TypeVar, Union +from typing_extensions import NotRequired + LLMOutput = List[Union[str, None]] +class Logprob(TypedDict): + token: str + logprob: float + + +LLMLogprobs = List[List[List[Logprob]]] +"""A type alias representing the probability distributions output by an `LLM`. + +Structure: + - Outermost list: contains multiple generation choices when sampling (`n` sequences) + - Middle list: represents each position in the generated sequence + - Innermost list: contains the log probabilities for each token in the vocabulary at that position +""" + + class TokenCount(TypedDict): input_tokens: List[int] output_tokens: List[int] @@ -31,6 +48,7 @@ class TokenCount(TypedDict): class GenerateOutput(TypedDict): generations: LLMOutput statistics: LLMStatistics + logprobs: NotRequired[LLMLogprobs] if TYPE_CHECKING: diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py index 6a5ae78a1e..9cf6590c78 100644 --- a/src/distilabel/models/llms/utils.py +++ b/src/distilabel/models/llms/utils.py @@ -17,11 +17,11 @@ from distilabel.steps.tasks.typing import ChatType if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput, LLMOutput + from distilabel.models.llms.typing import GenerateOutput, LLMLogprobs, LLMOutput def compute_tokens( - text_or_messages: Union[str, ChatType], tokenizer: Callable[[str], List[int]] + text_or_messages: Union[str, ChatType], tokenizer: Callable[..., List[int]] ) -> int: """Helper function to count the number of tokens in a text or list of messages. @@ -42,6 +42,7 @@ def prepare_output( generations: "LLMOutput", input_tokens: Optional[List[int]] = None, output_tokens: Optional[List[int]] = None, + logprobs: Optional["LLMLogprobs"] = None, ) -> "GenerateOutput": """Helper function to prepare the output of the LLM. @@ -53,10 +54,13 @@ def prepare_output( Returns: Output generation from an LLM. 
""" - return { + output: "GenerateOutput" = { "generations": generations, "statistics": { "input_tokens": input_tokens or [], "output_tokens": output_tokens or [], }, } + if logprobs: + output["logprobs"] = logprobs + return output diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index 7665ea4221..9a75cd47c2 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -28,13 +28,12 @@ Union, ) -import numpy as np -from pydantic import Field, PrivateAttr, SecretStr, validate_call +from pydantic import Field, PositiveInt, PrivateAttr, SecretStr, validate_call from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.openai import OpenAILLM -from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.typing import GenerateOutput, Logprob from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin @@ -44,10 +43,12 @@ from openai import OpenAI # noqa from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM - from vllm.outputs import RequestOutputs, CompletionOutput + from vllm.outputs import RequestOutput, CompletionOutput from distilabel.steps.tasks.typing import StandardInput - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics + from distilabel.steps.tasks.typing import StructuredInput + from distilabel.models.llms.typing import LLMLogprobs, LLMOutput LogitsProcessorFn = Union[ @@ -270,8 +271,8 @@ def prepare_input(self, input: "StandardInput") -> str: return super().apply_magpie_pre_query_template(prompt, input) def _prepare_batches( - self, inputs: List[FormattedInput] - ) -> Tuple[List[List[FormattedInput]], List[int]]: + self, inputs: List["StructuredInput"] + ) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]: """Prepares the inputs by grouping them by the structured output. 
 
         When we generate structured outputs with schemas obtained from a dataset, we need to
@@ -289,16 +290,10 @@ def _prepare_batches(
             Each new tuple will contain instead of the single instruction, a list of instructions
         """
         instruction_order = {}
-        batches = {}
+        batches: Dict[str, List[str]] = {}
         for i, (instruction, structured_output) in enumerate(inputs):
             instruction = self.prepare_input(instruction)
-
-            # We need to convert the instruction to a string to make it hashable
-            str_instruction = instruction
-            if not isinstance(instruction, str):
-                str_instruction = json.dumps(instruction)
-
-            instruction_order[str_instruction] = i
+            instruction_order[instruction] = i
 
             structured_output = json.dumps(structured_output)
             if structured_output not in batches:
@@ -306,20 +301,22 @@
             else:
                 batches[structured_output].append(instruction)
 
-        # Flatten the instructions in prepared_data
+        # Build a list with instructions sorted by structured output
         flat_instructions = [
             instruction for _, group in batches.items() for instruction in group
         ]
+
+        # Generate the list of indices based on the original order
         sorted_indices = [
-            instruction_order[str_instruction] for instruction in flat_instructions
+            instruction_order[instruction] for instruction in flat_instructions
         ]
+
         return [
             (batch, json.loads(schema)) for schema, batch in batches.items()
         ], sorted_indices
 
     @validate_call
-    def generate(  # type: ignore
+    def generate(  # noqa: C901 # type: ignore
         self,
         inputs: List[FormattedInput],
         num_generations: int = 1,
@@ -331,6 +328,7 @@ def generate(  # type: ignore
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
+        logprobs: Optional[PositiveInt] = None,
         stop: Optional[List[str]] = None,
         stop_token_ids: Optional[List[int]] = None,
         include_stop_str_in_output: bool = False,
@@ -355,6 +353,8 @@ def generate(  # type: ignore
             top_p: the top-p value to use for the generation. Defaults to `1.0`.
             top_k: the top-k value to use for the generation. Defaults to `0`.
             min_p: the minimum probability to use for the generation. Defaults to `0.0`.
+            logprobs: number of log probabilities to return per output token. If `None`,
+                no log probabilities will be returned. Defaults to `None`.
             stop: a list of strings that will be used to stop the generation when found.
                 Defaults to `None`. 
             
stop_token_ids: a list of token ids that will be used to stop the generation @@ -380,22 +380,28 @@ def generate( # type: ignore structured_output = None if isinstance(inputs[0], tuple): - prepared_batches, sorted_indices = self._prepare_batches(inputs) + # Prepare the batches for structured generation + prepared_batches, sorted_indices = self._prepare_batches(inputs) # type: ignore else: # Simulate a batch without the structured output content - prepared_batches = [([self.prepare_input(input) for input in inputs], None)] + prepared_batches = [([self.prepare_input(input) for input in inputs], None)] # type: ignore sorted_indices = None + # Case in which we have a single structured output for the dataset if self._structured_output_logits_processor: logits_processors.append(self._structured_output_logits_processor) - batched_outputs = [] + batched_outputs: List["LLMOutput"] = [] generations = [] for prepared_inputs, structured_output in prepared_batches: - if structured_output: + if self.structured_output is not None and structured_output is not None: + # TODO: warning + pass + + if structured_output is not None: logits_processors.append( - self._prepare_structured_output(structured_output) + self._prepare_structured_output(structured_output) # type: ignore ) sampling_params = SamplingParams( # type: ignore @@ -408,6 +414,7 @@ def generate( # type: ignore top_k=top_k, min_p=min_p, max_tokens=max_new_tokens, + logprobs=logprobs, stop=stop, stop_token_ids=stop_token_ids, include_stop_str_in_output=include_stop_str_in_output, @@ -415,40 +422,57 @@ def generate( # type: ignore **extra_sampling_params, ) - batch_outputs: List["RequestOutputs"] = self._model.generate( - prepared_inputs, - sampling_params, - use_tqdm=False, # type: ignore + batch_outputs: List["RequestOutput"] = self._model.generate( + prompts=prepared_inputs, + sampling_params=sampling_params, + use_tqdm=False, ) - # TODO: This is repeated in prepare_output, but for simplicity we extract - # the batched_outputs as we did when there wasn't statistics and we just - # return the str generations - batched_outputs += [ - [output.text for output in outputs.outputs] for outputs in batch_outputs - ] + # Remove structured output logit processor to avoid stacking structured output + # logits processors that leads to non-sense generations + if structured_output is not None: + logits_processors.pop(-1) + for input, outputs in zip(prepared_inputs, batch_outputs): + texts, statistics, outputs_logprobs = self._process_outputs( + input, outputs + ) + batched_outputs.append(texts) generations.append( prepare_output( - [output.text for output in outputs.outputs], - **self._get_llm_statistics(input, outputs), + generations=texts, + input_tokens=statistics["input_tokens"], + output_tokens=statistics["output_tokens"], + logprobs=outputs_logprobs, ) ) - # If logits_processor is set, we need to sort the outputs back to the original order - # (would be needed only if we have multiple structured outputs in the dataset) if sorted_indices is not None: - # Sort the batched outputs together with the statistics - generations = self._prepare_sorted_results( - batched_outputs, - sorted_indices, - generations, - num_generations=num_generations, - ) + pairs = list(enumerate(sorted_indices)) + pairs.sort(key=lambda x: x[1]) + generations = [generations[original_idx] for original_idx, _ in pairs] + return generations + def _process_outputs( + self, input: str, outputs: "RequestOutput" + ) -> Tuple["LLMOutput", "LLMStatistics", "LLMLogprobs"]: + texts = [] + 
outputs_logprobs = [] + statistics = { + "input_tokens": [compute_tokens(input, self._tokenizer.encode)] + * len(outputs.outputs), + "output_tokens": [], + } + for output in outputs.outputs: + texts.append(output.text) + statistics["output_tokens"].append(len(output.token_ids)) + if output.logprobs is not None: + outputs_logprobs.append(self._get_llm_logprobs(output)) + return texts, statistics, outputs_logprobs + def _prepare_structured_output( - self, structured_output: Optional[OutlinesStructuredOutputType] = None + self, structured_output: "OutlinesStructuredOutputType" ) -> Union[Callable, None]: """Creates the appropriate function to filter tokens to generate structured outputs. @@ -462,69 +486,23 @@ def _prepare_structured_output( prepare_guided_output, ) + assert structured_output is not None, "`structured_output` cannot be `None`" + result = prepare_guided_output(structured_output, "vllm", self._model) if (schema := result.get("schema")) and self.structured_output: self.structured_output["schema"] = schema return result["processor"] - def _get_llm_statistics( - self, input: "FormattedInput", outputs: "CompletionOutput" - ) -> "LLMStatistics": - output_tokens = [len(output.token_ids) for output in outputs.outputs] - return { - "input_tokens": [compute_tokens(input, self._tokenizer.encode)] - * len(output_tokens), - "output_tokens": output_tokens, - } - - @staticmethod - def _prepare_sorted_results( - batched_outputs: List[List[FormattedInput]], - sorted_indices: List[int], - generations: List[GenerateOutput], - num_generations: int = 1, - ) -> List[GenerateOutput]: - """Helper method to sort the results in case of multiple structured outputs in the dataset. - - Args: - batched_outputs: The mini-batches generated by the model. - sorted_indices: The indices that would sort the mini-batches back to the original order. - generations: The prepared outputs that would be returned in the general case, - from which the statistics will be extracted and sorted. - num_generations: The number of generations requested to vLLM. Defaults to 1. - - Returns: - The list of GenerateOutput sorted back to the original order. 
- """ - - # This was the only required sort back with only the generations - batched_outputs = _sort_batches( - batched_outputs, sorted_indices, num_generations=num_generations - ) - # Prepare the statistics to be sorted - # Loop over all the variables in the statistics - # Get the keys from the LLMStatistics - statistic_fields = list(generations[0]["statistics"].keys()) - statistics = {} - for field in statistic_fields: - batched_field = _sort_batches( - [g["statistics"][field] for g in generations], - sorted_indices, - num_generations=num_generations, - ) - statistics[field] = batched_field - - # Regenerates the outputs as they are returned by `prepare_output` - sorted_results = [] - for i, batched_output in enumerate(batched_outputs): - generation = {"generations": batched_output} - statistics = { - field: batched_field[i] for field, batched_field in statistics.items() - } - generation.update({"statistics": statistics}) - sorted_results.append(generation) - - return sorted_results + def _get_llm_logprobs(self, output: "CompletionOutput") -> List[List["Logprob"]]: + logprobs = [] + for token_logprob in output.logprobs: # type: ignore + token_logprobs = [] + for logprob in token_logprob.values(): + token_logprobs.append( + {"token": logprob.decoded_token, "logprob": logprob.logprob} + ) + logprobs.append(token_logprobs) + return logprobs class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin): @@ -703,7 +681,8 @@ async def agenerate( # type: ignore generations = [] for choice in completion.choices: - if (text := choice.text) == "": + text = choice.text + if text == "": self._logger.warning( # type: ignore f"Received no response from vLLM server (model: '{self.model_name}')." f" Finish reason was: {choice.finish_reason}" @@ -711,48 +690,3 @@ async def agenerate( # type: ignore generations.append(text) return prepare_output(generations, **self._get_llm_statistics(completion)) - - -def _sort_batches( - batches: List[List[FormattedInput]], indices: List[int], num_generations: int = 1 -) -> List[str]: - """Helper function to sort back the mini-batches generated by the model. - - It must take into account the number of `num_generations` to repeat the indices - accordingly. - - Args: - batches: The mini-batches generated by the model. - indices: The indices that would sort the mini-batches back to the original order. - num_generations: The number of generations requested to vLLM. Defaults to 1. - - Returns: - Sorted batched_outputs. - """ - batch_sizes = [len(batch) for batch in batches] - flattened_batches = np.array([b for batch in batches for b in batch]) - sorted_batches = np.take_along_axis( - flattened_batches, - np.argsort(np.repeat(indices, num_generations)), - axis=0, - ).tolist() - sorted_batches = _batchify(sorted_batches, batch_sizes) - return sorted_batches - - -def _batchify(sorted_batches: List[str], batch_sizes: List[int]) -> List[List[str]]: - """Helper function to regenerate the sorted batches from the flattened sorted ones. - - Args: - sorted_batches: Output obtained from the `_sort_batches` function. - batch_sizes: The batch sizes to be used to split the sorted batches. - - Returns: - Batched sorted batches in the original shape. 
- """ - batches = [] - idx = 0 - for bs in batch_sizes: - batches.append(sorted_batches[idx : idx + bs]) - idx += bs - return batches diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index 9be8f8ee1a..dba92588cd 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -14,7 +14,7 @@ import importlib from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union from pydantic import Field, PrivateAttr from typing_extensions import override @@ -173,7 +173,7 @@ def _format_outputs( repeate_inputs = len(outputs.get("generations")) outputs = normalize_statistics(outputs) - for (output, stats), input in zip( + for (output, stats, extra), input in zip( iterate_generations_with_stats(outputs), inputs * repeate_inputs ): # type: ignore try: @@ -181,13 +181,16 @@ def _format_outputs( # to keep everything clean formatted_output = self.format_output(output, input) formatted_output = self._create_metadata( - formatted_output, - output, - input, + output=formatted_output, + raw_output=output, + input=input, add_raw_output=self.add_raw_output, # type: ignore add_raw_input=self.add_raw_input, # type: ignore statistics=stats, ) + formatted_output = self._create_extra( + output=formatted_output, extra=extra + ) formatted_outputs.append(formatted_output) except Exception as e: self._logger.warning( # type: ignore @@ -217,8 +220,8 @@ def _output_on_failure( def _create_metadata( self, output: Dict[str, Any], - raw_output: List[Union[str, None]], - input: Union[str, None], + raw_output: Union[str, None], + input: Union[Dict[str, Any], None] = None, add_raw_output: bool = True, add_raw_input: bool = True, statistics: Optional["LLMStatistics"] = None, @@ -230,8 +233,8 @@ def _create_metadata( output: The output dictionary after formatting the output from the LLM, to add the raw output and or raw input. - raw_output: The raw output of the LLM (the list of generations). - input: The raw input of the LLM. + raw_output: The raw output of the `LLM`. + input: The input used to generate the output. add_raw_output: Whether to add the raw output to the output dictionary. add_raw_input: Whether to add the raw input to the output dictionary. statistics: The statistics generated by the LLM, which should contain at least @@ -241,15 +244,27 @@ def _create_metadata( if add_raw_output: meta[f"raw_output_{self.name}"] = raw_output + if add_raw_input: meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None + if statistics: meta[f"statistics_{self.name}"] = statistics + if meta: output[DISTILABEL_METADATA_KEY] = meta return output + def _create_extra( + self, output: Dict[str, Any], extra: Dict[str, Any] + ) -> Dict[str, Any]: + column_name_prefix = f"llm_{self.name}_" + for key, value in extra.items(): + column_name = column_name_prefix + key + output[column_name] = value + return output + def _set_default_structured_output(self) -> None: """Prepares the structured output to be set in the selected `LLM`. @@ -520,18 +535,24 @@ def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": return output -def iterate_generations_with_stats(output: "GenerateOutput") -> "GenerateOutput": - """Helper function to iterate together generations and statistics while - processing them inside _format_outputs. 
+def iterate_generations_with_stats( + outputs: "GenerateOutput", +) -> Generator[Tuple[Union[str, None], "LLMStatistics", Dict[str, Any]], None, None]: + """Helper function to iterate together generations and statistics while processing + them inside `_format_outputs`. Args: - output: Output from the LLM.generate_outputs method. + outputs: outputs from the `LLM.generate_outputs` method. Yields: - Iterator of generation and statistics paired. + Iterator of generation, generation statistics and extra data generated by the `LLM`. """ - for i, generation in enumerate(output["generations"]): + extra_keys = [ + key for key in outputs.keys() if key not in ("generations", "statistics") + ] + for i, generation in enumerate(outputs["generations"]): # Create a new dictionary with the statistics for this index - stats = {key: values[i] for key, values in output["statistics"].items()} - - yield generation, stats + stats = {key: values[i] for key, values in outputs["statistics"].items()} # type: ignore + # Extra keys returned by the `LLM` + extra = {key: outputs[key][i] for key in extra_keys} + yield generation, stats, extra diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index 62419d37b0..fe561d11af 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -17,13 +17,13 @@ import inspect import json from typing import ( + TYPE_CHECKING, Any, Callable, Dict, Literal, Tuple, Type, - Union, get_args, ) @@ -31,7 +31,9 @@ from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict -from distilabel.steps.tasks.typing import StructuredOutputType + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import OutlinesStructuredOutputType Frameworks = Literal["transformers", "llamacpp", "vllm"] """Available frameworks for the structured output configuration. """ @@ -72,10 +74,10 @@ def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]: def prepare_guided_output( - structured_output: StructuredOutputType, + structured_output: "OutlinesStructuredOutputType", framework: Frameworks, llm: Any, -) -> Dict[str, Union[Callable, None]]: +) -> Dict[str, Any]: """Prepares the `LLM` to generate guided output using `outlines`. 
It allows to generate JSON or Regex structured outputs for the integrated @@ -105,6 +107,8 @@ def prepare_guided_output( format = structured_output.get("format") schema = structured_output.get("schema") + assert schema is not None, "schema cannot be `None`" + # If schema not informed (may be forgotten), try infering it if not format: if isinstance(schema, dict) or inspect.isclass(schema): diff --git a/src/distilabel/steps/tasks/structured_outputs/utils.py b/src/distilabel/steps/tasks/structured_outputs/utils.py index b041281989..46676fe5a6 100644 --- a/src/distilabel/steps/tasks/structured_outputs/utils.py +++ b/src/distilabel/steps/tasks/structured_outputs/utils.py @@ -19,14 +19,16 @@ def schema_as_dict( - schema: Union[str, Type[BaseModel], Dict[str, Any]], + schema: Union[str, Dict[str, Any], Type[BaseModel]], ) -> Dict[str, Any]: """Helper function to obtain the schema and simplify serialization.""" - if type(schema) is type(BaseModel): - return schema.model_json_schema() - elif isinstance(schema, str): + if isinstance(schema, str): return json.loads(schema) - return schema # type: ignore + + if isinstance(schema, dict): + return schema + + return schema.model_json_schema() # NOTE: The following functions were copied from: @@ -47,7 +49,7 @@ def json_schema_to_model(json_schema: Dict[str, Any]) -> Type[BaseModel]: """ # Extract the model name from the schema title. - model_name = json_schema.get("title") + model_name = json_schema["title"] if defs := json_schema.get("$defs", None): # This is done to grab the content of nested classes that need to dereference # the objects (those should be in a higher level). diff --git a/tests/unit/models/llms/huggingface/test_inference_endpoints.py b/tests/unit/models/llms/huggingface/test_inference_endpoints.py index 874cd9a595..f1dcd5e028 100644 --- a/tests/unit/models/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/models/llms/huggingface/test_inference_endpoints.py @@ -23,7 +23,10 @@ from huggingface_hub import ( ChatCompletionOutput, ChatCompletionOutputComplete, + ChatCompletionOutputLogprob, + ChatCompletionOutputLogprobs, ChatCompletionOutputMessage, + ChatCompletionOutputTopLogprob, ChatCompletionOutputUsage, ) @@ -134,6 +137,7 @@ async def test_agenerate_with_text_generation( generated_text="Aenean hendrerit aliquam velit...", details=MagicMock( generated_tokens=66, + top_tokens=None, ), ) ) @@ -146,12 +150,72 @@ async def test_agenerate_with_text_generation( }, ] ) + + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": { + "input_tokens": [31], + "output_tokens": [66], + }, + } + + @pytest.mark.asyncio + async def test_agenerate_with_text_generation_and_top_n_tokens( + self, mock_inference_client: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + tokenizer_id="distilabel-internal-testing/tiny-random-mistral", + ) + llm.load() + + llm._aclient.text_generation = AsyncMock( + return_value=MagicMock( + generated_text="Aenean hendrerit aliquam velit...", + details=MagicMock( + generated_tokens=66, + top_tokens=[ + [ + {"logprob": 0, "text": "Aenean"}, + {"logprob": -2, "text": "Hello"}, + ], + [ + {"logprob": 0, "text": " "}, + {"logprob": -2, "text": ","}, + ], + ], + ), + ) + ) + + result = await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ], + top_n_tokens=2, + ) + assert result == { "generations": ["Aenean hendrerit aliquam velit..."], 
"statistics": { "input_tokens": [31], "output_tokens": [66], }, + "logprobs": [ + [ + [ + {"logprob": 0, "token": "Aenean"}, + {"logprob": -2, "token": "Hello"}, + ], + [ + {"logprob": 0, "token": " "}, + {"logprob": -2, "token": ","}, + ], + ] + ], } @pytest.mark.asyncio @@ -201,6 +265,107 @@ async def test_agenerate_with_chat_completion( }, } + @pytest.mark.asyncio + async def test_agenerate_with_chat_completion_and_logprobs_and_top_logprobs( + self, mock_inference_client: MagicMock + ) -> None: + llm = InferenceEndpointsLLM( + model_id="distilabel-internal-testing/tiny-random-mistral", + ) + llm.load() + + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="length", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=" Aenean hendrerit aliquam velit. ...", + ), + logprobs=ChatCompletionOutputLogprobs( + content=[ + ChatCompletionOutputLogprob( + logprob=0, + token=" ", + top_logprobs=[ + ChatCompletionOutputTopLogprob( + logprob=0, token=" " + ), + ChatCompletionOutputTopLogprob( + logprob=-1, token="Hello" + ), + ], + ), + ChatCompletionOutputLogprob( + logprob=0, + token="Aenean", + top_logprobs=[ + ChatCompletionOutputTopLogprob( + logprob=0, token="Aenean" + ), + ChatCompletionOutputTopLogprob( + logprob=-1, token="miau" + ), + ], + ), + ] + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) + ) + + result = await llm.agenerate( + input=[ + { + "role": "user", + "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + }, + ], + logprobs=True, + top_logprobs=2, + ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": { + "input_tokens": [18], + "output_tokens": [66], + }, + "logprobs": [ + [ + [ + { + "logprob": 0, + "token": " ", + }, + { + "logprob": -1, + "token": "Hello", + }, + ], + [ + { + "logprob": 0, + "token": "Aenean", + }, + { + "logprob": -1, + "token": "miau", + }, + ], + ] + ], + } + @pytest.mark.asyncio async def test_agenerate_with_chat_completion_fails( self, mock_inference_client: MagicMock @@ -338,9 +503,7 @@ async def test_agenerate_with_structured_output( llm._aclient.text_generation = AsyncMock( return_value=MagicMock( generated_text="Aenean hendrerit aliquam velit...", - details=MagicMock( - generated_tokens=66, - ), + details=MagicMock(generated_tokens=66, top_tokens=None), ) ) # Since there's a pseudo-random number within the generation kwargs, we set the seed diff --git a/tests/unit/models/llms/test_openai.py b/tests/unit/models/llms/test_openai.py index b0c242f690..1e2a2d31d2 100644 --- a/tests/unit/models/llms/test_openai.py +++ b/tests/unit/models/llms/test_openai.py @@ -66,7 +66,15 @@ async def test_agenerate( mocked_completion = Mock( choices=[ - Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + Mock( + message=Mock(content=" Aenean hendrerit aliquam velit. ..."), + logprobs=Mock( + content=[ + Mock(top_logprobs=[Mock(token=" ", logprob=-1)]), + Mock(top_logprobs=[Mock(token="Aenean", logprob=-2)]), + ] + ), + ) ], usage=Mock(prompt_tokens=100, completion_tokens=100), ) @@ -84,6 +92,9 @@ async def test_agenerate( assert result == { "generations": [" Aenean hendrerit aliquam velit. 
..."], "statistics": {"input_tokens": [100], "output_tokens": [100]}, + "logprobs": [ + [[{"token": " ", "logprob": -1}], [{"token": "Aenean", "logprob": -2}]] + ], } @pytest.mark.asyncio @@ -100,9 +111,6 @@ async def test_agenerate_structured( }, ) # type: ignore llm._aclient = async_openai_mock - import tiktoken - - llm._tokenizer = tiktoken.encoding_for_model(self.model_id) mocked_usage = MagicMock( usage=MagicMock(prompt_tokens=100, completion_tokens=100), @@ -139,6 +147,12 @@ async def test_agenerate_structured( { "generations": [" Aenean hendrerit aliquam velit. ..."], "statistics": {"input_tokens": [100], "output_tokens": [100]}, + "logprobs": [ + [ + [{"token": " ", "logprob": -1}], + [{"token": "Aenean", "logprob": -2}], + ] + ], } ], ), @@ -148,6 +162,13 @@ async def test_agenerate_structured( { "generations": [" Aenean hendrerit aliquam velit. ..."] * 2, "statistics": {"input_tokens": [100], "output_tokens": [100]}, + "logprobs": [ + [ + [{"token": " ", "logprob": -1}], + [{"token": "Aenean", "logprob": -2}], + ] + ] + * 2, } ], ), @@ -165,7 +186,17 @@ async def test_generate( llm._aclient = async_openai_mock mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock( + message=Mock(content=" Aenean hendrerit aliquam velit. ..."), + logprobs=Mock( + content=[ + Mock(top_logprobs=[Mock(token=" ", logprob=-1)]), + Mock(top_logprobs=[Mock(token="Aenean", logprob=-2)]), + ] + ), + ) + ] * num_generations, usage=Mock(prompt_tokens=100, completion_tokens=100), ) @@ -186,6 +217,13 @@ async def test_generate( ) assert result == expected_result + @pytest.mark.asyncio + async def test_generate_raises_value_error_if_unknown_response_format( + self, async_openai_mock: MagicMock, _: MagicMock + ) -> None: + llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore + llm._aclient = async_openai_mock + with pytest.raises(ValueError): llm.generate( inputs=[ diff --git a/tests/unit/models/llms/test_vllm.py b/tests/unit/models/llms/test_vllm.py index dda129fd8b..6babb6232c 100644 --- a/tests/unit/models/llms/test_vllm.py +++ b/tests/unit/models/llms/test_vllm.py @@ -102,29 +102,11 @@ class Animal(BaseModel): ] -class DummyTokenizer: - # chat_template = None - chat_template = "template" - vocabulary = {"I'm": 1, "fine": 2, "thank": 3, "you": 4, "sir": 5} - - def __init__(self) -> None: - pass - - def apply_chat_template(self, input, **kwargs): - return input - - def encode(self, text: str): - return [1, 2, 3, 4, 5] - - def convert_token_to_string(self, token: str) -> str: - return "token" - - def get_vocab(self): - return self.vocabulary - - class TestvLLM: - @pytest.mark.parametrize("multi_structured_output", (False, True)) + @pytest.mark.parametrize( + "multi_structured_output", + (True, False), + ) @pytest.mark.parametrize( "num_generations, expected_result", [ @@ -133,7 +115,19 @@ class TestvLLM: [ { "generations": ["I'm fine thank you"], - "statistics": {"input_tokens": [10], "output_tokens": [6]}, + "statistics": {"input_tokens": [21], "output_tokens": [6]}, + "logprobs": [ + [ + [ + {"token": "I'm", "logprob": -1}, + {"token": "Hello", "logprob": -3}, + ], + [ + {"token": "I'm", "logprob": -1}, + {"token": "Hello", "logprob": -3}, + ], + ] + ], } ], ), @@ -143,9 +137,22 @@ class TestvLLM: { "generations": ["I'm fine thank you"] * 2, "statistics": { - "input_tokens": [10, 10], + "input_tokens": [21, 21], "output_tokens": [6, 6], }, + "logprobs": [ + [ + [ + {"token": "I'm", "logprob": -1}, + {"token": "Hello", 
"logprob": -3}, + ], + [ + {"token": "I'm", "logprob": -1}, + {"token": "Hello", "logprob": -3}, + ], + ] + ] + * 2, } ], ), @@ -161,7 +168,7 @@ def test_generate( tokenizer = AutoTokenizer.from_pretrained( "distilabel-internal-testing/tiny-random-mistral" ) - llm._tokenizer = DummyTokenizer() + llm._tokenizer = tokenizer vllm_mock = mock.MagicMock() vllm_mock.get_tokenizer = mock.MagicMock(return_value=tokenizer) # mock the import by hacking sys.modules @@ -178,6 +185,16 @@ def test_generate( mock.Mock( # CompletionOutput text="I'm fine thank you", token_ids=[1, 2, 3, 4, 5, 7], + logprobs=[ + { + 1: mock.Mock(decoded_token="I'm", logprob=-1), + 2: mock.Mock(decoded_token="Hello", logprob=-3), + }, + { + 1: mock.Mock(decoded_token="I'm", logprob=-1), + 2: mock.Mock(decoded_token="Hello", logprob=-3), + }, + ], ) ] * num_generations, From cf28976323acc45d15f6bb58470ac4c9a0d874b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 13 Dec 2024 11:36:24 +0100 Subject: [PATCH 08/30] Enable `RUF022` to automatically sort `__all__` --- pyproject.toml | 1 + src/distilabel/constants.py | 22 ++++++------- src/distilabel/llms.py | 12 +++---- src/distilabel/models/__init__.py | 16 +++++----- src/distilabel/models/llms/__init__.py | 12 +++---- src/distilabel/pipeline/__init__.py | 2 +- src/distilabel/steps/__init__.py | 44 +++++++++++++------------- src/distilabel/steps/tasks/__init__.py | 36 ++++++++++----------- src/distilabel/typing.py | 20 ++++++------ 9 files changed, 83 insertions(+), 82 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e5bcc8399a..b203f7edf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ exclude = ["docs"] [tool.ruff.lint] select = ["E", "W", "F", "I", "C", "B"] ignore = ["E501", "B905", "B008"] +extend-select = ["RUF022"] [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/distilabel/constants.py b/src/distilabel/constants.py index 44554f8423..20f644dffc 100644 --- a/src/distilabel/constants.py +++ b/src/distilabel/constants.py @@ -50,21 +50,21 @@ __all__ = [ - "DISTILABEL_METADATA_KEY", "BASE_CACHE_DIR", - "PIPELINES_CACHE_DIR", - "STEP_ATTR_NAME", + "CONVERGENCE_STEP_ATTR_NAME", + "DISTILABEL_DOCS_URL", + "DISTILABEL_METADATA_KEY", + "DISTISET_ARTIFACTS_FOLDER", + "DISTISET_CONFIG_FOLDER", "INPUT_QUEUE_ATTR_NAME", + "LAST_BATCH_SENT_FLAG", + "PIPELINES_CACHE_DIR", + "PIPELINE_CONFIG_FILENAME", + "PIPELINE_LOG_FILENAME", "RECEIVES_ROUTED_BATCHES_ATTR_NAME", "ROUTING_BATCH_FUNCTION_ATTR_NAME", - "CONVERGENCE_STEP_ATTR_NAME", - "LAST_BATCH_SENT_FLAG", "SIGINT_HANDLER_CALLED_ENV_NAME", - "STEPS_OUTPUTS_PATH", "STEPS_ARTIFACTS_PATH", - "DISTISET_CONFIG_FOLDER", - "DISTISET_ARTIFACTS_FOLDER", - "PIPELINE_CONFIG_FILENAME", - "PIPELINE_LOG_FILENAME", - "DISTILABEL_DOCS_URL", + "STEPS_OUTPUTS_PATH", + "STEP_ATTR_NAME", ] diff --git a/src/distilabel/llms.py b/src/distilabel/llms.py index e4970992ce..b00d891407 100644 --- a/src/distilabel/llms.py +++ b/src/distilabel/llms.py @@ -43,26 +43,26 @@ from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin __all__ = [ + "LLM", "AnthropicLLM", "AnyscaleLLM", - "AzureOpenAILLM", - "LLM", "AsyncLLM", + "AzureOpenAILLM", + "ClientvLLM", "CohereLLM", + "CudaDevicePlacementMixin", + "GenerateOutput", "GroqLLM", + "HiddenState", "InferenceEndpointsLLM", "LiteLLM", "LlamaCppLLM", "MistralLLM", - "CudaDevicePlacementMixin", "MixtureOfAgentsLLM", "OllamaLLM", "OpenAILLM", "TogetherLLM", "TransformersLLM", - "GenerateOutput", - 
"HiddenState", "VertexAILLM", - "ClientvLLM", "vLLM", ] diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py index 45807302f0..b84a2c4690 100644 --- a/src/distilabel/models/__init__.py +++ b/src/distilabel/models/__init__.py @@ -38,29 +38,29 @@ from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin __all__ = [ + "LLM", "AnthropicLLM", "AnyscaleLLM", - "AzureOpenAILLM", - "LLM", "AsyncLLM", + "AzureOpenAILLM", + "ClientvLLM", "CohereLLM", + "CudaDevicePlacementMixin", + "Embeddings", + "GenerateOutput", "GroqLLM", + "HiddenState", "InferenceEndpointsLLM", "LiteLLM", "LlamaCppLLM", "MistralLLM", - "CudaDevicePlacementMixin", "MixtureOfAgentsLLM", "OllamaLLM", "OpenAILLM", + "SentenceTransformerEmbeddings", "TogetherLLM", "TransformersLLM", - "GenerateOutput", - "HiddenState", "VertexAILLM", - "ClientvLLM", "vLLM", - "Embeddings", - "SentenceTransformerEmbeddings", "vLLMEmbeddings", ] diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py index 2ae3119832..cca70d64c3 100644 --- a/src/distilabel/models/llms/__init__.py +++ b/src/distilabel/models/llms/__init__.py @@ -32,26 +32,26 @@ from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin __all__ = [ + "LLM", "AnthropicLLM", "AnyscaleLLM", - "AzureOpenAILLM", - "LLM", "AsyncLLM", + "AzureOpenAILLM", + "ClientvLLM", "CohereLLM", + "CudaDevicePlacementMixin", + "GenerateOutput", "GroqLLM", + "HiddenState", "InferenceEndpointsLLM", "LiteLLM", "LlamaCppLLM", "MistralLLM", - "CudaDevicePlacementMixin", "MixtureOfAgentsLLM", "OllamaLLM", "OpenAILLM", "TogetherLLM", "TransformersLLM", - "GenerateOutput", - "HiddenState", "VertexAILLM", - "ClientvLLM", "vLLM", ] diff --git a/src/distilabel/pipeline/__init__.py b/src/distilabel/pipeline/__init__.py index 34400288da..b5e8b2c278 100644 --- a/src/distilabel/pipeline/__init__.py +++ b/src/distilabel/pipeline/__init__.py @@ -23,9 +23,9 @@ ) __all__ = [ + "InstructionResponsePipeline", "Pipeline", "RayPipeline", - "InstructionResponsePipeline", "routing_batch_function", "sample_n_steps", ] diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index cc1be59f92..58875bbec3 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -58,42 +58,42 @@ from distilabel.steps.typing import GeneratorStepOutput, StepOutput __all__ = [ - "PreferenceToArgilla", - "TextGenerationToArgilla", - "GeneratorStep", - "GlobalStep", - "Step", - "StepInput", - "StepResources", - "CombineOutputs", - "ExpandColumns", - "CombineColumns", - "GroupColumns", - "KeepColumns", - "MergeColumns", "DBSCAN", "UMAP", - "TextClustering", - "step", + "CombineColumns", + "CombineOutputs", + "ConversationTemplate", + "DataSampler", "DeitaFiltering", + "EmbeddingDedup", "EmbeddingGeneration", + "ExpandColumns", "FaissNearestNeighbour", - "ConversationTemplate", "FormatChatGenerationDPO", - "FormatTextGenerationDPO", "FormatChatGenerationSFT", + "FormatTextGenerationDPO", "FormatTextGenerationSFT", + "GeneratorStep", + "GeneratorStepOutput", + "GlobalStep", + "GroupColumns", + "KeepColumns", "LoadDataFromDicts", - "DataSampler", "LoadDataFromDisk", "LoadDataFromFileSystem", "LoadDataFromHub", - "EmbeddingDedup", + "MergeColumns", "MinHashDedup", - "make_generator_step", + "PreferenceToArgilla", "PushToHub", "RewardModelScore", - "TruncateTextColumn", - "GeneratorStepOutput", + "Step", + "StepInput", "StepOutput", + "StepResources", + "TextClustering", + 
"TextGenerationToArgilla", + "TruncateTextColumn", + "make_generator_step", + "step", ] diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index aa0460c3e1..f542aea232 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -60,47 +60,47 @@ from distilabel.steps.tasks.urial import URIAL __all__ = [ - "GeneratorTask", - "Task", - "ArgillaLabeller", + "CLAIR", + "URIAL", "APIGenExecutionChecker", "APIGenGenerator", "APIGenSemanticChecker", + "ArgillaLabeller", + "BitextRetrievalGenerator", + "ChatGeneration", + "ChatItem", + "ChatType", "ComplexityScorer", - "task", - "EvolInstruct", + "EmbeddingTaskGenerator", "EvolComplexity", "EvolComplexityGenerator", + "EvolInstruct", "EvolInstructGenerator", "EvolQuality", + "FormatPRM", "GenerateEmbeddings", - "Genstruct", - "BitextRetrievalGenerator", - "EmbeddingTaskGenerator", "GenerateLongTextMatchingData", + "GenerateSentencePair", "GenerateShortTextMatchingData", "GenerateTextClassificationData", "GenerateTextRetrievalData", - "MonolingualTripletGenerator", + "GeneratorTask", + "Genstruct", "InstructionBacktranslation", "Magpie", "MagpieGenerator", - "MathShepherdGenerator", "MathShepherdCompleter", - "FormatPRM", + "MathShepherdGenerator", + "MonolingualTripletGenerator", "PairRM", "PrometheusEval", "QualityScorer", "SelfInstruct", - "GenerateSentencePair", "StructuredGeneration", + "Task", "TextClassification", - "ChatGeneration", "TextGeneration", - "ChatItem", - "ChatType", - "CLAIR", - "UltraFeedback", - "URIAL", "TextGenerationWithImage", + "UltraFeedback", + "task", ] diff --git a/src/distilabel/typing.py b/src/distilabel/typing.py index 28bfd57fc5..a3d65d5d75 100644 --- a/src/distilabel/typing.py +++ b/src/distilabel/typing.py @@ -34,22 +34,22 @@ from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput __all__ = [ - "GenerateOutput", - "DownstreamConnectable", - "DownstreamConnectableSteps", - "InputDataset", - "PipelineRuntimeParametersInfo", - "StepLoadStatus", - "UpstreamConnectableSteps", "ChatItem", "ChatType", + "DownstreamConnectable", + "DownstreamConnectableSteps", "FormattedInput", + "GenerateOutput", + "GeneratorStepOutput", + "InputDataset", "InstructorStructuredOutputType", "OutlinesStructuredOutputType", + "PipelineRuntimeParametersInfo", "StandardInput", - "StructuredInput", - "StructuredOutputType", - "GeneratorStepOutput", "StepColumns", + "StepLoadStatus", "StepOutput", + "StructuredInput", + "StructuredOutputType", + "UpstreamConnectableSteps", ] From c2ae3f1f0543e0b7a6089779044da63ed60cd966 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 18 Dec 2024 11:34:55 +0100 Subject: [PATCH 09/30] Fix `vLLM` unload logic when model is `None` (#1080) --- src/distilabel/models/llms/_dummy.py | 70 ------------------- src/distilabel/models/llms/vllm.py | 3 + .../integration/test_generator_and_sampler.py | 26 ++++++- tests/unit/models/llms/test_vllm.py | 4 +- .../tasks/structured_outputs/test_outlines.py | 3 + 5 files changed, 34 insertions(+), 72 deletions(-) delete mode 100644 src/distilabel/models/llms/_dummy.py diff --git a/src/distilabel/models/llms/_dummy.py b/src/distilabel/models/llms/_dummy.py deleted file mode 100644 index de89356d0f..0000000000 --- a/src/distilabel/models/llms/_dummy.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, List - -from distilabel.models.llms.base import LLM, AsyncLLM -from distilabel.models.mixins.magpie import MagpieChatTemplateMixin - -if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput - - -class DummyAsyncLLM(AsyncLLM): - structured_output: Any = None - - def load(self) -> None: - pass - - @property - def model_name(self) -> str: - return "test" - - async def agenerate( # type: ignore - self, input: "FormattedInput", num_generations: int = 1 - ) -> "GenerateOutput": - return ["output" for _ in range(num_generations)] - - -class DummySyncLLM(LLM): - structured_output: Any = None - - def load(self) -> None: - super().load() - - @property - def model_name(self) -> str: - return "test" - - def generate( # type: ignore - self, inputs: "FormattedInput", num_generations: int = 1 - ) -> "GenerateOutput": - return [["output" for _ in range(num_generations)] for _ in range(len(inputs))] - - -class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): - def load(self) -> None: - pass - - @property - def model_name(self) -> str: - return "test" - - def generate( - self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any - ) -> List["GenerateOutput"]: - return [ - ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs)) - ] diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index 9a75cd47c2..401bc66d09 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -224,6 +224,9 @@ def unload(self) -> None: super().unload() def _cleanup_vllm_model(self) -> None: + if self._model is None: + return + import torch # noqa from vllm.distributed.parallel_state import ( destroy_distributed_environment, diff --git a/tests/integration/test_generator_and_sampler.py b/tests/integration/test_generator_and_sampler.py index cdbeb5703a..5c53346f48 100644 --- a/tests/integration/test_generator_and_sampler.py +++ b/tests/integration/test_generator_and_sampler.py @@ -12,12 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from distilabel.models.llms._dummy import DummyAsyncLLM +from typing import TYPE_CHECKING, Any + +from distilabel.models.llms.base import AsyncLLM from distilabel.pipeline import Pipeline from distilabel.steps import CombineOutputs, LoadDataFromDicts from distilabel.steps.generators.data_sampler import DataSampler from distilabel.steps.tasks import TextGeneration +if TYPE_CHECKING: + from distilabel.typing import FormattedInput, GenerateOutput + + +class DummyAsyncLLM(AsyncLLM): + structured_output: Any = None + + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + async def agenerate( # type: ignore + self, input: "FormattedInput", num_generations: int = 1 + ) -> "GenerateOutput": + return { + "generations": ["output" for _ in range(num_generations)], + "statistics": {}, + } + def get_pipeline(): with Pipeline() as pipe: diff --git a/tests/unit/models/llms/test_vllm.py b/tests/unit/models/llms/test_vllm.py index 6babb6232c..2230186bf3 100644 --- a/tests/unit/models/llms/test_vllm.py +++ b/tests/unit/models/llms/test_vllm.py @@ -105,7 +105,9 @@ class Animal(BaseModel): class TestvLLM: @pytest.mark.parametrize( "multi_structured_output", - (True, False), + # TODO: uncomment once with update our code to work with `outlines>0.1.0` + # (True, False), + (False,), ) @pytest.mark.parametrize( "num_generations, expected_result", diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index fc6f9a2f7c..e4eb2025c8 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -100,6 +100,9 @@ class DummyUserTest(BaseModel): } +@pytest.mark.skip( + reason="won't work until we update our code to work with `outlines>0.1.0`" +) class TestOutlinesIntegration: @pytest.mark.parametrize( "format, schema, prompt", From 925d259c50f6ab8646335a22989e0f7192c3380f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 18 Dec 2024 15:24:30 +0100 Subject: [PATCH 10/30] Fix `merge_distilabel_metadata` function when handling outputs from `Task` with `group_generations==True` (#1082) --- src/distilabel/steps/columns/group.py | 6 +-- src/distilabel/steps/columns/utils.py | 28 ++++++++++---- tests/unit/steps/columns/test_utils.py | 53 ++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 10 deletions(-) create mode 100644 tests/unit/steps/columns/test_utils.py diff --git a/src/distilabel/steps/columns/group.py b/src/distilabel/steps/columns/group.py index 876af1f0ad..4cc77b50f0 100644 --- a/src/distilabel/steps/columns/group.py +++ b/src/distilabel/steps/columns/group.py @@ -21,7 +21,7 @@ from distilabel.steps.columns.utils import group_columns if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.steps.typing import StepOutput class GroupColumns(Step): @@ -96,12 +96,12 @@ class GroupColumns(Step): output_columns: Optional[List[str]] = None @property - def inputs(self) -> "StepColumns": + def inputs(self) -> List[str]: """The inputs for the task are the column names in `columns`.""" return self.columns @property - def outputs(self) -> "StepColumns": + def outputs(self) -> List[str]: """The outputs for the task are the column names in `output_columns` or `grouped_{column}` for each column in `columns`.""" return ( diff --git a/src/distilabel/steps/columns/utils.py b/src/distilabel/steps/columns/utils.py index 7b3efe2262..58bcae6139 100644 
--- a/src/distilabel/steps/columns/utils.py +++ b/src/distilabel/steps/columns/utils.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from distilabel.constants import DISTILABEL_METADATA_KEY @@ -21,22 +21,36 @@ from distilabel.steps.base import StepInput -def merge_distilabel_metadata(*output_dicts: Dict[str, Any]) -> Dict[str, Any]: +def merge_distilabel_metadata( + *output_dicts: Dict[str, Any], +) -> Union[Dict[str, Any], List[Dict[str, Any]]]: """ - Merge the `DISTILABEL_METADATA_KEY` from multiple output dictionaries. + Merge the `DISTILABEL_METADATA_KEY` from multiple output dictionaries. `DISTILABEL_METADATA_KEY` + can be either a dictionary containing metadata keys or a list containing dictionaries + of metadata keys. Args: - *output_dicts: Variable number of dictionaries containing distilabel metadata. + *output_dicts: Variable number of dictionaries or lists containing distilabel metadata. Returns: - A merged dictionary containing all the distilabel metadata from the input dictionaries. + A merged dictionary or list containing all the distilabel metadata. """ merged_metadata = defaultdict(list) for output_dict in output_dicts: metadata = output_dict.get(DISTILABEL_METADATA_KEY, {}) - for key, value in metadata.items(): - merged_metadata[key].append(value) + # If `distilabel_metadata_key` is a `list` then it contains dictionaries with + # the metadata per `num_generations` created when `group_generations==True` + if isinstance(metadata, list): + if not isinstance(merged_metadata, list): + merged_metadata = [] + merged_metadata.extend(metadata) + else: + for key, value in metadata.items(): + merged_metadata[key].append(value) + + if isinstance(merged_metadata, list): + return merged_metadata final_metadata = {} for key, value_list in merged_metadata.items(): diff --git a/tests/unit/steps/columns/test_utils.py b/tests/unit/steps/columns/test_utils.py new file mode 100644 index 0000000000..790cf73fb5 --- /dev/null +++ b/tests/unit/steps/columns/test_utils.py @@ -0,0 +1,53 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from distilabel.constants import DISTILABEL_METADATA_KEY +from distilabel.steps.columns.utils import merge_distilabel_metadata + + +def test_merge_distilabel_metadata() -> None: + rows = [ + {DISTILABEL_METADATA_KEY: {"a": 1, "b": 1}}, + {DISTILABEL_METADATA_KEY: {"a": 2, "b": 2}}, + ] + result = merge_distilabel_metadata(*rows) + assert result == {"a": [1, 2], "b": [1, 2]} + + +def test_merge_distilabel_metadata_list() -> None: + rows = [ + { + DISTILABEL_METADATA_KEY: [ + {"a": 1.0, "b": 1.0}, + {"a": 1.1, "b": 1.1}, + {"a": 1.2, "b": 1.2}, + ] + }, + { + DISTILABEL_METADATA_KEY: [ + {"a": 2.0, "b": 2.0}, + {"a": 2.1, "b": 2.1}, + {"a": 2.2, "b": 2.2}, + ] + }, + ] + result = merge_distilabel_metadata(*rows) + assert result == [ + {"a": 1.0, "b": 1.0}, + {"a": 1.1, "b": 1.1}, + {"a": 1.2, "b": 1.2}, + {"a": 2.0, "b": 2.0}, + {"a": 2.1, "b": 2.1}, + {"a": 2.2, "b": 2.2}, + ] From bfc84458593b892204e9a4ea9e8f742e9353b57d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 18 Dec 2024 17:25:51 +0100 Subject: [PATCH 11/30] Fix chat template not applied in `TransformersLLM` (#1083) --- .github/workflows/docs-pr-close.yml | 4 ++++ .github/workflows/docs-pr.yml | 4 ++++ .github/workflows/docs.yml | 4 ++++ src/distilabel/__init__.py | 2 +- src/distilabel/llms/huggingface/transformers.py | 2 +- tests/unit/llms/huggingface/test_transformers.py | 15 +++++++++++++++ tests/unit/steps/argilla/test_preference.py | 12 +++++++++++- .../tasks/structured_outputs/test_outlines.py | 3 +++ 8 files changed, 43 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs-pr-close.yml b/.github/workflows/docs-pr-close.yml index 71f4e5ff93..61008bcee1 100644 --- a/.github/workflows/docs-pr-close.yml +++ b/.github/workflows/docs-pr-close.yml @@ -8,6 +8,10 @@ concurrency: group: distilabel-docs cancel-in-progress: false +permissions: + contents: write + pull-requests: write + jobs: cleanup: runs-on: ubuntu-latest diff --git a/.github/workflows/docs-pr.yml b/.github/workflows/docs-pr.yml index 48c7236a58..ec963ccf98 100644 --- a/.github/workflows/docs-pr.yml +++ b/.github/workflows/docs-pr.yml @@ -10,6 +10,10 @@ concurrency: group: distilabel-docs cancel-in-progress: false +permissions: + contents: write + pull-requests: write + jobs: publish: runs-on: ubuntu-latest diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index dd59a5129d..93a17408e8 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -12,6 +12,10 @@ concurrency: group: distilabel-docs cancel-in-progress: false +permissions: + contents: write + pull-requests: write + jobs: publish: runs-on: ubuntu-latest diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index f6ca72cd10..11a8378250 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -14,6 +14,6 @@ from rich import traceback as rich_traceback -__version__ = "1.4.1" +__version__ = "1.4.2" rich_traceback.install(show_locals=True) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 27ab00e5b9..a0582b4155 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -174,7 +174,7 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. 
""" - if self._pipeline.tokenizer.chat_template: # type: ignore + if self._pipeline.tokenizer.chat_template is None: # type: ignore return input[0]["content"] prompt: str = ( diff --git a/tests/unit/llms/huggingface/test_transformers.py b/tests/unit/llms/huggingface/test_transformers.py index 97214ef5fc..79d6089f79 100644 --- a/tests/unit/llms/huggingface/test_transformers.py +++ b/tests/unit/llms/huggingface/test_transformers.py @@ -40,6 +40,21 @@ def test_model_name(self, transformers_llm: TransformersLLM) -> None: == "distilabel-internal-testing/tiny-random-mistral" ) + def test_prepare_input(self, transformers_llm: TransformersLLM) -> None: + assert ( + transformers_llm.prepare_input([{"role": "user", "content": "Hello"}]) + == " [INST] Hello [/INST]" + ) + + def test_prepare_input_no_chat_template( + self, transformers_llm: TransformersLLM + ) -> None: + transformers_llm._pipeline.tokenizer.chat_template = None + assert ( + transformers_llm.prepare_input([{"role": "user", "content": "Hello"}]) + == "Hello" + ) + def test_generate(self, transformers_llm: TransformersLLM) -> None: responses = transformers_llm.generate( inputs=[ diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index ab63ee5419..1c99f2f5c4 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -83,13 +83,23 @@ def test_process(self, mock_dataset) -> None: ) with patch.object(PreferenceToArgilla, "load"): step.load() + step._instruction = "instruction" step._generations = "generations" + step._ratings = "ratings" + step._rationales = "rationales" step._dataset = mock_dataset # type: ignore step._dataset.records.log = lambda x: x # type: ignore assert list( - step.process([{"instruction": "test", "generations": ["test", "test"]}]) + step.process( + [ + { + "instruction": "test", + "generations": ["test", "test"], + } + ] + ) ) == [[{"instruction": "test", "generations": ["test", "test"]}]] assert step._dataset.records # type: ignore diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index d2be053aa5..ecdfd04240 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -101,6 +101,9 @@ class DummyUserTest(BaseModel): } +@pytest.mark.skip( + reason="won't work until we update our code to work with `outlines>0.1.0`" +) class TestOutlinesIntegration: @pytest.mark.parametrize( "format, schema, prompt", From e65894c2f2328d4cb267768332cb838ff7cd6227 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Mon, 23 Dec 2024 17:07:40 +0900 Subject: [PATCH 12/30] chore: update base.py (#1085) --- src/distilabel/llms/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index ced6a8e041..f061af32e0 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -220,7 +220,7 @@ def _offline_batch_generate_polling( f" for {self.offline_batch_generation_block_until_done} seconds before" " trying to get the results again." ) - # When running a `Step` in a child process, SIGINT is overriden so the child + # When running a `Step` in a child process, SIGINT is overridden so the child # process doesn't stop when the parent process receives a SIGINT signal. # The new handler sets an environment variable that is checked here to stop # the polling. 
From 344cce7a89cc60953f6b6dae0659bd2de312d7ed Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Wed, 8 Jan 2025 17:19:40 +0100 Subject: [PATCH 13/30] Add magpie support llama cpp ollama (#1086) Co-authored-by: burtenshaw --- src/distilabel/models/llms/anthropic.py | 2 +- src/distilabel/models/llms/cohere.py | 2 +- src/distilabel/models/llms/groq.py | 2 +- src/distilabel/models/llms/llamacpp.py | 162 ++++++++++++++++-- src/distilabel/models/llms/mistral.py | 2 +- src/distilabel/models/llms/ollama.py | 131 ++++++++++++-- src/distilabel/models/llms/openai.py | 3 +- src/distilabel/models/llms/vertexai.py | 2 +- .../steps/tasks/evol_instruct/base.py | 2 +- .../steps/tasks/evol_instruct/generator.py | 2 +- tests/unit/models/llms/test_llamacpp.py | 27 ++- tests/unit/models/llms/test_ollama.py | 14 ++ 12 files changed, 307 insertions(+), 44 deletions(-) diff --git a/src/distilabel/models/llms/anthropic.py b/src/distilabel/models/llms/anthropic.py index c6c79a9141..0eefc092dc 100644 --- a/src/distilabel/models/llms/anthropic.py +++ b/src/distilabel/models/llms/anthropic.py @@ -42,7 +42,7 @@ from anthropic import AsyncAnthropic from anthropic.types import Message - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" diff --git a/src/distilabel/models/llms/cohere.py b/src/distilabel/models/llms/cohere.py index 043ac4214c..8b081a762e 100644 --- a/src/distilabel/models/llms/cohere.py +++ b/src/distilabel/models/llms/cohere.py @@ -40,7 +40,7 @@ from pydantic import BaseModel from tokenizers import Tokenizer - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" diff --git a/src/distilabel/models/llms/groq.py b/src/distilabel/models/llms/groq.py index 2977c513f3..8000211936 100644 --- a/src/distilabel/models/llms/groq.py +++ b/src/distilabel/models/llms/groq.py @@ -30,7 +30,7 @@ from groq import AsyncGroq from groq.types.chat.chat_completion import ChatCompletion - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py index 77e2707c1c..822e5cea77 100644 --- a/src/distilabel/models/llms/llamacpp.py +++ b/src/distilabel/models/llms/llamacpp.py @@ -14,19 +14,22 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from pydantic import Field, FilePath, PrivateAttr, validate_call +from pydantic import Field, FilePath, PrivateAttr, model_validator, validate_call from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output +from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType if TYPE_CHECKING: from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList + from distilabel.steps.tasks.typing import FormattedInput, StandardInput -class LlamaCppLLM(LLM): + +class LlamaCppLLM(LLM, MagpieChatTemplateMixin): """llama.cpp LLM implementation running the Python bindings for the C++ code. Attributes: @@ -44,6 +47,15 @@ class LlamaCppLLM(LLM): fine-grained control is needed, an instance of `OutlinesStructuredOutput`. 
Defaults to None. extra_kwargs: additional dictionary of keyword arguments that will be passed to the `Llama` class of `llama_cpp` library. Defaults to `{}`. + tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing + the tokenizer config files. If not provided, the one associated to the `model` + will be used. Defaults to `None`. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. _model: the Llama model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the `load` method. @@ -140,10 +152,27 @@ class User(BaseModel): default=None, description="The structured output format to use across all the generations.", ) - + tokenizer_id: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The Hugging Face Hub repo id or a path to a directory containing" + " the tokenizer config files. If not provided, the one associated to the `model`" + " will be used.", + ) _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None) _model: Optional["Llama"] = PrivateAttr(...) + @model_validator(mode="after") + def validate_magpie_usage( + self, + ) -> "LlamaCppLLM": + """Validates that magpie usage is valid.""" + + if self.use_magpie_template and self.tokenizer_id is None: + raise ValueError( + "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please," + " set a `tokenizer_id` and try again." + ) + def load(self) -> None: """Loads the `Llama` model from the `model_path`.""" try: @@ -154,7 +183,7 @@ def load(self) -> None: ) from ie self._model = Llama( - model_path=self.model_path.as_posix(), # type: ignore + model_path=self.model_path.as_posix(), seed=self.seed, n_ctx=self.n_ctx, n_batch=self.n_batch, @@ -169,6 +198,27 @@ def load(self) -> None: self.structured_output ) + if self.use_magpie_template or self.magpie_pre_query_template: + if not self.tokenizer_id: + raise ValueError( + "The Hugging Face Hub repo id or a path to a directory containing" + " the tokenizer config files is required when using the `use_magpie_template`" + " or `magpie_pre_query_template` runtime parameters." + ) + + if self.tokenizer_id: + try: + from transformers import AutoTokenizer + except ImportError as ie: + raise ImportError( + "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`." + ) from ie + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) + if self._tokenizer.chat_template is None: + raise ValueError( + "The tokenizer does not have a chat template. Please use a tokenizer with a chat template." + ) + # NOTE: Here because of the custom `logging` interface used, since it will create the logging name # out of the model name, which won't be available until the `Llama` instance is created. 
super().load() @@ -178,6 +228,70 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self._model.model_path # type: ignore + def _generate_chat_completion( + self, + input: FormattedInput, + max_new_tokens: int = 128, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + extra_generation_kwargs: Optional[Dict[str, Any]] = None, + ) -> "CreateChatCompletionResponse": + return self._model.create_chat_completion( # type: ignore + messages=input, # type: ignore + max_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + temperature=temperature, + top_p=top_p, + logits_processor=self._logits_processor, + **(extra_generation_kwargs or {}), + ) + + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + prompt: str = ( + self._tokenizer.apply_chat_template( # type: ignore + conversation=input, # type: ignore + tokenize=False, + add_generation_prompt=True, + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + def _generate_with_text_generation( + self, + input: FormattedInput, + max_new_tokens: int = 128, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + extra_generation_kwargs: Optional[Dict[str, Any]] = None, + ) -> "CreateChatCompletionResponse": + prompt = self.prepare_input(input) + return self._model.create_completion( + prompt=prompt, + max_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + temperature=temperature, + top_p=top_p, + logits_processor=self._logits_processor, + **(extra_generation_kwargs or {}), + ) + @validate_call def generate( # type: ignore self, @@ -230,24 +344,36 @@ def generate( # type: ignore self._logits_processor = self._prepare_structured_output( structured_output ) - chat_completions: "CreateChatCompletionResponse" = ( - self._model.create_chat_completion( # type: ignore - messages=input, # type: ignore - max_tokens=max_new_tokens, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - top_p=top_p, - logits_processor=self._logits_processor, - **(extra_generation_kwargs or {}), + if self.tokenizer_id is None: + completion = self._generate_chat_completion( + input, + max_new_tokens, + frequency_penalty, + presence_penalty, + temperature, + top_p, + extra_generation_kwargs, ) - ) - outputs.append(chat_completions["choices"][0]["message"]["content"]) - output_tokens.append(chat_completions["usage"]["completion_tokens"]) + outputs.append(completion["choices"][0]["message"]["content"]) + output_tokens.append(completion["usage"]["completion_tokens"]) + else: + completion: "CreateChatCompletionResponse" = ( + self._generate_with_text_generation( # type: ignore + input, + max_new_tokens, + frequency_penalty, + presence_penalty, + temperature, + top_p, + extra_generation_kwargs, + ) + ) + outputs.append(completion["choices"][0]["text"]) + output_tokens.append(completion["usage"]["completion_tokens"]) batch_outputs.append( prepare_output( outputs, - input_tokens=[chat_completions["usage"]["prompt_tokens"]] + input_tokens=[completion["usage"]["prompt_tokens"]] * num_generations, output_tokens=output_tokens, ) diff --git 
a/src/distilabel/models/llms/mistral.py b/src/distilabel/models/llms/mistral.py index 873565091b..9fe9f357da 100644 --- a/src/distilabel/models/llms/mistral.py +++ b/src/distilabel/models/llms/mistral.py @@ -30,7 +30,7 @@ from mistralai import Mistral from mistralai.models.chatcompletionresponse import ChatCompletionResponse - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY" diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py index f704627487..ff0779d881 100644 --- a/src/distilabel/models/llms/ollama.py +++ b/src/distilabel/models/llms/ollama.py @@ -14,19 +14,22 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Union -from pydantic import Field, PrivateAttr, validate_call +from pydantic import Field, PrivateAttr, model_validator, validate_call from typing_extensions import TypedDict from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output +from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput if TYPE_CHECKING: from ollama import AsyncClient + from ollama._types import ChatResponse, GenerateResponse - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics + from distilabel.steps.tasks.typing import StandardInput # Copied from `ollama._types.Options` @@ -69,13 +72,25 @@ class Options(TypedDict, total=False): stop: Sequence[str] -class OllamaLLM(AsyncLLM): +class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin): """Ollama LLM implementation running the Async API client. Attributes: model: the model name to use for the LLM e.g. "notus". host: the Ollama server host. timeout: the timeout for the LLM. Defaults to `120`. + follow_redirects: whether to follow redirects. Defaults to `True`. + structured_output: a dictionary containing the structured output configuration or if more + fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing + the tokenizer config files. If not provided, the one associated to the `model` + will be used. Defaults to `None`. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. _aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally. Set in the `load` method. @@ -112,10 +127,26 @@ class OllamaLLM(AsyncLLM): description="The structured output format to use across all the generations.", ) ) - + tokenizer_id: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The Hugging Face Hub repo id or a path to a directory containing" + " the tokenizer config files. If not provided, the one associated to the `model`" + " will be used.", + ) _num_generations_param_supported = False + _aclient: Optional["AsyncClient"] = PrivateAttr(...) # type: ignore - _aclient: Optional["AsyncClient"] = PrivateAttr(...) 
+ @model_validator(mode="after") # type: ignore + def validate_magpie_usage( + self, + ) -> "OllamaLLM": + """Validates that magpie usage is valid.""" + + if self.use_magpie_template and self.tokenizer_id is None: + raise ValueError( + "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please," + " set a `tokenizer_id` and try again." + ) def load(self) -> None: """Loads the `AsyncClient` to use Ollama async API.""" @@ -135,13 +166,80 @@ def load(self) -> None: " `pip install ollama`." ) from e + if self.tokenizer_id: + try: + from transformers import AutoTokenizer + except ImportError as ie: + raise ImportError( + "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`." + ) from ie + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) + if self._tokenizer.chat_template is None: + raise ValueError( + "The tokenizer does not have a chat template. Please use a tokenizer with a chat template." + ) + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model + async def _generate_chat_completion( + self, + input: "StandardInput", + format: Literal["", "json"] = "", + options: Union[Options, None] = None, + keep_alive: Union[bool, None] = None, + ) -> "ChatResponse": + return await self._aclient.chat( + model=self.model, + messages=input, + stream=False, + format=format, + options=options, + keep_alive=keep_alive, + ) + + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + prompt: str = ( + self._tokenizer.apply_chat_template( + conversation=input, + tokenize=False, + add_generation_prompt=True, + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + async def _generate_with_text_generation( + self, + input: "StandardInput", + format: Literal["", "json"] = None, + options: Union[Options, None] = None, + keep_alive: Union[bool, None] = None, + ) -> "GenerateResponse": + input = self.prepare_input(input) + return await self._aclient.generate( + model=self.model, + prompt=input, + format=format, + options=options, + keep_alive=keep_alive, + raw=True, + ) + @validate_call - async def agenerate( # type: ignore + async def agenerate( self, input: StandardInput, format: Literal["", "json"] = "", @@ -163,15 +261,18 @@ async def agenerate( # type: ignore """ text = None try: - completion: Dict[str, Any] = await self._aclient.chat( # type: ignore - model=self.model, - messages=input, # type: ignore - stream=False, - format=format, - options=options, - keep_alive=keep_alive, - ) - text = completion["message"]["content"] + if not format: + format = None + if self.tokenizer_id is None: + completion = await self._generate_chat_completion( + input, format, options, keep_alive + ) + text = completion["message"]["content"] + else: + completion = await self._generate_with_text_generation( + input, format, options, keep_alive + ) + text = completion.response except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Ollama client (model: '{self.model_name}')." 
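The Ollama integration gains the same knobs; a hedged sketch of how they might be wired together (model and tokenizer ids are illustrative):

```python
from distilabel.models.llms import OllamaLLM

# Illustrative model/tokenizer ids; any tokenizer that ships a chat template works.
llm = OllamaLLM(
    model="llama3.1",                                 # model already pulled into Ollama
    tokenizer_id="meta-llama/Llama-3.1-8B-Instruct",  # provides the chat template
    use_magpie_template=True,
    magpie_pre_query_template="llama3",
)
llm.load()
# With `tokenizer_id` set, `agenerate` routes through `_generate_with_text_generation`
# (raw completion with the pre-rendered prompt) instead of the chat endpoint.
```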
diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index c53122fa63..37bb5bb6be 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -35,8 +35,7 @@ from openai.types.chat.chat_completion import Choice as OpenAIChoice from openai.types.completion import Completion as OpenAICompletion - from distilabel.llms.typing import LLMStatistics - from distilabel.models.llms.typing import Logprob + from distilabel.models.llms.typing import LLMStatistics, Logprob _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" diff --git a/src/distilabel/models/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py index 8f5dc28bbd..62235dd321 100644 --- a/src/distilabel/models/llms/vertexai.py +++ b/src/distilabel/models/llms/vertexai.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from vertexai.generative_models import Content, GenerationResponse, GenerativeModel - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics class VertexChatItem(TypedDict): diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py index 3f2ba5da4f..f1a44d6a84 100644 --- a/src/distilabel/steps/tasks/evol_instruct/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/base.py @@ -27,7 +27,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics from distilabel.steps.typing import StepOutput diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index fa568d392e..6f985464eb 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -33,7 +33,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: - from distilabel.llms.typing import LLMStatistics + from distilabel.models.llms.typing import LLMStatistics from distilabel.steps.tasks.typing import ChatType from distilabel.steps.typing import GeneratorStepOutput diff --git a/tests/unit/models/llms/test_llamacpp.py b/tests/unit/models/llms/test_llamacpp.py index 94bf008f19..f5b9f51cec 100644 --- a/tests/unit/models/llms/test_llamacpp.py +++ b/tests/unit/models/llms/test_llamacpp.py @@ -23,14 +23,18 @@ from .utils import DummyUserDetail -@pytest.fixture(scope="module") -def llm() -> Generator[LlamaCppLLM, None, None]: +def download_tinyllama() -> None: if not os.path.exists("tinyllama.gguf"): urllib.request.urlretrieve( "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q2_K.gguf", "tinyllama.gguf", ) + +@pytest.fixture(scope="module") +def llm() -> Generator[LlamaCppLLM, None, None]: + download_tinyllama() + llm = LlamaCppLLM(model_path="tinyllama.gguf", n_gpu_layers=0) # type: ignore llm.load() @@ -38,6 +42,19 @@ def llm() -> Generator[LlamaCppLLM, None, None]: class TestLlamaCppLLM: + def test_no_tokenizer_magpie_raise_value_error(self) -> None: + download_tinyllama() + + with pytest.raises( + ValueError, + match="`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`", + ): + LlamaCppLLM( + model_path="tinyllama.gguf", + use_magpie_template=True, + magpie_pre_query_template="llama3", + ) + def test_model_name(self, llm: LlamaCppLLM) -> None: assert llm.model_name == "tinyllama.gguf" @@ -83,6 +100,9 @@ def test_generate(self, llm: LlamaCppLLM) -> None: "name": "LlamaCppLLM", 
}, "verbose": False, + "magpie_pre_query_template": None, + "tokenizer_id": None, + "use_magpie_template": False, }, ), ( @@ -110,6 +130,9 @@ def test_generate(self, llm: LlamaCppLLM) -> None: "name": "LlamaCppLLM", }, "verbose": False, + "magpie_pre_query_template": None, + "tokenizer_id": None, + "use_magpie_template": False, }, ), ], diff --git a/tests/unit/models/llms/test_ollama.py b/tests/unit/models/llms/test_ollama.py index 167ec6a1dc..3d80846370 100644 --- a/tests/unit/models/llms/test_ollama.py +++ b/tests/unit/models/llms/test_ollama.py @@ -22,6 +22,17 @@ @patch("ollama.AsyncClient") class TestOllamaLLM: + def test_no_tokenizer_magpie_raise_value_error(self, _: MagicMock) -> None: + with pytest.raises( + ValueError, + match="`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`", + ): + OllamaLLM( + model="llama3.1", + use_magpie_template=True, + magpie_pre_query_template="llama3", + ) + def test_ollama_llm(self, _: MagicMock) -> None: llm = OllamaLLM(model="notus") # type: ignore assert isinstance(llm, OllamaLLM) @@ -97,6 +108,9 @@ def test_serialization(self, _: MagicMock) -> None: "generation_kwargs": {}, "structured_output": None, "jobs_ids": None, + "magpie_pre_query_template": None, + "tokenizer_id": None, + "use_magpie_template": False, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, "type_info": { From 99c24485263fdbab8ec21bfa6874c1320ec6f2e1 Mon Sep 17 00:00:00 2001 From: bikash119 Date: Thu, 9 Jan 2025 14:11:00 +0530 Subject: [PATCH 14/30] Feat/954 llama cpp (#1000) Co-authored-by: David Berenstein --- .gitignore | 1 - src/distilabel/models/embeddings/__init__.py | 2 + src/distilabel/models/embeddings/llamacpp.py | 237 ++++++++++++++++++ tests/unit/conftest.py | 35 +++ tests/unit/models/embeddings/test_llamacpp.py | 185 ++++++++++++++ 5 files changed, 459 insertions(+), 1 deletion(-) create mode 100644 src/distilabel/models/embeddings/llamacpp.py create mode 100644 tests/unit/models/embeddings/test_llamacpp.py diff --git a/.gitignore b/.gitignore index 42967a7edb..d8337200af 100644 --- a/.gitignore +++ b/.gitignore @@ -77,4 +77,3 @@ venv.bak/ # Other *.log *.swp -.DS_Store \ No newline at end of file diff --git a/src/distilabel/models/embeddings/__init__.py b/src/distilabel/models/embeddings/__init__.py index 9177298748..573ba72266 100644 --- a/src/distilabel/models/embeddings/__init__.py +++ b/src/distilabel/models/embeddings/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from distilabel.models.embeddings.base import Embeddings +from distilabel.models.embeddings.llamacpp import LlamaCppEmbeddings from distilabel.models.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) @@ -22,4 +23,5 @@ "Embeddings", "SentenceTransformerEmbeddings", "vLLMEmbeddings", + "LlamaCppEmbeddings", ] diff --git a/src/distilabel/models/embeddings/llamacpp.py b/src/distilabel/models/embeddings/llamacpp.py new file mode 100644 index 0000000000..6596bb45ea --- /dev/null +++ b/src/distilabel/models/embeddings/llamacpp.py @@ -0,0 +1,237 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from pydantic import Field, PrivateAttr + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.embeddings.base import Embeddings +from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin + +if TYPE_CHECKING: + from llama_cpp import Llama + + +class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin): + """`LlamaCpp` library implementation for embedding generation. + + Attributes: + model_name: contains the name of the GGUF quantized model, compatible with the + installed version of the `llama.cpp` Python bindings. + model_path: contains the path to the GGUF quantized model, compatible with the + installed version of the `llama.cpp` Python bindings. + repo_id: the Hugging Face Hub repository id. + verbose: whether to print verbose output. Defaults to `False`. + n_gpu_layers: number of layers to run on the GPU. Defaults to `-1` (use the GPU if available). + disable_cuda_device_placement: whether to disable CUDA device placement. Defaults to `True`. + normalize_embeddings: whether to normalize the embeddings. Defaults to `False`. + seed: RNG seed, -1 for random + n_ctx: Text context, 0 = from model + n_batch: Prompt processing maximum batch size + extra_kwargs: additional dictionary of keyword arguments that will be passed to the + `Llama` class of `llama_cpp` library. Defaults to `{}`. + + Runtime parameters: + - `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`. + - `verbose`: whether to print verbose output. Defaults to `False`. + - `normalize_embeddings`: whether to normalize the embeddings. Defaults to `False`. + - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the + `Llama` class of `llama_cpp` library. Defaults to `{}`. 
+ + References: + - [Offline inference embeddings](https://llama-cpp-python.readthedocs.io/en/stable/#embeddings) + + Examples: + Generate sentence embeddings using a local model: + + ```python + from pathlib import Path + from distilabel.models.embeddings import LlamaCppEmbeddings + + # You can follow along this example downloading the following model running the following + # command in the terminal, that will download the model to the `Downloads` folder: + # curl -L -o ~/Downloads/all-MiniLM-L6-v2-Q2_K.gguf https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q2_K.gguf + + model_path = "Downloads/" + model = "all-MiniLM-L6-v2-Q2_K.gguf" + embeddings = LlamaCppEmbeddings( + model=model, + model_path=str(Path.home() / model_path), + ) + + embeddings.load() + + results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"]) + print(results) + embeddings.unload() + ``` + + Generate sentence embeddings using a HuggingFace Hub model: + + ```python + from distilabel.models.embeddings import LlamaCppEmbeddings + # You need to set environment variable to download private model to the local machine + + repo_id = "second-state/All-MiniLM-L6-v2-Embedding-GGUF" + model = "all-MiniLM-L6-v2-Q2_K.gguf" + embeddings = LlamaCppEmbeddings(model=model,repo_id=repo_id) + + embeddings.load() + + results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"]) + print(results) + embeddings.unload() + # [ + # [-0.05447685346007347, -0.01623094454407692, ...], + # [4.4889533455716446e-05, 0.044016145169734955, ...], + # ] + ``` + + Generate sentence embeddings with cpu: + + ```python + from pathlib import Path + from distilabel.models.embeddings import LlamaCppEmbeddings + + # You can follow along this example downloading the following model running the following + # command in the terminal, that will download the model to the `Downloads` folder: + # curl -L -o ~/Downloads/all-MiniLM-L6-v2-Q2_K.gguf https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q2_K.gguf + + model_path = "Downloads/" + model = "all-MiniLM-L6-v2-Q2_K.gguf" + embeddings = LlamaCppEmbeddings( + model=model, + model_path=str(Path.home() / model_path), + n_gpu_layers=0, + disable_cuda_device_placement=True, + ) + + embeddings.load() + + results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"]) + print(results) + embeddings.unload() + # [ + # [-0.05447685346007347, -0.01623094454407692, ...], + # [4.4889533455716446e-05, 0.044016145169734955, ...], + # ] + ``` + + + """ + + model: str = Field( + description="The name of the model to use for embeddings.", + ) + + model_path: RuntimeParameter[str] = Field( + default=None, + description="The path to the GGUF quantized model, compatible with the installed version of the `llama.cpp` Python bindings.", + ) + + repo_id: RuntimeParameter[str] = Field( + default=None, description="The Hugging Face Hub repository id.", exclude=True + ) + + n_gpu_layers: RuntimeParameter[int] = Field( + default=-1, + description="The number of layers that will be loaded in the GPU.", + ) + + n_ctx: int = 512 + n_batch: int = 512 + seed: int = 4294967295 + + normalize_embeddings: RuntimeParameter[bool] = Field( + default=False, + description="Whether to normalize the embeddings.", + ) + verbose: RuntimeParameter[bool] = Field( + default=False, + description="Whether to print verbose output from llama.cpp library.", + ) + extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = 
Field( + default_factory=dict, + description="Additional dictionary of keyword arguments that will be passed to the" + " `Llama` class of `llama_cpp` library. See all the supported arguments at: " + "https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__", + ) + _model: Optional["Llama"] = PrivateAttr(...) + + def load(self) -> None: + """Loads the `gguf` model using either the path or the Hugging Face Hub repository id.""" + super().load() + CudaDevicePlacementMixin.load(self) + + try: + from llama_cpp import Llama + except ImportError as ie: + raise ImportError( + "`llama-cpp-python` package is not installed. Please install it using" + " `pip install llama-cpp-python`." + ) from ie + + if self.repo_id is not None: + # use repo_id to download the model + from huggingface_hub.utils import validate_repo_id + + validate_repo_id(self.repo_id) + self._model = Llama.from_pretrained( + repo_id=self.repo_id, + filename=self.model, + n_gpu_layers=self.n_gpu_layers, + seed=self.seed, + n_ctx=self.n_ctx, + n_batch=self.n_batch, + verbose=self.verbose, + embedding=True, + kwargs=self.extra_kwargs, + ) + elif self.model_path is not None: + self._model = Llama( + model_path=str(Path(self.model_path) / self.model), + n_gpu_layers=self.n_gpu_layers, + seed=self.seed, + n_ctx=self.n_ctx, + n_batch=self.n_batch, + verbose=self.verbose, + embedding=True, + kwargs=self.extra_kwargs, + ) + else: + raise ValueError("Either 'model_path' or 'repo_id' must be provided") + + def unload(self) -> None: + """Unloads the `gguf` model.""" + CudaDevicePlacementMixin.unload(self) + self._model.close() + super().unload() + + @property + def model_name(self) -> str: + """Returns the name of the model.""" + return self.model + + def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]: + """Generates embeddings for the provided inputs. + + Args: + inputs: a list of texts for which an embedding has to be generated. + + Returns: + The generated embeddings. + """ + return self._model.embed(inputs, normalize=self.normalize_embeddings) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9aa4ea3361..32f70133a2 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit +import os from typing import TYPE_CHECKING, Any, Dict, List, Union +from urllib.request import urlretrieve import pytest from pydantic import PrivateAttr @@ -126,3 +129,35 @@ class DummyTaskOfflineBatchGeneration(DummyTask): @pytest.fixture def dummy_llm() -> AsyncLLM: return DummyAsyncLLM() + + +@pytest.fixture(scope="session") +def local_llamacpp_model_path(tmp_path_factory): + """ + Session-scoped fixture that provides the local model path for LlamaCpp testing. + + Download a small test model to a temporary directory. + The model is downloaded once per test session and cleaned up after all tests. + + Args: + tmp_path_factory: Pytest fixture providing a temporary directory factory. + + Returns: + str: The path to the local LlamaCpp model file. 
+ """ + model_name = "all-MiniLM-L6-v2-Q2_K.gguf" + model_url = f"https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/{model_name}" + tmp_path = tmp_path_factory.getbasetemp() + model_path = tmp_path / model_name + + if not model_path.exists(): + urlretrieve(model_url, model_path) + + def cleanup(): + if model_path.exists(): + os.remove(model_path) + + # Register the cleanup function to be called at exit + atexit.register(cleanup) + + return str(tmp_path) diff --git a/tests/unit/models/embeddings/test_llamacpp.py b/tests/unit/models/embeddings/test_llamacpp.py new file mode 100644 index 0000000000..b219ac7798 --- /dev/null +++ b/tests/unit/models/embeddings/test_llamacpp.py @@ -0,0 +1,185 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pytest + +from distilabel.models.embeddings import LlamaCppEmbeddings + + +class TestLlamaCppEmbeddings: + @pytest.fixture(autouse=True) + def setup_embeddings(self, local_llamacpp_model_path): + """ + Fixture to set up embeddings for each test, considering CPU usage. + """ + self.model_name = "all-MiniLM-L6-v2-Q2_K.gguf" + self.repo_id = "second-state/All-MiniLM-L6-v2-Embedding-GGUF" + self.disable_cuda_device_placement = True + self.n_gpu_layers = 0 + self.embeddings = LlamaCppEmbeddings( + model=self.model_name, + model_path=local_llamacpp_model_path, + n_gpu_layers=self.n_gpu_layers, + disable_cuda_device_placement=self.disable_cuda_device_placement, + ) + + self.embeddings.load() + + @pytest.fixture + def test_inputs(self): + """ + Fixture that provides a list of test input strings. + + Returns: + list: A list of strings to be used as test inputs for embeddings. + """ + return [ + "Hello, how are you?", + "What a nice day!", + "I hear that llamas are very popular now.", + ] + + def test_model_name(self) -> None: + """ + Test if the model name is correctly set. + """ + assert self.embeddings.model_name == self.model_name + + def test_encode(self, test_inputs) -> None: + """ + Test if the model can generate embeddings. + """ + results = self.embeddings.encode(inputs=test_inputs) + + for result in results: + assert len(result) == 384 + + def test_load_model_from_local(self, test_inputs): + """ + Test if the model can be loaded from a local file and generate embeddings. + + Args: + local_llamacpp_model_path (str): Fixture providing the local model path. + """ + + results = self.embeddings.encode(inputs=test_inputs) + + for result in results: + assert len(result) == 384 + + def test_load_model_from_repo(self, test_inputs): + """ + Test if the model can be loaded from a Hugging Face repository. 
+ """ + embeddings = LlamaCppEmbeddings( + repo_id=self.repo_id, + model=self.model_name, + normalize_embeddings=True, + n_gpu_layers=self.n_gpu_layers, + disable_cuda_device_placement=self.disable_cuda_device_placement, + ) + embeddings.load() + results = embeddings.encode(inputs=test_inputs) + + for result in results: + assert len(result) == 384 + + def test_normalize_embeddings(self, test_inputs): + """ + Test if embeddings are normalized when normalize_embeddings is True. + """ + + embeddings = LlamaCppEmbeddings( + repo_id=self.repo_id, + model=self.model_name, + normalize_embeddings=True, + n_gpu_layers=self.n_gpu_layers, + disable_cuda_device_placement=self.disable_cuda_device_placement, + ) + embeddings.load() + results = embeddings.encode(inputs=test_inputs) + + for result in results: + # Check if the embedding is normalized (L2 norm should be close to 1) + norm = np.linalg.norm(result) + assert np.isclose( + norm, 1.0, atol=1e-6 + ), f"Norm is {norm}, expected close to 1.0" + + def test_normalize_embeddings_false(self, test_inputs): + """ + Test if embeddings are not normalized when normalize_embeddings is False. + """ + + results = self.embeddings.encode(inputs=test_inputs) + + for result in results: + # Check if the embedding is not normalized (L2 norm should not be close to 1) + norm = np.linalg.norm(result) + assert not np.isclose( + norm, 1.0, atol=1e-6 + ), f"Norm is {norm}, expected not close to 1.0" + + # Additional check: ensure that at least one embedding has a norm significantly different from 1 + norms = [np.linalg.norm(result) for result in results] + assert any( + not np.isclose(norm, 1.0, atol=0.1) for norm in norms + ), "Expected at least one embedding with norm not close to 1.0" + + def test_encode_batch(self) -> None: + """ + Test if the model can generate embeddings for batches of inputs. + """ + # Test with different batch sizes + batch_sizes = [1, 2, 5, 10] + for batch_size in batch_sizes: + inputs = [f"This is test sentence {i}" for i in range(batch_size)] + results = self.embeddings.encode(inputs=inputs) + + assert ( + len(results) == batch_size + ), f"Expected {batch_size} results, got {len(results)}" + for result in results: + assert ( + len(result) == 384 + ), f"Expected embedding dimension 384, got {len(result)}" + + # Test with a large batch to ensure it doesn't cause issues + large_batch = ["Large batch test" for _ in range(100)] + large_results = self.embeddings.encode(inputs=large_batch) + assert ( + len(large_results) == 100 + ), f"Expected 100 results for large batch, got {len(large_results)}" + + def test_encode_batch_consistency(self) -> None: + """ + Test if the model produces consistent embeddings for the same input in different batch sizes. + + Args: + local_llamacpp_model_path (str): Fixture providing the local model path. 
+ """ + input_text = "This is a test sentence for consistency" + + # Generate embedding individually + single_result = self.embeddings.encode([input_text])[0] + + # Generate embedding as part of a batch + batch_result = self.embeddings.encode([input_text, "Another sentence"])[0] + + # Compare the embeddings + assert np.allclose( + single_result, batch_result, atol=1e-5 + ), "Embeddings are not consistent between single and batch processing" From 8ad48387dfa4d7bd5639065661f1975dcb44c16a Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 10 Jan 2025 08:44:29 +0100 Subject: [PATCH 15/30] fix import by replacing GeneratorOutput with GeneratorStepOutput (#1093) --- docs/sections/how_to_guides/basic/task/generator_task.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sections/how_to_guides/basic/task/generator_task.md b/docs/sections/how_to_guides/basic/task/generator_task.md index 613d8deb17..2bd84c1aa8 100644 --- a/docs/sections/how_to_guides/basic/task/generator_task.md +++ b/docs/sections/how_to_guides/basic/task/generator_task.md @@ -13,14 +13,13 @@ from typing_extensions import override from distilabel.steps.tasks.base import GeneratorTask from distilabel.steps.tasks.typing import ChatType -from distilabel.steps.typing import GeneratorOutput - +from distilabel.steps.typing import GeneratorStepOutput class MyCustomTask(GeneratorTask): instruction: str @override - def process(self, offset: int = 0) -> GeneratorOutput: + def process(self, offset: int = 0) -> GeneratorStepOutput: output = self.llm.generate( inputs=[ [ @@ -79,11 +78,12 @@ from typing import Any, Dict, List, Union from distilabel.steps.tasks.base import GeneratorTask from distilabel.steps.tasks.typing import ChatType +from distilabel.steps.typing import GeneratorStepOutput class MyCustomTask(GeneratorTask): @override - def process(self, offset: int = 0) -> GeneratorOutput: + def process(self, offset: int = 0) -> GeneratorStepOutput: output = self.llm.generate( inputs=[ [{"role": "user", "content": "Tell me a joke."}], From 2c893c15131be49215426d14faabdf5c9b29f47b Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 10 Jan 2025 12:10:44 +0100 Subject: [PATCH 16/30] add mlx support (#1089) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 1 + docs/sections/getting_started/installation.md | 2 + pyproject.toml | 1 + src/distilabel/models/__init__.py | 4 + src/distilabel/models/embeddings/__init__.py | 2 +- src/distilabel/models/llms/__init__.py | 2 + src/distilabel/models/llms/mlx.py | 288 ++++++++++++++++++ src/distilabel/models/llms/utils.py | 1 + tests/unit/models/llms/test_mlx.py | 124 ++++++++ 9 files changed, 424 insertions(+), 1 deletion(-) create mode 100644 src/distilabel/models/llms/mlx.py create mode 100644 tests/unit/models/llms/test_mlx.py diff --git a/README.md b/README.md index 7a7dfc8d3d..d10e5d924a 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ In addition, the following extras are available: - `vertexai`: for using [Google Vertex AI](https://cloud.google.com/vertex-ai) proprietary models via the `VertexAILLM` integration. - `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration. - `sentence-transformers`: for generating sentence embeddings using [sentence-transformers](https://github.com/UKPLab/sentence-transformers). +- `mlx`: for using [MLX](https://github.com/ml-explore/mlx) models via the `MlxLLM` integration. 
### Structured generation diff --git a/docs/sections/getting_started/installation.md b/docs/sections/getting_started/installation.md index c11392e3f1..14e08e1ff9 100644 --- a/docs/sections/getting_started/installation.md +++ b/docs/sections/getting_started/installation.md @@ -57,6 +57,8 @@ Additionally, as part of `distilabel` some extra dependencies are available, mai - `sentence-transformers`: for generating sentence embeddings using [sentence-transformers](https://github.com/UKPLab/sentence-transformers). +- `mlx`: for using [MLX](https://github.com/ml-explore/mlx) models via the `MlxLLM` integration. + ### Data processing - `ray`: for scaling and distributing a pipeline with [Ray](https://github.com/ray-project/ray). diff --git a/pyproject.toml b/pyproject.toml index b203f7edf5..3123d56b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ text-clustering = [ "scikit-learn >= 1.4.1", "matplotlib >= 3.8.3", # For the figure (even though it's optional) ] +mlx = ["mlx >= 0.21.0", "mlx-lm"] # minhash minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"] diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py index b84a2c4690..86ea2023e4 100644 --- a/src/distilabel/models/__init__.py +++ b/src/distilabel/models/__init__.py @@ -14,6 +14,7 @@ from distilabel.models.embeddings.base import Embeddings +from distilabel.models.embeddings.llamacpp import LlamaCppEmbeddings from distilabel.models.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) @@ -28,6 +29,7 @@ from distilabel.models.llms.litellm import LiteLLM from distilabel.models.llms.llamacpp import LlamaCppLLM from distilabel.models.llms.mistral import MistralLLM +from distilabel.models.llms.mlx import MlxLLM from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM @@ -52,9 +54,11 @@ "HiddenState", "InferenceEndpointsLLM", "LiteLLM", + "LlamaCppEmbeddings", "LlamaCppLLM", "MistralLLM", "MixtureOfAgentsLLM", + "MlxLLM", "OllamaLLM", "OpenAILLM", "SentenceTransformerEmbeddings", diff --git a/src/distilabel/models/embeddings/__init__.py b/src/distilabel/models/embeddings/__init__.py index 573ba72266..65eb00c469 100644 --- a/src/distilabel/models/embeddings/__init__.py +++ b/src/distilabel/models/embeddings/__init__.py @@ -21,7 +21,7 @@ __all__ = [ "Embeddings", + "LlamaCppEmbeddings", "SentenceTransformerEmbeddings", "vLLMEmbeddings", - "LlamaCppEmbeddings", ] diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py index cca70d64c3..0b0f3a7a9c 100644 --- a/src/distilabel/models/llms/__init__.py +++ b/src/distilabel/models/llms/__init__.py @@ -22,6 +22,7 @@ from distilabel.models.llms.litellm import LiteLLM from distilabel.models.llms.llamacpp import LlamaCppLLM from distilabel.models.llms.mistral import MistralLLM +from distilabel.models.llms.mlx import MlxLLM from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM @@ -48,6 +49,7 @@ "LlamaCppLLM", "MistralLLM", "MixtureOfAgentsLLM", + "MlxLLM", "OllamaLLM", "OpenAILLM", "TogetherLLM", diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py new file mode 100644 index 0000000000..4ffcceddab --- /dev/null +++ b/src/distilabel/models/llms/mlx.py @@ -0,0 +1,288 @@ +# Copyright 2023-present, Argilla, Inc. 
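With the `mlx` extra declared in `pyproject.toml` above, installing the new integration follows the same pattern as the other extras (MLX itself only runs on Apple Silicon):

```bash
pip install "distilabel[mlx]"
```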
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Union, +) + +from pydantic import ( + Field, + PrivateAttr, + validate_call, +) + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.llms.base import LLM +from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.llms.utils import compute_tokens, prepare_output +from distilabel.models.mixins.magpie import MagpieChatTemplateMixin +from distilabel.steps.tasks.typing import ( + OutlinesStructuredOutputType, + StandardInput, +) + +if TYPE_CHECKING: + import mlx.nn as nn + from mlx_lm.tokenizer_utils import TokenizerWrapper + + +class MlxLLM(LLM, MagpieChatTemplateMixin): + """Apple MLX LLM implementation. + + Attributes: + path_or_hf_repo: the path to the model or the Hugging Face Hub repo id. + tokenizer_config: the tokenizer configuration. + model_config: the model configuration. + adapter_path: the path to the adapter. + structured_output: a dictionary containing the structured output configuration or if more + fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. + + Icon: + `:apple:` + + Examples: + Generate text: + + ```python + from distilabel.models.llms import MlxLLM + + llm = MlxLLM(model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit") + + llm.load() + + # Call the model + output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` + """ + + path_or_hf_repo: str + tokenizer_config: Dict[str, Any] = {} + model_config: Dict[str, Any] = {} + adapter_path: Optional[str] = None + structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + + _mlx_generate: Optional[Callable] = PrivateAttr(default=None) + _model: Optional["nn.Module"] = PrivateAttr(...) + _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...) + _structured_output_logits_processor: Union[Callable, None] = PrivateAttr( + default=None + ) + + def load(self) -> None: + """Loads the model and tokenizer and creates the text generation pipeline. In addition, + it will configure the tokenizer chat template.""" + try: + import mlx # noqa + from mlx_lm import generate, load + except ImportError as ie: + raise ImportError( + "MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`." 
+ ) from ie + + self._model, self._tokenizer = load( + self.path_or_hf_repo, + tokenizer_config=self.tokenizer_config, + model_config=self.model_config, + adapter_path=self.adapter_path, + ) + + if self.structured_output: + self._structured_output_logits_processor = self._prepare_structured_output( + self.structured_output + ) + + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token + + self._mlx_generate = generate + + super().load() + + @property + def model_name(self) -> str: + """Returns the model name used for the LLM.""" + return self.path_or_hf_repo + + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + if self._tokenizer.chat_template is None: + return input[0]["content"] + + prompt: str = ( + self._tokenizer.apply_chat_template( + input, + tokenize=False, + add_generation_prompt=True, + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + @validate_call + def generate( + self, + inputs: List[StandardInput], + num_generations: int = 1, + max_tokens: int = 256, + sampler: Optional[Callable] = None, + logits_processors: Optional[List[Callable]] = None, + max_kv_size: Optional[int] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, + prompt_progress_callback: Optional[Callable[[int, int], None]] = None, + temp: Optional[float] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: Optional[float] = None, + min_p: Optional[float] = None, + min_tokens_to_keep: Optional[int] = None, + ) -> List[GenerateOutput]: + """Generates `num_generations` responses for each input using the text generation + pipeline. + + Args: + inputs: the inputs to generate responses for. + num_generations: the number of generations to create per input. Defaults to + `1`. + max_tokens: the maximum number of new tokens that the model will generate. + Defaults to `128`. + sampler: the sampler to use for the generation. Defaults to `None`. + logits_processors: the logits processors to use for the generation. Defaults to + `None`. + max_kv_size: the maximum size of the key-value cache. Defaults to `None`. + prompt_cache: the prompt cache to use for the generation. Defaults to `None`. + prefill_step_size: the prefill step size. Defaults to `512`. + kv_bits: the number of bits to use for the key-value cache. Defaults to `None`. + kv_group_size: the group size for the key-value cache. Defaults to `64`. + quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`. + prompt_progress_callback: the callback to use for the generation. Defaults to + `None`. + temp: the temperature to use for the generation. Defaults to `None`. + repetition_penalty: the repetition penalty to use for the generation. Defaults to + `None`. + repetition_context_size: the context size for the repetition penalty. Defaults to + `None`. + top_p: the top-p value to use for the generation. Defaults to `None`. + min_p: the minimum p value to use for the generation. Defaults to `None`. + min_tokens_to_keep: the minimum number of tokens to keep. Defaults to `None`. + + Returns: + A list of lists of strings containing the generated responses for each input. 
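As a reference for the generation options documented above, a hedged sketch using the same small quantized model the unit tests rely on:

```python
from distilabel.models.llms import MlxLLM

llm = MlxLLM(path_or_hf_repo="mlx-community/Qwen2.5-0.5B-4bit")
llm.load()

# Two generations per input, using a few of the sampling options exposed by `generate`.
outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Write a haiku about llamas."}]],
    num_generations=2,
    max_tokens=64,
    temp=0.7,
    top_p=0.9,
)
print(outputs[0]["generations"])  # list with two generated strings
print(outputs[0]["statistics"])   # input/output token counts per generation
```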
+ """ + logits_processors = [] + if self._structured_output_logits_processor: + logits_processors.append(self._structured_output_logits_processor) + + structured_output = None + result = [] + for input in inputs: + if isinstance(input, tuple): + input, structured_output = input + + output: List[str] = [] + for _ in range(num_generations): + if structured_output: + additional_logits_processors = self._prepare_structured_output( + structured_output + ) + logits_processors.append(additional_logits_processors) + prompt = self.prepare_input(input) + + generation = self._mlx_generate( + prompt=prompt, + model=self._model, + tokenizer=self._tokenizer, + logits_processors=logits_processors, + max_tokens=max_tokens, + sampler=sampler, + max_kv_size=max_kv_size, + prompt_cache=prompt_cache, + prefill_step_size=prefill_step_size, + kv_bits=kv_bits, + kv_group_size=kv_group_size, + quantized_kv_start=quantized_kv_start, + prompt_progress_callback=prompt_progress_callback, + temp=temp, + repetition_penalty=repetition_penalty, + repetition_context_size=repetition_context_size, + top_p=top_p, + min_p=min_p, + min_tokens_to_keep=min_tokens_to_keep, + ) + + output.append(generation) + + result.append( + prepare_output( + output, + input_tokens=[compute_tokens(input, self._tokenizer.encode)], + output_tokens=[ + compute_tokens( + text_or_messages=generation, + tokenizer=self._tokenizer.encode, + ) + for generation in output + ], + ) + ) + return result + + def _prepare_structured_output( + self, structured_output: Optional[OutlinesStructuredOutputType] = None + ) -> Union[Callable, None]: + """Creates the appropriate function to filter tokens to generate structured outputs. + + Args: + structured_output: the configuration dict to prepare the structured output. + + Returns: + The callable that will be used to guide the generation of the model. + """ + from distilabel.steps.tasks.structured_outputs.outlines import ( + prepare_guided_output, + ) + + result = prepare_guided_output( + structured_output, "transformers", self._pipeline + ) + if schema := result.get("schema"): + self.structured_output["schema"] = schema + return result["processor"] diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py index 9cf6590c78..afb09cab4d 100644 --- a/src/distilabel/models/llms/utils.py +++ b/src/distilabel/models/llms/utils.py @@ -50,6 +50,7 @@ def prepare_output( generations: The outputs from an LLM. input_tokens: The number of tokens of the inputs. Defaults to `None`. output_tokens: The number of tokens of the LLM response. Defaults to `None`. + logprobs: The logprobs of the LLM response. Defaults to `None`. Returns: Output generation from an LLM. diff --git a/tests/unit/models/llms/test_mlx.py b/tests/unit/models/llms/test_mlx.py new file mode 100644 index 0000000000..16cd44f6bc --- /dev/null +++ b/tests/unit/models/llms/test_mlx.py @@ -0,0 +1,124 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import platform +from typing import Any, Dict, Generator + +import pytest + +from distilabel.models.llms.mlx import MlxLLM + +from .utils import DummyUserDetail + +RUNS_ON_APPLE_SILICON = platform.processor() == "arm" and platform.system() == "Darwin" + + +@pytest.mark.skipif( + not RUNS_ON_APPLE_SILICON, + reason="MLX only runs on Apple Silicon", +) +@pytest.fixture(scope="module") +def llm() -> Generator[MlxLLM, None, None]: + llm = MlxLLM(path_or_hf_repo="mlx-community/Qwen2.5-0.5B-4bit") + llm.load() + yield llm + + +@pytest.mark.skipif( + not RUNS_ON_APPLE_SILICON, + reason="MLX only runs on Apple Silicon", +) +class TestMlxLLM: + def test_model_name(self, llm: MlxLLM) -> None: + assert llm.path_or_hf_repo == "mlx-community/Qwen2.5-0.5B-4bit" + + def test_generate(self, llm: MlxLLM) -> None: + responses = llm.generate( + inputs=[ + [{"role": "user", "content": "Hello, how are you?"}], + [ + { + "role": "user", + "content": "You're GPT2, you're old now but you still serves a purpose which is being used in unit tests.", + } + ], + ], + num_generations=3, + ) + assert len(responses) == 2 + generations = responses[0]["generations"] + statistics = responses[0]["statistics"] + assert len(generations) == 3 + assert "input_tokens" in statistics + assert "output_tokens" in statistics + + @pytest.mark.parametrize( + "structured_output, dump", + [ + ( + None, + { + "path_or_hf_repo": "mlx-community/Qwen2.5-0.5B-4bit", + "generation_kwargs": {}, + "structured_output": None, + "adapter_path": None, + "jobs_ids": None, + "offline_batch_generation_block_until_done": None, + "use_offline_batch_generation": False, + "magpie_pre_query_template": None, + "tokenizer_config": {}, + "use_magpie_template": False, + "type_info": { + "module": "distilabel.models.llms.mlx", + "name": "MlxLLM", + }, + }, + ), + ( + { + "schema": DummyUserDetail.model_json_schema(), + "format": "json", + }, + { + "path_or_hf_repo": "mlx-community/Qwen2.5-0.5B-4bit", + "generation_kwargs": {}, + "magpie_pre_query_template": None, + "tokenizer_config": {}, + "use_magpie_template": False, + "structured_output": { + "schema": DummyUserDetail.model_json_schema(), + "format": "json", + }, + "adapter_path": None, + "jobs_ids": None, + "offline_batch_generation_block_until_done": None, + "use_offline_batch_generation": False, + "type_info": { + "module": "distilabel.models.llms.mlx", + "name": "MlxLLM", + }, + }, + ), + ], + ) + def test_serialization( + self, structured_output: Dict[str, Any], dump: Dict[str, Any] + ) -> None: + llm = MlxLLM( + path_or_hf_repo="mlx-community/Qwen2.5-0.5B-4bit", + structured_output=structured_output, + ) + + assert llm.dump() == dump + assert isinstance(MlxLLM.from_dict(dump), MlxLLM) From 680dd0946934a053991a7197b22e30caaf50cfa1 Mon Sep 17 00:00:00 2001 From: Zhangda Xu Date: Fri, 10 Jan 2025 22:13:16 +0800 Subject: [PATCH 17/30] Support custom default headers in `OpenAILLM` class. 
(#1088) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ikko Eltociear Ashimine Co-authored-by: Gabriel Martín Blázquez --- src/distilabel/models/llms/base.py | 2 +- src/distilabel/models/llms/openai.py | 7 +++++++ tests/unit/models/llms/test_openai.py | 25 +++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/distilabel/models/llms/base.py b/src/distilabel/models/llms/base.py index 785668cbee..df274df402 100644 --- a/src/distilabel/models/llms/base.py +++ b/src/distilabel/models/llms/base.py @@ -220,7 +220,7 @@ def _offline_batch_generate_polling( f" for {self.offline_batch_generation_block_until_done} seconds before" " trying to get the results again." ) - # When running a `Step` in a child process, SIGINT is overriden so the child + # When running a `Step` in a child process, SIGINT is overridden so the child # process doesn't stop when the parent process receives a SIGINT signal. # The new handler sets an environment variable that is checked here to stop # the polling. diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index 37bb5bb6be..a6dd9dbb5b 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -153,6 +153,10 @@ class User(BaseModel): default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME), description="The API key to authenticate the requests to the OpenAI API.", ) + default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field( + default=None, + description="The default headers to use for the OpenAI API requests.", + ) max_retries: RuntimeParameter[int] = Field( default=6, description="The maximum number of times to retry the request to the API before" @@ -196,6 +200,7 @@ def load(self) -> None: api_key=self.api_key.get_secret_value(), max_retries=self.max_retries, # type: ignore timeout=self.timeout, + default_headers=self.default_headers, ) self._aclient = AsyncOpenAI( @@ -203,6 +208,7 @@ def load(self) -> None: api_key=self.api_key.get_secret_value(), max_retries=self.max_retries, # type: ignore timeout=self.timeout, + default_headers=self.default_headers, ) if self.structured_output: @@ -221,6 +227,7 @@ def unload(self) -> None: self._client = None # type: ignore self._aclient = None # type: ignore + self.default_headers = None self.structured_output = None super().unload() diff --git a/tests/unit/models/llms/test_openai.py b/tests/unit/models/llms/test_openai.py index 1e2a2d31d2..c7cf7d4c45 100644 --- a/tests/unit/models/llms/test_openai.py +++ b/tests/unit/models/llms/test_openai.py @@ -569,9 +569,10 @@ def test_create_jsonl_row( } @pytest.mark.parametrize( - "structured_output, dump", + "default_headers, structured_output, dump", [ ( + None, None, { "model": "gpt-4", @@ -579,6 +580,7 @@ def test_create_jsonl_row( "max_retries": 6, "base_url": "https://api.openai.com/v1", "timeout": 120, + "default_headers": None, "structured_output": None, "jobs_ids": None, "offline_batch_generation_block_until_done": None, @@ -590,6 +592,7 @@ def test_create_jsonl_row( }, ), ( + {"X-Custom-Header": "test"}, { "schema": DummyUserDetail.model_json_schema(), "mode": "tool_call", @@ -601,6 +604,7 @@ def test_create_jsonl_row( "max_retries": 6, "base_url": "https://api.openai.com/v1", "timeout": 120, + "default_headers": {"X-Custom-Header": "test"}, "structured_output": { "schema": DummyUserDetail.model_json_schema(), "mode": "tool_call", @@ -621,10 +625,27 @@ def test_serialization( self, 
_async_openai_mock: MagicMock, _openai_mock: MagicMock, + default_headers: Dict[str, Any], structured_output: Dict[str, Any], dump: Dict[str, Any], ) -> None: - llm = OpenAILLM(model=self.model_id, structured_output=structured_output) + llm = OpenAILLM( + model=self.model_id, + default_headers=default_headers, + structured_output=structured_output, + ) assert llm.dump() == dump assert isinstance(OpenAILLM.from_dict(dump), OpenAILLM) + + def test_openai_llm_default_headers( + self, _async_openai_mock: MagicMock, _openai_mock: MagicMock + ) -> None: + custom_headers = {"X-Custom-Header": "test"} + llm = OpenAILLM( + model=self.model_id, api_key="api.key", default_headers=custom_headers + ) # type: ignore + + assert isinstance(llm, OpenAILLM) + assert llm.model_name == self.model_id + assert llm.default_headers == custom_headers From aaebaa5ff75b777ec3f880c8a23da664229d8c8c Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 10 Jan 2025 16:34:22 +0100 Subject: [PATCH 18/30] fix/pip install messages (#1095) --- src/distilabel/models/embeddings/llamacpp.py | 2 +- src/distilabel/models/embeddings/sentence_transformers.py | 2 +- src/distilabel/models/embeddings/vllm.py | 2 +- src/distilabel/models/llms/anthropic.py | 2 +- src/distilabel/models/llms/azure.py | 2 +- src/distilabel/models/llms/groq.py | 2 +- .../models/llms/huggingface/inference_endpoints.py | 4 ++-- src/distilabel/models/llms/huggingface/transformers.py | 2 +- src/distilabel/models/llms/litellm.py | 2 +- src/distilabel/models/llms/mistral.py | 2 +- src/distilabel/models/llms/ollama.py | 2 +- src/distilabel/models/llms/openai.py | 3 ++- src/distilabel/models/llms/vertexai.py | 2 +- src/distilabel/models/llms/vllm.py | 6 +++--- src/distilabel/pipeline/ray.py | 2 +- src/distilabel/steps/argilla/base.py | 3 +-- src/distilabel/steps/clustering/dbscan.py | 2 +- src/distilabel/steps/clustering/umap.py | 2 +- src/distilabel/steps/embeddings/nearest_neighbour.py | 2 +- src/distilabel/steps/filtering/_datasketch.py | 2 +- src/distilabel/steps/filtering/minhash.py | 4 ++-- src/distilabel/steps/reward_model.py | 2 +- src/distilabel/steps/tasks/structured_outputs/instructor.py | 2 +- src/distilabel/steps/tasks/structured_outputs/outlines.py | 2 +- src/distilabel/steps/truncate.py | 2 +- tests/unit/models/llms/test_anyscale.py | 1 + tests/unit/models/llms/test_azure.py | 2 ++ tests/unit/models/llms/test_together.py | 1 + 28 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/distilabel/models/embeddings/llamacpp.py b/src/distilabel/models/embeddings/llamacpp.py index 6596bb45ea..47f1f9720e 100644 --- a/src/distilabel/models/embeddings/llamacpp.py +++ b/src/distilabel/models/embeddings/llamacpp.py @@ -181,7 +181,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "`llama-cpp-python` package is not installed. Please install it using" - " `pip install llama-cpp-python`." + " `pip install 'distilabel[llama-cpp]'`." ) from ie if self.repo_id is not None: diff --git a/src/distilabel/models/embeddings/sentence_transformers.py b/src/distilabel/models/embeddings/sentence_transformers.py index 8c6e015027..a96b40a7ed 100644 --- a/src/distilabel/models/embeddings/sentence_transformers.py +++ b/src/distilabel/models/embeddings/sentence_transformers.py @@ -110,7 +110,7 @@ def load(self) -> None: except ImportError as e: raise ImportError( "`sentence-transformers` package is not installed. Please install it using" - " `pip install sentence-transformers`." + " `pip install 'distilabel[sentence-transformers]'`." 
) from e self._model = SentenceTransformer( diff --git a/src/distilabel/models/embeddings/vllm.py b/src/distilabel/models/embeddings/vllm.py index 8ddaccd7bb..28ba10a12b 100644 --- a/src/distilabel/models/embeddings/vllm.py +++ b/src/distilabel/models/embeddings/vllm.py @@ -93,7 +93,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( - "vLLM is not installed. Please install it using `pip install vllm`." + "vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`." ) from ie self._model = _vLLM( diff --git a/src/distilabel/models/llms/anthropic.py b/src/distilabel/models/llms/anthropic.py index 0eefc092dc..ab364bad58 100644 --- a/src/distilabel/models/llms/anthropic.py +++ b/src/distilabel/models/llms/anthropic.py @@ -176,7 +176,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "Anthropic Python client is not installed. Please install it using" - " `pip install anthropic`." + " `pip install 'distilabel[anthropic]'`." ) from ie if self.api_key is None: diff --git a/src/distilabel/models/llms/azure.py b/src/distilabel/models/llms/azure.py index 964612f372..b9132991a2 100644 --- a/src/distilabel/models/llms/azure.py +++ b/src/distilabel/models/llms/azure.py @@ -131,7 +131,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "OpenAI Python client is not installed. Please install it using" - " `pip install openai`." + " `pip install 'distilabel[openai]'`." ) from ie if self.api_key is None: diff --git a/src/distilabel/models/llms/groq.py b/src/distilabel/models/llms/groq.py index 8000211936..fec511bbee 100644 --- a/src/distilabel/models/llms/groq.py +++ b/src/distilabel/models/llms/groq.py @@ -144,7 +144,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "Groq Python client is not installed. Please install it using" - ' `pip install groq` or from the extras as `pip install "distilabel[groq]"`.' + ' `pip install "distilabel[groq]"`.' ) from ie if self.api_key is None: diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index d4e53f1ed2..6f97c5814a 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -262,7 +262,7 @@ def load(self) -> None: # noqa: C901 except ImportError as ie: raise ImportError( "Hugging Face Hub Python client is not installed. Please install it using" - " `pip install huggingface-hub`." + " `pip install 'distilabel[hf-inference-endpoints]'`." ) from ie if self.api_key is None: @@ -311,7 +311,7 @@ def load(self) -> None: # noqa: C901 except ImportError as ie: raise ImportError( "Transformers Python client is not installed. Please install it using" - " `pip install transformers`." + " `pip install 'distilabel[hf-inference-endpoints]'`." ) from ie self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) diff --git a/src/distilabel/models/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py index a4f9de95ab..69f3d02a2e 100644 --- a/src/distilabel/models/llms/huggingface/transformers.py +++ b/src/distilabel/models/llms/huggingface/transformers.py @@ -122,7 +122,7 @@ def load(self) -> None: from transformers import pipeline except ImportError as ie: raise ImportError( - "Transformers is not installed. Please install it using `pip install transformers`." + "Transformers is not installed. 
Please install it using `pip install 'distilabel[hf-transformers]'`." ) from ie token = self.token.get_secret_value() if self.token is not None else self.token diff --git a/src/distilabel/models/llms/litellm.py b/src/distilabel/models/llms/litellm.py index d2471f2991..9b52ad8c71 100644 --- a/src/distilabel/models/llms/litellm.py +++ b/src/distilabel/models/llms/litellm.py @@ -104,7 +104,7 @@ def load(self) -> None: except ImportError as e: raise ImportError( "LiteLLM Python client is not installed. Please install it using" - " `pip install litellm`." + " `pip install 'distilabel[litellm]'`." ) from e self._aclient = litellm.acompletion diff --git a/src/distilabel/models/llms/mistral.py b/src/distilabel/models/llms/mistral.py index 9fe9f357da..4147edaf03 100644 --- a/src/distilabel/models/llms/mistral.py +++ b/src/distilabel/models/llms/mistral.py @@ -140,7 +140,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "MistralAI Python client is not installed. Please install it using" - " `pip install mistralai`." + " `pip install 'distilabel[mistralai]'`." ) from ie if self.api_key is None: diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py index ff0779d881..a930399114 100644 --- a/src/distilabel/models/llms/ollama.py +++ b/src/distilabel/models/llms/ollama.py @@ -163,7 +163,7 @@ def load(self) -> None: except ImportError as e: raise ImportError( "Ollama Python client is not installed. Please install it using" - " `pip install ollama`." + " `pip install 'distilabel[ollama]'`." ) from e if self.tokenizer_id: diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index a6dd9dbb5b..91f24a3336 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -54,6 +54,7 @@ class OpenAILLM(AsyncLLM): api_key: the API key to authenticate the requests to the OpenAI API. Defaults to `None` which means that the value set for the environment variable `OPENAI_API_KEY` will be used, or `None` if not set. + default_headers: the default headers to use for the OpenAI API requests. max_retries: the maximum number of times to retry the request to the API before failing. Defaults to `6`. timeout: the maximum time in seconds to wait for a response from the API. Defaults @@ -186,7 +187,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "OpenAI Python client is not installed. Please install it using" - " `pip install openai`." + " `pip install 'distilabel[openai]'`." ) from ie if self.api_key is None: diff --git a/src/distilabel/models/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py index 62235dd321..7c1b3e6bb4 100644 --- a/src/distilabel/models/llms/vertexai.py +++ b/src/distilabel/models/llms/vertexai.py @@ -89,7 +89,7 @@ def load(self) -> None: except ImportError as e: raise ImportError( "vertexai is not installed. Please install it using" - " `pip install google-cloud-aiplatform`." + " `pip install 'distilabel[vertexai]'`." ) from e if _is_gemini_model(self.model): diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index 401bc66d09..ceab8e3e30 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -189,7 +189,7 @@ def load(self) -> None: from vllm import LLM as _vLLM except ImportError as ie: raise ImportError( - "vLLM is not installed. Please install it using `pip install vllm`." + "vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`." 
) from ie self._model = _vLLM( @@ -585,7 +585,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "OpenAI Python client is not installed. Please install it using" - " `pip install openai`." + " `pip install 'distilabel[openai]'`." ) from ie self._client = OpenAI( @@ -602,7 +602,7 @@ def load(self) -> None: except ImportError as ie: raise ImportError( "To use `ClientvLLM` you need to install `transformers`." - "Please install it using `pip install transformers`." + "Please install it using `pip install 'distilabel[hf-transformers]'`." ) from ie self._tokenizer = AutoTokenizer.from_pretrained( diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py index 30d2e5a47e..c2e85afd86 100644 --- a/src/distilabel/pipeline/ray.py +++ b/src/distilabel/pipeline/ray.py @@ -204,7 +204,7 @@ def _init_ray(self) -> None: import ray except ImportError as ie: raise ImportError( - "ray is not installed. Please install it using `pip install ray[default]`." + "ray is not installed. Please install it using `pip install 'distilabel[ray]'`." ) from ie if self._ray_head_node_url: diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py index ea491e07a5..06db05e05b 100644 --- a/src/distilabel/steps/argilla/base.py +++ b/src/distilabel/steps/argilla/base.py @@ -94,8 +94,7 @@ def model_post_init(self, __context: Any) -> None: if importlib.util.find_spec("argilla") is None: raise ImportError( - "Argilla is not installed. Please install it using `pip install argilla" - " --upgrade`." + "Argilla is not installed. Please install it using `pip install 'distilabel[argilla]'`." ) def _client_init(self) -> None: diff --git a/src/distilabel/steps/clustering/dbscan.py b/src/distilabel/steps/clustering/dbscan.py index 03ac5dcb3e..2124d787c1 100644 --- a/src/distilabel/steps/clustering/dbscan.py +++ b/src/distilabel/steps/clustering/dbscan.py @@ -124,7 +124,7 @@ def load(self) -> None: super().load() if importlib.util.find_spec("sklearn") is None: raise ImportError( - "`sklearn` package is not installed. Please install it using `pip install scikit-learn`." + "`sklearn` package is not installed. Please install it using `pip install 'distilabel[text-clustering]'`." ) from sklearn.cluster import DBSCAN as _DBSCAN diff --git a/src/distilabel/steps/clustering/umap.py b/src/distilabel/steps/clustering/umap.py index daeb37486d..9bf71c68e3 100644 --- a/src/distilabel/steps/clustering/umap.py +++ b/src/distilabel/steps/clustering/umap.py @@ -112,7 +112,7 @@ def load(self) -> None: super().load() if importlib.util.find_spec("umap") is None: raise ImportError( - "`umap` package is not installed. Please install it using `pip install umap-learn`." + "`umap` package is not installed. Please install it using `pip install 'distilabel[text-clustering]'`." ) from umap import UMAP as _UMAP diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py index df5f48f8fa..a962ca3b14 100644 --- a/src/distilabel/steps/embeddings/nearest_neighbour.py +++ b/src/distilabel/steps/embeddings/nearest_neighbour.py @@ -163,7 +163,7 @@ def load(self) -> None: if importlib.util.find_spec("faiss") is None: raise ImportError( "`faiss` package is not installed. Please install it using `pip install" - " faiss-cpu` or `pip install faiss-gpu`." + " 'distilabel[faiss-cpu]' or 'distilabel[faiss-gpu]'`." 
) @property diff --git a/src/distilabel/steps/filtering/_datasketch.py b/src/distilabel/steps/filtering/_datasketch.py index 5e21940499..d3d0db74ef 100644 --- a/src/distilabel/steps/filtering/_datasketch.py +++ b/src/distilabel/steps/filtering/_datasketch.py @@ -43,7 +43,7 @@ def __init__(self, config, name) -> None: except ImportError as e: raise ImportError( "`diskcache` is required for disk storage using `MinHashDedup`. " - "Please install it using `pip install diskcache`." + "Please install it using `pip install 'distilabel[minhash]'`." ) from e # Start with a clean file on each pipeline diff --git a/src/distilabel/steps/filtering/minhash.py b/src/distilabel/steps/filtering/minhash.py index e6bb8038a3..7e86d30543 100644 --- a/src/distilabel/steps/filtering/minhash.py +++ b/src/distilabel/steps/filtering/minhash.py @@ -176,7 +176,7 @@ def load(self) -> None: if not importlib.import_module("datasketch"): raise ImportError( "`datasketch` is needed to deduplicate with MinHash, but is not installed. " - "Please install it using `pip install datasketch`." + "Please install it using `pip install 'distilabel[minhash]'`." ) from datasketch import MinHash @@ -193,7 +193,7 @@ def load(self) -> None: if not importlib.import_module("nltk"): raise ImportError( "`nltk` is needed to tokenize based on words, but is not installed. " - "Please install it using `pip install nltk`. Then run `nltk.download('punkt_tab')`." + "Please install it using `pip install 'distilabel[minhash]'`. Then run `nltk.download('punkt_tab')`." ) self._tokenizer = tokenized_on_words else: diff --git a/src/distilabel/steps/reward_model.py b/src/distilabel/steps/reward_model.py index fcb5b27371..0af5d5cfdd 100644 --- a/src/distilabel/steps/reward_model.py +++ b/src/distilabel/steps/reward_model.py @@ -156,7 +156,7 @@ def load(self) -> None: from transformers import AutoModelForSequenceClassification, AutoTokenizer except ImportError as e: raise ImportError( - "`transformers` is not installed. Please install it using `pip install transformers`." + "`transformers` is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`." ) from e token = self.token.get_secret_value() if self.token is not None else self.token diff --git a/src/distilabel/steps/tasks/structured_outputs/instructor.py b/src/distilabel/steps/tasks/structured_outputs/instructor.py index 93b90d9916..184c9be7b6 100644 --- a/src/distilabel/steps/tasks/structured_outputs/instructor.py +++ b/src/distilabel/steps/tasks/structured_outputs/instructor.py @@ -109,7 +109,7 @@ def prepare_instructor( """ if not importlib.util.find_spec("instructor"): raise ImportError( - "`instructor` is not installed. Please install it using `pip install instructor`." + "`instructor` is not installed. Please install it using `pip install 'distilabel[instructor]'`." ) import instructor diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index fe561d11af..b8ac03641a 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -99,7 +99,7 @@ def prepare_guided_output( """ if not importlib.util.find_spec("outlines"): raise ImportError( - "Outlines is not installed. Please install it using `pip install outlines`." + "Outlines is not installed. Please install it using `pip install 'distilabel[outlines]'`." 
) json_processor, regex_processor = _get_logits_processor(framework) diff --git a/src/distilabel/steps/truncate.py b/src/distilabel/steps/truncate.py index 6e68af6630..a2240d716b 100644 --- a/src/distilabel/steps/truncate.py +++ b/src/distilabel/steps/truncate.py @@ -108,7 +108,7 @@ def load(self): if not importlib.util.find_spec("transformers"): raise ImportError( "`transformers` is needed to tokenize, but is not installed. " - "Please install it using `pip install transformers`." + "Please install it using `pip install 'distilabel[hf-transformers]'`." ) from transformers import AutoTokenizer diff --git a/tests/unit/models/llms/test_anyscale.py b/tests/unit/models/llms/test_anyscale.py index d12dbebd02..6a31d60809 100644 --- a/tests/unit/models/llms/test_anyscale.py +++ b/tests/unit/models/llms/test_anyscale.py @@ -46,6 +46,7 @@ def test_serialization(self) -> None: "model": self.model_id, "generation_kwargs": {}, "max_retries": 6, + "default_headers": None, "base_url": "https://api.endpoints.anyscale.com/v1", "timeout": 120, "structured_output": None, diff --git a/tests/unit/models/llms/test_azure.py b/tests/unit/models/llms/test_azure.py index a2122b611f..1e874c5f9b 100644 --- a/tests/unit/models/llms/test_azure.py +++ b/tests/unit/models/llms/test_azure.py @@ -71,6 +71,7 @@ def test_azure_openai_llm_env_vars(self) -> None: "api_version": "preview", "generation_kwargs": {}, "max_retries": 6, + "default_headers": None, "base_url": "https://example-resource.azure.openai.com/", "timeout": 120, "structured_output": None, @@ -95,6 +96,7 @@ def test_azure_openai_llm_env_vars(self) -> None: "generation_kwargs": {}, "max_retries": 6, "base_url": "https://example-resource.azure.openai.com/", + "default_headers": None, "timeout": 120, "structured_output": { "schema": DummyUserDetail.model_json_schema(), diff --git a/tests/unit/models/llms/test_together.py b/tests/unit/models/llms/test_together.py index 88208bf6c6..b7a045fbbb 100644 --- a/tests/unit/models/llms/test_together.py +++ b/tests/unit/models/llms/test_together.py @@ -46,6 +46,7 @@ def test_serialization(self) -> None: "model": self.model_id, "generation_kwargs": {}, "max_retries": 6, + "default_headers": None, "base_url": "https://api.together.xyz/v1", "timeout": 120, "structured_output": None, From d9fd15c30c790f38f0c2c51c8178cacc6fa1a6b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 10 Jan 2025 16:34:43 +0100 Subject: [PATCH 19/30] Fix handling empty list statistics (#1094) --- src/distilabel/models/llms/utils.py | 12 ++++++++---- src/distilabel/steps/tasks/base.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py index afb09cab4d..ef97e53e1f 100644 --- a/src/distilabel/models/llms/utils.py +++ b/src/distilabel/models/llms/utils.py @@ -57,11 +57,15 @@ def prepare_output( """ output: "GenerateOutput" = { "generations": generations, - "statistics": { - "input_tokens": input_tokens or [], - "output_tokens": output_tokens or [], - }, + "statistics": {}, } + + if input_tokens: + output["statistics"]["input_tokens"] = input_tokens + + if output_tokens: + output["statistics"]["output_tokens"] = output_tokens + if logprobs: output["logprobs"] = logprobs return output diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index dba92588cd..ae19a1038f 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -521,15 +521,15 @@ def 
normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": gen_length = len(output["generations"]) for stat_key, stat_values in output["statistics"].items(): - current_length = len(stat_values) + current_length = len(stat_values) # type: ignore - if current_length < gen_length: + if current_length > 0 and current_length < gen_length: # Calculate how many times to repeat the tokens repeats = gen_length // current_length remainder = gen_length % current_length # Create new list with repeated values - new_values = stat_values * repeats + stat_values[:remainder] + new_values = stat_values * repeats + stat_values[:remainder] # type: ignore output["statistics"][stat_key] = new_values return output @@ -552,7 +552,11 @@ def iterate_generations_with_stats( ] for i, generation in enumerate(outputs["generations"]): # Create a new dictionary with the statistics for this index - stats = {key: values[i] for key, values in outputs["statistics"].items()} # type: ignore + stats = { + key: values[i] # type: ignore + for key, values in outputs["statistics"].items() + if values + } # Extra keys returned by the `LLM` extra = {key: outputs[key][i] for key in extra_keys} yield generation, stats, extra From 95069301140c4604680e22a886f257b7bc5015db Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 10 Jan 2025 17:59:24 +0100 Subject: [PATCH 20/30] update to outlines010 (#1092) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .gitignore | 1 + .../models/llms/huggingface/transformers.py | 17 ++- src/distilabel/models/llms/llamacpp.py | 11 +- src/distilabel/models/llms/mlx.py | 52 +------- .../tasks/structured_outputs/outlines.py | 119 +++++++++++++----- .../tasks/structured_outputs/test_outlines.py | 14 ++- 6 files changed, 123 insertions(+), 91 deletions(-) diff --git a/.gitignore b/.gitignore index d8337200af..1aab313fb9 100644 --- a/.gitignore +++ b/.gitignore @@ -77,3 +77,4 @@ venv.bak/ # Other *.log *.swp +.DS_Store diff --git a/src/distilabel/models/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py index 69f3d02a2e..19dc32dd2d 100644 --- a/src/distilabel/models/llms/huggingface/transformers.py +++ b/src/distilabel/models/llms/huggingface/transformers.py @@ -23,6 +23,9 @@ from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin +from distilabel.steps.tasks.structured_outputs.outlines import ( + _is_outlines_version_below_0_1_0, +) from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR @@ -111,6 +114,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): _pipeline: Optional["Pipeline"] = PrivateAttr(...) _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None) + _logits_processor: Union[Callable, None] = PrivateAttr(default=None) def load(self) -> None: """Loads the model and tokenizer and creates the text generation pipeline. 
In addition, @@ -149,9 +153,11 @@ def load(self) -> None: self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token # type: ignore if self.structured_output: - self._prefix_allowed_tokens_fn = self._prepare_structured_output( - self.structured_output - ) + processor = self._prepare_structured_output(self.structured_output) + if _is_outlines_version_below_0_1_0(): + self._prefix_allowed_tokens_fn = processor + else: + self._logits_processor = [processor] super().load() @@ -232,7 +238,8 @@ def generate( # type: ignore do_sample=do_sample, num_return_sequences=num_generations, prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, - pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore + pad_token_id=self._pipeline.tokenizer.eos_token_id, + logits_processor=self._logits_processor, ) llm_output = [ [generation["generated_text"] for generation in output] @@ -292,7 +299,7 @@ def get_last_hidden_states( def _prepare_structured_output( self, structured_output: Optional[OutlinesStructuredOutputType] = None - ) -> Union[Callable, None]: + ) -> Union[Callable, List[Callable]]: """Creates the appropriate function to filter tokens to generate structured outputs. Args: diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py index 822e5cea77..a754f6b84f 100644 --- a/src/distilabel/models/llms/llamacpp.py +++ b/src/distilabel/models/llms/llamacpp.py @@ -24,7 +24,12 @@ from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType if TYPE_CHECKING: - from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList + from llama_cpp import ( + CreateChatCompletionResponse, + Llama, + LogitsProcessor, + LogitsProcessorList, + ) from distilabel.steps.tasks.typing import FormattedInput, StandardInput @@ -383,7 +388,7 @@ def generate( # type: ignore def _prepare_structured_output( self, structured_output: Optional[OutlinesStructuredOutputType] = None - ) -> Union["LogitsProcessorList", None]: + ) -> Union["LogitsProcessorList", "LogitsProcessor"]: """Creates the appropriate function to filter tokens to generate structured outputs. Args: @@ -399,4 +404,4 @@ def _prepare_structured_output( result = prepare_guided_output(structured_output, "llamacpp", self._model) if (schema := result.get("schema")) and self.structured_output: self.structured_output["schema"] = schema - return result["processor"] + return [result["processor"]] diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py index 4ffcceddab..1f8c9b8c65 100644 --- a/src/distilabel/models/llms/mlx.py +++ b/src/distilabel/models/llms/mlx.py @@ -19,22 +19,18 @@ Dict, List, Optional, - Union, ) from pydantic import ( - Field, PrivateAttr, validate_call, ) -from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import ( - OutlinesStructuredOutputType, StandardInput, ) @@ -51,8 +47,6 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): tokenizer_config: the tokenizer configuration. model_config: the model configuration. adapter_path: the path to the adapter. - structured_output: a dictionary containing the structured output configuration or if more - fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. 
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query template. Defaults to `False`. magpie_pre_query_template: the pre-query template to be applied to the prompt or @@ -82,17 +76,10 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): tokenizer_config: Dict[str, Any] = {} model_config: Dict[str, Any] = {} adapter_path: Optional[str] = None - structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( - default=None, - description="The structured output format to use across all the generations.", - ) _mlx_generate: Optional[Callable] = PrivateAttr(default=None) _model: Optional["nn.Module"] = PrivateAttr(...) _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...) - _structured_output_logits_processor: Union[Callable, None] = PrivateAttr( - default=None - ) def load(self) -> None: """Loads the model and tokenizer and creates the text generation pipeline. In addition, @@ -112,11 +99,6 @@ def load(self) -> None: adapter_path=self.adapter_path, ) - if self.structured_output: - self._structured_output_logits_processor = self._prepare_structured_output( - self.structured_output - ) - if self._tokenizer.pad_token is None: self._tokenizer.pad_token = self._tokenizer.eos_token @@ -207,10 +189,6 @@ def generate( Returns: A list of lists of strings containing the generated responses for each input. """ - logits_processors = [] - if self._structured_output_logits_processor: - logits_processors.append(self._structured_output_logits_processor) - structured_output = None result = [] for input in inputs: @@ -219,13 +197,9 @@ def generate( output: List[str] = [] for _ in range(num_generations): - if structured_output: - additional_logits_processors = self._prepare_structured_output( - structured_output - ) - logits_processors.append(additional_logits_processors) + if structured_output: # will raise a NotImplementedError + self._prepare_structured_output(structured_output) prompt = self.prepare_input(input) - generation = self._mlx_generate( prompt=prompt, model=self._model, @@ -264,25 +238,3 @@ def generate( ) ) return result - - def _prepare_structured_output( - self, structured_output: Optional[OutlinesStructuredOutputType] = None - ) -> Union[Callable, None]: - """Creates the appropriate function to filter tokens to generate structured outputs. - - Args: - structured_output: the configuration dict to prepare the structured output. - - Returns: - The callable that will be used to guide the generation of the model. 
- """ - from distilabel.steps.tasks.structured_outputs.outlines import ( - prepare_guided_output, - ) - - result = prepare_guided_output( - structured_output, "transformers", self._pipeline - ) - if schema := result.get("schema"): - self.structured_output["schema"] = schema - return result["processor"] diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index b8ac03641a..a5aceacb3b 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -24,19 +24,38 @@ Literal, Tuple, Type, + Union, get_args, ) +import pkg_resources from pydantic import BaseModel from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict -if TYPE_CHECKING: - from distilabel.steps.tasks.typing import OutlinesStructuredOutputType +if TYPE_CHECKING: # noqa + from llama_cpp import Llama # noqa + from transformers import Pipeline # noqa + from vllm import LLM as _vLLM # noqa + + from distilabel.steps.tasks.typing import OutlinesStructuredOutputType # noqa Frameworks = Literal["transformers", "llamacpp", "vllm"] -"""Available frameworks for the structured output configuration. """ + + +def _is_outlines_version_below_0_1_0() -> bool: + """Helper function to check outlines availability and version. + + Returns: + bool: True if outlines is not installed or version is below 0.1.0 + """ + if not importlib.util.find_spec("outlines"): + raise ImportError( + "Outlines is not installed. Please install it using `pip install outlines`." + ) + version = pkg_resources.get_distribution("outlines").version + return pkg_resources.parse_version(version) < pkg_resources.parse_version("0.1.0") def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]: @@ -45,38 +64,77 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]: def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]: - """Helper function to return the appropriate logits processor for the given framework.""" - if framework == "transformers": - from outlines.integrations.transformers import ( - JSONPrefixAllowedTokens, - RegexPrefixAllowedTokens, + """Helper function to return the appropriate logits processors for the given framework.""" + if _is_outlines_version_below_0_1_0(): + processors = { + "transformers": ( + "outlines.integrations.transformers", + "JSONPrefixAllowedTokens", + "RegexPrefixAllowedTokens", + ), + "llamacpp": ( + "outlines.integrations.llamacpp", + "JSONLogitsProcessor", + "RegexLogitsProcessor", + ), + "vllm": ( + "outlines.integrations.vllm", + "JSONLogitsProcessor", + "RegexLogitsProcessor", + ), + } + else: + processors = { + "transformers": ( + "outlines.processors", + "JSONLogitsProcessor", + "RegexLogitsProcessor", + ), + "llamacpp": ( + "outlines.processors", + "JSONLogitsProcessor", + "RegexLogitsProcessor", + ), + "vllm": ( + "outlines.processors", + "JSONLogitsProcessor", + "RegexLogitsProcessor", + ), + } + + if framework not in processors: + raise DistilabelUserError( + f"Invalid framework '{framework}'. 
Must be one of {get_args(Frameworks)}", + page="sections/how_to_guides/advanced/structured_generation/", ) - return JSONPrefixAllowedTokens, RegexPrefixAllowedTokens + module_path, json_cls, regex_cls = processors[framework] + module = importlib.import_module(module_path) + return getattr(module, json_cls), getattr(module, regex_cls) + +def _get_tokenizer_from_model( + llm: Union["_vLLM", "Pipeline", "Llama"], + framework: Frameworks, +) -> Callable: if framework == "llamacpp": - from outlines.integrations.llamacpp import ( - JSONLogitsProcessor, - RegexLogitsProcessor, - ) + from outlines.models.llamacpp import LlamaCppTokenizer - return JSONLogitsProcessor, RegexLogitsProcessor + return LlamaCppTokenizer(llm) + if framework == "transformers": + from outlines.models.transformers import TransformerTokenizer + return TransformerTokenizer(llm.tokenizer) if framework == "vllm": - from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor + from outlines.models.vllm import adapt_tokenizer - return JSONLogitsProcessor, RegexLogitsProcessor - - raise DistilabelUserError( - f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}", - page="sections/how_to_guides/advanced/structured_generation/", - ) + return adapt_tokenizer(llm.get_tokenizer()) def prepare_guided_output( structured_output: "OutlinesStructuredOutputType", framework: Frameworks, - llm: Any, + llm: Union["_vLLM", "Pipeline", "Llama"], ) -> Dict[str, Any]: """Prepares the `LLM` to generate guided output using `outlines`. @@ -97,10 +155,6 @@ def prepare_guided_output( case of "json" will also include the schema as a dict, to simplify serialization and deserialization. """ - if not importlib.util.find_spec("outlines"): - raise ImportError( - "Outlines is not installed. Please install it using `pip install 'distilabel[outlines]'`." - ) json_processor, regex_processor = _get_logits_processor(framework) @@ -116,18 +170,27 @@ def prepare_guided_output( elif isinstance(schema, str): format = "regex" + if _is_outlines_version_below_0_1_0(): + # use the llm for processor initialization + model = llm + tokenizer = None + else: + # use the tokenizer for processor initialization + model = None + tokenizer = _get_tokenizer_from_model(llm, framework) + if format == "json": return { "processor": json_processor( schema, - llm, + model or tokenizer, whitespace_pattern=structured_output.get("whitespace_pattern"), ), "schema": schema_as_dict(schema), } if format == "regex": - return {"processor": regex_processor(schema, llm)} + return {"processor": regex_processor(schema, model or tokenizer)} raise DistilabelUserError( f"Invalid format '{format}'. 
Must be either 'json' or 'regex'.", diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index e4eb2025c8..2812c2e48b 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -19,6 +19,7 @@ from distilabel.models.llms.huggingface.transformers import TransformersLLM from distilabel.steps.tasks.structured_outputs.outlines import ( + _is_outlines_version_below_0_1_0, model_to_schema, ) from distilabel.steps.tasks.typing import OutlinesStructuredOutputType @@ -100,9 +101,6 @@ class DummyUserTest(BaseModel): } -@pytest.mark.skip( - reason="won't work until we update our code to work with `outlines>0.1.0`" -) class TestOutlinesIntegration: @pytest.mark.parametrize( "format, schema, prompt", @@ -138,7 +136,7 @@ def test_generation( prompt = [ [{"role": "system", "content": ""}, {"role": "user", "content": prompt}] ] - result = llm.generate(prompt, max_new_tokens=30) + result = llm.generate(prompt, max_new_tokens=30, temperature=0.7) assert isinstance(result, list) assert isinstance(result[0], dict) assert "generations" in result[0] and "statistics" in result[0] @@ -174,6 +172,7 @@ def test_serialization( structured_output=OutlinesStructuredOutputType( format=format, schema=schema ), + token=None, ) llm.load() assert llm.dump() == dump @@ -182,4 +181,9 @@ def test_load_from_dict(self) -> None: llm = TransformersLLM.from_dict(DUMP_JSON) assert isinstance(llm, TransformersLLM) llm.load() - assert llm._prefix_allowed_tokens_fn is not None + if _is_outlines_version_below_0_1_0(): + assert llm._prefix_allowed_tokens_fn is not None + assert llm._logits_processor is None + else: + assert llm._prefix_allowed_tokens_fn is None + assert llm._logits_processor is not None From e866345a449a4febf5fc7c59dc6402095c05a11f Mon Sep 17 00:00:00 2001 From: Sara Han <127759186+sdiazlor@users.noreply.github.com> Date: Tue, 14 Jan 2025 16:11:50 +0100 Subject: [PATCH 21/30] update: search by match (#1096) --- src/distilabel/steps/tasks/complexity_scorer.py | 2 +- src/distilabel/steps/tasks/sentence_transformers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py index 7578ecf187..bd8a99c6b0 100644 --- a/src/distilabel/steps/tasks/complexity_scorer.py +++ b/src/distilabel/steps/tasks/complexity_scorer.py @@ -185,7 +185,7 @@ def format_output( scores = [] score_lines = output.split("\n") for i, line in enumerate(score_lines): - match = _PARSE_SCORE_LINE_REGEX.match(line) + match = _PARSE_SCORE_LINE_REGEX.search(line) score = float(match.group(1)) if match else None scores.append(score) if i == len(input["instructions"]) - 1: diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py index fa29bbe367..350849e3d0 100644 --- a/src/distilabel/steps/tasks/sentence_transformers.py +++ b/src/distilabel/steps/tasks/sentence_transformers.py @@ -346,7 +346,7 @@ def format_output( if self.use_default_structured_output: return self._format_structured_output(output) - match = POSITIVE_NEGATIVE_PAIR_REGEX.match(output) + match = POSITIVE_NEGATIVE_PAIR_REGEX.search(output) if match is None: formatted_output = {"positive": None} if self.triplet: From 27b5db21b46f620161600da917e2d45667041894 Mon Sep 17 00:00:00 2001 From: Parag Ekbote Date: Tue, 14 Jan 2025 20:42:47 +0530 Subject: [PATCH 22/30] Add Legend 
to Component Gallery Icons (#1090) --- src/distilabel/utils/mkdocs/components_gallery.py | 2 +- .../utils/mkdocs/templates/components-gallery/index.md | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index 621f4b61dc..022c6cf1fd 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -454,7 +454,7 @@ def _generate_embeddings_pages(self, src_dir: Path, embeddings: list) -> List[st paths.append(llm_path) - # Create the `components-gallery/llms/index.md` file + # Create the `components-gallery/embeddings/index.md` file content = _COMPONENTS_LIST_TEMPLATE.render( title="Embeddings Gallery", description="", diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md index cc3e44aecf..7d13969da3 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md @@ -5,6 +5,14 @@ hide: --- # Components Gallery +??? info "Category Overview" + | Icon | Category | Description | + |----------------------------|------------|-------------------------------------------------------------------| + | :material-step-forward: | Steps | Steps are used for data manipulation. | + | :material-check-outline: | Tasks | Tasks allow performing data generation, annotation, and more. | + | :material-brain: | LLMs | Explore all available Large Language Models integrated with distilabel. | + | :material-vector-line: | Embeddings | Explore all available Embeddings Models integrated with distilabel. | +

- :material-step-forward:{ .lg .middle } __Steps__ From 5257600f1d3f03bdcb9045e0bbcb56e3aaec1b2d Mon Sep 17 00:00:00 2001 From: Agus Date: Wed, 15 Jan 2025 12:28:19 +0100 Subject: [PATCH 23/30] Image Language Models and `ImageGeneration` task (#1060) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../image_generation_gallery.md | 10 + docs/api/models/image_generation/index.md | 7 + docs/api/pipeline/typing.md | 3 - docs/api/step/typing.md | 3 - docs/api/task/image_task.md | 7 + docs/api/task/task_gallery.md | 1 + docs/api/task/typing.md | 3 - docs/api/typing.md | 8 + .../how_to_guides/advanced/distiset.md | 27 ++ .../advanced/pipeline_requirements.md | 2 +- .../advanced/structured_generation.md | 4 +- .../basic/step/generator_step.md | 8 +- .../how_to_guides/basic/step/global_step.md | 6 +- .../how_to_guides/basic/step/index.md | 8 +- .../basic/task/generator_task.md | 5 +- .../how_to_guides/basic/task/image_task.md | 104 ++++++++ .../how_to_guides/basic/task/index.md | 5 +- .../examples/image_generation.md | 108 ++++++++ docs/sections/pipeline_samples/index.md | 8 + .../tutorials/clean_existing_dataset.ipynb | 2 +- examples/image_generation.py | 42 +++ mkdocs.yml | 9 +- pyproject.toml | 1 + src/distilabel/distiset.py | 54 +++- src/distilabel/llms.py | 2 +- src/distilabel/mixins/runtime_parameters.py | 87 ++++++ src/distilabel/models/__init__.py | 14 +- .../models/base_clients/__init__.py | 20 ++ .../base_clients/inference_endpoints.py | 154 +++++++++++ src/distilabel/models/base_clients/openai.py | 122 +++++++++ src/distilabel/models/embeddings/base.py | 4 +- .../models/image_generation/__init__.py | 29 ++ .../models/image_generation/base.py | 247 ++++++++++++++++++ .../image_generation/huggingface/__init__.py | 14 + .../huggingface/inference_endpoints.py | 106 ++++++++ .../models/image_generation/openai.py | 129 +++++++++ .../models/image_generation/utils.py | 31 +++ src/distilabel/models/llms/__init__.py | 2 +- src/distilabel/models/llms/anthropic.py | 9 +- src/distilabel/models/llms/base.py | 90 +------ src/distilabel/models/llms/cohere.py | 6 +- src/distilabel/models/llms/groq.py | 6 +- .../llms/huggingface/inference_endpoints.py | 152 ++--------- .../models/llms/huggingface/transformers.py | 9 +- src/distilabel/models/llms/litellm.py | 7 +- src/distilabel/models/llms/llamacpp.py | 9 +- src/distilabel/models/llms/mistral.py | 6 +- src/distilabel/models/llms/mlx.py | 5 +- src/distilabel/models/llms/moa.py | 5 +- src/distilabel/models/llms/ollama.py | 10 +- src/distilabel/models/llms/openai.py | 108 +------- src/distilabel/models/llms/typing.py | 62 ----- src/distilabel/models/llms/utils.py | 4 +- src/distilabel/models/llms/vertexai.py | 5 +- src/distilabel/models/llms/vllm.py | 19 +- src/distilabel/models/mixins/magpie.py | 2 +- src/distilabel/pipeline/base.py | 4 +- src/distilabel/pipeline/local.py | 2 +- src/distilabel/pipeline/ray.py | 2 +- .../pipeline/routing_batch_function.py | 2 +- src/distilabel/pipeline/step_wrapper.py | 2 +- src/distilabel/steps/__init__.py | 2 +- src/distilabel/steps/argilla/base.py | 2 +- src/distilabel/steps/argilla/preference.py | 2 +- .../steps/argilla/text_generation.py | 2 +- src/distilabel/steps/base.py | 6 +- src/distilabel/steps/clustering/dbscan.py | 2 +- .../steps/clustering/text_clustering.py | 2 +- src/distilabel/steps/clustering/umap.py | 2 +- 
src/distilabel/steps/columns/combine.py | 2 +- src/distilabel/steps/columns/expand.py | 2 +- src/distilabel/steps/columns/group.py | 2 +- src/distilabel/steps/columns/keep.py | 2 +- src/distilabel/steps/columns/merge.py | 2 +- src/distilabel/steps/decorator.py | 2 +- .../steps/embeddings/embedding_generation.py | 2 +- .../steps/embeddings/nearest_neighbour.py | 2 +- src/distilabel/steps/filtering/embedding.py | 2 +- src/distilabel/steps/filtering/minhash.py | 2 +- .../steps/formatting/conversation.py | 2 +- src/distilabel/steps/formatting/dpo.py | 2 +- src/distilabel/steps/formatting/sft.py | 2 +- src/distilabel/steps/generators/data.py | 2 +- .../steps/generators/huggingface.py | 2 +- src/distilabel/steps/globals/huggingface.py | 2 +- src/distilabel/steps/reward_model.py | 3 +- src/distilabel/steps/tasks/__init__.py | 10 +- .../steps/tasks/apigen/execution_checker.py | 2 +- .../steps/tasks/apigen/generator.py | 3 +- .../steps/tasks/apigen/semantic_checker.py | 3 +- src/distilabel/steps/tasks/apigen/utils.py | 2 +- .../steps/tasks/argilla_labeller.py | 3 +- src/distilabel/steps/tasks/base.py | 107 +++++++- src/distilabel/steps/tasks/clair.py | 3 +- .../steps/tasks/complexity_scorer.py | 2 +- src/distilabel/steps/tasks/decorator.py | 4 +- .../steps/tasks/evol_instruct/base.py | 5 +- .../steps/tasks/evol_instruct/generator.py | 4 +- .../steps/tasks/evol_quality/base.py | 4 +- .../steps/tasks/generate_embeddings.py | 3 +- src/distilabel/steps/tasks/genstruct.py | 2 +- .../steps/tasks/image_generation.py | 188 +++++++++++++ .../steps/tasks/improving_text_embeddings.py | 3 +- .../tasks/instruction_backtranslation.py | 2 +- src/distilabel/steps/tasks/magpie/base.py | 4 +- .../steps/tasks/magpie/generator.py | 3 +- .../steps/tasks/math_shepherd/completer.py | 4 +- .../steps/tasks/math_shepherd/generator.py | 3 +- .../steps/tasks/math_shepherd/utils.py | 2 +- src/distilabel/steps/tasks/pair_rm.py | 2 +- src/distilabel/steps/tasks/prometheus_eval.py | 2 +- src/distilabel/steps/tasks/quality_scorer.py | 2 +- src/distilabel/steps/tasks/self_instruct.py | 2 +- .../steps/tasks/sentence_transformers.py | 2 +- .../steps/tasks/structured_generation.py | 2 +- .../tasks/structured_outputs/outlines.py | 2 +- .../steps/tasks/text_classification.py | 2 +- src/distilabel/steps/tasks/text_generation.py | 3 +- .../steps/tasks/text_generation_with_image.py | 3 +- src/distilabel/steps/tasks/ultrafeedback.py | 2 +- src/distilabel/steps/tasks/urial.py | 3 +- .../{typing.py => typing/__init__.py} | 44 +++- src/distilabel/typing/base.py | 46 ++++ .../tasks/typing.py => typing/models.py} | 72 +++-- .../typing.py => typing/pipeline.py} | 3 +- .../{steps/typing.py => typing/steps.py} | 0 .../utils/export_components_info.py | 22 ++ .../utils/mkdocs/components_gallery.py | 68 ++++- .../templates/components-gallery/index.md | 8 + .../integration/test_dataset_without_step.py | 2 +- tests/integration/test_embedding_dedup.py | 2 +- tests/integration/test_load_stages.py | 2 +- tests/integration/test_multiple_replicas.py | 2 +- .../test_offline_batch_generation.py | 3 +- tests/integration/test_pipe_llms.py | 2 +- tests/integration/test_pipe_simple.py | 2 +- tests/integration/test_ray_pipeline.py | 2 +- .../test_routing_batch_function.py | 2 +- tests/unit/conftest.py | 27 +- .../unit/models/image_generation/__init__.py | 14 + .../image_generation/huggingface/__init__.py | 14 + .../huggingface/test_inference_endpoints.py | 59 +++++ .../models/image_generation/test_openai.py | 105 ++++++++ .../mixins/test_cuda_device_placement.py | 
2 +- tests/unit/pipeline/test_base.py | 2 +- tests/unit/pipeline/test_dag.py | 2 +- tests/unit/pipeline/utils.py | 2 +- tests/unit/steps/argilla/test_base.py | 2 +- .../steps/clustering/test_text_clustering.py | 3 +- .../unit/steps/tasks/apigen/test_generator.py | 6 +- .../tasks/math_shepherd/test_completer.py | 2 +- .../tasks/math_shepherd/test_generator.py | 2 +- .../tasks/structured_outputs/test_outlines.py | 2 +- .../unit/steps/tasks/test_argilla_labeller.py | 2 +- .../unit/steps/tasks/test_image_generation.py | 55 ++++ .../tasks/test_improving_text_embeddings.py | 3 +- .../tasks/test_instruction_backtranslation.py | 3 +- .../steps/tasks/test_structured_generation.py | 3 +- .../steps/tasks/test_text_classification.py | 3 +- tests/unit/steps/tasks/test_ultrafeedback.py | 3 +- tests/unit/steps/test_base.py | 2 +- tests/unit/steps/test_decorator.py | 2 +- tests/unit/test_distiset.py | 56 ++++ tests/unit/utils/test_requirements.py | 2 +- 164 files changed, 2398 insertions(+), 632 deletions(-) create mode 100644 docs/api/models/image_generation/image_generation_gallery.md create mode 100644 docs/api/models/image_generation/index.md delete mode 100644 docs/api/pipeline/typing.md delete mode 100644 docs/api/step/typing.md create mode 100644 docs/api/task/image_task.md delete mode 100644 docs/api/task/typing.md create mode 100644 docs/api/typing.md create mode 100644 docs/sections/how_to_guides/basic/task/image_task.md create mode 100644 docs/sections/pipeline_samples/examples/image_generation.md create mode 100644 examples/image_generation.py create mode 100644 src/distilabel/models/base_clients/__init__.py create mode 100644 src/distilabel/models/base_clients/inference_endpoints.py create mode 100644 src/distilabel/models/base_clients/openai.py create mode 100644 src/distilabel/models/image_generation/__init__.py create mode 100644 src/distilabel/models/image_generation/base.py create mode 100644 src/distilabel/models/image_generation/huggingface/__init__.py create mode 100644 src/distilabel/models/image_generation/huggingface/inference_endpoints.py create mode 100644 src/distilabel/models/image_generation/openai.py create mode 100644 src/distilabel/models/image_generation/utils.py delete mode 100644 src/distilabel/models/llms/typing.py create mode 100644 src/distilabel/steps/tasks/image_generation.py rename src/distilabel/{typing.py => typing/__init__.py} (72%) create mode 100644 src/distilabel/typing/base.py rename src/distilabel/{steps/tasks/typing.py => typing/models.py} (66%) rename src/distilabel/{pipeline/typing.py => typing/pipeline.py} (98%) rename src/distilabel/{steps/typing.py => typing/steps.py} (100%) create mode 100644 tests/unit/models/image_generation/__init__.py create mode 100644 tests/unit/models/image_generation/huggingface/__init__.py create mode 100644 tests/unit/models/image_generation/huggingface/test_inference_endpoints.py create mode 100644 tests/unit/models/image_generation/test_openai.py create mode 100644 tests/unit/steps/tasks/test_image_generation.py diff --git a/docs/api/models/image_generation/image_generation_gallery.md b/docs/api/models/image_generation/image_generation_gallery.md new file mode 100644 index 0000000000..2baab4baee --- /dev/null +++ b/docs/api/models/image_generation/image_generation_gallery.md @@ -0,0 +1,10 @@ +# ImageGenerationModel Gallery + +This section contains the existing [`ImageGenerationModel`][distilabel.models.image_generation] subclasses implemented in `distilabel`. 
+ +::: distilabel.models.image_generation + options: + filters: + - "!^ImageGenerationModel$" + - "!^AsyngImageGenerationModel$" + - "!typing" \ No newline at end of file diff --git a/docs/api/models/image_generation/index.md b/docs/api/models/image_generation/index.md new file mode 100644 index 0000000000..f8d326236b --- /dev/null +++ b/docs/api/models/image_generation/index.md @@ -0,0 +1,7 @@ +# ImageGenerationModel + +This section contains the API reference for the `distilabel` image generation models, both for the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] synchronous implementation, and for the [`AsyncImageGenerationModel`][distilabel.models.image_generation.AsyncImageGenerationModel] asynchronous one. + +For more information and examples on how to use existing LLMs or create custom ones, please refer to [Tutorial - ImageGenerationModel](../../../sections/how_to_guides/basic/task/image_task.md). + +::: distilabel.models.image_generation.base diff --git a/docs/api/pipeline/typing.md b/docs/api/pipeline/typing.md deleted file mode 100644 index e4455ece00..0000000000 --- a/docs/api/pipeline/typing.md +++ /dev/null @@ -1,3 +0,0 @@ -# Pipeline Typing - -::: distilabel.pipeline.typing diff --git a/docs/api/step/typing.md b/docs/api/step/typing.md deleted file mode 100644 index 1a86e7dac1..0000000000 --- a/docs/api/step/typing.md +++ /dev/null @@ -1,3 +0,0 @@ -# Step Typing - -::: distilabel.steps.typing \ No newline at end of file diff --git a/docs/api/task/image_task.md b/docs/api/task/image_task.md new file mode 100644 index 0000000000..5cb698d548 --- /dev/null +++ b/docs/api/task/image_task.md @@ -0,0 +1,7 @@ +# ImageTask + +This section contains the API reference for the `distilabel` image generation tasks. + +For more information on how the [`ImageTask`][distilabel.steps.tasks.ImageTask] works and see some examples, check the [Tutorial - Task - ImageTask](../../sections/how_to_guides/basic/task/generator_task.md) page. + +::: distilabel.steps.tasks.base.ImageTask diff --git a/docs/api/task/task_gallery.md b/docs/api/task/task_gallery.md index 4cf90c479d..aa2f3ecf2d 100644 --- a/docs/api/task/task_gallery.md +++ b/docs/api/task/task_gallery.md @@ -8,5 +8,6 @@ This section contains the existing [`Task`][distilabel.steps.tasks.Task] subclas - "!Task" - "!_Task" - "!GeneratorTask" + - "!ImageTask" - "!ChatType" - "!typing" \ No newline at end of file diff --git a/docs/api/task/typing.md b/docs/api/task/typing.md deleted file mode 100644 index 818ad070b6..0000000000 --- a/docs/api/task/typing.md +++ /dev/null @@ -1,3 +0,0 @@ -# Task Typing - -::: distilabel.steps.tasks.typing \ No newline at end of file diff --git a/docs/api/typing.md b/docs/api/typing.md new file mode 100644 index 0000000000..53d33868d8 --- /dev/null +++ b/docs/api/typing.md @@ -0,0 +1,8 @@ +# Types + +This section contains the different types used accross the distilabel codebase. + +::: distilabel.typing.base +::: distilabel.typing.steps +::: distilabel.typing.models +::: distilabel.typing.pipeline diff --git a/docs/sections/how_to_guides/advanced/distiset.md b/docs/sections/how_to_guides/advanced/distiset.md index 1c00554e28..001ec827ed 100644 --- a/docs/sections/how_to_guides/advanced/distiset.md +++ b/docs/sections/how_to_guides/advanced/distiset.md @@ -119,6 +119,33 @@ class MagpieGenerator(GeneratorTask, MagpieBase): The `Citations` section can include any number of bibtex references. 
To define them, you can add as much elements as needed just like in the example: each citation will be a block of the form: ` ```@misc{...}``` `. This information will be automatically used in the README of your `Distiset` if you decide to call `distiset.push_to_hub`. Alternatively, if the `Citations` is not found, but in the `References` there are found any urls pointing to `https://arxiv.org/`, we will try to obtain the `Bibtex` equivalent automatically. This way, Hugging Face can automatically track the paper for you and it's easier to find other datasets citing the same paper, or directly visiting the paper page. +#### Image Datasets + +!!! info "Keep reading if you are interested in Image datasets" + + The `Distiset` object has a new method `transform_columns_to_image` specifically to transform the images to `PIL.Image.Image` before pushing the dataset to the hugging face hub. + +Since version `1.5.0` we have the [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/task/imagegeneration/) task that is able to generate images from text. By default, all the process will work internally with a string representation for the images. This is done for simplicity while processing. But to take advantage of the Hugging Face Hub functionalities if the dataset generated is going to be stored there, a proper Image object may be preferable, so we can see the images in the dataset viewer for example. Let's take a look at the following pipeline extracted from "examples/image_generation.py" at the root of the repository to see how we can do it: + +```diff +# Assume all the imports are already done, we are only interested +with Pipeline(name="image_generation_pipeline") as pipeline: + img_generation = ImageGeneration( + name="flux_schnell", + llm=igm, + InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell") + ) + ... + +if __name__ == "__main__": + distiset = pipeline.run(use_cache=False, dataset=ds) + # Save the images as `PIL.Image.Image` ++ distiset = distiset.transform_columns_to_image("image") + distiset.push_to_hub(...) +``` + +After calling [`transform_columns_to_image`][distilabel.distiset.Distiset.transform_columns_to_image] on the image columns we may have generated (in this case we only want to transform the `image` column, but a list can be passed). This will apply to any leaf nodes we have in the pipeline, meaning if we have different subsets, the "image" column will be found in all of them, or we can pass a list of columns. + ### Save and load from disk Take into account that these methods work as `datasets.load_from_disk` and `datasets.Dataset.save_to_disk` so the arguments are directly passed to those methods. This means you can also make use of `storage_options` argument to save your [`Distiset`][distilabel.distiset.Distiset] in your cloud provider, including the distilabel artifacts (`pipeline.yaml`, `pipeline.log` and the `README.md` with the dataset card). You can read more in `datasets` documentation [here](https://huggingface.co/docs/datasets/filesystems#saving-serialized-datasets). 
diff --git a/docs/sections/how_to_guides/advanced/pipeline_requirements.md b/docs/sections/how_to_guides/advanced/pipeline_requirements.md index 66a2594bd2..3f739cbf2a 100644 --- a/docs/sections/how_to_guides/advanced/pipeline_requirements.md +++ b/docs/sections/how_to_guides/advanced/pipeline_requirements.md @@ -9,7 +9,7 @@ from typing import List from distilabel.steps import Step from distilabel.steps.base import StepInput -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput from distilabel.steps import LoadDataFromDicts from distilabel.utils.requirements import requirements from distilabel.pipeline import Pipeline diff --git a/docs/sections/how_to_guides/advanced/structured_generation.md b/docs/sections/how_to_guides/advanced/structured_generation.md index 3eb1da99af..1675d369a6 100644 --- a/docs/sections/how_to_guides/advanced/structured_generation.md +++ b/docs/sections/how_to_guides/advanced/structured_generation.md @@ -21,7 +21,7 @@ The [`LLM`][distilabel.models.llms.LLM] has an argument named `structured_output We will start with a JSON example, where we initially define a `pydantic.BaseModel` schema to guide the generation of the structured output. !!! NOTE - Take a look at [`StructuredOutputType`][distilabel.steps.tasks.typing.StructuredOutputType] to see the expected format + Take a look at [`StructuredOutputType`][distilabel.typing.models.StructuredOutputType] to see the expected format of the `structured_output` dict variable. ```python @@ -139,7 +139,7 @@ For other LLM providers behind APIs, there's no direct way of accessing the inte ``` !!! Note - Take a look at [`InstructorStructuredOutputType`][distilabel.steps.tasks.typing.InstructorStructuredOutputType] to see the expected format + Take a look at [`InstructorStructuredOutputType`][distilabel.typing.models.InstructorStructuredOutputType] to see the expected format of the `structured_output` dict variable. The following is the same example you can see with `outlines`'s `JSON` section for comparison purposes. diff --git a/docs/sections/how_to_guides/basic/step/generator_step.md b/docs/sections/how_to_guides/basic/step/generator_step.md index 0422644c36..50ca5e52d7 100644 --- a/docs/sections/how_to_guides/basic/step/generator_step.md +++ b/docs/sections/how_to_guides/basic/step/generator_step.md @@ -9,7 +9,7 @@ from typing_extensions import override from distilabel.steps import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, GeneratorStepOutput + from distilabel.typing import StepColumns, GeneratorStepOutput class MyGeneratorStep(GeneratorStep): instructions: List[str] @@ -67,7 +67,7 @@ We can define a custom generator step by creating a new subclass of the [`Genera The default signature for the `process` method is `process(self, offset: int = 0) -> GeneratorStepOutput`. The argument `offset` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. !!! 
WARNING - For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. + For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. === "Inherit from `GeneratorStep`" @@ -81,7 +81,7 @@ We can define a custom generator step by creating a new subclass of the [`Genera from distilabel.steps import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, GeneratorStepOutput + from distilabel.typing import StepColumns, GeneratorStepOutput class MyGeneratorStep(GeneratorStep): instructions: List[str] @@ -104,7 +104,7 @@ We can define a custom generator step by creating a new subclass of the [`Genera from distilabel.steps import step if TYPE_CHECKING: - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import GeneratorStepOutput @step(outputs=[...], step_type="generator") def CustomGeneratorStep(offset: int = 0) -> "GeneratorStepOutput": diff --git a/docs/sections/how_to_guides/basic/step/global_step.md b/docs/sections/how_to_guides/basic/step/global_step.md index 814f01a0fb..db050a6dc7 100644 --- a/docs/sections/how_to_guides/basic/step/global_step.md +++ b/docs/sections/how_to_guides/basic/step/global_step.md @@ -16,7 +16,7 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`. The argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. !!! WARNING - For the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. 
+ For the custom [`GlobalStep`][distilabel.steps.GlobalStep] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. === "Inherit from `GlobalStep`" @@ -27,7 +27,7 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis from distilabel.steps import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class CustomStep(Step): @property @@ -61,7 +61,7 @@ We can define a custom step by creating a new subclass of the [`GlobalStep`][dis from distilabel.steps import StepInput, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @step(inputs=[...], outputs=[...], step_type="global") def CustomStep(inputs: StepInput) -> "StepOutput": diff --git a/docs/sections/how_to_guides/basic/step/index.md b/docs/sections/how_to_guides/basic/step/index.md index d03a6b2149..76cae37075 100644 --- a/docs/sections/how_to_guides/basic/step/index.md +++ b/docs/sections/how_to_guides/basic/step/index.md @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING from distilabel.steps import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class MyStep(Step): @property @@ -87,7 +87,7 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe The default signature for the `process` method is `process(self, *inputs: StepInput) -> StepOutput`. The argument `inputs` should be respected, no more arguments can be provided, and the type-hints and return type-hints should be respected too because it should be able to receive any number of inputs by default i.e. more than one [`Step`][distilabel.steps.Step] at a time could be connected to the current one. !!! WARNING - For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.steps.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. + For the custom [`Step`][distilabel.steps.Step] subclasses to work properly with `distilabel` and with the validation and serialization performed by default over each [`Step`][distilabel.steps.Step] in the [`Pipeline`][distilabel.pipeline.Pipeline], the type-hint for both [`StepInput`][distilabel.steps.StepInput] and [`StepOutput`][distilabel.typing.StepOutput] should be used and not surrounded with double-quotes or imported under `typing.TYPE_CHECKING`, otherwise, the validation and/or serialization will fail. 
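
For reference, the following is a minimal sketch of the pattern this warning asks for: the typing aliases are imported at module level from the new `distilabel.typing` module (rather than under `typing.TYPE_CHECKING`) and used unquoted in the `process` signature. The step class and its column names are illustrative only, not part of the library.

```python
from distilabel.steps import Step, StepInput
from distilabel.typing import StepColumns, StepOutput  # module-level import, not under TYPE_CHECKING


class EchoStep(Step):
    """Toy step that copies the `instruction` column into a new `prompt` column."""

    @property
    def inputs(self) -> StepColumns:
        return ["instruction"]

    @property
    def outputs(self) -> StepColumns:
        return ["prompt"]

    def process(self, *inputs: StepInput) -> StepOutput:  # unquoted type-hints
        for batch in inputs:
            for row in batch:
                row["prompt"] = row["instruction"]
            yield batch
```
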
=== "Inherit from `Step`" @@ -98,7 +98,7 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe from distilabel.steps import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class CustomStep(Step): @property @@ -132,7 +132,7 @@ We can define a custom step by creating a new subclass of the [`Step`][distilabe from distilabel.steps import StepInput, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @step(inputs=[...], outputs=[...]) def CustomStep(inputs: StepInput) -> "StepOutput": diff --git a/docs/sections/how_to_guides/basic/task/generator_task.md b/docs/sections/how_to_guides/basic/task/generator_task.md index 6fbb3d742e..bb86c28e31 100644 --- a/docs/sections/how_to_guides/basic/task/generator_task.md +++ b/docs/sections/how_to_guides/basic/task/generator_task.md @@ -12,8 +12,7 @@ from typing import Any, Dict, List, Union from typing_extensions import override from distilabel.steps.tasks.base import GeneratorTask -from distilabel.steps.tasks.typing import ChatType -from distilabel.steps.typing import GeneratorOutput +from distilabel.typing import ChatType, GeneratorOutput class MyCustomTask(GeneratorTask): @@ -78,7 +77,7 @@ We can define a custom generator task by creating a new subclass of the [`Genera from typing import Any, Dict, List, Union from distilabel.steps.tasks.base import GeneratorTask -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType class MyCustomTask(GeneratorTask): diff --git a/docs/sections/how_to_guides/basic/task/image_task.md b/docs/sections/how_to_guides/basic/task/image_task.md new file mode 100644 index 0000000000..ecdee6c66f --- /dev/null +++ b/docs/sections/how_to_guides/basic/task/image_task.md @@ -0,0 +1,104 @@ +# ImageTask to work with Image Generation Models + +## Working with ImageTasks + +The [`ImageTask`][distilabel.steps.tasks.ImageTask] is a custom implementation of a [`Task`][distilabel.steps.tasks.Task] special to deal images. These tasks behave exactly as any other [`Task`][distilabel.steps.tasks.Task], but instead of relying on an [`LLM`][distilabel.models.llms.LLM], they work with a [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel]. + +!!! info "New in version 1.5.0" + This task is new and is expected to work with Image Generation Models. + +These tasks take as attribute an `image_generation_model` instead of `llm` as we would have with the standard `Task`, but everything else remains the same. Let's see an example with [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/imagegeneration/): + +```python +from distilabel.steps.tasks import ImageGeneration +from distilabel.models.image_generation import InferenceEndpointsImageGeneration + +task = ImageGeneration( + name="image-generation", + image_generation_model=InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell"), +) +task.load() + +next(task.process([{"prompt": "a white siamese cat"}])) +# [{'image": "iVBORw0KGgoAAAANSUhEUgA...", "model_name": "black-forest-labs/FLUX.1-schnell"}] +``` + +!!! 
info "Visualize the image in a notebook" + If you are testing the `ImageGeneration` task in a notebook, you can do the following + to see the rendered image: + + ```python + from distilabel.models.image_generation.utils import image_from_str + + result = next(task.process([{"prompt": "a white siamese cat"}])) + image_from_str(result[0]["image"]) # Returns a `PIL.Image.Image` that renders directly + ``` + +!!! tip "Running ImageGeneration in a Pipeline" + This transformation between image as string and as PIL object can be done for the whole dataset if running a pipeline, by calling the method `transform_columns_to_image` on the final distiset and passing the name (or list of names) of the column image. + +## Defining custom ImageTasks + +We can define a custom generator task by creating a new subclass of the [`ImageTask`][distilabel.steps.tasks.ImageTask] and defining the following: + +- `process`: is a method that generates the data based on the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] and the `prompt` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. + +- `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. + +- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the LLM. + +- `format_input`: is a method that receives a dictionary with the input data and returns a *prompt* to be passed to the model. + +- `format_output`: is a method that receives the output from the [`ImageGenerationModel`][distilabel.models.image_generation.ImageGenerationModel] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. 
+ +```python +from typing import TYPE_CHECKING + +from distilabel.models.image_generation.utils import image_from_str, image_to_str +from distilabel.steps.base import StepInput +from distilabel.steps.tasks.base import ImageTask + +if TYPE_CHECKING: + from distilabel.typing import StepColumns, StepOutput + + +class MyCustomImageTask(ImageTask): + @override + def process(self, offset: int = 0) -> GeneratorOutput: + formatted_inputs = self._format_inputs(inputs) + + outputs = self.llm.generate_outputs( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.get_generation_kwargs(), + ) + + task_outputs = [] + for input, input_outputs in zip(inputs, outputs): + formatted_outputs = self._format_outputs(input_outputs, input) + for formatted_output in formatted_outputs: + task_outputs.append( + {**input, **formatted_output, "model_name": self.llm.model_name} + ) + yield task_outputs + + @property + def inputs(self) -> "StepColumns": + return ["prompt"] + + @property + def outputs(self) -> "StepColumns": + return ["image", "model_name"] + + def format_input(self, input: dict[str, any]) -> str: + return input["prompt"] + + def format_output( + self, output: Union[str, None], input: dict[str, any] + ) -> Dict[str, Any]: + # Extract/generate/modify the image from the output + return {"image": ..., "model_name": self.llm.model_name} +``` + +!!! Warning + Note the fact that in the `process` method we are not dealing with the `image_generation` attribute but with the `llm`. This is not a bug, but intended, as internally we rename the `image_generation` to `llm` to reuse the code. diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md index dd5de6f837..c2291c8769 100644 --- a/docs/sections/how_to_guides/basic/task/index.md +++ b/docs/sections/how_to_guides/basic/task/index.md @@ -217,7 +217,7 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe - `inputs`: is a property that returns a list of strings with the names of the required input fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. -- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.steps.tasks.ChatType] following [the chat-completion OpenAI message formatting](https://platform.openai.com/docs/guides/text-generation). +- `format_input`: is a method that receives a dictionary with the input data and returns a [`ChatType`][distilabel.typing.models.ChatType] following [the chat-completion OpenAI message formatting](https://platform.openai.com/docs/guides/text-generation). - `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the LLM. 
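
For orientation, a `ChatType` is simply a list of OpenAI-style chat messages. Below is a minimal sketch of the two formatting hooks for an instruction-following task; the column names and the system prompt are illustrative assumptions, not the library's defaults.

```python
from typing import Any, Dict, Union

from distilabel.typing import ChatType


class MyTask:  # in practice this would subclass `distilabel.steps.tasks.Task`
    def format_input(self, input: Dict[str, Any]) -> ChatType:
        # Build chat-completion style messages from the `instruction` column.
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": input["instruction"]},
        ]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Map the raw generation into the declared output columns.
        return {"generation": output}
```
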
@@ -233,8 +233,7 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe from distilabel.steps.tasks import Task if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import StepColumns, ChatType class MyCustomTask(Task): diff --git a/docs/sections/pipeline_samples/examples/image_generation.md b/docs/sections/pipeline_samples/examples/image_generation.md new file mode 100644 index 0000000000..39f8daba3b --- /dev/null +++ b/docs/sections/pipeline_samples/examples/image_generation.md @@ -0,0 +1,108 @@ +--- +hide: toc +--- + +# Image generation with `distilabel` + +Create synthetic images using `distilabel`. + +This example shows how distilabel can be used to generate image data, either using [`InferenceEndpointsImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/image_generation/inferenceendpointsimagegeneration/) or [`OpenAIImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/image_generation/openaiimagegeneration/), thanks to the [`ImageGeneration`](https://distilabel.argilla.io/dev/components-gallery/task/imagegeneration/) task. + + +=== "Inference Endpoints - black-forest-labs/FLUX.1-schnell" + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps import KeepColumns + from distilabel.models.image_generation import InferenceEndpointsImageGeneration + from distilabel.steps.tasks import ImageGeneration + + from datasets import load_dataset + + ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3)) + + with Pipeline(name="image_generation_pipeline") as pipeline: + ilm = InferenceEndpointsImageGeneration( + model_id="black-forest-labs/FLUX.1-schnell" + ) + + img_generation = ImageGeneration( + name="flux_schnell", + llm=ilm, + input_mappings={"prompt": "persona"} + ) + + keep_columns = KeepColumns(columns=["persona", "model_name", "image"]) + + img_generation >> keep_columns + ``` + + Sample image for the prompt: + + > A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati. + + ![image_ie](https://huggingface.co/datasets/plaguss/test-finepersonas-v0.1-tiny-flux-schnell/resolve/main/artifacts/flux_schnell/images/3333f9870feda32a449994017eb72675.jpeg) + +=== "OpenAI - dall-e-3" + + ```python + from distilabel.pipeline import Pipeline + from distilabel.steps import KeepColumns + from distilabel.models.image_generation import OpenAIImageGeneration + from distilabel.steps.tasks import ImageGeneration + + from datasets import load_dataset + + ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3)) + + with Pipeline(name="image_generation_pipeline") as pipeline: + ilm = OpenAIImageGeneration( + model="dall-e-3", + generation_kwargs={ + "size": "1024x1024", + "quality": "standard", + "style": "natural" + } + ) + + img_generation = ImageGeneration( + name="dalle-3" + llm=ilm, + input_mappings={"prompt": "persona"} + ) + + keep_columns = KeepColumns(columns=["persona", "model_name", "image"]) + + img_generation >> keep_columns + ``` + + Sample image for the prompt: + + > A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati. + + ![image_oai](https://huggingface.co/datasets/plaguss/test-finepersonas-v0.1-tiny-dall-e-3/resolve/main/artifacts/dalle-3/images/3333f9870feda32a449994017eb72675.jpeg) + +!!! 
success "Save the Distiset as an Image Dataset" + + Note the call to `Distiset.transform_columns_to_image`, to have the images uploaded directly as an [`Image dataset`](https://huggingface.co/docs/hub/en/datasets-image): + + ```python + if __name__ == "__main__": + distiset = pipeline.run(use_cache=False, dataset=ds) + # Save the images as `PIL.Image.Image` + distiset = distiset.transform_columns_to_image("image") + distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell") + + ``` + +The full pipeline can be run at the following example. Keep in mind, you need to install `pillow` first: `pip install distilabel[vision]`. + +??? Run + + ```python + python examples/image_generation.py + ``` + +```python title="image_generation.py" +--8<-- "examples/image_generation.py" +``` diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md index 1c95b60b18..956af3b518 100644 --- a/docs/sections/pipeline_samples/index.md +++ b/docs/sections/pipeline_samples/index.md @@ -161,6 +161,14 @@ hide: toc [:octicons-arrow-right-24: Example](examples/exam_questions.md) +- __Image generation with distilabel__ + + --- + + Generate synthetic images using distilabel. + + [:octicons-arrow-right-24: Example](examples/image_generation.md) + - __Text generation with images in distilabel__ --- diff --git a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb index 7b75f7fcaa..6730a80892 100644 --- a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb +++ b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb @@ -196,7 +196,7 @@ " from distilabel.steps import GlobalStep, StepInput\n", "\n", " if TYPE_CHECKING:\n", - " from distilabel.steps.typing import StepOutput\n", + " from distilabel.typing import StepOutput\n", " \n", " import random\n", "\n", diff --git a/examples/image_generation.py b/examples/image_generation.py new file mode 100644 index 0000000000..dbee42ebe4 --- /dev/null +++ b/examples/image_generation.py @@ -0,0 +1,42 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datasets import load_dataset + +from distilabel.models.image_generation import InferenceEndpointsImageGeneration +from distilabel.pipeline import Pipeline +from distilabel.steps import KeepColumns +from distilabel.steps.tasks import ImageGeneration + +ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3)) + +with Pipeline(name="image_generation_pipeline") as pipeline: + igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell") + + img_generation = ImageGeneration( + name="flux_schnell", + image_generation_model=igm, + input_mappings={"prompt": "persona"}, + ) + + keep_columns = KeepColumns(columns=["persona", "model_name", "image"]) + + img_generation >> keep_columns + + +if __name__ == "__main__": + distiset = pipeline.run(use_cache=False, dataset=ds) + # Save the images as `PIL.Image.Image` + distiset = distiset.transform_columns_to_image("image") + distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell") diff --git a/mkdocs.yml b/mkdocs.yml index f5a98be65d..24e5ca9b74 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -184,6 +184,7 @@ nav: - Tasks for generating and judging with LLMs: - "sections/how_to_guides/basic/task/index.md" - GeneratorTask: "sections/how_to_guides/basic/task/generator_task.md" + - ImageTask: "sections/how_to_guides/basic/task/image_task.md" - Executing Tasks with LLMs: "sections/how_to_guides/basic/llm/index.md" - Execute Steps and Tasks in a Pipeline: "sections/how_to_guides/basic/pipeline/index.md" - Advanced: @@ -222,6 +223,7 @@ nav: - Structured generation with instructor: "sections/pipeline_samples/examples/mistralai_with_instructor.md" - Create a social network with FinePersonas: "sections/pipeline_samples/examples/fine_personas_social_network.md" - Create questions and answers for a exam: "sections/pipeline_samples/examples/exam_questions.md" + - Image generation with distilabel: "sections/pipeline_samples/examples/image_generation.md" - Text generation with images in distilabel: "sections/pipeline_samples/examples/text_generation_with_image.md" - API Reference: - Step: @@ -235,22 +237,22 @@ nav: - Hugging Face: "api/step_gallery/hugging_face.md" - Columns: "api/step_gallery/columns.md" - Extra: "api/step_gallery/extra.md" - - Typing: "api/step/typing.md" - Task: - "api/task/index.md" - GeneratorTask: "api/task/generator_task.md" - Task Gallery: "api/task/task_gallery.md" - - Typing: "api/task/typing.md" - LLM: - "api/models/llm/index.md" - LLM Gallery: "api/models/llm/llm_gallery.md" - Embedding: - "api/models/embedding/index.md" - Embedding Gallery: "api/models/embedding/embedding_gallery.md" + - ImageGenerationModels: + - "api/models/image_generation/index.md" + - Image Generation Gallery: "api/models/image_generation/image_generation_gallery.md" - Pipeline: - "api/pipeline/index.md" - Routing Batch Function: "api/pipeline/routing_batch_function.md" - - Typing: "api/pipeline/typing.md" - Step Wrapper: "api/pipeline/step_wrapper.md" - Mixins: - RuntimeParametersMixin: "api/mixins/runtime_parameters.md" @@ -259,6 +261,7 @@ nav: - Errors: "api/errors.md" - Distiset: "api/distiset.md" - CLI: "api/cli.md" + - Types: "api/typing.md" - Community: - sections/community/index.md - How to contribute?: sections/community/contributor.md diff --git a/pyproject.toml b/pyproject.toml index 3123d56b55..1c55ebb1c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ text-clustering = [ "matplotlib >= 3.8.3", # For the figure (even though it's optional) ] mlx = ["mlx >= 0.21.0", 
"mlx-lm"] +vision = ["Pillow >= 10.3.0"] # To work with images. # minhash minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"] diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index e934f9d340..ce4e855858 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -24,7 +24,7 @@ import fsspec import yaml -from datasets import Dataset, load_dataset, load_from_disk +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from datasets.filesystems import is_remote_filesystem from huggingface_hub import DatasetCardData, HfApi, upload_file, upload_folder from huggingface_hub.file_download import hf_hub_download @@ -187,9 +187,14 @@ def _get_card( record = ( dataset[0] if not isinstance(dataset, dict) else dataset["train"][0] ) + from PIL import ImageFile + for key, value in record.items(): + # If the value is an image, we set it to an empty string to avoid the `README.md` to huge + if isinstance(value, ImageFile.ImageFile): + value = "" # If list is too big, the `README.md` generated will be huge so we truncate it - if isinstance(value, list): + elif isinstance(value, list): length = len(value) if length < 10: continue @@ -585,6 +590,51 @@ def __repr__(self): repr = re.sub(r"^", " " * 4, repr, count=0, flags=re.M) return f"Distiset({{\n{repr}\n}})" + def transform_columns_to_image(self, columns: Union[str, list[str]]) -> Self: + """Transforms the columns of the dataset to `PIL.Image` objects. + + Args: + columns: Column or list of columns to transform. + + Returns: + Transforms the columns of the dataset to `PIL.Image` objects before pushing, + so the Hub treats them as Image objects and can be rendered in the dataset + viewer, and cast them to be automatically transformed when downloading + the dataset back. 
+ """ + from datasets import Image + + from distilabel.models.image_generation.utils import image_from_str + + columns = [columns] if isinstance(columns, str) else columns + + def cast_to_image(row: dict) -> dict: + for column in columns: + row[column] = image_from_str(row[column]) + return row + + for name, dataset in self.items(): + # In case train_test_split was called + if isinstance(dataset, DatasetDict): + for split, dataset_split in dataset.items(): + dataset_split = dataset_split.map(cast_to_image) + for column in columns: + if column in dataset_split.column_names: + dataset_split = dataset_split.cast_column( + column, Image(decode=True) + ) + self[name][split] = dataset_split + else: + dataset = dataset.map(cast_to_image) + + for column in columns: + if column in dataset.column_names: + dataset = dataset.cast_column(column, Image(decode=True)) + + self[name] = dataset + + return self + def create_distiset( # noqa: C901 data_dir: Path, diff --git a/src/distilabel/llms.py b/src/distilabel/llms.py index b00d891407..8d579048df 100644 --- a/src/distilabel/llms.py +++ b/src/distilabel/llms.py @@ -37,10 +37,10 @@ from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.together import TogetherLLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.typing import GenerateOutput, HiddenState __all__ = [ "LLM", diff --git a/src/distilabel/mixins/runtime_parameters.py b/src/distilabel/mixins/runtime_parameters.py index f8371e30ab..73b5845e8f 100644 --- a/src/distilabel/mixins/runtime_parameters.py +++ b/src/distilabel/mixins/runtime_parameters.py @@ -13,11 +13,14 @@ # limitations under the License. import difflib +import inspect +from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Tuple, TypeVar, Union from pydantic import BaseModel, Field, PrivateAttr from typing_extensions import Annotated, get_args, get_origin +from distilabel.utils.docstring import parse_google_docstring from distilabel.utils.typing_ import ( extract_annotation_inner_type, is_type_pydantic_secret_field, @@ -26,6 +29,9 @@ if TYPE_CHECKING: from pydantic.fields import FieldInfo + from distilabel.utils.docstring import Docstring + + _T = TypeVar("_T") _RUNTIME_PARAMETER_ANNOTATION = "distilabel_step_runtime_parameter" RuntimeParameter = Annotated[ @@ -218,3 +224,84 @@ def _is_runtime_parameter(field: "FieldInfo") -> Tuple[bool, bool]: return True, is_optional return False, False + + +class RuntimeParametersModelMixin(RuntimeParametersMixin): + """Specific mixin for RuntimeParameters that affect the model classes, LLM, + ImageGenerationModel, etc. + """ + + @property + def generate_parameters(self) -> list["inspect.Parameter"]: + """Returns the parameters of the `generate` method. + + Returns: + A list containing the parameters of the `generate` method. + """ + return list(inspect.signature(self.generate).parameters.values()) + + @property + def runtime_parameters_names(self) -> "RuntimeParametersNames": + """Returns the runtime parameters of the `ImageGenerationModel`, which are combination of the + attributes of the `ImageGenerationModel` type hinted with `RuntimeParameter` and the parameters + of the `generate` method that are not `input` and `num_generations`. 
+ + Returns: + A dictionary with the name of the runtime parameters as keys and a boolean + indicating if the parameter is optional or not. + """ + runtime_parameters = super().runtime_parameters_names + runtime_parameters["generation_kwargs"] = {} + + # runtime parameters from the `generate` method + for param in self.generate_parameters: + if param.name in ["input", "inputs", "num_generations"]: + continue + is_optional = param.default != inspect.Parameter.empty + runtime_parameters["generation_kwargs"][param.name] = is_optional + + return runtime_parameters + + def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]: + """Gets the information of the runtime parameters of the `LLM` such as the name + and the description. This function is meant to include the information of the runtime + parameters in the serialized data of the `LLM`. + + Returns: + A list containing the information for each runtime parameter of the `LLM`. + """ + runtime_parameters_info = super().get_runtime_parameters_info() + + generation_kwargs_info = next( + ( + runtime_parameter_info + for runtime_parameter_info in runtime_parameters_info + if runtime_parameter_info["name"] == "generation_kwargs" + ), + None, + ) + + # If `generation_kwargs` attribute is present, we need to include the `generate` + # method arguments as the information for this attribute. + if generation_kwargs_info: + generate_docstring_args = self.generate_parsed_docstring["args"] + generation_kwargs_info["keys"] = [] + + for key, value in generation_kwargs_info["optional"].items(): + info = {"name": key, "optional": value} + if description := generate_docstring_args.get(key): + info["description"] = description + generation_kwargs_info["keys"].append(info) + + generation_kwargs_info.pop("optional") + + return runtime_parameters_info + + @cached_property + def generate_parsed_docstring(self) -> "Docstring": + """Returns the parsed docstring of the `generate` method. + + Returns: + The parsed docstring of the `generate` method. 
+ """ + return parse_google_docstring(self.generate) diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py index 86ea2023e4..1c96f5ab0b 100644 --- a/src/distilabel/models/__init__.py +++ b/src/distilabel/models/__init__.py @@ -19,6 +19,14 @@ SentenceTransformerEmbeddings, ) from distilabel.models.embeddings.vllm import vLLMEmbeddings +from distilabel.models.image_generation.base import ( + AsyncImageGenerationModel, + ImageGenerationModel, +) +from distilabel.models.image_generation.huggingface.inference_endpoints import ( + InferenceEndpointsImageGeneration, +) +from distilabel.models.image_generation.openai import OpenAIImageGeneration from distilabel.models.llms.anthropic import AnthropicLLM from distilabel.models.llms.anyscale import AnyscaleLLM from distilabel.models.llms.azure import AzureOpenAILLM @@ -34,15 +42,16 @@ from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.together import TogetherLLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.typing import GenerateOutput, HiddenState __all__ = [ "LLM", "AnthropicLLM", "AnyscaleLLM", + "AsyncImageGenerationModel", "AsyncLLM", "AzureOpenAILLM", "ClientvLLM", @@ -52,6 +61,8 @@ "GenerateOutput", "GroqLLM", "HiddenState", + "ImageGenerationModel", + "InferenceEndpointsImageGeneration", "InferenceEndpointsLLM", "LiteLLM", "LlamaCppEmbeddings", @@ -60,6 +71,7 @@ "MixtureOfAgentsLLM", "MlxLLM", "OllamaLLM", + "OpenAIImageGeneration", "OpenAILLM", "SentenceTransformerEmbeddings", "TogetherLLM", diff --git a/src/distilabel/models/base_clients/__init__.py b/src/distilabel/models/base_clients/__init__.py new file mode 100644 index 0000000000..07e329b4d9 --- /dev/null +++ b/src/distilabel/models/base_clients/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.models.base_clients.inference_endpoints import ( + InferenceEndpointsBaseClient, +) +from distilabel.models.base_clients.openai import OpenAIBaseClient + +__all__ = ["InferenceEndpointsBaseClient", "OpenAIBaseClient"] diff --git a/src/distilabel/models/base_clients/inference_endpoints.py b/src/distilabel/models/base_clients/inference_endpoints.py new file mode 100644 index 0000000000..ebcc84e344 --- /dev/null +++ b/src/distilabel/models/base_clients/inference_endpoints.py @@ -0,0 +1,154 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import ( + TYPE_CHECKING, + Optional, + Union, +) + +from pydantic import ( + BaseModel, + Field, + PrivateAttr, + SecretStr, +) + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.typing import StructuredOutputType +from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR, get_hf_token + +if TYPE_CHECKING: + from huggingface_hub import AsyncInferenceClient + from transformers import PreTrainedTokenizer + + +class InferenceEndpointsBaseClient(BaseModel): + model_id: Optional[str] = None + + endpoint_name: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The name of the Inference Endpoint to use for the LLM.", + ) + endpoint_namespace: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The namespace of the Inference Endpoint to use for the LLM.", + ) + base_url: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The base URL to use for the Inference Endpoints API requests.", + ) + api_key: Optional[RuntimeParameter[SecretStr]] = Field( + default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR), + description="The API key to authenticate the requests to the Inference Endpoints API.", + ) + + tokenizer_id: Optional[str] = None + model_display_name: Optional[str] = None + + structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + + _num_generations_param_supported = False + + _model_name: Optional[str] = PrivateAttr(default=None) + _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) + _api_key_env_var: str = PrivateAttr(HF_TOKEN_ENV_VAR) + _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) + + def load(self) -> None: # noqa: C901 + """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference + Endpoint. + + Raises: + ImportError: if the `huggingface-hub` Python client is not installed. + ValueError: if the model is not currently deployed or is not running the TGI framework. + ImportError: if the `transformers` Python client is not installed. + """ + + try: + from huggingface_hub import ( + AsyncInferenceClient, + InferenceClient, + get_inference_endpoint, + ) + except ImportError as ie: + raise ImportError( + "Hugging Face Hub Python client is not installed. Please install it using" + " `pip install 'distilabel[hf-inference-endpoints]'`." 
+ ) from ie + + if self.api_key is None: + self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key")) + + if self.model_id is not None: + client = InferenceClient( + model=self.model_id, token=self.api_key.get_secret_value() + ) + status = client.get_model_status() + + if ( + status.state not in {"Loadable", "Loaded"} + and status.framework != "text-generation-inference" + ): + raise ValueError( + f"Model {self.model_id} is not currently deployed or is not running the TGI framework" + ) + + self.base_url = client._resolve_url( + model=self.model_id, task="text-generation" + ) + + if self.endpoint_name is not None: + client = get_inference_endpoint( + name=self.endpoint_name, + namespace=self.endpoint_namespace, + token=self.api_key.get_secret_value(), + ) + if client.status in ["paused", "scaledToZero"]: + client.resume().wait(timeout=300) + elif client.status == "initializing": + client.wait(timeout=300) + + self.base_url = client.url + self._model_name = client.repository + + self._aclient = AsyncInferenceClient( + base_url=self.base_url, + token=self.api_key.get_secret_value(), + ) + + if self.tokenizer_id: + try: + from transformers import AutoTokenizer + except ImportError as ie: + raise ImportError( + "Transformers Python client is not installed. Please install it using" + " `pip install 'distilabel[hf-inference-endpoints]'`." + ) from ie + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) + + @property + def model_name(self) -> Union[str, None]: # type: ignore + """Returns the model name used for the model.""" + return ( + self.model_display_name + or self._model_name + or self.model_id + or self.endpoint_name + or self.base_url + ) diff --git a/src/distilabel/models/base_clients/openai.py b/src/distilabel/models/base_clients/openai.py new file mode 100644 index 0000000000..ada4d0b4d7 --- /dev/null +++ b/src/distilabel/models/base_clients/openai.py @@ -0,0 +1,122 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from typing import TYPE_CHECKING, Dict, Optional + +from pydantic import BaseModel, Field, PrivateAttr, SecretStr + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.typing import InstructorStructuredOutputType + +if TYPE_CHECKING: + from openai import AsyncOpenAI, OpenAI + + +_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" + + +class OpenAIBaseClient(BaseModel): + model: str + base_url: Optional[RuntimeParameter[str]] = Field( + default_factory=lambda: os.getenv( + "OPENAI_BASE_URL", "https://api.openai.com/v1" + ), + description="The base URL to use for the OpenAI API requests.", + ) + api_key: Optional[RuntimeParameter[SecretStr]] = Field( + default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME), + description="The API key to authenticate the requests to the OpenAI API.", + ) # type: ignore + default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field( + default=None, + description="The default headers to use for the OpenAI API requests.", + ) + max_retries: RuntimeParameter[int] = Field( + default=6, + description="The maximum number of times to retry the request to the API before" + " failing.", + ) + timeout: RuntimeParameter[int] = Field( + default=120, + description="The maximum time in seconds to wait for a response from the API.", + ) + structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( + Field( + default=None, + description="The structured output format to use across all the generations.", + ) + ) + + _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) + _client: "OpenAI" = PrivateAttr(None) # type: ignore + _aclient: "AsyncOpenAI" = PrivateAttr(None) # type: ignore + + def load(self) -> None: + """Loads the `AsyncOpenAI` client to benefit from async requests.""" + + try: + from openai import AsyncOpenAI, OpenAI + except ImportError as ie: + raise ImportError( + "OpenAI Python client is not installed. Please install it using" + " `pip install 'distilabel[openai]'`." + ) from ie + + if self.api_key is None: + raise ValueError( + f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`" + f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`." + ) + + self._client = OpenAI( + base_url=self.base_url, + api_key=self.api_key.get_secret_value(), + max_retries=self.max_retries, # type: ignore + timeout=self.timeout, + default_headers=self.default_headers, + ) + + self._aclient = AsyncOpenAI( + base_url=self.base_url, + api_key=self.api_key.get_secret_value(), + max_retries=self.max_retries, # type: ignore + timeout=self.timeout, + default_headers=self.default_headers, + ) + + if self.structured_output: + # This applies only to the LLMs. 
+ result = self._prepare_structured_output( + structured_output=self.structured_output, + client=self._aclient, + framework="openai", + ) + self._aclient = result.get("client") # type: ignore + if structured_output := result.get("structured_output"): + self.structured_output = structured_output + + def unload(self) -> None: + """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled + in case an exception is raised and has to be handled in the main process""" + + self._client = None # type: ignore + self._aclient = None # type: ignore + self.default_headers = None + self.structured_output = None + + @property + def model_name(self) -> str: + """Returns the model name used for the LLM.""" + return self.model diff --git a/src/distilabel/models/embeddings/base.py b/src/distilabel/models/embeddings/base.py index e2ee4af3f1..ad46345d54 100644 --- a/src/distilabel/models/embeddings/base.py +++ b/src/distilabel/models/embeddings/base.py @@ -50,7 +50,9 @@ class Embeddings(RuntimeParametersMixin, BaseModel, _Serializable, ABC): def load(self) -> None: """Method to be called to initialize the `Embeddings`""" - self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}") + self._logger = logging.getLogger( + f"distilabel.models.embeddings.{self.model_name}" + ) def unload(self) -> None: """Method to be called to unload the `Embeddings` and release any resources.""" diff --git a/src/distilabel/models/image_generation/__init__.py b/src/distilabel/models/image_generation/__init__.py new file mode 100644 index 0000000000..42a4f5a3db --- /dev/null +++ b/src/distilabel/models/image_generation/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.models.image_generation.base import ( + AsyncImageGenerationModel, + ImageGenerationModel, +) +from distilabel.models.image_generation.huggingface.inference_endpoints import ( + InferenceEndpointsImageGeneration, +) +from distilabel.models.image_generation.openai import OpenAIImageGeneration + +__all__ = [ + "AsyncImageGenerationModel", + "ImageGenerationModel", + "InferenceEndpointsImageGeneration", + "OpenAIImageGeneration", +] diff --git a/src/distilabel/models/image_generation/base.py b/src/distilabel/models/image_generation/base.py new file mode 100644 index 0000000000..bdce07d504 --- /dev/null +++ b/src/distilabel/models/image_generation/base.py @@ -0,0 +1,247 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import inspect +import logging +import sys +from abc import ABC, abstractmethod +from functools import cached_property +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr + +from distilabel.mixins.runtime_parameters import ( + RuntimeParameter, + RuntimeParametersModelMixin, +) +from distilabel.utils.docstring import parse_google_docstring +from distilabel.utils.itertools import grouper +from distilabel.utils.serialization import _Serializable + +if TYPE_CHECKING: + from logging import Logger + + from distilabel.utils.docstring import Docstring + + +class ImageGenerationModel(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC): + """Base class for `ImageGeneration` models. + + To implement an `ImageGeneration` subclass, you need to subclass this class and implement: + - `load` method to load the `ImageGeneration` model if needed. Don't forget to call `super().load()`, + so the `_logger` attribute is initialized. + - `model_name` property to return the model name used for the LLM. + - `generate` method to generate `num_generations` per input in `inputs`. + + Attributes: + generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate` + methods within each `ImageGenerationModel`. + _logger: the logger to be used for the `ImageGenerationModel`. It will be initialized + when the `load` method is called. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + protected_namespaces=(), + validate_default=True, + validate_assignment=True, + extra="forbid", + ) + + generation_kwargs: Optional[RuntimeParameter[dict[str, Any]]] = Field( + default_factory=dict, + description="The kwargs to be propagated to either `generate` or `agenerate`" + " methods within each `ImageGenerationModel`.", + ) + _logger: "Logger" = PrivateAttr(None) + + def load(self) -> None: + """Method to be called to initialize the `ImageGenerationModel`, and its logger.""" + self._logger = logging.getLogger( + f"distilabel.models.image_generation.{self.model_name}" + ) + + def unload(self) -> None: + """Method to be called to unload the `ImageGenerationModel` and release any resources.""" + pass + + @property + @abstractmethod + def model_name(self) -> str: + """Returns the model name used for the `ImageGenerationModel`.""" + pass + + def get_generation_kwargs(self) -> dict[str, Any]: + """Returns the generation kwargs to be used for the generation. This method can + be overridden to provide a more complex logic for the generation kwargs. + + Returns: + The kwargs to be used for the generation. + """ + return self.generation_kwargs # type: ignore + + @abstractmethod + def generate( + self, inputs: list[str], num_generations: int = 1, **kwargs: Any + ) -> list[list[dict[str, Any]]]: + """Generates images from the provided input. + + Args: + inputs: the prompt text to generate the image from. + num_generations: the number of images to generate. Defaults to `1`. + + Returns: + A list with a dictionary with the list of images generated. + """ + pass + + def generate_outputs( + self, + inputs: list[str], + num_generations: int = 1, + **kwargs: Any, + ) -> list[list[dict[str, Any]]]: + """This method is defined for compatibility with the `LLMs`. It calls the `generate` + method. 
+ """ + return self.generate(inputs=inputs, num_generations=num_generations, **kwargs) + + +class AsyncImageGenerationModel(ImageGenerationModel): + """Abstract class for asynchronous `ImageGenerationModels`, to benefit from the async capabilities + of each LLM implementation. This class is meant to be subclassed by each `ImageGenerationModel`, and the + method `agenerate` needs to be implemented to provide the asynchronous generation of + responses. + + Attributes: + _event_loop: the event loop to be used for the asynchronous generation of responses. + """ + + _num_generations_param_supported = True + _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None) + _new_event_loop: bool = PrivateAttr(default=False) + + @property + def generate_parameters(self) -> list[inspect.Parameter]: + """Returns the parameters of the `agenerate` method. + + Returns: + A list containing the parameters of the `agenerate` method. + """ + return list(inspect.signature(self.agenerate).parameters.values()) + + @cached_property + def generate_parsed_docstring(self) -> "Docstring": + """Returns the parsed docstring of the `agenerate` method. + + Returns: + The parsed docstring of the `agenerate` method. + """ + return parse_google_docstring(self.agenerate) + + @property + def event_loop(self) -> "asyncio.AbstractEventLoop": + if self._event_loop is None: + try: + self._event_loop = asyncio.get_running_loop() + if self._event_loop.is_closed(): + self._event_loop = asyncio.new_event_loop() # type: ignore + self._new_event_loop = True + except RuntimeError: + self._event_loop = asyncio.new_event_loop() + self._new_event_loop = True + asyncio.set_event_loop(self._event_loop) + return self._event_loop + + @abstractmethod + async def agenerate( + self, input: str, num_generations: int = 1, **kwargs: Any + ) -> list[dict[str, Any]]: + """Generates images from the provided input. + + Args: + input: the input text to generate the image from. + num_generations: the number of images to generate. Defaults to `1`. + + Returns: + A list with a dictionary with the list of images generated. + """ + pass + + async def _agenerate( + self, inputs: list[str], num_generations: int = 1, **kwargs: Any + ) -> list[list[dict[str, Any]]]: + """Internal function to concurrently generate images for a list of inputs. + + Args: + inputs: the list of inputs to generate images for. + num_generations: the number of generations to generate per input. + **kwargs: the additional kwargs to be used for the generation. + + Returns: + A list containing the generations for each input. + """ + if self._num_generations_param_supported: + tasks = [ + asyncio.create_task( + self.agenerate( + input=input, num_generations=num_generations, **kwargs + ) + ) + for input in inputs + ] + return await asyncio.gather(*tasks) + + tasks = [ + asyncio.create_task(self.agenerate(input=input, **kwargs)) + for input in inputs + for _ in range(num_generations) + ] + outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)] + return [ + list(group) + for group in grouper(outputs, n=num_generations, incomplete="ignore") + ] + + def generate( + self, + inputs: list[str], + num_generations: int = 1, + **kwargs: Any, + ) -> list[list[dict[str, Any]]]: + """Method to generate a list of images asynchronously, returning the output + synchronously awaiting for the image of each input sent to `agenerate`. + + Args: + inputs: the list of inputs to generate images for. + num_generations: the number of generations to generate per input. 
+ **kwargs: the additional kwargs to be used for the generation. + + Returns: + A list containing the images for each input. + """ + return self.event_loop.run_until_complete( + self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs) + ) + + def __del__(self) -> None: + """Closes the event loop when the object is deleted.""" + if sys.meta_path is None: + return + + if self._new_event_loop: + if self._event_loop.is_running(): + self._event_loop.stop() + self._event_loop.close() diff --git a/src/distilabel/models/image_generation/huggingface/__init__.py b/src/distilabel/models/image_generation/huggingface/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/models/image_generation/huggingface/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/models/image_generation/huggingface/inference_endpoints.py b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py new file mode 100644 index 0000000000..2403fbf018 --- /dev/null +++ b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py @@ -0,0 +1,106 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import validate_call + +from distilabel.models.base_clients.inference_endpoints import ( + InferenceEndpointsBaseClient, +) +from distilabel.models.image_generation.base import AsyncImageGenerationModel +from distilabel.models.image_generation.utils import image_to_str + +if TYPE_CHECKING: + from PIL.Image import Image + + +class InferenceEndpointsImageGeneration( # type: ignore + InferenceEndpointsBaseClient, AsyncImageGenerationModel +): + """Inference Endpoint image generation implementation running the async API client. + + Attributes: + model_id: the model ID to use for the ImageGenerationModel as available in the Hugging Face Hub, which + will be used to resolve the base URL for the serverless Inference Endpoints API requests. + Defaults to `None`. + endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`. + endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`. + base_url: the base URL to use for the Inference Endpoints API requests. + api_key: the API key to authenticate the requests to the Inference Endpoints API. 
+ + Icon: + `:hugging:` + + Examples: + Generate images from text prompts: + + ```python + from distilabel.models.image_generation import InferenceEndpointsImageGeneration + + igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell", api_key="api.key") + igm.load() + + output = igm.generate_outputs( + inputs=["a white siamese cat"], + ) + # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}] + ``` + """ + + def load(self) -> None: + # Sets the logger and calls the load method of the BaseClient + AsyncImageGenerationModel.load(self) + InferenceEndpointsBaseClient.load(self) + + @validate_call + async def agenerate( # type: ignore + self, + input: str, + negative_prompt: Optional[str] = None, + height: Optional[float] = None, + width: Optional[float] = None, + num_inference_steps: Optional[float] = None, + guidance_scale: Optional[float] = None, + num_generations: int = 1, + ) -> list[dict[str, Any]]: + """Generates images from text prompts using `huggingface_hub.AsyncInferenceClient.text_to_image`. + + Args: + input: Prompt to generate an image from. + negative_prompt: An optional negative prompt for the image generation. Defaults to None. + height: The height in pixels of the image to generate. + width: The width in pixels of the image to generate. + num_inference_steps: The number of denoising steps. More denoising steps usually lead + to a higher quality image at the expense of slower inference. + guidance_scale: Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + num_generations: The number of images to generate. Defaults to `1`. + It's here to ensure the validation succeeds, but it won't have effect. + + Returns: + A list with a dictionary containing a list with the image as a base64 string. + """ + + image: "Image" = await self._aclient.text_to_image( # type: ignore + input, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ) + img_str = image_to_str(image, image_format="JPEG") + + return [{"images": [img_str]}] diff --git a/src/distilabel/models/image_generation/openai.py b/src/distilabel/models/image_generation/openai.py new file mode 100644 index 0000000000..6315eb8046 --- /dev/null +++ b/src/distilabel/models/image_generation/openai.py @@ -0,0 +1,129 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +from typing import TYPE_CHECKING, Any, Literal, Optional + +import requests +from pydantic import validate_call + +from distilabel.models.base_clients.openai import OpenAIBaseClient +from distilabel.models.image_generation.base import AsyncImageGenerationModel + +if TYPE_CHECKING: + from openai.types import ImagesResponse + + +class OpenAIImageGeneration(OpenAIBaseClient, AsyncImageGenerationModel): + """OpenAI image generation implementation running the async API client. 
+ + Attributes: + model: the model name to use for the ImageGenerationModel e.g. "dall-e-3", etc. + Supported models can be found [here](https://platform.openai.com/docs/guides/images). + base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which + means that the value set for the environment variable `OPENAI_BASE_URL` will + be used, or "https://api.openai.com/v1" if not set. + api_key: the API key to authenticate the requests to the OpenAI API. Defaults to + `None` which means that the value set for the environment variable `OPENAI_API_KEY` + will be used, or `None` if not set. + max_retries: the maximum number of times to retry the request to the API before + failing. Defaults to `6`. + timeout: the maximum time in seconds to wait for a response from the API. Defaults + to `120`. + + Icon: + `:simple-openai:` + + Examples: + Generate images from text prompts: + + ```python + from distilabel.models.image_generation import OpenAIImageGeneration + + igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key") + + igm.load() + + output = igm.generate_outputs( + inputs=["a white siamese cat"], + size="1024x1024", + quality="standard", + style="natural", + ) + # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}] + ``` + """ + + def load(self) -> None: + # Sets the logger and calls the load method of the BaseClient + AsyncImageGenerationModel.load(self) + OpenAIBaseClient.load(self) + + @validate_call + async def agenerate( # type: ignore + self, + input: str, + num_generations: int = 1, + quality: Optional[Literal["standard", "hd"]] = "standard", + response_format: Optional[Literal["url", "b64_json"]] = "url", + size: Optional[ + Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + ] = None, + style: Optional[Literal["vivid", "natural"]] = None, + ) -> list[dict[str, Any]]: + """Generates `num_generations` images for the given input using the OpenAI async + client. The images are base64 string representations. + + Args: + input: A text description of the desired image(s). The maximum length is 1000 + characters for `dall-e-2` and 4000 characters for `dall-e-3`. + num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only + `n=1` is supported. + quality: The quality of the image that will be generated. `hd` creates images with finer + details and greater consistency across the image. This param is only supported + for `dall-e-3`. + response_format: The format in which the generated images are returned. Must be one of `url` or + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. + size: The size of the generated images. Must be one of `256x256`, `512x512`, or + `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or + `1024x1792` for `dall-e-3` models. + style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid + causes the model to lean towards generating hyper-real and dramatic images. + Natural causes the model to produce more natural, less hyper-real looking + images. This param is only supported for `dall-e-3`. + + Returns: + A list with a dictionary with the list of images generated. 
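Building on the docstring example above, a small usage sketch: whichever `response_format` is requested, the task-facing output is a base64 string, so it can be decoded back into a PIL image directly. The API key and file name below are placeholders.

```python
import base64
import io

from PIL import Image

from distilabel.models.image_generation import OpenAIImageGeneration

igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key")
igm.load()

output = igm.generate_outputs(
    inputs=["a white siamese cat"],
    size="1024x1024",
    quality="standard",
    response_format="b64_json",
)

# `output` holds one dict per input, each with a list of base64-encoded images.
image = Image.open(io.BytesIO(base64.b64decode(output[0]["images"][0])))
image.save("siamese_cat.png")
```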
+ """ + images_response: "ImagesResponse" = await self._aclient.images.generate( + model=self.model_name, + prompt=input, + n=num_generations, + quality=quality, + response_format=response_format, + size=size, + style=style, + ) + images = [] + for image in images_response.data: + if response_format == "url": + image_data = requests.get( + image.url + ).content # TODO: Keep a requests/httpx session instead + image_str = base64.b64encode(image_data).decode() + images.append(image_str) + elif response_format == "b64_json": + images.append(image.b64_json) + return [{"images": images}] diff --git a/src/distilabel/models/image_generation/utils.py b/src/distilabel/models/image_generation/utils.py new file mode 100644 index 0000000000..e5f08ca343 --- /dev/null +++ b/src/distilabel/models/image_generation/utils.py @@ -0,0 +1,31 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import io + +from PIL import Image + + +def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: + """Converts a PIL Image to a base64 encoded string.""" + buffered = io.BytesIO() + image.save(buffered, format=image_format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +def image_from_str(image_str: str) -> Image.Image: + """Converts a base64 encoded string to a PIL Image.""" + image_bytes = base64.b64decode(image_str) + return Image.open(io.BytesIO(image_bytes)) diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py index 0b0f3a7a9c..3469c1e2bc 100644 --- a/src/distilabel/models/llms/__init__.py +++ b/src/distilabel/models/llms/__init__.py @@ -27,10 +27,10 @@ from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM from distilabel.models.llms.together import TogetherLLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.typing import GenerateOutput, HiddenState __all__ = [ "LLM", diff --git a/src/distilabel/models/llms/anthropic.py b/src/distilabel/models/llms/anthropic.py index ab364bad58..3650671118 100644 --- a/src/distilabel/models/llms/anthropic.py +++ b/src/distilabel/models/llms/anthropic.py @@ -29,20 +29,19 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) if TYPE_CHECKING: - from typing import BaseModel - from anthropic import AsyncAnthropic from anthropic.types import Message + from pydantic import BaseModel - from distilabel.models.llms.typing import LLMStatistics + from 
distilabel.typing import LLMStatistics _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" diff --git a/src/distilabel/models/llms/base.py b/src/distilabel/models/llms/base.py index df274df402..912839b27b 100644 --- a/src/distilabel/models/llms/base.py +++ b/src/distilabel/models/llms/base.py @@ -31,7 +31,7 @@ from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.mixins.runtime_parameters import ( RuntimeParameter, - RuntimeParametersMixin, + RuntimeParametersModelMixin, ) from distilabel.utils.docstring import parse_google_docstring from distilabel.utils.notebook import in_notebook @@ -40,16 +40,13 @@ if TYPE_CHECKING: from logging import Logger - from distilabel.mixins.runtime_parameters import ( - RuntimeParameterInfo, - RuntimeParametersNames, - ) - from distilabel.models.llms.typing import GenerateOutput, HiddenState - from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType - from distilabel.steps.tasks.typing import ( + from distilabel.typing import ( FormattedInput, + GenerateOutput, + HiddenState, InstructorStructuredOutputType, StandardInput, + StructuredOutputType, ) from distilabel.utils.docstring import Docstring @@ -59,7 +56,7 @@ nest_asyncio.apply() -class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC): +class LLM(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC): """Base class for `LLM`s to be used in `distilabel` framework. To implement an `LLM` subclass, you need to subclass this class and implement: @@ -241,81 +238,6 @@ def _offline_batch_generate_polling( jobs_ids=self.jobs_ids # type: ignore ) from e - @property - def generate_parameters(self) -> List["inspect.Parameter"]: - """Returns the parameters of the `generate` method. - - Returns: - A list containing the parameters of the `generate` method. - """ - return list(inspect.signature(self.generate).parameters.values()) - - @property - def runtime_parameters_names(self) -> "RuntimeParametersNames": - """Returns the runtime parameters of the `LLM`, which are combination of the - attributes of the `LLM` type hinted with `RuntimeParameter` and the parameters - of the `generate` method that are not `input` and `num_generations`. - - Returns: - A dictionary with the name of the runtime parameters as keys and a boolean - indicating if the parameter is optional or not. - """ - runtime_parameters = super().runtime_parameters_names - runtime_parameters["generation_kwargs"] = {} - - # runtime parameters from the `generate` method - for param in self.generate_parameters: - if param.name in ["input", "inputs", "num_generations"]: - continue - is_optional = param.default != inspect.Parameter.empty - runtime_parameters["generation_kwargs"][param.name] = is_optional - - return runtime_parameters - - def get_runtime_parameters_info(self) -> List["RuntimeParameterInfo"]: - """Gets the information of the runtime parameters of the `LLM` such as the name - and the description. This function is meant to include the information of the runtime - parameters in the serialized data of the `LLM`. - - Returns: - A list containing the information for each runtime parameter of the `LLM`. 
- """ - runtime_parameters_info = super().get_runtime_parameters_info() - - generation_kwargs_info = next( - ( - runtime_parameter_info - for runtime_parameter_info in runtime_parameters_info - if runtime_parameter_info["name"] == "generation_kwargs" - ), - None, - ) - - # If `generation_kwargs` attribute is present, we need to include the `generate` - # method arguments as the information for this attribute. - if generation_kwargs_info: - generate_docstring_args = self.generate_parsed_docstring["args"] - - generation_kwargs_info["keys"] = [] - for key, value in generation_kwargs_info["optional"].items(): - info = {"name": key, "optional": value} - if description := generate_docstring_args.get(key): - info["description"] = description - generation_kwargs_info["keys"].append(info) - - generation_kwargs_info.pop("optional") - - return runtime_parameters_info - - @cached_property - def generate_parsed_docstring(self) -> "Docstring": - """Returns the parsed docstring of the `generate` method. - - Returns: - The parsed docstring of the `generate` method. - """ - return parse_google_docstring(self.generate) - def get_last_hidden_states( self, inputs: List["StandardInput"] ) -> List["HiddenState"]: diff --git a/src/distilabel/models/llms/cohere.py b/src/distilabel/models/llms/cohere.py index 8b081a762e..0c9a342aea 100644 --- a/src/distilabel/models/llms/cohere.py +++ b/src/distilabel/models/llms/cohere.py @@ -28,10 +28,10 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import compute_tokens, prepare_output -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) @@ -40,7 +40,7 @@ from pydantic import BaseModel from tokenizers import Tokenizer - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" diff --git a/src/distilabel/models/llms/groq.py b/src/distilabel/models/llms/groq.py index fec511bbee..4334d72bdd 100644 --- a/src/distilabel/models/llms/groq.py +++ b/src/distilabel/models/llms/groq.py @@ -18,11 +18,11 @@ from pydantic import Field, PrivateAttr, SecretStr, validate_call from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output from distilabel.steps.base import RuntimeParameter -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) @@ -30,7 +30,7 @@ from groq import AsyncGroq from groq.types.chat.chat_completion import ChatCompletion - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index 6f97c5814a..8956529999 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import random import sys import warnings @@ -30,29 +29,22 @@ from pydantic import ( Field, PositiveInt, - PrivateAttr, - SecretStr, ValidationError, model_validator, validate_call, ) from pydantic._internal._model_construction import ModelMetaclass -from typing_extensions import Annotated, override +from typing_extensions import Annotated -from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.base_clients.inference_endpoints import ( + InferenceEndpointsBaseClient, +) from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput, Logprob from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import ( - FormattedInput, - StandardInput, - StructuredOutputType, -) -from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR, get_hf_token +from distilabel.typing import FormattedInput, GenerateOutput, Logprob, StandardInput if TYPE_CHECKING: - from huggingface_hub import AsyncInferenceClient from huggingface_hub.inference._generated.types.chat_completion import ( ChatCompletionOutput, ChatCompletionOutputComplete, @@ -60,12 +52,13 @@ from huggingface_hub.inference._generated.types.text_generation import ( TextGenerationOutput, ) - from transformers import PreTrainedTokenizer - from distilabel.models.llms.typing import Logprob + from distilabel.typing import Logprob -class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin): +class InferenceEndpointsLLM( + InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin +): """InferenceEndpoints LLM implementation running the async API client. This LLM will internally use `huggingface_hub.AsyncInferenceClient`. @@ -164,39 +157,11 @@ class User(BaseModel): ``` """ - model_id: Optional[str] = None - - endpoint_name: Optional[RuntimeParameter[str]] = Field( - default=None, - description="The name of the Inference Endpoint to use for the LLM.", - ) - endpoint_namespace: Optional[RuntimeParameter[str]] = Field( - default=None, - description="The namespace of the Inference Endpoint to use for the LLM.", - ) - base_url: Optional[RuntimeParameter[str]] = Field( - default=None, - description="The base URL to use for the Inference Endpoints API requests.", - ) - api_key: Optional[RuntimeParameter[SecretStr]] = Field( - default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR), - description="The API key to authenticate the requests to the Inference Endpoints API.", - ) - - tokenizer_id: Optional[str] = None - model_display_name: Optional[str] = None - - structured_output: Optional[RuntimeParameter[StructuredOutputType]] = Field( - default=None, - description="The structured output format to use across all the generations.", - ) - - _num_generations_param_supported = False - - _model_name: Optional[str] = PrivateAttr(default=None) - _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None) - _api_key_env_var: str = PrivateAttr(HF_TOKEN_ENV_VAR) - _aclient: Optional["AsyncInferenceClient"] = PrivateAttr(...) 
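The runtime-parameter helpers removed from the `LLM` base class earlier in this patch now live in `RuntimeParametersModelMixin`, but the user-facing behaviour is unchanged: the keyword arguments of `generate` are still exposed under the `generation_kwargs` runtime parameter. A sketch of how that is typically supplied at run time; the step name, model, and values are illustrative, not part of this patch.

```python
from distilabel.models.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="generation-kwargs-demo") as pipeline:
    loader = LoadDataFromDicts(data=[{"instruction": "Write a haiku about data."}])
    text_generation = TextGeneration(
        name="text_generation",
        llm=OpenAILLM(model="gpt-4o-mini"),
    )
    loader >> text_generation

if __name__ == "__main__":
    # Every keyword argument of `LLM.generate` (except the inputs and
    # `num_generations`) is surfaced as part of `generation_kwargs`.
    distiset = pipeline.run(
        parameters={
            "text_generation": {
                "llm": {"generation_kwargs": {"temperature": 0.7, "max_new_tokens": 256}}
            }
        }
    )
```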
+ def load(self) -> None: + # Sets the logger and calls the load method of the BaseClient + self._num_generations_param_supported = False + AsyncLLM.load(self) + InferenceEndpointsBaseClient.load(self) @model_validator(mode="after") # type: ignore def only_one_of_model_id_endpoint_name_or_base_url_provided( @@ -242,92 +207,6 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided( f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}." ) - def load(self) -> None: # noqa: C901 - """Loads the `AsyncInferenceClient` client to connect to the Hugging Face Inference - Endpoint. - - Raises: - ImportError: if the `huggingface-hub` Python client is not installed. - ValueError: if the model is not currently deployed or is not running the TGI framework. - ImportError: if the `transformers` Python client is not installed. - """ - super().load() - - try: - from huggingface_hub import ( - AsyncInferenceClient, - InferenceClient, - get_inference_endpoint, - ) - except ImportError as ie: - raise ImportError( - "Hugging Face Hub Python client is not installed. Please install it using" - " `pip install 'distilabel[hf-inference-endpoints]'`." - ) from ie - - if self.api_key is None: - self.api_key = SecretStr(get_hf_token(self.__class__.__name__, "api_key")) - - if self.model_id is not None: - client = InferenceClient( - model=self.model_id, token=self.api_key.get_secret_value() - ) - status = client.get_model_status() - - if ( - status.state not in {"Loadable", "Loaded"} - and status.framework != "text-generation-inference" - ): - raise ValueError( - f"Model {self.model_id} is not currently deployed or is not running the TGI framework" - ) - - self.base_url = client._resolve_url( - model=self.model_id, task="text-generation" - ) - - if self.endpoint_name is not None: - client = get_inference_endpoint( - name=self.endpoint_name, - namespace=self.endpoint_namespace, - token=self.api_key.get_secret_value(), - ) - if client.status in ["paused", "scaledToZero"]: - client.resume().wait(timeout=300) - elif client.status == "initializing": - client.wait(timeout=300) - - self.base_url = client.url - self._model_name = client.repository - - self._aclient = AsyncInferenceClient( - base_url=self.base_url, - token=self.api_key.get_secret_value(), - ) - - if self.tokenizer_id: - try: - from transformers import AutoTokenizer - except ImportError as ie: - raise ImportError( - "Transformers Python client is not installed. Please install it using" - " `pip install 'distilabel[hf-inference-endpoints]'`." - ) from ie - - self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) - - @property - @override - def model_name(self) -> Union[str, None]: # type: ignore - """Returns the model name used for the LLM.""" - return ( - self.model_display_name - or self._model_name - or self.model_id - or self.endpoint_name - or self.base_url - ) - def prepare_input(self, input: "StandardInput") -> str: """Prepares the input (applying the chat template and tokenization) for the provided input. @@ -588,6 +467,7 @@ async def agenerate( # type: ignore top_k: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + num_generations: int = 1, ) -> GenerateOutput: """Generates completions for the given input using the async client. This method uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`. @@ -656,6 +536,8 @@ async def agenerate( # type: ignore watermark: whether to add the watermark to the generated text. 
This argument is exclusive of the `text_generation` method and will be only used if `tokenizer_id` is not `None`. Defaults to `None`. + num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure + the validation succeds. Returns: A list of lists of strings containing the generated responses for each input. diff --git a/src/distilabel/models/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py index 19dc32dd2d..aef8c40e16 100644 --- a/src/distilabel/models/llms/huggingface/transformers.py +++ b/src/distilabel/models/llms/huggingface/transformers.py @@ -19,14 +19,17 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.structured_outputs.outlines import ( _is_outlines_version_below_0_1_0, ) -from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput +from distilabel.typing import ( + GenerateOutput, + OutlinesStructuredOutputType, + StandardInput, +) from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR if TYPE_CHECKING: @@ -34,7 +37,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils import PreTrainedTokenizer - from distilabel.models.llms.typing import HiddenState + from distilabel.typing import HiddenState class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): diff --git a/src/distilabel/models/llms/litellm.py b/src/distilabel/models/llms/litellm.py index 9b52ad8c71..29c910622b 100644 --- a/src/distilabel/models/llms/litellm.py +++ b/src/distilabel/models/llms/litellm.py @@ -20,9 +20,12 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType +from distilabel.typing import ( + FormattedInput, + GenerateOutput, + InstructorStructuredOutputType, +) if TYPE_CHECKING: from litellm import Choices diff --git a/src/distilabel/models/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py index a754f6b84f..87f5eb358f 100644 --- a/src/distilabel/models/llms/llamacpp.py +++ b/src/distilabel/models/llms/llamacpp.py @@ -18,10 +18,13 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType +from distilabel.typing import ( + FormattedInput, + GenerateOutput, + OutlinesStructuredOutputType, +) if TYPE_CHECKING: from llama_cpp import ( @@ -31,7 +34,7 @@ LogitsProcessorList, ) - from distilabel.steps.tasks.typing import FormattedInput, StandardInput + from distilabel.typing import FormattedInput, StandardInput class LlamaCppLLM(LLM, MagpieChatTemplateMixin): diff --git a/src/distilabel/models/llms/mistral.py b/src/distilabel/models/llms/mistral.py index 4147edaf03..e6047d4be5 100644 --- 
a/src/distilabel/models/llms/mistral.py +++ b/src/distilabel/models/llms/mistral.py @@ -19,10 +19,10 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import ( +from distilabel.typing import ( FormattedInput, + GenerateOutput, InstructorStructuredOutputType, ) @@ -30,7 +30,7 @@ from mistralai import Mistral from mistralai.models.chatcompletionresponse import ChatCompletionResponse - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY" diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py index 1f8c9b8c65..ffdcf37526 100644 --- a/src/distilabel/models/llms/mlx.py +++ b/src/distilabel/models/llms/mlx.py @@ -27,12 +27,9 @@ ) from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import ( - StandardInput, -) +from distilabel.typing import GenerateOutput, StandardInput if TYPE_CHECKING: import mlx.nn as nn diff --git a/src/distilabel/models/llms/moa.py b/src/distilabel/models/llms/moa.py index 11af619ad4..ea859e95da 100644 --- a/src/distilabel/models/llms/moa.py +++ b/src/distilabel/models/llms/moa.py @@ -19,12 +19,11 @@ from pydantic import Field from distilabel.models.llms.base import LLM, AsyncLLM -from distilabel.steps.tasks.typing import StandardInput +from distilabel.typing import StandardInput if TYPE_CHECKING: from distilabel.mixins.runtime_parameters import RuntimeParametersNames - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput # Mixture-of-Agents system prompt from the paper with the addition instructing the LLM # to not mention that it used responses from previous models to avoid having texts like diff --git a/src/distilabel/models/llms/ollama.py b/src/distilabel/models/llms/ollama.py index a930399114..4cb5aa0428 100644 --- a/src/distilabel/models/llms/ollama.py +++ b/src/distilabel/models/llms/ollama.py @@ -19,17 +19,19 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput +from distilabel.typing import ( + GenerateOutput, + InstructorStructuredOutputType, + StandardInput, +) if TYPE_CHECKING: from ollama import AsyncClient from ollama._types import ChatResponse, GenerateResponse - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import StandardInput + from distilabel.typing import LLMStatistics, StandardInput # Copied from `ollama._types.Options` diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index 91f24a3336..66dbfcff17 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -13,36 +13,32 @@ # limitations under the License. 
import io -import os from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import orjson -from pydantic import Field, PositiveInt, PrivateAttr, SecretStr, validate_call +from pydantic import PositiveInt, validate_call from distilabel import envs from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException -from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.base_clients.openai import OpenAIBaseClient from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType +from distilabel.typing import FormattedInput, GenerateOutput if TYPE_CHECKING: - from openai import AsyncOpenAI, OpenAI from openai.types import Batch as OpenAIBatch from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat.chat_completion import Choice as OpenAIChoice from openai.types.completion import Completion as OpenAICompletion - from distilabel.models.llms.typing import LLMStatistics, Logprob + from distilabel.typing import LLMStatistics, Logprob -_OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" _OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB -class OpenAILLM(AsyncLLM): +class OpenAILLM(OpenAIBaseClient, AsyncLLM): """OpenAI LLM implementation running the async API client. Attributes: @@ -143,99 +139,9 @@ class User(BaseModel): ``` """ - model: str - base_url: Optional[RuntimeParameter[str]] = Field( - default_factory=lambda: os.getenv( - "OPENAI_BASE_URL", "https://api.openai.com/v1" - ), - description="The base URL to use for the OpenAI API requests.", - ) - api_key: Optional[RuntimeParameter[SecretStr]] = Field( - default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME), - description="The API key to authenticate the requests to the OpenAI API.", - ) - default_headers: Optional[RuntimeParameter[Dict[str, str]]] = Field( - default=None, - description="The default headers to use for the OpenAI API requests.", - ) - max_retries: RuntimeParameter[int] = Field( - default=6, - description="The maximum number of times to retry the request to the API before" - " failing.", - ) - timeout: RuntimeParameter[int] = Field( - default=120, - description="The maximum time in seconds to wait for a response from the API.", - ) - structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = ( - Field( - default=None, - description="The structured output format to use across all the generations.", - ) - ) - - _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) - _client: "OpenAI" = PrivateAttr(None) - _aclient: "AsyncOpenAI" = PrivateAttr(None) - def load(self) -> None: - """Loads the `AsyncOpenAI` client to benefit from async requests.""" - super().load() - - try: - from openai import AsyncOpenAI, OpenAI - except ImportError as ie: - raise ImportError( - "OpenAI Python client is not installed. Please install it using" - " `pip install 'distilabel[openai]'`." - ) from ie - - if self.api_key is None: - raise ValueError( - f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`" - f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`." 
- ) - - self._client = OpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=self.max_retries, # type: ignore - timeout=self.timeout, - default_headers=self.default_headers, - ) - - self._aclient = AsyncOpenAI( - base_url=self.base_url, - api_key=self.api_key.get_secret_value(), - max_retries=self.max_retries, # type: ignore - timeout=self.timeout, - default_headers=self.default_headers, - ) - - if self.structured_output: - result = self._prepare_structured_output( - structured_output=self.structured_output, - client=self._aclient, - framework="openai", - ) - self._aclient = result.get("client") # type: ignore - if structured_output := result.get("structured_output"): - self.structured_output = structured_output - - def unload(self) -> None: - """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled - in case an exception is raised and has to be handled in the main process""" - - self._client = None # type: ignore - self._aclient = None # type: ignore - self.default_headers = None - self.structured_output = None - super().unload() - - @property - def model_name(self) -> str: - """Returns the model name used for the LLM.""" - return self.model + AsyncLLM.load(self) + OpenAIBaseClient.load(self) @validate_call async def agenerate( # type: ignore diff --git a/src/distilabel/models/llms/typing.py b/src/distilabel/models/llms/typing.py deleted file mode 100644 index cfa4ec382f..0000000000 --- a/src/distilabel/models/llms/typing.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, TypeVar, Union - -from typing_extensions import NotRequired - -LLMOutput = List[Union[str, None]] - - -class Logprob(TypedDict): - token: str - logprob: float - - -LLMLogprobs = List[List[List[Logprob]]] -"""A type alias representing the probability distributions output by an `LLM`. - -Structure: - - Outermost list: contains multiple generation choices when sampling (`n` sequences) - - Middle list: represents each position in the generated sequence - - Innermost list: contains the log probabilities for each token in the vocabulary at that position -""" - - -class TokenCount(TypedDict): - input_tokens: List[int] - output_tokens: List[int] - - -LLMStatistics = Union[TokenCount, Dict[str, Any]] -"""Initially the LLMStatistics will contain the token count, but can have more variables. -They can be added once we have them defined for every LLM. 
-""" - - -class GenerateOutput(TypedDict): - generations: LLMOutput - statistics: LLMStatistics - logprobs: NotRequired[LLMLogprobs] - - -if TYPE_CHECKING: - from numpy import floating - from numpy.typing import NDArray - - GenericFloat = TypeVar("GenericFloat", bound=floating[Any]) - - HiddenState = NDArray[GenericFloat] -else: - HiddenState = Any diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py index ef97e53e1f..45f9088ca5 100644 --- a/src/distilabel/models/llms/utils.py +++ b/src/distilabel/models/llms/utils.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput, LLMLogprobs, LLMOutput + from distilabel.typing import GenerateOutput, LLMLogprobs, LLMOutput def compute_tokens( diff --git a/src/distilabel/models/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py index 7c1b3e6bb4..b241e4d8d8 100644 --- a/src/distilabel/models/llms/vertexai.py +++ b/src/distilabel/models/llms/vertexai.py @@ -18,14 +18,13 @@ from typing_extensions import TypedDict from distilabel.models.llms.base import AsyncLLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.models.llms.utils import prepare_output -from distilabel.steps.tasks.typing import StandardInput +from distilabel.typing import GenerateOutput, StandardInput if TYPE_CHECKING: from vertexai.generative_models import Content, GenerationResponse, GenerativeModel - from distilabel.models.llms.typing import LLMStatistics + from distilabel.typing import LLMStatistics class VertexChatItem(TypedDict): diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index ceab8e3e30..6075c4f54e 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -33,11 +33,15 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import LLM from distilabel.models.llms.openai import OpenAILLM -from distilabel.models.llms.typing import GenerateOutput, Logprob from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType +from distilabel.typing import ( + FormattedInput, + GenerateOutput, + Logprob, + OutlinesStructuredOutputType, +) if TYPE_CHECKING: from openai import OpenAI # noqa @@ -45,10 +49,13 @@ from vllm import LLM as _vLLM from vllm.outputs import RequestOutput, CompletionOutput - from distilabel.steps.tasks.typing import StandardInput - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import StructuredInput - from distilabel.models.llms.typing import LLMLogprobs, LLMOutput + from distilabel.typing import ( + StandardInput, + StructuredInput, + LLMStatistics, + LLMLogprobs, + LLMOutput, + ) LogitsProcessorFn = Union[ diff --git a/src/distilabel/models/mixins/magpie.py b/src/distilabel/models/mixins/magpie.py index 8efa3add58..8edc1d92e8 100644 --- a/src/distilabel/models/mixins/magpie.py +++ b/src/distilabel/models/mixins/magpie.py @@ -18,7 +18,7 @@ from typing_extensions import Self if TYPE_CHECKING: - from distilabel.steps.tasks.typing import StandardInput + from distilabel.typing import StandardInput 
MagpieAvailablePreQueryTemplates = Literal["llama3", "qwen2"] """The available predefined pre-query templates.""" diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 168599f782..2a0c89abd5 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -68,13 +68,13 @@ from distilabel.distiset import Distiset from distilabel.pipeline.routing_batch_function import RoutingBatchFunction - from distilabel.pipeline.typing import ( + from distilabel.steps.base import Step + from distilabel.typing import ( InputDataset, LoadGroups, PipelineRuntimeParametersInfo, StepLoadStatus, ) - from distilabel.steps.base import Step class _CacheLocation(TypedDict): """Dictionary to store the filenames and directories of a cached pipeline. diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index e8716f1ade..29ab8131cb 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -44,8 +44,8 @@ from queue import Queue from distilabel.distiset import Distiset - from distilabel.pipeline.typing import InputDataset, LoadGroups from distilabel.steps.base import _Step + from distilabel.typing import InputDataset, LoadGroups _SUBPROCESS_EXCEPTION: Union[Exception, None] = None diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py index c2e85afd86..2d1158aedd 100644 --- a/src/distilabel/pipeline/ray.py +++ b/src/distilabel/pipeline/ray.py @@ -32,8 +32,8 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from distilabel.distiset import Distiset - from distilabel.pipeline.typing import InputDataset, LoadGroups from distilabel.steps.base import _Step + from distilabel.typing import InputDataset, LoadGroups class RayPipeline(BasePipeline): diff --git a/src/distilabel/pipeline/routing_batch_function.py b/src/distilabel/pipeline/routing_batch_function.py index 3f0aaf9ff4..31889acc90 100644 --- a/src/distilabel/pipeline/routing_batch_function.py +++ b/src/distilabel/pipeline/routing_batch_function.py @@ -28,8 +28,8 @@ if TYPE_CHECKING: from distilabel.pipeline.batch import _Batch - from distilabel.pipeline.typing import DownstreamConnectableSteps from distilabel.steps.base import _Step + from distilabel.typing import DownstreamConnectableSteps RoutingBatchFunc = Callable[[List[str]], List[str]] """Type alias for a routing batch function. 
It takes a list of all the downstream steps and diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py index 1caa3a3e38..52937107f3 100644 --- a/src/distilabel/pipeline/step_wrapper.py +++ b/src/distilabel/pipeline/step_wrapper.py @@ -21,8 +21,8 @@ from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.pipeline.batch import _Batch -from distilabel.pipeline.typing import StepLoadStatus from distilabel.steps.base import GeneratorStep, Step, _Step +from distilabel.typing import StepLoadStatus class _StepWrapper: diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 58875bbec3..19d90f9a33 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -55,7 +55,7 @@ from distilabel.steps.globals.huggingface import PushToHub from distilabel.steps.reward_model import RewardModelScore from distilabel.steps.truncate import TruncateTextColumn -from distilabel.steps.typing import GeneratorStepOutput, StepOutput +from distilabel.typing import GeneratorStepOutput, StepOutput __all__ = [ "DBSCAN", diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py index 06db05e05b..1742ac675d 100644 --- a/src/distilabel/steps/argilla/base.py +++ b/src/distilabel/steps/argilla/base.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from argilla import Argilla, Dataset - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput _ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL" diff --git a/src/distilabel/steps/argilla/preference.py b/src/distilabel/steps/argilla/preference.py index 210cca208f..22cb6d02da 100644 --- a/src/distilabel/steps/argilla/preference.py +++ b/src/distilabel/steps/argilla/preference.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from argilla import RatingQuestion, Suggestion, TextField, TextQuestion - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class PreferenceToArgilla(ArgillaBase): diff --git a/src/distilabel/steps/argilla/text_generation.py b/src/distilabel/steps/argilla/text_generation.py index ad5323b0bc..ed590dec57 100644 --- a/src/distilabel/steps/argilla/text_generation.py +++ b/src/distilabel/steps/argilla/text_generation.py @@ -28,7 +28,7 @@ from distilabel.steps.base import StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class TextGenerationToArgilla(ArgillaBase): diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index 128ccefc75..88bed374bc 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -48,12 +48,14 @@ from distilabel.pipeline.base import BasePipeline from distilabel.pipeline.routing_batch_function import RoutingBatchFunction - from distilabel.pipeline.typing import ( + from distilabel.typing import ( DownstreamConnectable, DownstreamConnectableSteps, + GeneratorStepOutput, + StepColumns, + StepOutput, UpstreamConnectableSteps, ) - from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput DEFAULT_INPUT_BATCH_SIZE = 50 diff --git a/src/distilabel/steps/clustering/dbscan.py b/src/distilabel/steps/clustering/dbscan.py index 2124d787c1..238d9338ed 100644 --- a/src/distilabel/steps/clustering/dbscan.py +++ b/src/distilabel/steps/clustering/dbscan.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from 
sklearn.cluster import DBSCAN as _DBSCAN - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class DBSCAN(GlobalStep): diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py index 925ffab229..06358a6189 100644 --- a/src/distilabel/steps/clustering/text_clustering.py +++ b/src/distilabel/steps/clustering/text_clustering.py @@ -28,7 +28,7 @@ from distilabel.utils.itertools import batched if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class TextClustering(TextClassification, GlobalTask): diff --git a/src/distilabel/steps/clustering/umap.py b/src/distilabel/steps/clustering/umap.py index 9bf71c68e3..2688088c6f 100644 --- a/src/distilabel/steps/clustering/umap.py +++ b/src/distilabel/steps/clustering/umap.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from umap import UMAP as _UMAP - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class UMAP(GlobalStep): diff --git a/src/distilabel/steps/columns/combine.py b/src/distilabel/steps/columns/combine.py index 784beffe47..cd08303ac8 100644 --- a/src/distilabel/steps/columns/combine.py +++ b/src/distilabel/steps/columns/combine.py @@ -19,7 +19,7 @@ from distilabel.steps.columns.utils import merge_distilabel_metadata if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class CombineOutputs(Step): diff --git a/src/distilabel/steps/columns/expand.py b/src/distilabel/steps/columns/expand.py index 989924cf8a..aae1c336e7 100644 --- a/src/distilabel/steps/columns/expand.py +++ b/src/distilabel/steps/columns/expand.py @@ -22,7 +22,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class ExpandColumns(Step): diff --git a/src/distilabel/steps/columns/group.py b/src/distilabel/steps/columns/group.py index 4cc77b50f0..ed9ee7a2df 100644 --- a/src/distilabel/steps/columns/group.py +++ b/src/distilabel/steps/columns/group.py @@ -21,7 +21,7 @@ from distilabel.steps.columns.utils import group_columns if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class GroupColumns(Step): diff --git a/src/distilabel/steps/columns/keep.py b/src/distilabel/steps/columns/keep.py index c12dfdd61d..0835cd834c 100644 --- a/src/distilabel/steps/columns/keep.py +++ b/src/distilabel/steps/columns/keep.py @@ -19,7 +19,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class KeepColumns(Step): diff --git a/src/distilabel/steps/columns/merge.py b/src/distilabel/steps/columns/merge.py index 54ab3e3c75..0f2bb66f2f 100644 --- a/src/distilabel/steps/columns/merge.py +++ b/src/distilabel/steps/columns/merge.py @@ -20,7 +20,7 @@ from distilabel.steps.columns.utils import merge_columns if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class MergeColumns(Step): diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py index 3e84df66f2..9bcc6f2dcb 100644 --- a/src/distilabel/steps/decorator.py +++ b/src/distilabel/steps/decorator.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from distilabel.steps.base import _Step - 
from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput + from distilabel.typing import GeneratorStepOutput, StepColumns, StepOutput _STEP_MAPPING = { "normal": Step, diff --git a/src/distilabel/steps/embeddings/embedding_generation.py b/src/distilabel/steps/embeddings/embedding_generation.py index 0aeed03102..5e2a839f69 100644 --- a/src/distilabel/steps/embeddings/embedding_generation.py +++ b/src/distilabel/steps/embeddings/embedding_generation.py @@ -18,7 +18,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class EmbeddingGeneration(Step): diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py index a962ca3b14..ab33be2a4d 100644 --- a/src/distilabel/steps/embeddings/nearest_neighbour.py +++ b/src/distilabel/steps/embeddings/nearest_neighbour.py @@ -23,7 +23,7 @@ from distilabel.steps import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class FaissNearestNeighbour(GlobalStep): diff --git a/src/distilabel/steps/filtering/embedding.py b/src/distilabel/steps/filtering/embedding.py index cb1e710374..4572bca5cf 100644 --- a/src/distilabel/steps/filtering/embedding.py +++ b/src/distilabel/steps/filtering/embedding.py @@ -23,7 +23,7 @@ from distilabel.steps.base import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class EmbeddingDedup(GlobalStep): diff --git a/src/distilabel/steps/filtering/minhash.py b/src/distilabel/steps/filtering/minhash.py index 7e86d30543..3d89b11e50 100644 --- a/src/distilabel/steps/filtering/minhash.py +++ b/src/distilabel/steps/filtering/minhash.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from datasketch import MinHash, MinHashLSH - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput # Copied from: https://github.com/huggingface/datatrove/blob/main/src/datatrove/utils/text.py#L89C1-L95C65 diff --git a/src/distilabel/steps/formatting/conversation.py b/src/distilabel/steps/formatting/conversation.py index 29381521bd..0101aec196 100644 --- a/src/distilabel/steps/formatting/conversation.py +++ b/src/distilabel/steps/formatting/conversation.py @@ -17,7 +17,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class ConversationTemplate(Step): diff --git a/src/distilabel/steps/formatting/dpo.py b/src/distilabel/steps/formatting/dpo.py index 72253eb194..528abbbb87 100644 --- a/src/distilabel/steps/formatting/dpo.py +++ b/src/distilabel/steps/formatting/dpo.py @@ -18,7 +18,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class FormatTextGenerationDPO(Step): diff --git a/src/distilabel/steps/formatting/sft.py b/src/distilabel/steps/formatting/sft.py index 2793b212e6..6122ead0d1 100644 --- a/src/distilabel/steps/formatting/sft.py +++ b/src/distilabel/steps/formatting/sft.py @@ -18,7 +18,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput 
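For step authors the change is mechanical: `StepColumns` and `StepOutput` are now imported from `distilabel.typing` instead of `distilabel.steps.typing`. A toy custom step under the new import path might look like this (the step itself is illustrative):

```python
from typing import TYPE_CHECKING

from distilabel.steps import Step, StepInput

if TYPE_CHECKING:
    from distilabel.typing import StepColumns, StepOutput


class AddGreeting(Step):
    """Toy step adding a `greeting` column built from the `name` column."""

    @property
    def inputs(self) -> "StepColumns":
        return ["name"]

    @property
    def outputs(self) -> "StepColumns":
        return ["greeting"]

    def process(self, inputs: StepInput) -> "StepOutput":
        for row in inputs:
            row["greeting"] = f"Hello, {row['name']}!"
        yield inputs
```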
class FormatTextGenerationSFT(Step): diff --git a/src/distilabel/steps/generators/data.py b/src/distilabel/steps/generators/data.py index 803ee35eac..3b43e97c8f 100644 --- a/src/distilabel/steps/generators/data.py +++ b/src/distilabel/steps/generators/data.py @@ -20,7 +20,7 @@ from distilabel.steps.base import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import GeneratorStepOutput class LoadDataFromDicts(GeneratorStep): diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py index 721b3d4081..6c3b821a33 100644 --- a/src/distilabel/steps/generators/huggingface.py +++ b/src/distilabel/steps/generators/huggingface.py @@ -47,7 +47,7 @@ from distilabel.steps.base import GeneratorStep if TYPE_CHECKING: - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import GeneratorStepOutput T = TypeVar("T") diff --git a/src/distilabel/steps/globals/huggingface.py b/src/distilabel/steps/globals/huggingface.py index 82e7f35ab6..e9723f520d 100644 --- a/src/distilabel/steps/globals/huggingface.py +++ b/src/distilabel/steps/globals/huggingface.py @@ -23,7 +23,7 @@ from distilabel.steps.base import GlobalStep, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class PushToHub(GlobalStep): diff --git a/src/distilabel/steps/reward_model.py b/src/distilabel/steps/reward_model.py index 0af5d5cfdd..87fef02264 100644 --- a/src/distilabel/steps/reward_model.py +++ b/src/distilabel/steps/reward_model.py @@ -25,8 +25,7 @@ import torch from transformers import PreTrainedModel, PreTrainedTokenizer - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, StepColumns, StepOutput class RewardModelScore(Step, CudaDevicePlacementMixin): diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index f542aea232..977e663992 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -16,7 +16,7 @@ from distilabel.steps.tasks.apigen.generator import APIGenGenerator from distilabel.steps.tasks.apigen.semantic_checker import APIGenSemanticChecker from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller -from distilabel.steps.tasks.base import GeneratorTask, Task +from distilabel.steps.tasks.base import GeneratorTask, ImageTask, Task from distilabel.steps.tasks.clair import CLAIR from distilabel.steps.tasks.complexity_scorer import ComplexityScorer from distilabel.steps.tasks.decorator import task @@ -29,6 +29,7 @@ from distilabel.steps.tasks.evol_quality.base import EvolQuality from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings from distilabel.steps.tasks.genstruct import Genstruct +from distilabel.steps.tasks.image_generation import ImageGeneration from distilabel.steps.tasks.improving_text_embeddings import ( BitextRetrievalGenerator, EmbeddingTaskGenerator, @@ -55,9 +56,9 @@ from distilabel.steps.tasks.text_classification import TextClassification from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration from distilabel.steps.tasks.text_generation_with_image import TextGenerationWithImage -from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback from distilabel.steps.tasks.urial import URIAL +from distilabel.typing 
import ChatItem, ChatType __all__ = [ "CLAIR", @@ -66,6 +67,7 @@ "APIGenGenerator", "APIGenSemanticChecker", "ArgillaLabeller", + "ArgillaLabeller", "BitextRetrievalGenerator", "ChatGeneration", "ChatItem", @@ -86,18 +88,22 @@ "GenerateTextRetrievalData", "GeneratorTask", "Genstruct", + "ImageGeneration", + "ImageTask", "InstructionBacktranslation", "Magpie", "MagpieGenerator", "MathShepherdCompleter", "MathShepherdGenerator", "MonolingualTripletGenerator", + "MonolingualTripletGenerator", "PairRM", "PrometheusEval", "QualityScorer", "SelfInstruct", "StructuredGeneration", "Task", + "Task", "TextClassification", "TextGeneration", "TextGenerationWithImage", diff --git a/src/distilabel/steps/tasks/apigen/execution_checker.py b/src/distilabel/steps/tasks/apigen/execution_checker.py index 7d30dd1f75..7cd597e88e 100644 --- a/src/distilabel/steps/tasks/apigen/execution_checker.py +++ b/src/distilabel/steps/tasks/apigen/execution_checker.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from types import ModuleType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class APIGenExecutionChecker(Step): diff --git a/src/distilabel/steps/tasks/apigen/generator.py b/src/distilabel/steps/tasks/apigen/generator.py index 39f202d065..941c7b3ea4 100644 --- a/src/distilabel/steps/tasks/apigen/generator.py +++ b/src/distilabel/steps/tasks/apigen/generator.py @@ -26,8 +26,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT_API_GEN: Final[str] = """\ diff --git a/src/distilabel/steps/tasks/apigen/semantic_checker.py b/src/distilabel/steps/tasks/apigen/semantic_checker.py index c5cf0b183b..c5e7582313 100644 --- a/src/distilabel/steps/tasks/apigen/semantic_checker.py +++ b/src/distilabel/steps/tasks/apigen/semantic_checker.py @@ -24,8 +24,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT_SEMANTIC_CHECKER: Final[str] = """\ diff --git a/src/distilabel/steps/tasks/apigen/utils.py b/src/distilabel/steps/tasks/apigen/utils.py index 85ff0b764c..7e07997b06 100644 --- a/src/distilabel/steps/tasks/apigen/utils.py +++ b/src/distilabel/steps/tasks/apigen/utils.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from types import ModuleType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class PrepareExamples(Step): diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py index 1888087e8d..c3fae412c0 100644 --- a/src/distilabel/steps/tasks/argilla_labeller.py +++ b/src/distilabel/steps/tasks/argilla_labeller.py @@ -40,8 +40,7 @@ TextQuestion, ) - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepOutput + from distilabel.typing import ChatType, StepOutput class ArgillaLabeller(Task): diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index ae19a1038f..3a575545d1 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -22,6 +22,7 @@ from distilabel.constants import DISTILABEL_METADATA_KEY from distilabel.errors import DistilabelUserError from distilabel.mixins.runtime_parameters 
import RuntimeParameter +from distilabel.models.image_generation.base import ImageGenerationModel from distilabel.models.llms.base import LLM from distilabel.steps.base import ( GeneratorStep, @@ -33,9 +34,13 @@ from distilabel.utils.dicts import group_dicts if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput, LLMStatistics - from distilabel.steps.tasks.typing import ChatType, FormattedInput - from distilabel.steps.typing import StepOutput + from distilabel.typing import ( + ChatType, + FormattedInput, + GenerateOutput, + LLMStatistics, + StepOutput, + ) class _Task(_Step, ABC): @@ -491,6 +496,102 @@ class GlobalTask(_Task, GlobalStep): pass +class ImageTask(_Task, Step): + """`ImageTask` is a class that implements the `_Task` abstract class and adds the `Step` + interface to be used as a step in the pipeline. It differs from the `Task` in that it's + expected to work with `ImageGenerationModel`s instead of `LLM`s. + + Attributes: + image_generation_model: the `ImageGenerationModel` to be used to generate the outputs. + llm: This attribute is here to respect the `_Task` interface, but it's used internally only. + group_generations: whether to group the `num_generations` generated per input in + a list or create a row per generation. Defaults to `False`. + num_generations: The number of generations to be produced per input. + """ + + llm: Union[LLM, ImageGenerationModel, None] = None + image_generation_model: ImageGenerationModel + + def model_post_init(self, __context: Any) -> None: + assert self.llm is None, ( + "`ImageTask` cannot use an `LLM` attribute given by the user, pass " + "the `image_generation_model` attribute instead." + ) + self.llm = self.image_generation_model + # Call the post init from the Step, as we don't want to call specific behaviour + # from the task, that may need to deal with specific attributes from the LLM + # not in the ImageGenerationModel + super(Step, self).model_post_init(__context) + + @abstractmethod + def format_input(self, input: dict[str, any]) -> str: + """Abstract method to format the inputs of the task. It needs to receive an input + as a Python dictionary, and generates a string to be used as the prompt for the model.""" + pass + + def _format_inputs(self, inputs: list[dict[str, any]]) -> List["FormattedInput"]: + """Formats the inputs of the task using the `format_input` method. + + Args: + inputs: A list of Python dictionaries with the inputs of the task. + + Returns: + A list containing the formatted inputs, which are `ChatType`-like following + the OpenAI formatting. + """ + return [self.format_input(input) for input in inputs] + + def _format_outputs( + self, + outputs: list[Union[str, None]], + input: Union[Dict[str, Any], None] = None, + ) -> List[Dict[str, Any]]: + """Formats the outputs of the task using the `format_output` method. If the output + is `None` (i.e. the LLM failed to generate a response), then the outputs will be + set to `None` as well. + + Args: + outputs: The outputs (`n` generations) for the provided `input`. + input: The input used to generate the output. + + Returns: + A list containing a dictionary with the outputs of the task for each input. 
+ """ + inputs = [None] if input is None else [input] + formatted_outputs = [] + + for output, input in zip(outputs, inputs): # type: ignore + try: + formatted_output = self.format_output(output, input) + formatted_output = self._create_metadata( + formatted_output, + output, + input, + add_raw_output=self.add_raw_output, # type: ignore + add_raw_input=self.add_raw_input, # type: ignore + statistics=None, + ) + formatted_outputs.append(formatted_output) + except Exception as e: + self._logger.warning( # type: ignore + f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore + ) + formatted_outputs.append(self._output_on_failure(output, input)) + return formatted_outputs + + @abstractmethod + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + """Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`. + + Args: + inputs: A list of Python dictionaries with the inputs of the task. + + Yields: + A list of Python dictionaries with the outputs of the task. + """ + pass + + def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": """Transforms the GenerateOutput statistics to have the same length as the generations. diff --git a/src/distilabel/steps/tasks/clair.py b/src/distilabel/steps/tasks/clair.py index 524a1d76c9..b619ef9dbb 100644 --- a/src/distilabel/steps/tasks/clair.py +++ b/src/distilabel/steps/tasks/clair.py @@ -21,8 +21,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT: Final[str] = ( diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py index bd8a99c6b0..d36c7f1d07 100644 --- a/src/distilabel/steps/tasks/complexity_scorer.py +++ b/src/distilabel/steps/tasks/complexity_scorer.py @@ -30,7 +30,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType _PARSE_SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+)", re.IGNORECASE) diff --git a/src/distilabel/steps/tasks/decorator.py b/src/distilabel/steps/tasks/decorator.py index 8862734f8c..c9752f247c 100644 --- a/src/distilabel/steps/tasks/decorator.py +++ b/src/distilabel/steps/tasks/decorator.py @@ -20,10 +20,10 @@ from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import FormattedInput +from distilabel.typing import FormattedInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns + from distilabel.typing import StepColumns TaskFormattingOutputFunc = Callable[..., Dict[str, Any]] diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py index f1a44d6a84..eae066c690 100644 --- a/src/distilabel/steps/tasks/evol_instruct/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/base.py @@ -23,12 +23,11 @@ from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task from distilabel.steps.tasks.evol_instruct.utils import MUTATION_TEMPLATES -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.typing import StepOutput + from 
distilabel.typing import LLMStatistics, StepOutput class EvolInstruct(Task): diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index 6f985464eb..415654ba12 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -33,9 +33,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import GeneratorStepOutput + from distilabel.typing import ChatType, GeneratorStepOutput, LLMStatistics class EvolInstructGenerator(GeneratorTask): diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py index 8ea7061105..41441381df 100644 --- a/src/distilabel/steps/tasks/evol_quality/base.py +++ b/src/distilabel/steps/tasks/evol_quality/base.py @@ -23,10 +23,10 @@ from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task from distilabel.steps.tasks.evol_quality.utils import MUTATION_TEMPLATES -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class EvolQuality(Task): diff --git a/src/distilabel/steps/tasks/generate_embeddings.py b/src/distilabel/steps/tasks/generate_embeddings.py index f73ee1b2b3..bedc8c5419 100644 --- a/src/distilabel/steps/tasks/generate_embeddings.py +++ b/src/distilabel/steps/tasks/generate_embeddings.py @@ -20,8 +20,7 @@ from distilabel.utils.chat import is_openai_format if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, StepColumns, StepOutput class GenerateEmbeddings(Step): diff --git a/src/distilabel/steps/tasks/genstruct.py b/src/distilabel/steps/tasks/genstruct.py index 2b9c307d5b..e63a75f704 100644 --- a/src/distilabel/steps/tasks/genstruct.py +++ b/src/distilabel/steps/tasks/genstruct.py @@ -28,7 +28,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType _PARSE_GENSTRUCT_OUTPUT_REGEX = r"(.+?)\[\[\[Assistant\]\]\](.+)$" diff --git a/src/distilabel/steps/tasks/image_generation.py b/src/distilabel/steps/tasks/image_generation.py new file mode 100644 index 0000000000..3484b90058 --- /dev/null +++ b/src/distilabel/steps/tasks/image_generation.py @@ -0,0 +1,188 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import hashlib
+from typing import TYPE_CHECKING
+
+from distilabel.models.image_generation.utils import image_from_str
+from distilabel.steps.base import StepInput
+from distilabel.steps.tasks.base import ImageTask
+
+if TYPE_CHECKING:
+    from distilabel.typing import StepColumns, StepOutput
+
+
+class ImageGeneration(ImageTask):
+    """Image generation with a text-to-image model given a prompt.
+
+    `ImageGeneration` is a pre-defined task that allows generating images from a prompt.
+    It works with any of the `image_generation` models defined under `distilabel.models.image_generation`,
+    i.e. the implemented models that allow image generation.
+    By default, the images are generated in base64 string format, and after the dataset
+    has been generated, the images can be automatically transformed to `PIL.Image.Image` using
+    `Distiset.transform_columns_to_image`. Take a look at the `Image Generation with distilabel`
+    example in the documentation for more information.
+    Using the `save_artifacts` attribute, the images can be saved in the artifacts folder of the
+    Hugging Face Hub repository.
+
+    Attributes:
+        save_artifacts: Bool value to save the images as artifacts in a dedicated folder.
+            Otherwise, the base64 representation of the image will be saved as
+            a string. Defaults to False.
+        image_format: Any of the formats supported by PIL. Defaults to `JPEG`.
+
+    Input columns:
+        - prompt (str): A column named prompt with the prompts to generate the images.
+
+    Output columns:
+        - image (`str`): The generated image. Initially it is a base64 string, for simplicity
+            during the pipeline run, but this can be transformed to an Image object after
+            distiset is returned at the end of a pipeline by calling
+            `distiset.transform_columns_to_image()`.
+        - image_path (`str`): The path where the image is saved. Only available if `save_artifacts`
+            is True.
+        - model_name (`str`): The name of the model used to generate the image.
+
+    Categories:
+        - image-generation
+
+    Examples:
+        Generate an image from a prompt:
+
+        ```python
+        from distilabel.steps.tasks import ImageGeneration
+        from distilabel.models.image_generation import InferenceEndpointsImageGeneration
+
+        igm = InferenceEndpointsImageGeneration(
+            model_id="black-forest-labs/FLUX.1-schnell"
+        )
+
+        # save_artifacts is False by default, so the image will be kept as a base64 string.
+        image_gen = ImageGeneration(image_generation_model=igm)
+
+        image_gen.load()
+
+        result = next(
+            image_gen.process(
+                [{"prompt": "a white siamese cat"}]
+            )
+        )
+        ```
+
+        Generate an image and save it as an artifact in a Hugging Face Hub repository:
+
+        ```python
+        from distilabel.steps.tasks import ImageGeneration
+        # Select the Image Generation model to use
+        from distilabel.models.image_generation import OpenAIImageGeneration
+
+        igm = OpenAIImageGeneration(
+            model="dall-e-3",
+            api_key="api.key",
+            generation_kwargs={
+                "size": "1024x1024",
+                "quality": "standard",
+                "style": "natural"
+            }
+        )
+
+        # Set save_artifacts=True to save the image as an artifact (JPEG format by default); if False, the image is kept as a base64 string.
+        image_gen = ImageGeneration(
+            image_generation_model=igm,
+            save_artifacts=True,
+            image_format="JPEG" # JPEG by default; the available options can be seen in the PIL documentation.
+ ) + + image_gen.load() + + result = next( + image_gen.process( + [{"prompt": "a white siamese cat"}] + ) + ) + ``` + """ + + save_artifacts: bool = False + image_format: str = "JPEG" + + @property + def inputs(self) -> "StepColumns": + return ["prompt"] + + @property + def outputs(self) -> "StepColumns": + return { + "image": True, + "image_path": False, + "model_name": True, + } + + def format_input(self, input: dict[str, any]) -> str: + return input["prompt"] + + def format_output( + self, output: dict[str, any], input: dict[str, any] + ) -> dict[str, any]: + image = None + if img_str := output.get("images"): + image = img_str[0] # Grab only the first image + + return {"image": image, "model_name": self.llm.model_name} + + def save(self, **kwargs): + if not self.save_artifacts: + from distilabel.utils.serialization import _Serializable + + super(_Serializable).save(**kwargs) + + def process(self, inputs: StepInput) -> "StepOutput": + formatted_inputs = self._format_inputs(inputs) + + outputs = self.llm.generate_outputs( + inputs=formatted_inputs, + num_generations=self.num_generations, + **self.llm.get_generation_kwargs(), + ) + + task_outputs = [] + for input, input_outputs in zip(inputs, outputs): + formatted_outputs = self._format_outputs(input_outputs, input) + for formatted_output in formatted_outputs: + if self.save_artifacts and ( + image := formatted_output.get("image", None) + ): + # use prompt as filename + prompt_hash = hashlib.md5(input["prompt"].encode()).hexdigest() + # Build PIL image to save it + image = image_from_str(image) + + self.save_artifact( + name="images", + write_function=lambda path, + prompt_hash=prompt_hash, + img=image: img.save( + path / f"{prompt_hash}.{self.image_format.lower()}", + format=self.image_format, + ), + metadata={"type": "image"}, + ) + formatted_output["image_path"] = ( + f"artifacts/{self.name}/images/{prompt_hash}.{self.image_format.lower()}" + ) + + task_outputs.append( + {**input, **formatted_output, "model_name": self.llm.model_name} + ) + yield task_outputs diff --git a/src/distilabel/steps/tasks/improving_text_embeddings.py b/src/distilabel/steps/tasks/improving_text_embeddings.py index d806e3aded..8569c12810 100644 --- a/src/distilabel/steps/tasks/improving_text_embeddings.py +++ b/src/distilabel/steps/tasks/improving_text_embeddings.py @@ -23,8 +23,7 @@ from typing_extensions import override from distilabel.steps.tasks.base import GeneratorTask, Task -from distilabel.steps.tasks.typing import ChatType -from distilabel.steps.typing import GeneratorStepOutput +from distilabel.typing import ChatType, GeneratorStepOutput # BASE CLASSES diff --git a/src/distilabel/steps/tasks/instruction_backtranslation.py b/src/distilabel/steps/tasks/instruction_backtranslation.py index a0420ef8f3..52405406ed 100644 --- a/src/distilabel/steps/tasks/instruction_backtranslation.py +++ b/src/distilabel/steps/tasks/instruction_backtranslation.py @@ -26,7 +26,7 @@ from pydantic import PrivateAttr from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType class InstructionBacktranslation(Task): diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index 265497409c..13c4f9f0be 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -30,9 +30,7 @@ from distilabel.utils.dicts import merge_dicts if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from 
distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, LLMStatistics, StepColumns, StepOutput MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py index c9d18d9fca..c1fcd14828 100644 --- a/src/distilabel/steps/tasks/magpie/generator.py +++ b/src/distilabel/steps/tasks/magpie/generator.py @@ -24,8 +24,7 @@ from distilabel.steps.tasks.magpie.base import MagpieBase if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import GeneratorStepOutput, StepColumns + from distilabel.typing import ChatType, GeneratorStepOutput, StepColumns class MagpieGenerator(GeneratorTask, MagpieBase): diff --git a/src/distilabel/steps/tasks/math_shepherd/completer.py b/src/distilabel/steps/tasks/math_shepherd/completer.py index 5d3fdd7e15..05ff410ac5 100644 --- a/src/distilabel/steps/tasks/math_shepherd/completer.py +++ b/src/distilabel/steps/tasks/math_shepherd/completer.py @@ -28,9 +28,7 @@ from distilabel.utils.itertools import batched if TYPE_CHECKING: - from distilabel.models.llms.typing import LLMStatistics - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import ChatType, LLMStatistics, StepColumns, StepOutput SYSTEM_PROMPT = """\ diff --git a/src/distilabel/steps/tasks/math_shepherd/generator.py b/src/distilabel/steps/tasks/math_shepherd/generator.py index d9ab565e54..efcd986549 100644 --- a/src/distilabel/steps/tasks/math_shepherd/generator.py +++ b/src/distilabel/steps/tasks/math_shepherd/generator.py @@ -25,8 +25,7 @@ ) if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns SYSTEM_PROMPT = """\ diff --git a/src/distilabel/steps/tasks/math_shepherd/utils.py b/src/distilabel/steps/tasks/math_shepherd/utils.py index 978496996f..8a04f325b5 100644 --- a/src/distilabel/steps/tasks/math_shepherd/utils.py +++ b/src/distilabel/steps/tasks/math_shepherd/utils.py @@ -20,7 +20,7 @@ from distilabel.steps.base import Step, StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput def split_solution_steps(text: str) -> list[str]: diff --git a/src/distilabel/steps/tasks/pair_rm.py b/src/distilabel/steps/tasks/pair_rm.py index 23262a533f..4def62615f 100644 --- a/src/distilabel/steps/tasks/pair_rm.py +++ b/src/distilabel/steps/tasks/pair_rm.py @@ -20,7 +20,7 @@ from distilabel.steps.tasks.base import Step if TYPE_CHECKING: - from distilabel.steps.typing import StepColumns, StepOutput + from distilabel.typing import StepColumns, StepOutput class PairRM(Step): diff --git a/src/distilabel/steps/tasks/prometheus_eval.py b/src/distilabel/steps/tasks/prometheus_eval.py index 4c61c416be..36e8e6ac9f 100644 --- a/src/distilabel/steps/tasks/prometheus_eval.py +++ b/src/distilabel/steps/tasks/prometheus_eval.py @@ -30,7 +30,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType _DEFAULT_RUBRICS = { diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py index efafda2b7a..81dc0c1632 100644 --- a/src/distilabel/steps/tasks/quality_scorer.py +++ 
b/src/distilabel/steps/tasks/quality_scorer.py @@ -28,7 +28,7 @@ from typing_extensions import override from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType _PARSE_SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+)", re.IGNORECASE) diff --git a/src/distilabel/steps/tasks/self_instruct.py b/src/distilabel/steps/tasks/self_instruct.py index dcca46ee67..2b36740c7a 100644 --- a/src/distilabel/steps/tasks/self_instruct.py +++ b/src/distilabel/steps/tasks/self_instruct.py @@ -27,7 +27,7 @@ from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType class SelfInstruct(Task): diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py index 350849e3d0..a4c2feb20b 100644 --- a/src/distilabel/steps/tasks/sentence_transformers.py +++ b/src/distilabel/steps/tasks/sentence_transformers.py @@ -28,7 +28,7 @@ import importlib.resources as importlib_resources if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType GenerationAction = Literal["paraphrase", "semantically-similar", "query", "answer"] diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py index 905a6672d0..92eb3fd9e1 100644 --- a/src/distilabel/steps/tasks/structured_generation.py +++ b/src/distilabel/steps/tasks/structured_generation.py @@ -17,7 +17,7 @@ from distilabel.errors import DistilabelUserError from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import StructuredInput +from distilabel.typing import StructuredInput class StructuredGeneration(Task): diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index a5aceacb3b..45b5fe7494 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -39,7 +39,7 @@ from transformers import Pipeline # noqa from vllm import LLM as _vLLM # noqa - from distilabel.steps.tasks.typing import OutlinesStructuredOutputType # noqa + from distilabel.typing import OutlinesStructuredOutputType # noqa Frameworks = Literal["transformers", "llamacpp", "vllm"] diff --git a/src/distilabel/steps/tasks/text_classification.py b/src/distilabel/steps/tasks/text_classification.py index 19df530fb6..ec032241f5 100644 --- a/src/distilabel/steps/tasks/text_classification.py +++ b/src/distilabel/steps/tasks/text_classification.py @@ -23,7 +23,7 @@ from distilabel.steps.tasks import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType TEXT_CLASSIFICATION_TEMPLATE: str = """\ diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py index b6620430cc..59cf932423 100644 --- a/src/distilabel/steps/tasks/text_generation.py +++ b/src/distilabel/steps/tasks/text_generation.py @@ -23,8 +23,7 @@ from distilabel.utils.template import check_column_in_template if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns class TextGeneration(Task): diff --git a/src/distilabel/steps/tasks/text_generation_with_image.py b/src/distilabel/steps/tasks/text_generation_with_image.py index 
8494afc9db..8aee386f80 100644 --- a/src/distilabel/steps/tasks/text_generation_with_image.py +++ b/src/distilabel/steps/tasks/text_generation_with_image.py @@ -28,8 +28,7 @@ if TYPE_CHECKING: from PIL.Image import Image - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns class TextGenerationWithImage(TextGeneration): diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py index bac144f54d..1139254abb 100644 --- a/src/distilabel/steps/tasks/ultrafeedback.py +++ b/src/distilabel/steps/tasks/ultrafeedback.py @@ -22,7 +22,7 @@ from typing_extensions import override from distilabel.steps.tasks.base import Task -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType from distilabel.utils.dicts import group_dicts diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py index 24b643ada6..b49c5d9f36 100644 --- a/src/distilabel/steps/tasks/urial.py +++ b/src/distilabel/steps/tasks/urial.py @@ -20,8 +20,7 @@ from distilabel.steps.tasks import Task if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType - from distilabel.steps.typing import StepColumns + from distilabel.typing import ChatType, StepColumns class URIAL(Task): diff --git a/src/distilabel/typing.py b/src/distilabel/typing/__init__.py similarity index 72% rename from src/distilabel/typing.py rename to src/distilabel/typing/__init__.py index a3d65d5d75..ec65f26878 100644 --- a/src/distilabel/typing.py +++ b/src/distilabel/typing/__init__.py @@ -12,26 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from distilabel.models.llms.typing import GenerateOutput -from distilabel.pipeline.typing import ( - DownstreamConnectable, - DownstreamConnectableSteps, - InputDataset, - PipelineRuntimeParametersInfo, - StepLoadStatus, - UpstreamConnectableSteps, -) -from distilabel.steps.tasks.typing import ( +from distilabel.typing.base import ( ChatItem, ChatType, + ImageContent, + ImageUrl, + TextContent, +) +from distilabel.typing.models import ( FormattedInput, + GenerateOutput, + HiddenState, InstructorStructuredOutputType, + LLMLogprobs, + LLMOutput, + LLMStatistics, + Logprob, OutlinesStructuredOutputType, StandardInput, StructuredInput, StructuredOutputType, + TokenCount, +) +from distilabel.typing.pipeline import ( + DownstreamConnectable, + DownstreamConnectableSteps, + InputDataset, + LoadGroups, + PipelineRuntimeParametersInfo, + StepLoadStatus, + UpstreamConnectableSteps, ) -from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput +from distilabel.typing.steps import GeneratorStepOutput, StepColumns, StepOutput __all__ = [ "ChatItem", @@ -41,8 +53,16 @@ "FormattedInput", "GenerateOutput", "GeneratorStepOutput", + "HiddenState", + "ImageContent", + "ImageUrl", "InputDataset", "InstructorStructuredOutputType", + "LLMLogprobs", + "LLMOutput", + "LLMStatistics", + "LoadGroups", + "Logprob", "OutlinesStructuredOutputType", "PipelineRuntimeParametersInfo", "StandardInput", @@ -51,5 +71,7 @@ "StepOutput", "StructuredInput", "StructuredOutputType", + "TextContent", + "TokenCount", "UpstreamConnectableSteps", ] diff --git a/src/distilabel/typing/base.py b/src/distilabel/typing/base.py new file mode 100644 index 0000000000..16645c0957 --- /dev/null +++ b/src/distilabel/typing/base.py @@ -0,0 +1,46 @@ +# Copyright 2023-present, Argilla, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Literal, Union + +from typing_extensions import Required, TypedDict + + +class TextContent(TypedDict, total=False): + type: Required[Literal["text"]] + text: Required[str] + + +class ImageUrl(TypedDict): + url: Required[str] + """Either a URL of the image or the base64 encoded image data.""" + + +class ImageContent(TypedDict, total=False): + """Type alias for the user's message in a conversation that can include text or an image. + It's the standard type for vision language models: + https://platform.openai.com/docs/guides/vision + """ + + type: Required[Literal["image_url"]] + image_url: Required[ImageUrl] + + +class ChatItem(TypedDict): + role: Literal["system", "user", "assistant"] + content: Union[str, list[Union[TextContent, ImageContent]]] + + +ChatType = List[ChatItem] +"""ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format.""" diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/typing/models.py similarity index 66% rename from src/distilabel/steps/tasks/typing.py rename to src/distilabel/typing/models.py index d0d22a6811..aa11305421 100644 --- a/src/distilabel/steps/tasks/typing.py +++ b/src/distilabel/typing/models.py @@ -12,39 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from pydantic import BaseModel -from typing_extensions import Required, TypedDict +from typing_extensions import NotRequired, TypedDict +from distilabel.typing.base import ChatType -class TextContent(TypedDict, total=False): - type: Required[Literal["text"]] - text: Required[str] +LLMOutput = List[Union[str, None]] -class ImageUrl(TypedDict): - url: Required[str] - """Either a URL of the image or the base64 encoded image data.""" +class Logprob(TypedDict): + token: str + logprob: float -class ImageContent(TypedDict, total=False): - """Type alias for the user's message in a conversation that can include text or an image. - It's the standard type for vision language models: - https://platform.openai.com/docs/guides/vision - """ +LLMLogprobs = List[List[List[Logprob]]] +"""A type alias representing the probability distributions output by an `LLM`. 
+ +Structure: + - Outermost list: contains multiple generation choices when sampling (`n` sequences) + - Middle list: represents each position in the generated sequence + - Innermost list: contains the log probabilities for each token in the vocabulary at that position +""" - type: Required[Literal["image_url"]] - image_url: Required[ImageUrl] +class TokenCount(TypedDict): + input_tokens: List[int] + output_tokens: List[int] -class ChatItem(TypedDict): - role: Literal["system", "user", "assistant"] - content: Union[str, list[Union[TextContent, ImageContent]]] +LLMStatistics = Union[TokenCount, Dict[str, Any]] +"""Initially the LLMStatistics will contain the token count, but can have more variables. +They can be added once we have them defined for every LLM. +""" -ChatType = List[ChatItem] -"""ChatType is a type alias for a `list` of `dict`s following the OpenAI conversational format.""" + +class GenerateOutput(TypedDict): + generations: LLMOutput + statistics: LLMStatistics + logprobs: NotRequired[LLMLogprobs] class OutlinesStructuredOutputType(TypedDict, total=False): @@ -84,11 +102,21 @@ class InstructorStructuredOutputType(TypedDict, total=False): OutlinesStructuredOutputType, InstructorStructuredOutputType ] """StructuredOutputType is an alias for the union of `OutlinesStructuredOutputType` and `InstructorStructuredOutputType`.""" - StandardInput = ChatType """StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`.""" StructuredInput = Tuple[StandardInput, Union[StructuredOutputType, None]] """StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it.""" -FormattedInput = Union[StandardInput, StructuredInput, ChatType] +FormattedInput = Union[StandardInput, StructuredInput] """FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s, as well as `ConversationType` for the vision language models.""" + + +if TYPE_CHECKING: + from numpy import floating + from numpy.typing import NDArray + + GenericFloat = TypeVar("GenericFloat", bound=floating[Any]) + + HiddenState = NDArray[GenericFloat] +else: + HiddenState = Any diff --git a/src/distilabel/pipeline/typing.py b/src/distilabel/typing/pipeline.py similarity index 98% rename from src/distilabel/pipeline/typing.py rename to src/distilabel/typing/pipeline.py index 3e796948aa..3824cbf116 100644 --- a/src/distilabel/pipeline/typing.py +++ b/src/distilabel/typing/pipeline.py @@ -18,11 +18,12 @@ Dict, List, Literal, - TypedDict, TypeVar, Union, ) +from typing_extensions import TypedDict + if TYPE_CHECKING: import pandas as pd from datasets import Dataset diff --git a/src/distilabel/steps/typing.py b/src/distilabel/typing/steps.py similarity index 100% rename from src/distilabel/steps/typing.py rename to src/distilabel/typing/steps.py diff --git a/src/distilabel/utils/export_components_info.py b/src/distilabel/utils/export_components_info.py index 00144fd041..dcf9e0ecd7 100644 --- a/src/distilabel/utils/export_components_info.py +++ b/src/distilabel/utils/export_components_info.py @@ -16,6 +16,7 @@ from typing import Generator, List, Type, TypedDict, TypeVar from distilabel.models.embeddings.base import Embeddings +from distilabel.models.image_generation.base import ImageGenerationModel from distilabel.models.llms.base import LLM from distilabel.steps.base import _Step from distilabel.steps.tasks.base import _Task @@ -28,6 +29,7 @@ class 
ComponentsInfo(TypedDict): """A dictionary containing `distilabel` components information.""" llms: List + image_generation_models: List steps: List tasks: List embeddings: List @@ -55,6 +57,10 @@ def export_components_info() -> ComponentsInfo: {"name": llm_type.__name__, "docstring": parse_google_docstring(llm_type)} for llm_type in _get_llms() ], + "image_generation_models": [ + {"name": igm_type.__name__, "docstring": parse_google_docstring(igm_type)} + for igm_type in _get_image_generation_models() + ], "embeddings": [ { "name": embeddings_type.__name__, @@ -113,6 +119,22 @@ def _get_llms() -> List[Type["LLM"]]: ] +def _get_image_generation_models() -> List[Type["ImageGenerationModel"]]: + """Get all `ImageGenerationModel` subclasses, that are not abstract classes. + + Note: + This is a placeholder as we don't have `ImageGenerationModel` classes yet. + + Returns: + The list of all the classes under `distilabel.models.image_generation` that are not abstract classes. + """ + return [ + igm_type + for igm_type in _recursive_subclasses(ImageGenerationModel) + if not inspect.isabstract(igm_type) + ] + + def _get_embeddings() -> List[Type["Embeddings"]]: """Get all `Embeddings` subclasses, that are not abstract classes. diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index 005f74748e..7293d90e69 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -75,6 +75,7 @@ ).read() ) + _STEPS_CATEGORY_TO_ICON = { "text-generation": ":material-text-box-edit:", "chat-generation": ":material-chat:", @@ -92,6 +93,7 @@ "load": ":material-file-download:", "execution": ":octicons-code-16:", "save": ":material-content-save:", + "image-generation": ":material-image:", "labelling": ":label:", } @@ -112,6 +114,7 @@ "load": "Load steps are used to load the data.", "execution": "Executes python functions.", "save": "Save steps are used to save the data.", + "image-generation": "Image generation steps are used to generate images based on a given prompt.", "labelling": "Labelling steps are used to label the data.", } @@ -199,6 +202,12 @@ def on_files( self.file_paths["llms"] = self._generate_llms_pages( src_dir=src_dir, llms=components_info["llms"] ) + self.file_paths["image_generation_models"] = ( + self._generate_image_generation_pages( + src_dir=src_dir, + image_generation_models=components_info["image_generation_models"], + ) + ) self.file_paths["embeddings"] = self._generate_embeddings_pages( src_dir=src_dir, embeddings=components_info["embeddings"] ) @@ -209,6 +218,7 @@ def on_files( *self.file_paths["steps"], *self.file_paths["tasks"], *self.file_paths["llms"], + *self.file_paths["image_generation_models"], *self.file_paths["embeddings"], ]: file = File( @@ -429,6 +439,48 @@ def _generate_llms_pages(self, src_dir: Path, llms: list) -> List[str]: return paths + def _generate_image_generation_pages( + self, src_dir: Path, image_generation_models: list + ) -> List[str]: + """Generates the files for the `ILMs` subsection of the components gallery. + + Args: + src_dir: The path to the source directory. + image_generation_models: The list of `ImageGenerationModel` components. + + Returns: + The relative paths to the generated files. 
+ """ + + paths = ["components-gallery/image_generation/index.md"] + steps_gallery_page_path = src_dir / paths[0] + steps_gallery_page_path.parent.mkdir(parents=True, exist_ok=True) + + # Create detail page for each `ImageGenerationModel` + for igm in image_generation_models: + content = _LLM_DETAIL_TEMPLATE.render(llm=igm) + + ilm_path = f"components-gallery/image_generation/{igm['name'].lower()}.md" + path = src_dir / ilm_path + with open(path, "w") as f: + f.write(content) + + paths.append(ilm_path) + + # Create the `components-gallery/ilms/index.md` file + content = _COMPONENTS_LIST_TEMPLATE.render( + title="Image Generation Gallery", + description="", + components=image_generation_models, + component_group="image_generation_models", + default_icon=":material-image:", + ) + + with open(steps_gallery_page_path, "w") as f: + f.write(content) + + return paths + def _generate_embeddings_pages(self, src_dir: Path, embeddings: list) -> List[str]: """Generates the files for the `Embeddings` subsection of the components gallery. @@ -491,6 +543,10 @@ def on_nav( steps_file = files.get_file_from_path(self.file_paths["steps"][0]) tasks_file = files.get_file_from_path(self.file_paths["tasks"][0]) llms_file = files.get_file_from_path(self.file_paths["llms"][0]) + image_generation_file = files.get_file_from_path( + self.file_paths["image_generation_models"][0] + ) + steps_files = [ files.get_file_from_path(path) for path in self.file_paths["steps"][0:] ] @@ -500,6 +556,10 @@ def on_nav( llms_files = [ files.get_file_from_path(path) for path in self.file_paths["llms"][0:] ] + image_generation_files = [ + files.get_file_from_path(path) + for path in self.file_paths["image_generation_models"][0:] + ] # Create subsections steps_page = SectionPage( @@ -511,13 +571,19 @@ def on_nav( llms_page = SectionPage( "LLMs", file=llms_file, config=config, children=llms_files ) # type: ignore + igms_page = SectionPage( + "ImageGenerationModels", + file=image_generation_file, + config=config, + children=image_generation_files, + ) # type: ignore # Create the gallery section page = SectionPage( title=self.config.page_title, file=components_gallery_file, config=config, - children=[steps_page, tasks_page, llms_page], + children=[steps_page, tasks_page, llms_page, igms_page], ) # Add the page diff --git a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md index cc3e44aecf..96f09dcc32 100644 --- a/src/distilabel/utils/mkdocs/templates/components-gallery/index.md +++ b/src/distilabel/utils/mkdocs/templates/components-gallery/index.md @@ -31,6 +31,14 @@ hide: [:octicons-arrow-right-24: LLMs](llms/index.md){ .bottom } +- :material-image:{ .lg .middle } __ImageGenerationModels__ + + --- + + Explore all the available `ImageGenerationModels`s integrated with `distilabel`. 
+ + [:octicons-arrow-right-24: ImageGenerationModels](image_generation/index.md){ .bottom } + - :material-vector-line:{ .lg .middle } __Embeddings__ --- diff --git a/tests/integration/test_dataset_without_step.py b/tests/integration/test_dataset_without_step.py index b71631c27e..820793355f 100644 --- a/tests/integration/test_dataset_without_step.py +++ b/tests/integration/test_dataset_without_step.py @@ -21,7 +21,7 @@ from distilabel.pipeline import Pipeline from distilabel.steps import make_generator_step from distilabel.steps.base import Step, StepInput -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput if TYPE_CHECKING: pass diff --git a/tests/integration/test_embedding_dedup.py b/tests/integration/test_embedding_dedup.py index 7806cf6761..7ff02f3d70 100644 --- a/tests/integration/test_embedding_dedup.py +++ b/tests/integration/test_embedding_dedup.py @@ -22,7 +22,7 @@ from distilabel.steps.filtering.embedding import EmbeddingDedup if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput SAMPLE_DATA = [ diff --git a/tests/integration/test_load_stages.py b/tests/integration/test_load_stages.py index 9faa771d77..fa7806a6eb 100644 --- a/tests/integration/test_load_stages.py +++ b/tests/integration/test_load_stages.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from distilabel.pipeline.batch import _Batch - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput routing_batch_function = sample_n_steps(2) diff --git a/tests/integration/test_multiple_replicas.py b/tests/integration/test_multiple_replicas.py index 26d0f19b57..210a338e35 100644 --- a/tests/integration/test_multiple_replicas.py +++ b/tests/integration/test_multiple_replicas.py @@ -22,7 +22,7 @@ from distilabel.steps import LoadDataFromDicts, StepInput, StepResources, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @step(outputs=["generation"]) diff --git a/tests/integration/test_offline_batch_generation.py b/tests/integration/test_offline_batch_generation.py index ae34d04159..7c81e94663 100644 --- a/tests/integration/test_offline_batch_generation.py +++ b/tests/integration/test_offline_batch_generation.py @@ -22,8 +22,7 @@ from distilabel.steps.tasks import TextGeneration if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput class DummyOfflineBatchGenerateLLM(LLM): diff --git a/tests/integration/test_pipe_llms.py b/tests/integration/test_pipe_llms.py index c95af1ac3f..b148d00b79 100644 --- a/tests/integration/test_pipe_llms.py +++ b/tests/integration/test_pipe_llms.py @@ -24,7 +24,7 @@ from distilabel.steps.tasks.text_generation import TextGeneration if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class RenameColumns(Step): diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py index eee334677e..fd4548700b 100644 --- a/tests/integration/test_pipe_simple.py +++ b/tests/integration/test_pipe_simple.py @@ -21,7 +21,7 @@ from distilabel.steps.generators.data import LoadDataFromDicts if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput DATA = [ {"prompt": "Tell me a joke"}, diff --git a/tests/integration/test_ray_pipeline.py 
b/tests/integration/test_ray_pipeline.py index 241232b0cf..b29c7f454b 100644 --- a/tests/integration/test_ray_pipeline.py +++ b/tests/integration/test_ray_pipeline.py @@ -22,7 +22,7 @@ from distilabel.steps.generators.data import LoadDataFromDicts if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput DATA = [ {"prompt": "Tell me a joke"}, diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py index 3a48543a8d..60e951f556 100644 --- a/tests/integration/test_routing_batch_function.py +++ b/tests/integration/test_routing_batch_function.py @@ -22,7 +22,7 @@ from distilabel.steps import LoadDataFromDicts, StepInput, step if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput @routing_batch_function() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 32f70133a2..86c8e6c33c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,13 +20,13 @@ import pytest from pydantic import PrivateAttr +from distilabel.models.image_generation.base import AsyncImageGenerationModel from distilabel.models.llms.base import LLM, AsyncLLM from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.base import Task if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import ChatType, FormattedInput + from distilabel.typing import ChatType, FormattedInput, GenerateOutput # Defined here too, so that the serde still works @@ -101,6 +101,29 @@ def generate( ] +class DummyAsyncImageGenerationModel(AsyncImageGenerationModel): + def load(self) -> None: + pass + + @property + def model_name(self) -> str: + return "test" + + async def agenerate( # type: ignore + self, input: str, num_generations: int = 1 + ) -> list[dict[str, Any]]: + import numpy as np + from PIL import Image + + np.random.seed(42) + arr = np.random.randint(0, 255, (100, 100, 3)) + random_image = Image.fromarray(arr, "RGB") + from distilabel.models.image_generation.utils import image_to_str + + img_str = image_to_str(random_image) + return [{"images": [img_str]} for _ in range(num_generations)] + + class DummyTask(Task): @property def inputs(self) -> List[str]: diff --git a/tests/unit/models/image_generation/__init__.py b/tests/unit/models/image_generation/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/models/image_generation/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/models/image_generation/huggingface/__init__.py b/tests/unit/models/image_generation/huggingface/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/models/image_generation/huggingface/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/models/image_generation/huggingface/test_inference_endpoints.py b/tests/unit/models/image_generation/huggingface/test_inference_endpoints.py new file mode 100644 index 0000000000..2ca5eeab0d --- /dev/null +++ b/tests/unit/models/image_generation/huggingface/test_inference_endpoints.py @@ -0,0 +1,59 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import AsyncMock, MagicMock, patch + +import nest_asyncio +import numpy as np +import pytest +from PIL import Image + +from distilabel.models.image_generation.huggingface.inference_endpoints import ( + InferenceEndpointsImageGeneration, +) + + +@patch("huggingface_hub.AsyncInferenceClient") +class TestInferenceEndpointsImageGeneration: + @pytest.mark.asyncio + async def test_agenerate(self, mock_inference_client: MagicMock) -> None: + igm = InferenceEndpointsImageGeneration( + model_id="black-forest-labs/FLUX.1-schnell", + api_key="api.key", + ) + igm.load() + + arr = np.random.randint(0, 255, (100, 100, 3)) + random_image = Image.fromarray(arr, "RGB") + igm._aclient.text_to_image = AsyncMock(return_value=random_image) + + assert await igm.agenerate("Aenean hend") + + @pytest.mark.asyncio + async def test_generate(self, mock_inference_client: MagicMock) -> None: + igm = InferenceEndpointsImageGeneration( + model_id="black-forest-labs/FLUX.1-schnell", + api_key="api.key", + ) + igm.load() + + arr = np.random.randint(0, 255, (100, 100, 3)) + random_image = Image.fromarray(arr, "RGB") + igm._aclient.text_to_image = AsyncMock(return_value=random_image) + + nest_asyncio.apply() + + images = igm.generate(inputs=["Aenean hendrerit aliquam velit. ..."]) + assert images[0][0]["images"][0].startswith("/9j/4AAQSkZJRgABAQAAAQABAAD/2w") diff --git a/tests/unit/models/image_generation/test_openai.py b/tests/unit/models/image_generation/test_openai.py new file mode 100644 index 0000000000..d057d6c85f --- /dev/null +++ b/tests/unit/models/image_generation/test_openai.py @@ -0,0 +1,105 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import nest_asyncio +import pytest + +from distilabel.models.image_generation.openai import OpenAIImageGeneration + + +@patch("openai.OpenAI") +@patch("openai.AsyncOpenAI") +class TestOpenAIImageGeneration: + model_id: str = "dall-e-3" + + def test_openai_image_generation( + self, _async_openai_mock: MagicMock, _openai_mock: MagicMock + ): + igm = OpenAIImageGeneration( + model="dall-e-3", + api_key="api.key", + generation_kwargs={ + "size": "1024x1024", + "quality": "standard", + "style": "natural", + }, + ) + + assert isinstance(igm, OpenAIImageGeneration) + assert igm.model_name == self.model_id + + @pytest.mark.parametrize("response_format", ["url", "b64_json"]) + @pytest.mark.asyncio + async def test_agenerate( + self, + async_openai_mock: MagicMock, + _openai_mock: MagicMock, + response_format: str, + ) -> None: + igm = OpenAIImageGeneration(model=self.model_id, api_key="api.key") # type: ignore + igm._aclient = async_openai_mock + + with patch("requests.get") as mock_get: + # Mock the download of the image + mock_get.return_value = Mock(content=b"iVBORw0KGgoAAAANSUhEUgA...") + if response_format == "url": + mocked_response = Mock(b64_json=None, url="https://example.com") + else: + mocked_response = Mock(b64_json="iVBORw0KGgoAAAANSUhEUgA...", url=None) + + mocked_generation = Mock(data=[mocked_response]) + igm._aclient.images.generate = AsyncMock(return_value=mocked_generation) + + await igm.agenerate( + input="a white siamese cat", response_format=response_format + ) + + @pytest.mark.parametrize("response_format", ["url", "b64_json"]) + @pytest.mark.asyncio + async def test_generate( + self, + async_openai_mock: MagicMock, + _openai_mock: MagicMock, + response_format: str, + ) -> None: + igm = OpenAIImageGeneration(model=self.model_id, api_key="api.key") # type: ignore + igm._aclient = async_openai_mock + + with patch("requests.get") as mock_get: + # Mock the download of the image + mock_get.return_value = Mock(content=b"iVBORw0KGgoAAAANSUhEUgA...") + + if response_format == "url": + mocked_response = Mock(b64_json=None, url="https://example.com") + else: + mocked_response = Mock(b64_json="iVBORw0KGgoAAAANSUhEUgA...", url=None) + + mocked_generation = Mock(data=[mocked_response]) + igm._aclient.images.generate = AsyncMock(return_value=mocked_generation) + + nest_asyncio.apply() + + igm.generate( + inputs=["a white siamese cat"], response_format=response_format + ) + + with pytest.raises(ValueError): + igm.generate( + inputs=[ + "a white siamese cat", + ], + response_format="unkown_format", + ) diff --git a/tests/unit/models/mixins/test_cuda_device_placement.py b/tests/unit/models/mixins/test_cuda_device_placement.py index bdddabf83e..a20bc2098c 100644 --- a/tests/unit/models/mixins/test_cuda_device_placement.py +++ b/tests/unit/models/mixins/test_cuda_device_placement.py @@ -23,7 +23,7 @@ from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin if TYPE_CHECKING: - from distilabel.steps.tasks.typing import ChatType + from distilabel.typing import ChatType @pytest.fixture diff --git 
a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 3cb680eb06..aa4da987fa 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -47,7 +47,7 @@ ) from distilabel.pipeline.write_buffer import _WriteBuffer from distilabel.steps.base import Step, StepInput, StepResources, _Step -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput from distilabel.utils.requirements import requirements from distilabel.utils.serialization import TYPE_INFO_KEY diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py index 6a6163b75e..1874a14986 100644 --- a/tests/unit/pipeline/test_dag.py +++ b/tests/unit/pipeline/test_dag.py @@ -28,7 +28,7 @@ from .utils import DummyGeneratorStep, DummyGlobalStep, DummyStep1, DummyStep2 if TYPE_CHECKING: - from distilabel.steps.typing import ( + from distilabel.typing import ( GeneratorStepOutput, StepOutput, ) diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py index cb223755aa..bc3a618adb 100644 --- a/tests/unit/pipeline/utils.py +++ b/tests/unit/pipeline/utils.py @@ -16,7 +16,7 @@ from distilabel.pipeline.batch import _Batch from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput -from distilabel.steps.typing import GeneratorStepOutput, StepOutput +from distilabel.typing import GeneratorStepOutput, StepOutput class DummyGeneratorStep(GeneratorStep): diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py index c0a452e72b..da7b971250 100644 --- a/tests/unit/steps/argilla/test_base.py +++ b/tests/unit/steps/argilla/test_base.py @@ -23,7 +23,7 @@ from distilabel.steps.base import StepInput if TYPE_CHECKING: - from distilabel.steps.typing import StepOutput + from distilabel.typing import StepOutput class CustomArgilla(ArgillaBase): diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py index ddd473bb76..b5eb7a29dd 100644 --- a/tests/unit/steps/clustering/test_text_clustering.py +++ b/tests/unit/steps/clustering/test_text_clustering.py @@ -21,8 +21,7 @@ from tests.unit.conftest import DummyAsyncLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput class ClusteringLLM(DummyAsyncLLM): diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py index 38580c2c42..a73ed68dbf 100644 --- a/tests/unit/steps/tasks/apigen/test_generator.py +++ b/tests/unit/steps/tasks/apigen/test_generator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import random from typing import TYPE_CHECKING, List, Union @@ -21,10 +22,7 @@ from tests.unit.conftest import DummyLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput - -import json + from distilabel.typing import FormattedInput, GenerateOutput class DummyAPIGenLLM(DummyLLM): diff --git a/tests/unit/steps/tasks/math_shepherd/test_completer.py b/tests/unit/steps/tasks/math_shepherd/test_completer.py index c5e8092cd3..5283bb79c3 100644 --- a/tests/unit/steps/tasks/math_shepherd/test_completer.py +++ b/tests/unit/steps/tasks/math_shepherd/test_completer.py @@ -21,7 +21,7 @@ from tests.unit.conftest import DummyLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput + from distilabel.typing import GenerateOutput class MathShepherdCompleterLLM(DummyLLM): diff --git a/tests/unit/steps/tasks/math_shepherd/test_generator.py b/tests/unit/steps/tasks/math_shepherd/test_generator.py index 14ccc87533..6be30405f3 100644 --- a/tests/unit/steps/tasks/math_shepherd/test_generator.py +++ b/tests/unit/steps/tasks/math_shepherd/test_generator.py @@ -26,7 +26,7 @@ from tests.unit.conftest import DummyLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput + from distilabel.typing import GenerateOutput class MathShepherdGeneratorLLM(DummyLLM): diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index 2812c2e48b..d6a7c11126 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -22,7 +22,7 @@ _is_outlines_version_below_0_1_0, model_to_schema, ) -from distilabel.steps.tasks.typing import OutlinesStructuredOutputType +from distilabel.typing import OutlinesStructuredOutputType class DummyUserTest(BaseModel): diff --git a/tests/unit/steps/tasks/test_argilla_labeller.py b/tests/unit/steps/tasks/test_argilla_labeller.py index 9418e899a5..b883b39197 100644 --- a/tests/unit/steps/tasks/test_argilla_labeller.py +++ b/tests/unit/steps/tasks/test_argilla_labeller.py @@ -19,7 +19,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.argilla_labeller import ArgillaLabeller -from distilabel.steps.tasks.typing import ChatItem +from distilabel.typing import ChatItem from tests.unit.conftest import DummyAsyncLLM diff --git a/tests/unit/steps/tasks/test_image_generation.py b/tests/unit/steps/tasks/test_image_generation.py new file mode 100644 index 0000000000..4c588419e0 --- /dev/null +++ b/tests/unit/steps/tasks/test_image_generation.py @@ -0,0 +1,55 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from distilabel.steps.tasks.image_generation import ImageGeneration +from tests.unit.conftest import DummyAsyncImageGenerationModel + + +class TestImageGeneration: + def test_format_input(self) -> None: + igm = DummyAsyncImageGenerationModel() + task = ImageGeneration(image_generation_model=igm) + task.load() + + assert ( + task.format_input({"prompt": "a white siamese cat"}) + == "a white siamese cat" + ) + + @pytest.mark.parametrize("save_artifacts", [False]) + def test_process(self, save_artifacts: bool) -> None: + igm = DummyAsyncImageGenerationModel() + task = ImageGeneration( + image_generation_model=igm, save_artifacts=save_artifacts + ) + task.load() + img_str = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABkAGQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDw9whjujGGK7EOS3fv2HfJxz0/ixuDrgqv2jciofJjUKiZG7A7jAxgE55z1+b74jkfzBcMWZfkVRsQYbHZsdM4JzzkjJz94OMg23hIALxIACevKnPBGemed3rz98EU1Z+n/toSVtwupVZ7krEQsipyeMcA/rjPJPqdx+anTiZVuMNhfJi38bdwIBHpnse+cbvmxupJ3mfz2YhGaKMsB8u5cA9Mc9j7/e5+9SzFSt0QikGNCGckEZ5yPc+nPBz82N4UI2S+X/to7p6jZB5guGwqkRIdu7bxgdBgbucHuep55YOdVjS9VlCsYkOHbnJIOVPGQevfg5wcbwXEnNyvmAkxRqSp4bgE5wBnnnvkjPzffBJuj+2fMwV4EHQrnJVgCMjPTP8AFnrz98NO6VvL/wBsJd0guFmVrkSGNXMUZI4XKkAjA/i/hOec/e5+8ImQQpOrFWLImDg55w2ePYd8g57/AHg0fvBc7AmwIDk4U4BGMDPJ9ue57bhPdSNFJOiKcSQxAnGM/KrZ4AzkjPcd8scPRH7Kt2/9tDrYZcghrk4VwVX5mzkEnOQc8/rnJPON1LO/k/aEZXBkjRQTxkcNk465wD3Hfk4YJNcEtdBGwHVVbDY3Ac8468gHqeRnk/NS3BZmuHkVlLQpgMNpOcEHqOo57k5zz96iG135f+2lT313FddqXXlFoovLTcrH72ecc9s8gc9AecbhGw2LchDLGrRoGCtuDngkE8cZBYdfujr96pJyE+1hGbY6ISS2ck84JPqecc9P4sbgXAAM5VQo8tBwSwyQCRnj39emfm+/RFp2v5f+2hJakWprtvTwfmVW5HJyAc/jnPfPq33iUmpGM3f7oKEEaYCjA+6PYf1+rfeJQvhXovyFr1HSqI3mV42jYxhlXHY4Pr0IOQefx+9Trpjvm+980UYJVQA3yg88DrjOeckZ+b71E5K+cjRlWaNMBlwcYznj1GD75zz96iSIJHcAExnyo229mzg45wSOc8Z6DqPmD/lfp/7aLrqx7xLEt4AQFEaMu3ockEDk579t3TPI+cMnLYnADIAiBjlQG/Lrn73Gc4zz96lmMkbXQlRgXRcZXkg8g9ehHPfPB5+8JJpDKL0kBT5UY5KksQQCQRjOeT/ET1O4guFFtJddv/bP6/4cp7tlZyCbk9cjjAyMk5xnPpn16d/vCaYQr9pGN37mMRsq9+Cc4xg4B5+b/gX3ws6uFuAsiriGLftYKGGBx0G7nB4znG75vv0XOGa4fzMbo4yFVcbs4POcfU9ckZ+b79EW218v/bRO0nd7iTOyPdqJAQ8S5IGNwyDg88+vfJGefv0l1E/mXG/ch2I5BGd2Rnr6EHPfPB5HzUt15ckkxMQVvJjKg8Y+UcgYGc/jwSfm+/THLSJcuVVcovYjvkd/T6568/eDgtE/T/20E73aZNKFCXuPLKmKMAoNoHIwByMn1+9nBPzffEM2VWdVLKdqbg7glvUg45BOG4Pp97G4SSOVF2GwzPEgyhO0ZIYjtnp1OQcZ5++GGQf6YTnEiDBOSSSwPPP167v/AGYKC27af+2jva7X9LXoPv40SSUNlSsUW0CIfMSo74GARk5GcnHLffpJPMk+1tIqqxjVum3IyMdTk5BB756nP3gtzJGrXScx7o4wqgdeh7Y4PXvnj733w102R3IYKxMMbDdlWGQGyMgZ689c5zzjeFCXw38v/bRN293+v61ItRwbrIXb8i9gM8Dn8evvnq33iVHdtun6AYUDAxjge3+T6nqSn0XovyC1ieUxgzqkLhWRdu49OhyPr178ev3qU7hHcfvEBEKIVjOAw44wMA8gHvkjPP3gtwrJ9o8xOqpgsuDzyD+I56nOc8/eEcsiuZmlTLmNVUgZweOeMdgeTnPuTuFQtZfL/wBtCUetgl8orOYgEXahCk5Oe+D6Z7c9vvY3VJcqm6cLJjbFHjhRu4A9vrxnnn5vv0+7jiWW4DZV/JjaMYPOQCeuOxzn5v8AgWd9RvJs+1AzmTzEAyu7nJDYPPbHOcgkcZ4
YTDo15f8AtoPVXW6/IddkLNO2XHmQocKOCSFODnHuc4OcdW+/TDII1ulVsCWFAR8wzyre2enfP44DB8zf8fO503NEnCdDyDj3x685Izz98I4DLdvGoCKijBI457c8+uOT1PONwIpWSfl/7aLlbGkGGO5T513RrkjO05IbB9u46jjv94OuJHL3DvECZI0BIUgDIBz2zwOpznk8n5qW4WWRrmQblXy037zgsDgg++SN2OT35wWpSSsd4QkiGSFAd7HnJDe2c4yM545wcbwR6S9P/bRsjuVkBkEiEErGRiMLkbflJwO45z368/eoeWKQXDPFtcxIqYXhSMemOoB5Oe+ck7wk5Iln3xuHaNcbhjIIBz75HOefXn71EiCMzq2Y90alVC43A4Izz0xg988dfvBws0reX/tvYTa+4SVFiMyyqDKUTZgcDIBz27d+c9ec7hPO7RC5HQyQxA4yAQQrdMDPQHnOevzffEckZ2XAE0bBUTJTjd7e5B64zkjPI+YNmj8nzkEuRsXJTo2ecH+fGRxkZHzUoxvbXt/7b9w7EF0rLOQxJOAcnvkZz+v/ANc9aKffBVnXZ90xocemVBPYf57t94lGtlfsvyC99SxIUl+2Nt4WNACVUEsMDPBHUZPG4nqc8uC4VnFw8igNsQrmPaSD0P4rz3z15+8FkQbbvzV2usUZH3eTx9M5BzxnPXn74Jnmf7W7ps3xoW+XZkHBX3ORg9843HP3hNO1l8v/AG0aa6fd9/4ELSMEuQCRvRc5G0kZBHGec8Hv68/eDn3wi6KHfHJGoZiWX7xDDr1PHQ56ZGcBqddkrJOWiYEoi5kPOSAdwIwDuxkZzwc8n5qUMXhvSZAT5a5OfvHcCe4z69+mcHG8ONnZry/9tB/3thbgSMblxLuxFGJGBChgccYwNxyAe+SCfm5an3XzLdMgXBiiLEnBPAPoMknnHPr82N4jcu8dyVYQr5KExqMbxwQOcEjv3JIB5wWEc6+Z58iMGUBGYkgnJHOCR6knHJ7/ADY3URitL+X/ALaEbD3XfHcsFgZRFHkj5dpwOnAyeCCOc8nnG8SOyyR3zFSpMaYBI9R05Gc9f4j3wfvhk4ljW4wzorQxeYrHBfIDDsMgnDY5zwfmxuolCzfa5FbywiICqsMMeMjPfkZ7njPPLgglovT/ANtEr8um3/DiHe6Xsmcfu1Dcj5vmHvz0z3PGcHG4LLIifahCWMbxKhGWOTwx6YGMqeDn8cb6hYvtnwDgqFJDcYznHHXJGe/rz1Fi4heL7UqoI08qMlSexwRjpkHqBzkc/NjeHHRr5f8Ato2rt3RFOhLT+ZF5TiNHClgMggcjuc5B4zkc8/eC+ZF5N0Akg3RKoJbcNwIJ5BHXBI6/Qn5wtxIy/aSCCskaKdoKDBwwGO54HXOeTz96mu8aJPsLfPEinDZGeCQencZ79O/3gR2Sfl/7aS09mRXylbgZUqTGhORjOVBz0HXrnvnqepKbeYFwQIzGAB8pIPbqD6HqOvXqepKFsvRfkNK2jJ59xM7AkAxoOm3cMA8gYz0B7+vP3qdOjkzGRgHEEbjK7SwIHY4ycHORnPXn71SXkSiS4LblxDFs+XAOVB54HXk55z1y2d9Muv8AWXB3lB5SDCLgNwCAfyznnJGct96lTa0a8v8A20Vno0EzjfeFVkTeiqfmyG5BOeeQcbh16d/vBJSMTmf7xiQoDEQSTgg+3y5Oec5zz98LKix/ahHuAESLkEbT0yO2c4yOvTPP3wyRpnS5Z5OSqq2xR8+PUjjtnvkgHn7wdPZW8v8A23+mU022xHIk89mIjxEoRUUAEccY47DPcnqc5LCSVN4uS8TRlYUYByM545B4yCCWHXjnnlxG7F47hn2SMQvzkYOfUe/r1zknB+8HXChXmSUMsgiQrkg54HPQcHOcjOffO4OO6Xp/7b+AmreQyVWQzKyr/q1IyoU44wR+H1znPP3qklkj3XSgAb4xxncdwIJII7dfXt1++Gyq7NOcGMCFTjaE3LxtyO+Rhu5OM88tT5MTx3MnlgERxk7mGc9yDxnPXHJwcnOC4ILZvy/9tEno1f7iM7IFuYzuO6JVDZOM5DdiM5x7j68MFaI+XctISHCq43Dlt3156NnjOcZwR8wGuiY7hUVB5kaodvyAKCOw6nheue5OT8wdNNHIbpiisXRNrHsRjJ4xyffPcnJ+cKPMmvl/7aNe7ewsgaL7ZkH95EuSSe7K3qM9M/xevP3wSSlVuwn3ZI0XhSvHDe3pnnOcZ5OGBcwFWuMHGI42fLZyxAJwSBkZ57+vzAb6JYoVjuticCOMpkngnBPp78c8f3vviY2aT9P/AG0N3fuV74g3TEDAIB785Gf89fqepKZdFjMN6hTtXAC44xx+nfv1yc5JVdF6L8gvfUtMUiW8WN1KsiqAhbGCQxHvgj3HGRn7wbMXj+0Isi7SiK21Qu8cEA+vY98kZ5+9T5lIa7KloV8lAVBHzn5ep4yDjcOp4B55emyuyfagNzCWNdxyW5JDHnI44J5yPrgNUxTaXfT/ANtDvpqOnhRGuYyCNsaMmV5JODnORgEEnjdn3++ImfCTKcfMibcrg4xnsP8A9fXn7wmbYsd55bAhok7EdSGx29Pf15xvC3K83J3YYwxsRnGQQDjkDPOD39fm++Kg3dX8v/bQvqRkmNbxUKlWjUMVfjqDjnG7ntz0zzjcCUtH542OokjTrxkY3Z6d8A859efvBd8ckV2zMGby12HHJOefx656/Q/fV1wgie4XlB5EYUEY3AhTnAwOevf1+b79ELJq/l/7aJ6PQSZuLqR0kRnjQDd3zg5PTrjcM5P1+8HTRqgu8jIEUeM+pIPByPc/xZ68/fEMyhDNhtxZFJJ3fxDceo5/H8M/eqbywkF6EkkVfKjJHA8zJBwc44/iwM/dHUDeEla1n2/9tKdnqNuUSJ7hQxBMaFFUcMCAec9u+eeg+998RSW7qs7OHBUIx3HltwznJHOev055HzCQEvHeuspQNGpYZyZDuHBJI4745PAODgsGjYYbx4htXaoO5iOCc/jyBxk/jjcCN1a77f8Ato1u7f1uFwFd7iRF3DC/MT0J6/U9fXv1+9Sygj7Qdu3EaBsEYPT884z36Z5+8GuBG10sqksYwIzs6HIIPBxyuTn5s5/4EJphJGbxRKCjQpkjjIJVgOoz6/xZIzzjeHDpby/9tFJ6u6Kt+E+1EoSVZVbJzkkgE5z7/X6t94lO1IMLw7sZKIeFwMFQfx69ec9ctncSkvhXovyEWLlFSGViNzFIBlh03Rlyfz4/HJyearGdtkxCgb1VMAkAD73rz0HXPr15ooqruz+X/tgb0035fqKHzZzuVXJ8uPgYwME547/KP59eaex+0RzzygGT5FBAxj5Sc8dT8vU9cknJ5oooiv3n3f8AtpSXu/15iXyLBOUQYV4o5MHnBZAxAPpkn9Op5p8qho5myRlY+B05Qvj8wP65PNFFFLVxv5f+2lLr/XRi3LmBrgLyJ4oi2WPG5Q5788jvn16gEJeILe5eNCxWW3jc5Y8FkWQ/UZ9c/nzRRWNFty
in/XwmM3rL1H3Ci3inCE4kjhzkn+JPMP6jofr1ANMv/luinUPBE5OBnJjDfzP49Tk80UVvT+Nei/KA2yO7fbKQFX5oY+gxj5VPb+vXqcnmpLqT7O8saKu2aCInPUZVX4x7+ufU5IBooqdvuX/tpD0Wncr3pzc7j1ZEY/UqD/X6+uTRRRSWy9Eay3Z//9k=" + + assert next(task.process([{"prompt": "a white siamese cat"}])) == [ + { + "prompt": "a white siamese cat", + "image": img_str, + "model_name": "test", + "distilabel_metadata": { + "raw_input_image_generation_0": "a white siamese cat", + "raw_output_image_generation_0": { + "images": [ + "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCABkAGQDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDw9whjujGGK7EOS3fv2HfJxz0/ixuDrgqv2jciofJjUKiZG7A7jAxgE55z1+b74jkfzBcMWZfkVRsQYbHZsdM4JzzkjJz94OMg23hIALxIACevKnPBGemed3rz98EU1Z+n/toSVtwupVZ7krEQsipyeMcA/rjPJPqdx+anTiZVuMNhfJi38bdwIBHpnse+cbvmxupJ3mfz2YhGaKMsB8u5cA9Mc9j7/e5+9SzFSt0QikGNCGckEZ5yPc+nPBz82N4UI2S+X/to7p6jZB5guGwqkRIdu7bxgdBgbucHuep55YOdVjS9VlCsYkOHbnJIOVPGQevfg5wcbwXEnNyvmAkxRqSp4bgE5wBnnnvkjPzffBJuj+2fMwV4EHQrnJVgCMjPTP8AFnrz98NO6VvL/wBsJd0guFmVrkSGNXMUZI4XKkAjA/i/hOec/e5+8ImQQpOrFWLImDg55w2ePYd8g57/AHg0fvBc7AmwIDk4U4BGMDPJ9ue57bhPdSNFJOiKcSQxAnGM/KrZ4AzkjPcd8scPRH7Kt2/9tDrYZcghrk4VwVX5mzkEnOQc8/rnJPON1LO/k/aEZXBkjRQTxkcNk465wD3Hfk4YJNcEtdBGwHVVbDY3Ac8468gHqeRnk/NS3BZmuHkVlLQpgMNpOcEHqOo57k5zz96iG135f+2lT313FddqXXlFoovLTcrH72ecc9s8gc9AecbhGw2LchDLGrRoGCtuDngkE8cZBYdfujr96pJyE+1hGbY6ISS2ck84JPqecc9P4sbgXAAM5VQo8tBwSwyQCRnj39emfm+/RFp2v5f+2hJakWprtvTwfmVW5HJyAc/jnPfPq33iUmpGM3f7oKEEaYCjA+6PYf1+rfeJQvhXovyFr1HSqI3mV42jYxhlXHY4Pr0IOQefx+9Trpjvm+980UYJVQA3yg88DrjOeckZ+b71E5K+cjRlWaNMBlwcYznj1GD75zz96iSIJHcAExnyo229mzg45wSOc8Z6DqPmD/lfp/7aLrqx7xLEt4AQFEaMu3ockEDk579t3TPI+cMnLYnADIAiBjlQG/Lrn73Gc4zz96lmMkbXQlRgXRcZXkg8g9ehHPfPB5+8JJpDKL0kBT5UY5KksQQCQRjOeT/ET1O4guFFtJddv/bP6/4cp7tlZyCbk9cjjAyMk5xnPpn16d/vCaYQr9pGN37mMRsq9+Cc4xg4B5+b/gX3ws6uFuAsiriGLftYKGGBx0G7nB4znG75vv0XOGa4fzMbo4yFVcbs4POcfU9ckZ+b79EW218v/bRO0nd7iTOyPdqJAQ8S5IGNwyDg88+vfJGefv0l1E/mXG/ch2I5BGd2Rnr6EHPfPB5HzUt15ckkxMQVvJjKg8Y+UcgYGc/jwSfm+/THLSJcuVVcovYjvkd/T6568/eDgtE/T/20E73aZNKFCXuPLKmKMAoNoHIwByMn1+9nBPzffEM2VWdVLKdqbg7glvUg45BOG4Pp97G4SSOVF2GwzPEgyhO0ZIYjtnp1OQcZ5++GGQf6YTnEiDBOSSSwPPP167v/AGYKC27af+2jva7X9LXoPv40SSUNlSsUW0CIfMSo74GARk5GcnHLffpJPMk+1tIqqxjVum3IyMdTk5BB756nP3gtzJGrXScx7o4wqgdeh7Y4PXvnj733w102R3IYKxMMbDdlWGQGyMgZ689c5zzjeFCXw38v/bRN293+v61ItRwbrIXb8i9gM8Dn8evvnq33iVHdtun6AYUDAxjge3+T6nqSn0XovyC1ieUxgzqkLhWRdu49OhyPr178ev3qU7hHcfvEBEKIVjOAw44wMA8gHvkjPP3gtwrJ9o8xOqpgsuDzyD+I56nOc8/eEcsiuZmlTLmNVUgZweOeMdgeTnPuTuFQtZfL/wBtCUetgl8orOYgEXahCk5Oe+D6Z7c9vvY3VJcqm6cLJjbFHjhRu4A9vrxnnn5vv0+7jiWW4DZV/JjaMYPOQCeuOxzn5v8AgWd9RvJs+1AzmTzEAyu7nJDYPPbHOcgkcZ4YTDo15f8AtoPVXW6/IddkLNO2XHmQocKOCSFODnHuc4OcdW+/TDII1ulVsCWFAR8wzyre2enfP44DB8zf8fO503NEnCdDyDj3x685Izz98I4DLdvGoCKijBI457c8+uOT1PONwIpWSfl/7aLlbGkGGO5T513RrkjO05IbB9u46jjv94OuJHL3DvECZI0BIUgDIBz2zwOpznk8n5qW4WWRrm
QblXy037zgsDgg++SN2OT35wWpSSsd4QkiGSFAd7HnJDe2c4yM545wcbwR6S9P/bRsjuVkBkEiEErGRiMLkbflJwO45z368/eoeWKQXDPFtcxIqYXhSMemOoB5Oe+ck7wk5Iln3xuHaNcbhjIIBz75HOefXn71EiCMzq2Y90alVC43A4Izz0xg988dfvBws0reX/tvYTa+4SVFiMyyqDKUTZgcDIBz27d+c9ec7hPO7RC5HQyQxA4yAQQrdMDPQHnOevzffEckZ2XAE0bBUTJTjd7e5B64zkjPI+YNmj8nzkEuRsXJTo2ecH+fGRxkZHzUoxvbXt/7b9w7EF0rLOQxJOAcnvkZz+v/ANc9aKffBVnXZ90xocemVBPYf57t94lGtlfsvyC99SxIUl+2Nt4WNACVUEsMDPBHUZPG4nqc8uC4VnFw8igNsQrmPaSD0P4rz3z15+8FkQbbvzV2usUZH3eTx9M5BzxnPXn74Jnmf7W7ps3xoW+XZkHBX3ORg9843HP3hNO1l8v/AG0aa6fd9/4ELSMEuQCRvRc5G0kZBHGec8Hv68/eDn3wi6KHfHJGoZiWX7xDDr1PHQ56ZGcBqddkrJOWiYEoi5kPOSAdwIwDuxkZzwc8n5qUMXhvSZAT5a5OfvHcCe4z69+mcHG8ONnZry/9tB/3thbgSMblxLuxFGJGBChgccYwNxyAe+SCfm5an3XzLdMgXBiiLEnBPAPoMknnHPr82N4jcu8dyVYQr5KExqMbxwQOcEjv3JIB5wWEc6+Z58iMGUBGYkgnJHOCR6knHJ7/ADY3URitL+X/ALaEbD3XfHcsFgZRFHkj5dpwOnAyeCCOc8nnG8SOyyR3zFSpMaYBI9R05Gc9f4j3wfvhk4ljW4wzorQxeYrHBfIDDsMgnDY5zwfmxuolCzfa5FbywiICqsMMeMjPfkZ7njPPLgglovT/ANtEr8um3/DiHe6Xsmcfu1Dcj5vmHvz0z3PGcHG4LLIifahCWMbxKhGWOTwx6YGMqeDn8cb6hYvtnwDgqFJDcYznHHXJGe/rz1Fi4heL7UqoI08qMlSexwRjpkHqBzkc/NjeHHRr5f8Ato2rt3RFOhLT+ZF5TiNHClgMggcjuc5B4zkc8/eC+ZF5N0Akg3RKoJbcNwIJ5BHXBI6/Qn5wtxIy/aSCCskaKdoKDBwwGO54HXOeTz96mu8aJPsLfPEinDZGeCQencZ79O/3gR2Sfl/7aS09mRXylbgZUqTGhORjOVBz0HXrnvnqepKbeYFwQIzGAB8pIPbqD6HqOvXqepKFsvRfkNK2jJ59xM7AkAxoOm3cMA8gYz0B7+vP3qdOjkzGRgHEEbjK7SwIHY4ycHORnPXn71SXkSiS4LblxDFs+XAOVB54HXk55z1y2d9Muv8AWXB3lB5SDCLgNwCAfyznnJGct96lTa0a8v8A20Vno0EzjfeFVkTeiqfmyG5BOeeQcbh16d/vBJSMTmf7xiQoDEQSTgg+3y5Oec5zz98LKix/ahHuAESLkEbT0yO2c4yOvTPP3wyRpnS5Z5OSqq2xR8+PUjjtnvkgHn7wdPZW8v8A23+mU022xHIk89mIjxEoRUUAEccY47DPcnqc5LCSVN4uS8TRlYUYByM545B4yCCWHXjnnlxG7F47hn2SMQvzkYOfUe/r1zknB+8HXChXmSUMsgiQrkg54HPQcHOcjOffO4OO6Xp/7b+AmreQyVWQzKyr/q1IyoU44wR+H1znPP3qklkj3XSgAb4xxncdwIJII7dfXt1++Gyq7NOcGMCFTjaE3LxtyO+Rhu5OM88tT5MTx3MnlgERxk7mGc9yDxnPXHJwcnOC4ILZvy/9tEno1f7iM7IFuYzuO6JVDZOM5DdiM5x7j68MFaI+XctISHCq43Dlt3156NnjOcZwR8wGuiY7hUVB5kaodvyAKCOw6nheue5OT8wdNNHIbpiisXRNrHsRjJ4xyffPcnJ+cKPMmvl/7aNe7ewsgaL7ZkH95EuSSe7K3qM9M/xevP3wSSlVuwn3ZI0XhSvHDe3pnnOcZ5OGBcwFWuMHGI42fLZyxAJwSBkZ57+vzAb6JYoVjuticCOMpkngnBPp78c8f3vviY2aT9P/AG0N3fuV74g3TEDAIB785Gf89fqepKZdFjMN6hTtXAC44xx+nfv1yc5JVdF6L8gvfUtMUiW8WN1KsiqAhbGCQxHvgj3HGRn7wbMXj+0Isi7SiK21Qu8cEA+vY98kZ5+9T5lIa7KloV8lAVBHzn5ep4yDjcOp4B55emyuyfagNzCWNdxyW5JDHnI44J5yPrgNUxTaXfT/ANtDvpqOnhRGuYyCNsaMmV5JODnORgEEnjdn3++ImfCTKcfMibcrg4xnsP8A9fXn7wmbYsd55bAhok7EdSGx29Pf15xvC3K83J3YYwxsRnGQQDjkDPOD39fm++Kg3dX8v/bQvqRkmNbxUKlWjUMVfjqDjnG7ntz0zzjcCUtH542OokjTrxkY3Z6d8A859efvBd8ckV2zMGby12HHJOefx656/Q/fV1wgie4XlB5EYUEY3AhTnAwOevf1+b79ELJq/l/7aJ6PQSZuLqR0kRnjQDd3zg5PTrjcM5P1+8HTRqgu8jIEUeM+pIPByPc/xZ68/fEMyhDNhtxZFJJ3fxDceo5/H8M/eqbywkF6EkkVfKjJHA8zJBwc44/iwM/dHUDeEla1n2/9tKdnqNuUSJ7hQxBMaFFUcMCAec9u+eeg+998RSW7qs7OHBUIx3HltwznJHOev055HzCQEvHeuspQNGpYZyZDuHBJI4745PAODgsGjYYbx4htXaoO5iOCc/jyBxk/jjcCN1a77f8Ato1u7f1uFwFd7iRF3DC/MT0J6/U9fXv1+9Sygj7Qdu3EaBsEYPT884z36Z5+8GuBG10sqksYwIzs6HIIPBxyuTn5s5/4EJphJGbxRKCjQpkjjIJVgOoz6/xZIzzjeHDpby/9tFJ6u6Kt+E+1EoSVZVbJzkkgE5z7/X6t94lO1IMLw7sZKIeFwMFQfx69ec9ctncSkvhXovyEWLlFSGViNzFIBlh03Rlyfz4/HJyearGdtkxCgb1VMAkAD73rz0HXPr15ooqruz+X/tgb0035fqKHzZzuVXJ8uPgYwME547/KP59eaex+0RzzygGT5FBAxj5Sc8dT8vU9cknJ5oooiv3n3f8AtpSXu/15iXyLBOUQYV4o5MHnBZAxAPpkn9Op5p8qho5myRlY+B05Qvj8wP65PNFFFLVxv5f+2lLr/XRi3LmBrgLyJ4oi2WPG5Q5788jvn16gEJeILe5eNCxWW3jc5Y8FkWQ/UZ9c/nzRRWNFtyin/XwmM3rL1H3Ci3inCE4kjhzkn+JPMP6jofr1ANMv/luinUPBE5OBnJjDfzP49Tk80UVvT+Nei/KA2yO7fbKQFX5oY+gxj5VPb+vXqcnmpLqT7O8saKu2aCInPUZVX4x7+ufU5IBooqdvuX/tpD0Wncr3pzc7j1ZEY/UqD/X6+uTRRRSWy9Eay3Z//9k=" + ] + }, + }, + } + ] 
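The new test above pins down the `ImageGeneration` contract: rows with a `prompt` column come back with the generated `image` as a base64 string, the `model_name`, and the raw input/output under `distilabel_metadata`. A minimal usage sketch against a real backend follows; only the task-side API comes from this patch series, while the Inference Endpoints model id and its `model_id` keyword are illustrative assumptions.

```python
# Hedged sketch of the ImageGeneration task exercised by the test above.
# The model id and the `model_id` keyword are assumptions, not taken from this
# patch; the task API (image_generation_model, load, process, output columns)
# mirrors what the unit test asserts.
from distilabel.models.image_generation.huggingface.inference_endpoints import (
    InferenceEndpointsImageGeneration,
)
from distilabel.steps.tasks.image_generation import ImageGeneration

igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell")
task = ImageGeneration(image_generation_model=igm, save_artifacts=False)
task.load()

# process() yields batches of rows; each row keeps the prompt and adds the
# image encoded as a base64 string.
for row in next(task.process([{"prompt": "a white siamese cat"}])):
    print(row["model_name"], row["image"][:32], "...")
```
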
diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py index 1bc4128c7c..25e60cab0b 100644 --- a/tests/unit/steps/tasks/test_improving_text_embeddings.py +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -18,7 +18,6 @@ import pytest from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.improving_text_embeddings import ( BitextRetrievalGenerator, @@ -29,7 +28,7 @@ GenerateTextRetrievalData, MonolingualTripletGenerator, ) -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType, GenerateOutput class MockLLM(LLM): diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py index 5e54d94658..4139eef525 100644 --- a/tests/unit/steps/tasks/test_instruction_backtranslation.py +++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py @@ -15,12 +15,11 @@ from typing import Any, List from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.instruction_backtranslation import ( InstructionBacktranslation, ) -from distilabel.steps.tasks.typing import ChatType +from distilabel.typing import ChatType, GenerateOutput class InstructionBacktranslationLLM(LLM): diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py index 125b26ed37..689f18e007 100644 --- a/tests/unit/steps/tasks/test_structured_generation.py +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -18,10 +18,9 @@ from typing_extensions import override from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.structured_generation import StructuredGeneration -from distilabel.steps.tasks.typing import StructuredInput +from distilabel.typing import GenerateOutput, StructuredInput class DummyStructuredLLM(LLM): diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py index c1bcf47e24..a3d3b0518b 100644 --- a/tests/unit/steps/tasks/test_text_classification.py +++ b/tests/unit/steps/tasks/test_text_classification.py @@ -21,8 +21,7 @@ from tests.unit.conftest import DummyAsyncLLM if TYPE_CHECKING: - from distilabel.models.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import FormattedInput + from distilabel.typing import FormattedInput, GenerateOutput class TextClassificationLLM(DummyAsyncLLM): diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 3754c8803d..d94a4d8721 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -17,9 +17,8 @@ import pytest from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput -from distilabel.steps.tasks.typing import ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback +from distilabel.typing import ChatType, GenerateOutput class UltraFeedbackLLM(LLM): diff --git a/tests/unit/steps/test_base.py b/tests/unit/steps/test_base.py index 6e8297bb06..e3e479e5c1 100644 --- a/tests/unit/steps/test_base.py +++ b/tests/unit/steps/test_base.py 
@@ -24,7 +24,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput from distilabel.steps.decorator import step -from distilabel.steps.typing import GeneratorStepOutput, StepOutput +from distilabel.typing import GeneratorStepOutput, StepOutput from distilabel.utils.serialization import TYPE_INFO_KEY diff --git a/tests/unit/steps/test_decorator.py b/tests/unit/steps/test_decorator.py index 0071507cf0..d4b2aeb839 100644 --- a/tests/unit/steps/test_decorator.py +++ b/tests/unit/steps/test_decorator.py @@ -25,7 +25,7 @@ StepInput, ) from distilabel.steps.decorator import step -from distilabel.steps.typing import GeneratorStepOutput, StepOutput +from distilabel.typing import GeneratorStepOutput, StepOutput class TestStepDecorator: diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py index 1649a2ff18..1eb47a5e96 100644 --- a/tests/unit/test_distiset.py +++ b/tests/unit/test_distiset.py @@ -236,3 +236,59 @@ def test_dataset_card(self, distiset: Distiset) -> None: "size_categories": "n<1K", "tags": ["synthetic", "distilabel", "rlaif"], } + + def test_transform_columns_to_image(self): + import numpy as np + from PIL import Image + + arr = np.random.randint(0, 255, (100, 100, 3)) + image = Image.fromarray(arr, "RGB") + from distilabel.models.image_generation.utils import image_to_str + + img_str = image_to_str(image) + + distiset_with_images = Distiset( + { + "leaf_step_1": Dataset.from_dict({"image": [img_str] * 3}), + "leaf_step_2": Dataset.from_dict( + {"image": [img_str] * 4, "column": [5, 6, 7, 8]} + ), + } + ) + distiset_with_images.transform_columns_to_image("image") + assert all( + isinstance(img, Image.Image) + for img in distiset_with_images["leaf_step_1"]["image"] + ) + assert all( + isinstance(img, Image.Image) + for img in distiset_with_images["leaf_step_2"]["image"] + ) + + distiset_with_images = Distiset( + { + "leaf_step_1": Dataset.from_dict({"image": [img_str] * 3}), + "leaf_step_2": Dataset.from_dict( + {"image": [img_str] * 4, "column": [5, 6, 7, 8]} + ), + } + ) + distiset_with_images = distiset_with_images.train_test_split(0.8) + print(distiset_with_images) + distiset_with_images.transform_columns_to_image("image") + assert all( + isinstance(img, Image.Image) + for img in distiset_with_images["leaf_step_1"]["train"]["image"] + ) + assert all( + isinstance(img, Image.Image) + for img in distiset_with_images["leaf_step_1"]["test"]["image"] + ) + assert all( + isinstance(img, Image.Image) + for img in distiset_with_images["leaf_step_2"]["train"]["image"] + ) + assert all( + isinstance(img, Image.Image) + for img in distiset_with_images["leaf_step_2"]["test"]["image"] + ) diff --git a/tests/unit/utils/test_requirements.py b/tests/unit/utils/test_requirements.py index 32ed762560..04125242ee 100644 --- a/tests/unit/utils/test_requirements.py +++ b/tests/unit/utils/test_requirements.py @@ -19,7 +19,7 @@ from distilabel.pipeline import Pipeline from distilabel.steps import Step from distilabel.steps.base import StepInput -from distilabel.steps.typing import StepOutput +from distilabel.typing import StepOutput from distilabel.utils.requirements import requirements from ..pipeline.utils import DummyGeneratorStep From 74cc09e256f4bf92b9b7564e8607bbcac299de11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 17 Jan 2025 08:38:21 +0100 Subject: [PATCH 24/30] Update `LLM`s to support prompt logprobs use-case (#1099) --- .../base_clients/inference_endpoints.py 
| 5 +- .../llms/huggingface/inference_endpoints.py | 63 ++++---- src/distilabel/models/llms/openai.py | 143 ++++++++++++++++-- src/distilabel/models/llms/vllm.py | 79 +++++++--- src/distilabel/typing/models.py | 2 +- tests/unit/models/llms/test_openai.py | 61 ++++++++ tests/unit/models/llms/test_vllm.py | 5 +- 7 files changed, 288 insertions(+), 70 deletions(-) diff --git a/src/distilabel/models/base_clients/inference_endpoints.py b/src/distilabel/models/base_clients/inference_endpoints.py index ebcc84e344..7b7ed39337 100644 --- a/src/distilabel/models/base_clients/inference_endpoints.py +++ b/src/distilabel/models/base_clients/inference_endpoints.py @@ -16,7 +16,6 @@ from typing import ( TYPE_CHECKING, Optional, - Union, ) from pydantic import ( @@ -143,9 +142,9 @@ def load(self) -> None: # noqa: C901 self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id) @property - def model_name(self) -> Union[str, None]: # type: ignore + def model_name(self) -> str: """Returns the model name used for the model.""" - return ( + return ( # type: ignore self.model_display_name or self._model_name or self.model_id diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index 8956529999..8b31cbd471 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -273,7 +273,7 @@ def _get_structured_output( async def _generate_with_text_generation( self, - input: FormattedInput, + input: str, max_new_tokens: int = 128, repetition_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None, @@ -287,13 +287,12 @@ async def _generate_with_text_generation( return_full_text: bool = False, seed: Optional[int] = None, watermark: bool = False, + structured_output: Union[Dict[str, Any], None] = None, ) -> GenerateOutput: - input, structured_output = self._get_structured_output(input) - prompt = self.prepare_input(input) generation: Union["TextGenerationOutput", None] = None try: generation = await self._aclient.text_generation( # type: ignore - prompt=prompt, + prompt=input, max_new_tokens=max_new_tokens, do_sample=do_sample, typical_p=typical_p, @@ -319,7 +318,9 @@ async def _generate_with_text_generation( ) return prepare_output( generations=[generation.generated_text] if generation else [None], - input_tokens=[compute_tokens(prompt, self._tokenizer.encode)], # type: ignore + input_tokens=[ + compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1 + ], output_tokens=[ generation.details.generated_tokens if generation and generation.details @@ -544,37 +545,43 @@ async def agenerate( # type: ignore """ stop_sequences = self._check_stop_sequences(stop_sequences) - if self.tokenizer_id is None: - return await self._generate_with_chat_completion( - input=input, # type: ignore + if isinstance(input, str) or self.tokenizer_id is not None: + structured_output = None + if not isinstance(input, str): + input, structured_output = self._get_structured_output(input) + input = self.prepare_input(input) + + return await self._generate_with_text_generation( + input=input, max_new_tokens=max_new_tokens, + do_sample=do_sample, + typical_p=typical_p, + repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - presence_penalty=presence_penalty, - seed=seed, - stop_sequences=stop_sequences, temperature=temperature, - tool_choice=tool_choice, - 
tool_prompt=tool_prompt, - tools=tools, - top_logprobs=top_logprobs, + top_n_tokens=top_n_tokens, top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + seed=seed, + watermark=watermark, + structured_output=structured_output, ) - return await self._generate_with_text_generation( - input=input, + return await self._generate_with_chat_completion( + input=input, # type: ignore max_new_tokens=max_new_tokens, - do_sample=do_sample, - typical_p=typical_p, - repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + presence_penalty=presence_penalty, + seed=seed, + stop_sequences=stop_sequences, temperature=temperature, - top_n_tokens=top_n_tokens, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, + top_logprobs=top_logprobs, top_p=top_p, - top_k=top_k, - stop_sequences=stop_sequences, - return_full_text=return_full_text, - seed=seed, - watermark=watermark, ) diff --git a/src/distilabel/models/llms/openai.py b/src/distilabel/models/llms/openai.py index 66dbfcff17..3cdd17fc96 100644 --- a/src/distilabel/models/llms/openai.py +++ b/src/distilabel/models/llms/openai.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union import orjson -from pydantic import PositiveInt, validate_call +from pydantic import NonNegativeInt, PositiveInt, validate_call from distilabel import envs from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException @@ -29,10 +29,18 @@ from openai.types import Batch as OpenAIBatch from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion - from openai.types.chat.chat_completion import Choice as OpenAIChoice + from openai.types.chat.chat_completion import Choice as OpenAIChatCompletionChoice from openai.types.completion import Completion as OpenAICompletion + from openai.types.completion_choice import ( + CompletionChoice as OpenAICompletionChoice, + ) - from distilabel.typing import LLMStatistics, Logprob + from distilabel.typing.models import ( + LLMStatistics, + Logprob, + StandardInput, + StructuredInput, + ) _OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB @@ -148,15 +156,17 @@ async def agenerate( # type: ignore self, input: FormattedInput, num_generations: int = 1, - max_new_tokens: int = 128, + max_new_tokens: NonNegativeInt = 128, logprobs: bool = False, top_logprobs: Optional[PositiveInt] = None, + echo: bool = False, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, temperature: float = 1.0, top_p: float = 1.0, stop: Optional[Union[str, List[str]]] = None, response_format: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> GenerateOutput: """Generates `num_generations` responses for the given input using the OpenAI async client. @@ -170,6 +180,8 @@ async def agenerate( # type: ignore logprobs: whether to return the log probabilities or not. Defaults to `False`. top_logprobs: the number of top log probabilities to return per output token generated. Defaults to `None`. + echo: whether to echo the input in the response or not. It's only used if the + `input` argument is an `str`. Defaults to `False`. frequency_penalty: the repetition penalty to use for the generation. Defaults to `0.0`. presence_penalty: the presence penalty to use for the generation. Defaults to @@ -182,14 +194,115 @@ async def agenerate( # type: ignore "text" or "json". 
Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode) for more information on how to use the JSON model from OpenAI. Defaults to None which returns text. To return JSON, use {"type": "json_object"}. - - Note: - If response_format + extra_body: an optional dictionary containing extra body parameters that will + be sent to the OpenAI API endpoint. Defaults to `None`. Returns: A list of lists of strings containing the generated responses for each input. """ + if isinstance(input, str): + return await self._generate_completion( + input=input, + num_generations=num_generations, + max_new_tokens=max_new_tokens, + echo=echo, + top_logprobs=top_logprobs, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + temperature=temperature, + top_p=top_p, + extra_body=extra_body, + ) + + return await self._generate_chat_completion( + input=input, + num_generations=num_generations, + max_new_tokens=max_new_tokens, + logprobs=logprobs, + top_logprobs=top_logprobs, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + temperature=temperature, + top_p=top_p, + stop=stop, + response_format=response_format, + extra_body=extra_body, + ) + + async def _generate_completion( + self, + input: str, + num_generations: int = 1, + max_new_tokens: int = 128, + echo: bool = False, + top_logprobs: Optional[PositiveInt] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + extra_body: Optional[Dict[str, Any]] = None, + ) -> GenerateOutput: + completion = await self._aclient.completions.create( + prompt=input, + echo=echo, + model=self.model, + n=num_generations, + max_tokens=max_new_tokens, + logprobs=top_logprobs, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + temperature=temperature, + top_p=top_p, + extra_body=extra_body, + ) + + generations = [] + logprobs = [] + for choice in completion.choices: + generations.append(choice.text) + if choice_logprobs := self._get_logprobs_from_completion_choice(choice): + logprobs.append(choice_logprobs) + + statistics = self._get_llm_statistics(completion) + return prepare_output( + generations=generations, + input_tokens=statistics["input_tokens"], + output_tokens=statistics["output_tokens"], + logprobs=logprobs, + ) + + def _get_logprobs_from_completion_choice( + self, choice: "OpenAICompletionChoice" + ) -> Union[List[Union[List["Logprob"], None]], None]: + if choice.logprobs is None or choice.logprobs.top_logprobs is None: + return None + + return [ + [ + {"token": token, "logprob": token_logprob} + for token, token_logprob in logprobs.items() + ] + if logprobs is not None + else None + for logprobs in choice.logprobs.top_logprobs + ] + + async def _generate_chat_completion( + self, + input: Union["StandardInput", "StructuredInput"], + num_generations: int = 1, + max_new_tokens: int = 128, + logprobs: bool = False, + top_logprobs: Optional[PositiveInt] = None, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + stop: Optional[Union[str, List[str]]] = None, + response_format: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> GenerateOutput: structured_output = None if isinstance(input, tuple): input, structured_output = input @@ -215,9 +328,11 @@ async def agenerate( # type: ignore "temperature": temperature, "top_p": top_p, "stop": stop, + "extra_body": extra_body, } - # Check if it's a 
vision generation task, in that case "stop" cannot be used or raises - # an error in the API. + + # Checks if any message contains an image, in that case "stop" cannot be used or + # raises an error in the API. if isinstance( [row for row in input if row["role"] == "user"][0]["content"], list ): @@ -235,7 +350,7 @@ async def agenerate( # type: ignore # NOTE: `instructor` doesn't work with `n` parameter, so it will always return # only 1 choice. statistics = self._get_llm_statistics(completion._raw_response) - if choice_logprobs := self._get_logprobs_from_choice( + if choice_logprobs := self._get_logprobs_from_chat_completion_choice( completion._raw_response.choices[0] ): output_logprobs = [choice_logprobs] @@ -270,7 +385,9 @@ def _generations_from_openai_completion( f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - if choice_logprobs := self._get_logprobs_from_choice(choice): + if choice_logprobs := self._get_logprobs_from_chat_completion_choice( + choice + ): logprobs.append(choice_logprobs) statistics = self._get_llm_statistics(completion) @@ -281,8 +398,8 @@ def _generations_from_openai_completion( logprobs=logprobs, ) - def _get_logprobs_from_choice( - self, choice: "OpenAIChoice" + def _get_logprobs_from_chat_completion_choice( + self, choice: "OpenAIChatCompletionChoice" ) -> Union[List[List["Logprob"]], None]: if choice.logprobs is None or choice.logprobs.content is None: return None diff --git a/src/distilabel/models/llms/vllm.py b/src/distilabel/models/llms/vllm.py index 6075c4f54e..4082f978e8 100644 --- a/src/distilabel/models/llms/vllm.py +++ b/src/distilabel/models/llms/vllm.py @@ -47,7 +47,8 @@ from openai import OpenAI # noqa from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM - from vllm.outputs import RequestOutput, CompletionOutput + from vllm.outputs import RequestOutput + from vllm.sequence import SampleLogprobs, PromptLogprobs from distilabel.typing import ( StandardInput, @@ -256,7 +257,7 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.model - def prepare_input(self, input: "StandardInput") -> str: + def prepare_input(self, input: Union["StandardInput", str]) -> str: """Prepares the input (applying the chat template and tokenization) for the provided input. @@ -266,8 +267,8 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. """ - if self._tokenizer.chat_template is None: - return [item["content"] for item in input if item["role"] == "user"][0] + if isinstance(input, str): + return input prompt: str = ( self._tokenizer.apply_chat_template( @@ -342,8 +343,10 @@ def generate( # noqa: C901 # type: ignore stop: Optional[List[str]] = None, stop_token_ids: Optional[List[int]] = None, include_stop_str_in_output: bool = False, + skip_special_tokens: bool = True, logits_processors: Optional[LogitsProcessors] = None, extra_sampling_params: Optional[Dict[str, Any]] = None, + echo: bool = False, ) -> List[GenerateOutput]: """Generates `num_generations` responses for each input. @@ -371,10 +374,14 @@ def generate( # noqa: C901 # type: ignore when found. Defaults to `None`. include_stop_str_in_output: whether to include the stop string in the output. Defaults to `False`. + skip_special_tokens: whether to exclude special tokens from the output. Defaults + to `False`. logits_processors: a list of functions to process the logits before sampling. Defaults to `None`. 
extra_sampling_params: dictionary with additional arguments to be passed to the `SamplingParams` class from `vllm`. + echo: whether to echo the include the prompt in the response or not. Defaults + to `False`. Returns: A list of lists of strings containing the generated responses for each input. @@ -406,8 +413,11 @@ def generate( # noqa: C901 # type: ignore for prepared_inputs, structured_output in prepared_batches: if self.structured_output is not None and structured_output is not None: - # TODO: warning - pass + self._logger.warning( + "An `structured_output` was provided in the model configuration, but" + " one was also provided in the input. The input structured output will" + " be used." + ) if structured_output is not None: logits_processors.append( @@ -424,10 +434,12 @@ def generate( # noqa: C901 # type: ignore top_k=top_k, min_p=min_p, max_tokens=max_new_tokens, + prompt_logprobs=logprobs if echo else None, logprobs=logprobs, stop=stop, stop_token_ids=stop_token_ids, include_stop_str_in_output=include_stop_str_in_output, + skip_special_tokens=skip_special_tokens, logits_processors=logits_processors, **extra_sampling_params, ) @@ -444,19 +456,27 @@ def generate( # noqa: C901 # type: ignore logits_processors.pop(-1) for input, outputs in zip(prepared_inputs, batch_outputs): + processed_prompt_logprobs = [] + if outputs.prompt_logprobs is not None: + processed_prompt_logprobs = self._get_llm_logprobs( + outputs.prompt_logprobs + ) texts, statistics, outputs_logprobs = self._process_outputs( - input, outputs + input=input, + outputs=outputs, + echo=echo, + prompt_logprobs=processed_prompt_logprobs, ) batched_outputs.append(texts) - generations.append( - prepare_output( - generations=texts, - input_tokens=statistics["input_tokens"], - output_tokens=statistics["output_tokens"], - logprobs=outputs_logprobs, - ) + generation = prepare_output( + generations=texts, + input_tokens=statistics["input_tokens"], + output_tokens=statistics["output_tokens"], + logprobs=outputs_logprobs, ) + generations.append(generation) + if sorted_indices is not None: pairs = list(enumerate(sorted_indices)) pairs.sort(key=lambda x: x[1]) @@ -465,7 +485,11 @@ def generate( # noqa: C901 # type: ignore return generations def _process_outputs( - self, input: str, outputs: "RequestOutput" + self, + input: str, + outputs: "RequestOutput", + prompt_logprobs: List[List["Logprob"]], + echo: bool = False, ) -> Tuple["LLMOutput", "LLMStatistics", "LLMLogprobs"]: texts = [] outputs_logprobs = [] @@ -475,13 +499,17 @@ def _process_outputs( "output_tokens": [], } for output in outputs.outputs: - texts.append(output.text) + text = output.text + if echo: + text = input + text + texts.append(text) statistics["output_tokens"].append(len(output.token_ids)) if output.logprobs is not None: - outputs_logprobs.append(self._get_llm_logprobs(output)) + processed_output_logprobs = self._get_llm_logprobs(output.logprobs) + outputs_logprobs.append(prompt_logprobs + processed_output_logprobs) return texts, statistics, outputs_logprobs - def _prepare_structured_output( + def _prepare_structured_output( # type: ignore self, structured_output: "OutlinesStructuredOutputType" ) -> Union[Callable, None]: """Creates the appropriate function to filter tokens to generate structured outputs. 
@@ -503,16 +531,21 @@ def _prepare_structured_output( self.structured_output["schema"] = schema return result["processor"] - def _get_llm_logprobs(self, output: "CompletionOutput") -> List[List["Logprob"]]: - logprobs = [] - for token_logprob in output.logprobs: # type: ignore + def _get_llm_logprobs( + self, logprobs: Union["PromptLogprobs", "SampleLogprobs"] + ) -> List[List["Logprob"]]: + processed_logprobs = [] + for token_logprob in logprobs: # type: ignore token_logprobs = [] + if token_logprob is None: + processed_logprobs.append(None) + continue for logprob in token_logprob.values(): token_logprobs.append( {"token": logprob.decoded_token, "logprob": logprob.logprob} ) - logprobs.append(token_logprobs) - return logprobs + processed_logprobs.append(token_logprobs) + return processed_logprobs class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin): diff --git a/src/distilabel/typing/models.py b/src/distilabel/typing/models.py index aa11305421..0b6d1715d6 100644 --- a/src/distilabel/typing/models.py +++ b/src/distilabel/typing/models.py @@ -106,7 +106,7 @@ class InstructorStructuredOutputType(TypedDict, total=False): """StandardInput is an alias for ChatType that defines the default / standard input produced by `format_input`.""" StructuredInput = Tuple[StandardInput, Union[StructuredOutputType, None]] """StructuredInput defines a type produced by `format_input` when using either `StructuredGeneration` or a subclass of it.""" -FormattedInput = Union[StandardInput, StructuredInput] +FormattedInput = Union[StandardInput, StructuredInput, str] """FormattedInput is an alias for the union of `StandardInput` and `StructuredInput` as generated by `format_input` and expected by the `LLM`s, as well as `ConversationType` for the vision language models.""" diff --git a/tests/unit/models/llms/test_openai.py b/tests/unit/models/llms/test_openai.py index c7cf7d4c45..ac27160a42 100644 --- a/tests/unit/models/llms/test_openai.py +++ b/tests/unit/models/llms/test_openai.py @@ -97,6 +97,33 @@ async def test_agenerate( ], } + @pytest.mark.asyncio + async def test_agenerate_with_string_input( + self, async_openai_mock: MagicMock, _openai_mock: MagicMock + ) -> None: + llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore + llm._aclient = async_openai_mock + + mocked_completion = Mock( + choices=[ + Mock( + text=" Aenean hendrerit aliquam velit. ...", + logprobs=Mock(top_logprobs=[{" ": -1}, {"Aenean": -2}]), + ) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), + ) + llm._aclient.completions.create = AsyncMock(return_value=mocked_completion) + + result = await llm.agenerate(input="string input") + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + "logprobs": [ + [[{"token": " ", "logprob": -1}], [{"token": "Aenean", "logprob": -2}]] + ], + } + @pytest.mark.asyncio async def test_agenerate_structured( self, async_openai_mock: MagicMock, _openai_mock: MagicMock @@ -217,6 +244,40 @@ async def test_generate( ) assert result == expected_result + @pytest.mark.asyncio + async def test_generate_with_string_input( + self, async_openai_mock: MagicMock, _openai_mock: MagicMock + ) -> None: + llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore + llm._aclient = async_openai_mock + + mocked_completion = Mock( + choices=[ + Mock( + text=" Aenean hendrerit aliquam velit. 
...", + logprobs=Mock(top_logprobs=[{" ": -1}, {"Aenean": -2}]), + ) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), + ) + llm._aclient.completions.create = AsyncMock(return_value=mocked_completion) + + nest_asyncio.apply() + + result = llm.generate(inputs=["input string"]) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + "logprobs": [ + [ + [{"token": " ", "logprob": -1}], + [{"token": "Aenean", "logprob": -2}], + ] + ], + } + ] + @pytest.mark.asyncio async def test_generate_raises_value_error_if_unknown_response_format( self, async_openai_mock: MagicMock, _: MagicMock diff --git a/tests/unit/models/llms/test_vllm.py b/tests/unit/models/llms/test_vllm.py index 2230186bf3..f21c0a9691 100644 --- a/tests/unit/models/llms/test_vllm.py +++ b/tests/unit/models/llms/test_vllm.py @@ -106,8 +106,8 @@ class TestvLLM: @pytest.mark.parametrize( "multi_structured_output", # TODO: uncomment once with update our code to work with `outlines>0.1.0` - # (True, False), - (False,), + (True, False), + # (False,), ) @pytest.mark.parametrize( "num_generations, expected_result", @@ -183,6 +183,7 @@ def test_generate( mocked_requests_output = [ mock.Mock( # RequestOutput + prompt_logprobs=[], outputs=[ mock.Mock( # CompletionOutput text="I'm fine thank you", From 2c85dcc5672abf4dcb5c64c2d594bde23a333d8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 17 Jan 2025 09:08:39 +0100 Subject: [PATCH 25/30] Bump version to `1.6.0` --- src/distilabel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index 47628af331..41bc97e783 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -14,6 +14,6 @@ from rich import traceback as rich_traceback -__version__ = "1.5.0" +__version__ = "1.6.0" rich_traceback.install(show_locals=True) From d04f069ad05117aba2ee3c780e1f20e9ea861e0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 17 Jan 2025 11:05:33 +0100 Subject: [PATCH 26/30] Remove deprecated `CombineColumns` step (#1101) --- src/distilabel/steps/__init__.py | 3 +-- src/distilabel/steps/columns/group.py | 15 +-------------- tests/unit/steps/columns/test_group.py | 17 +---------------- tests/unit/test_imports.py | 1 - 4 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 19d90f9a33..661704c7d8 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -26,7 +26,7 @@ from distilabel.steps.clustering.umap import UMAP from distilabel.steps.columns.combine import CombineOutputs from distilabel.steps.columns.expand import ExpandColumns -from distilabel.steps.columns.group import CombineColumns, GroupColumns +from distilabel.steps.columns.group import GroupColumns from distilabel.steps.columns.keep import KeepColumns from distilabel.steps.columns.merge import MergeColumns from distilabel.steps.decorator import step @@ -60,7 +60,6 @@ __all__ = [ "DBSCAN", "UMAP", - "CombineColumns", "CombineOutputs", "ConversationTemplate", "DataSampler", diff --git a/src/distilabel/steps/columns/group.py b/src/distilabel/steps/columns/group.py index ed9ee7a2df..aaea6a3fee 100644 --- a/src/distilabel/steps/columns/group.py +++ b/src/distilabel/steps/columns/group.py @@ -12,8 +12,7 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -import warnings -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, List, Optional from typing_extensions import override @@ -125,15 +124,3 @@ def process(self, *inputs: StepInput) -> "StepOutput": group_columns=self.inputs, output_group_columns=self.outputs, ) - - -class CombineColumns(GroupColumns): - """`CombineColumns` is deprecated and will be removed in version 1.5.0, use `GroupColumns` instead.""" - - def __init__(self, **data: Any) -> None: - warnings.warn( - "`CombineColumns` is deprecated and will be removed in version 1.5.0, use `GroupColumns` instead.", - DeprecationWarning, - stacklevel=2, - ) - return super().__init__(**data) diff --git a/tests/unit/steps/columns/test_group.py b/tests/unit/steps/columns/test_group.py index 57f9f114de..c929cda927 100644 --- a/tests/unit/steps/columns/test_group.py +++ b/tests/unit/steps/columns/test_group.py @@ -13,11 +13,9 @@ # limitations under the License. -import pytest - from distilabel.constants import DISTILABEL_METADATA_KEY from distilabel.pipeline.local import Pipeline -from distilabel.steps.columns.group import CombineColumns, GroupColumns +from distilabel.steps.columns.group import GroupColumns class TestGroupColumns: @@ -58,16 +56,3 @@ def test_process(self) -> None: DISTILABEL_METADATA_KEY: {"model": ["model-1", "model-2"]}, } ] - - -def test_CombineColumns_deprecation_warning(): - with pytest.deprecated_call(): - CombineColumns( - name="combine_columns", - columns=["generation", "model_name"], - ) - from packaging.version import Version - - import distilabel - - assert Version(distilabel.__version__) <= Version("1.5.0") diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py index a836cceb15..1309232a27 100644 --- a/tests/unit/test_imports.py +++ b/tests/unit/test_imports.py @@ -40,7 +40,6 @@ def test_imports() -> None: from distilabel.steps import ( StepResources, - CombineColumns, GroupColumns, MergeColumns, ConversationTemplate, From e6c9d9e2cd9f2a853fbb2d8759a3fb86f4931656 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 17 Jan 2025 13:13:48 +0100 Subject: [PATCH 27/30] Fix `Image` import handling and update `MlxLLM` initialisation (#1102) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez --- pyproject.toml | 5 ++--- .../huggingface/inference_endpoints.py | 7 +++++-- src/distilabel/models/image_generation/utils.py | 4 ++-- src/distilabel/models/llms/mlx.py | 2 +- src/distilabel/steps/tasks/image_generation.py | 10 ++++++++-- .../steps/tasks/structured_outputs/outlines.py | 3 ++- .../steps/tasks/text_generation_with_image.py | 1 - src/distilabel/utils/image.py | 7 ++++--- 8 files changed, 24 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c55ebb1c7..a413f05e74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "orjson >= 3.10.0", "universal_pathlib >= 0.2.2", "portalocker >= 2.8.2", + "setuptools", ] dynamic = ["version"] @@ -90,9 +91,7 @@ ray = ["ray[default] >= 2.31.0"] vertexai = ["google-cloud-aiplatform >= 1.38.0"] vllm = [ "vllm >= 0.5.3", - "filelock >= 3.13.4", - # `setuptools` is needed to be installed if installed with `uv pip install distilabel[vllm]` - "setuptools", + "filelock >= 3.13.4" ] sentence-transformers = ["sentence-transformers >= 3.0.0"] faiss-cpu = ["faiss-cpu >= 1.8.0"] diff --git 
a/src/distilabel/models/image_generation/huggingface/inference_endpoints.py b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py index 2403fbf018..a5225815ef 100644 --- a/src/distilabel/models/image_generation/huggingface/inference_endpoints.py +++ b/src/distilabel/models/image_generation/huggingface/inference_endpoints.py @@ -20,7 +20,6 @@ InferenceEndpointsBaseClient, ) from distilabel.models.image_generation.base import AsyncImageGenerationModel -from distilabel.models.image_generation.utils import image_to_str if TYPE_CHECKING: from PIL.Image import Image @@ -60,10 +59,14 @@ class InferenceEndpointsImageGeneration( # type: ignore """ def load(self) -> None: + from distilabel.models.image_generation.utils import image_to_str + # Sets the logger and calls the load method of the BaseClient AsyncImageGenerationModel.load(self) InferenceEndpointsBaseClient.load(self) + self._image_to_str = image_to_str + @validate_call async def agenerate( # type: ignore self, @@ -101,6 +104,6 @@ async def agenerate( # type: ignore num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ) - img_str = image_to_str(image, image_format="JPEG") + img_str = self._image_to_str(image, image_format="JPEG") return [{"images": [img_str]}] diff --git a/src/distilabel/models/image_generation/utils.py b/src/distilabel/models/image_generation/utils.py index e5f08ca343..fe7e4e5d7d 100644 --- a/src/distilabel/models/image_generation/utils.py +++ b/src/distilabel/models/image_generation/utils.py @@ -18,14 +18,14 @@ from PIL import Image -def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: +def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str: """Converts a PIL Image to a base64 encoded string.""" buffered = io.BytesIO() image.save(buffered, format=image_format) return base64.b64encode(buffered.getvalue()).decode("utf-8") -def image_from_str(image_str: str) -> Image.Image: +def image_from_str(image_str: str) -> "Image.Image": """Converts a base64 encoded string to a PIL Image.""" image_bytes = base64.b64decode(image_str) return Image.open(io.BytesIO(image_bytes)) diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py index ffdcf37526..e23401b07e 100644 --- a/src/distilabel/models/llms/mlx.py +++ b/src/distilabel/models/llms/mlx.py @@ -60,7 +60,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): ```python from distilabel.models.llms import MlxLLM - llm = MlxLLM(model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit") + llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit") llm.load() diff --git a/src/distilabel/steps/tasks/image_generation.py b/src/distilabel/steps/tasks/image_generation.py index 3484b90058..ebc411b8c7 100644 --- a/src/distilabel/steps/tasks/image_generation.py +++ b/src/distilabel/steps/tasks/image_generation.py @@ -15,7 +15,6 @@ import hashlib from typing import TYPE_CHECKING -from distilabel.models.image_generation.utils import image_from_str from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import ImageTask @@ -117,6 +116,13 @@ class ImageGeneration(ImageTask): save_artifacts: bool = False image_format: str = "JPEG" + def load(self) -> None: + from distilabel.models.image_generation.utils import image_from_str + + super().load() + + self._image_from_str = image_from_str + @property def inputs(self) -> "StepColumns": return ["prompt"] @@ -166,7 +172,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # use prompt as filename prompt_hash = 
hashlib.md5(input["prompt"].encode()).hexdigest() # Build PIL image to save it - image = image_from_str(image) + image = self._image_from_str(image) self.save_artifact( name="images", diff --git a/src/distilabel/steps/tasks/structured_outputs/outlines.py b/src/distilabel/steps/tasks/structured_outputs/outlines.py index 45b5fe7494..a0b4ced551 100644 --- a/src/distilabel/steps/tasks/structured_outputs/outlines.py +++ b/src/distilabel/steps/tasks/structured_outputs/outlines.py @@ -28,7 +28,6 @@ get_args, ) -import pkg_resources from pydantic import BaseModel from distilabel.errors import DistilabelUserError @@ -50,6 +49,8 @@ def _is_outlines_version_below_0_1_0() -> bool: Returns: bool: True if outlines is not installed or version is below 0.1.0 """ + import pkg_resources + if not importlib.util.find_spec("outlines"): raise ImportError( "Outlines is not installed. Please install it using `pip install outlines`." diff --git a/src/distilabel/steps/tasks/text_generation_with_image.py b/src/distilabel/steps/tasks/text_generation_with_image.py index 8aee386f80..3e0bef56e6 100644 --- a/src/distilabel/steps/tasks/text_generation_with_image.py +++ b/src/distilabel/steps/tasks/text_generation_with_image.py @@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, Any, Literal, Union from jinja2 import Template -from PIL import Image from pydantic import Field from distilabel.steps.tasks.base import Task diff --git a/src/distilabel/utils/image.py b/src/distilabel/utils/image.py index aa9d09089c..060eb71e08 100644 --- a/src/distilabel/utils/image.py +++ b/src/distilabel/utils/image.py @@ -14,12 +14,13 @@ import base64 import io +from typing import TYPE_CHECKING -from PIL import Image +if TYPE_CHECKING: + from PIL import Image -# TODO: Once we merge the image generation, this function can be reused -def image_to_str(image: Image.Image, image_format: str = "JPEG") -> str: +def image_to_str(image: "Image.Image", image_format: str = "JPEG") -> str: """Converts a PIL Image to a base64 encoded string.""" buffered = io.BytesIO() image.save(buffered, format=image_format) From 34e84e3a86b8ca40f2606074788c67c2402074f9 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Fri, 17 Jan 2025 14:58:11 +0100 Subject: [PATCH 28/30] Fix `MlxLLM` by aligning it with `mlx-lm>=0.21` (#1103) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez --- pyproject.toml | 2 +- src/distilabel/llms.py | 2 + src/distilabel/models/llms/mlx.py | 83 ++++++++++++++++--------------- 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a413f05e74..30bd06262d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ text-clustering = [ "scikit-learn >= 1.4.1", "matplotlib >= 3.8.3", # For the figure (even though it's optional) ] -mlx = ["mlx >= 0.21.0", "mlx-lm"] +mlx = ["mlx >= 0.21.0", "mlx-lm >= 0.21.0, < 0.22.0"] vision = ["Pillow >= 10.3.0"] # To work with images. 
# minhash diff --git a/src/distilabel/llms.py b/src/distilabel/llms.py index 8d579048df..730950a109 100644 --- a/src/distilabel/llms.py +++ b/src/distilabel/llms.py @@ -33,6 +33,7 @@ from distilabel.models.llms.litellm import LiteLLM from distilabel.models.llms.llamacpp import LlamaCppLLM from distilabel.models.llms.mistral import MistralLLM +from distilabel.models.llms.mlx import MlxLLM from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM @@ -59,6 +60,7 @@ "LlamaCppLLM", "MistralLLM", "MixtureOfAgentsLLM", + "MlxLLM", "OllamaLLM", "OpenAILLM", "TogetherLLM", diff --git a/src/distilabel/models/llms/mlx.py b/src/distilabel/models/llms/mlx.py index e23401b07e..8418fde4f6 100644 --- a/src/distilabel/models/llms/mlx.py +++ b/src/distilabel/models/llms/mlx.py @@ -19,9 +19,11 @@ Dict, List, Optional, + Union, ) from pydantic import ( + Field, PrivateAttr, validate_call, ) @@ -42,7 +44,7 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): Attributes: path_or_hf_repo: the path to the model or the Hugging Face Hub repo id. tokenizer_config: the tokenizer configuration. - model_config: the model configuration. + mlx_model_config: the MLX model configuration. adapter_path: the path to the adapter. use_magpie_template: a flag used to enable/disable applying the Magpie pre-query template. Defaults to `False`. @@ -70,20 +72,22 @@ class MlxLLM(LLM, MagpieChatTemplateMixin): """ path_or_hf_repo: str - tokenizer_config: Dict[str, Any] = {} - model_config: Dict[str, Any] = {} + tokenizer_config: Dict[str, Any] = Field(default_factory=dict) + mlx_model_config: Dict[str, Any] = Field(default_factory=dict) adapter_path: Optional[str] = None - _mlx_generate: Optional[Callable] = PrivateAttr(default=None) - _model: Optional["nn.Module"] = PrivateAttr(...) - _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...) + _model: Optional["nn.Module"] = PrivateAttr(None) + _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(None) + _mlx_generate: Optional[Callable] = PrivateAttr(None) + _make_sampler: Optional[Callable] = PrivateAttr(None) def load(self) -> None: """Loads the model and tokenizer and creates the text generation pipeline. In addition, it will configure the tokenizer chat template.""" try: import mlx # noqa - from mlx_lm import generate, load + from mlx_lm.utils import generate, load + from mlx_lm.sample_utils import make_sampler except ImportError as ie: raise ImportError( "MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`." @@ -92,7 +96,7 @@ def load(self) -> None: self._model, self._tokenizer = load( self.path_or_hf_repo, tokenizer_config=self.tokenizer_config, - model_config=self.model_config, + model_config=self.mlx_model_config, adapter_path=self.adapter_path, ) @@ -100,7 +104,7 @@ def load(self) -> None: self._tokenizer.pad_token = self._tokenizer.eos_token self._mlx_generate = generate - + self._make_sampler = make_sampler super().load() @property @@ -108,7 +112,7 @@ def model_name(self) -> str: """Returns the model name used for the LLM.""" return self.path_or_hf_repo - def prepare_input(self, input: "StandardInput") -> str: + def prepare_input(self, input: Union["StandardInput", str]) -> str: """Prepares the input (applying the chat template and tokenization) for the provided input. @@ -118,11 +122,11 @@ def prepare_input(self, input: "StandardInput") -> str: Returns: The prompt to send to the LLM. 
""" - if self._tokenizer.chat_template is None: - return input[0]["content"] + if isinstance(input, str): + return input prompt: str = ( - self._tokenizer.apply_chat_template( + self._tokenizer.apply_chat_template( # type: ignore input, tokenize=False, add_generation_prompt=True, @@ -133,12 +137,11 @@ def prepare_input(self, input: "StandardInput") -> str: return super().apply_magpie_pre_query_template(prompt, input) @validate_call - def generate( + def generate( # type: ignore self, - inputs: List[StandardInput], + inputs: List[Union[StandardInput, str]], num_generations: int = 1, max_tokens: int = 256, - sampler: Optional[Callable] = None, logits_processors: Optional[List[Callable]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, @@ -147,12 +150,11 @@ def generate( kv_group_size: int = 64, quantized_kv_start: int = 0, prompt_progress_callback: Optional[Callable[[int, int], None]] = None, - temp: Optional[float] = None, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = None, - top_p: Optional[float] = None, - min_p: Optional[float] = None, - min_tokens_to_keep: Optional[int] = None, + temp: float = 0.0, + top_p: float = 0.0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + top_k: int = -1, ) -> List[GenerateOutput]: """Generates `num_generations` responses for each input using the text generation pipeline. @@ -163,7 +165,6 @@ def generate( `1`. max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`. - sampler: the sampler to use for the generation. Defaults to `None`. logits_processors: the logits processors to use for the generation. Defaults to `None`. max_kv_size: the maximum size of the key-value cache. Defaults to `None`. @@ -174,18 +175,24 @@ def generate( quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`. prompt_progress_callback: the callback to use for the generation. Defaults to `None`. - temp: the temperature to use for the generation. Defaults to `None`. - repetition_penalty: the repetition penalty to use for the generation. Defaults to - `None`. - repetition_context_size: the context size for the repetition penalty. Defaults to - `None`. - top_p: the top-p value to use for the generation. Defaults to `None`. - min_p: the minimum p value to use for the generation. Defaults to `None`. - min_tokens_to_keep: the minimum number of tokens to keep. Defaults to `None`. + temp: The temperature for text generation. Defaults to `0.0`. + top_p: The top-p value used for the generation. Defaults to `0.0`. + min_p: The min-p value used for the generation. Defaults to `0.0`. + min_tokens_to_keep: Minimum number of tokens to keep for sampling after + filtering. Must be at least 1. Defaults to `1`. + top_k: The top-k value used for the generation. Defaults to `-1`. Returns: A list of lists of strings containing the generated responses for each input. 
""" + + sampler = self._make_sampler( # type: ignore + temp=temp, + top_p=top_p, + min_p=min_p, + min_tokens_to_keep=min_tokens_to_keep, + top_k=top_k, + ) structured_output = None result = [] for input in inputs: @@ -197,7 +204,7 @@ def generate( if structured_output: # will raise a NotImplementedError self._prepare_structured_output(structured_output) prompt = self.prepare_input(input) - generation = self._mlx_generate( + generation = self._mlx_generate( # type: ignore prompt=prompt, model=self._model, tokenizer=self._tokenizer, @@ -211,24 +218,18 @@ def generate( kv_group_size=kv_group_size, quantized_kv_start=quantized_kv_start, prompt_progress_callback=prompt_progress_callback, - temp=temp, - repetition_penalty=repetition_penalty, - repetition_context_size=repetition_context_size, - top_p=top_p, - min_p=min_p, - min_tokens_to_keep=min_tokens_to_keep, ) output.append(generation) result.append( prepare_output( - output, - input_tokens=[compute_tokens(input, self._tokenizer.encode)], + generations=output, + input_tokens=[compute_tokens(input, self._tokenizer.encode)], # type: ignore output_tokens=[ compute_tokens( text_or_messages=generation, - tokenizer=self._tokenizer.encode, + tokenizer=self._tokenizer.encode, # type: ignore ) for generation in output ], From 69bbe3d080e85af93361403a497dbe34fa00bd22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 17 Jan 2025 15:23:45 +0100 Subject: [PATCH 29/30] Update version to `1.5.1` --- src/distilabel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index 41bc97e783..c1336509f5 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -14,6 +14,6 @@ from rich import traceback as rich_traceback -__version__ = "1.6.0" +__version__ = "1.5.1" rich_traceback.install(show_locals=True) From a3320ed145ec989e26ac4c26e7defe9e49e49698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 17 Jan 2025 15:24:28 +0100 Subject: [PATCH 30/30] Bump version to `1.6.0` --- src/distilabel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py index c1336509f5..41bc97e783 100644 --- a/src/distilabel/__init__.py +++ b/src/distilabel/__init__.py @@ -14,6 +14,6 @@ from rich import traceback as rich_traceback -__version__ = "1.5.1" +__version__ = "1.6.0" rich_traceback.install(show_locals=True)