diff --git a/README.md b/README.md index 2794b13..18e21b1 100644 --- a/README.md +++ b/README.md @@ -70,9 +70,10 @@ Curious to learn more about GPT-4 Vision? [Check out our GPT-4V experiments 🧪 WARNING: DO NOT EDIT THIS TABLE MANUALLY. IT IS AUTOMATICALLY GENERATED. HEAD OVER TO CONTRIBUTING.MD FOR MORE DETAILS ON HOW TO MAKE CHANGES PROPERLY. --> -## 🚀 model tutorials (41 notebooks) +## 🚀 model tutorials (42 notebooks) | **notebook** | **open in colab / kaggle / sagemaker studio lab** | **complementary materials** | **repository / paper** | |:------------:|:-------------------------------------------------:|:---------------------------:|:----------------------:| +| [Fine-Tune GPT-4o](https://github.com/roboflow-ai/notebooks/blob/main/notebooks/openai-gpt-4o-fine-tuning.ipynb) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/openai-gpt-4o-fine-tuning.ipynb) [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/roboflow-ai/notebooks/blob/main/notebooks/openai-gpt-4o-fine-tuning.ipynb) | | | | [YOLO11 Object Detection](https://github.com/roboflow-ai/notebooks/blob/main/notebooks/train-yolo11-object-detection-on-custom-dataset.ipynb) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-yolo11-object-detection-on-custom-dataset.ipynb) [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/roboflow-ai/notebooks/blob/main/notebooks/train-yolo11-object-detection-on-custom-dataset.ipynb) | [![Roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)](https://blog.roboflow.com/yolov11-how-to-train-custom-data/) 
[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=jE_s4tVgPHA) | [![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/ultralytics/ultralytics) | | [YOLO11 Instance Segmentation](https://github.com/roboflow-ai/notebooks/blob/main/notebooks/train-yolo11-instance-segmentation-on-custom-dataset.ipynb) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-yolo11-instance-segmentation-on-custom-dataset.ipynb) [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/roboflow-ai/notebooks/blob/main/notebooks/train-yolo11-instance-segmentation-on-custom-dataset.ipynb) | [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=jE_s4tVgPHA) | [![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/ultralytics/ultralytics) | | [Segment Images with SAM2](https://github.com/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb) [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb) | [![Roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)](https://blog.roboflow.com/what-is-segment-anything-2/) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/live/Dv003fTyO-Y) | [![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/facebookresearch/segment-anything-2) 
[![arXiv](https://img.shields.io/badge/arXiv-2408.00714-b31b1b.svg)](https://arxiv.org/abs/2408.00714)| diff --git a/automation/notebooks-table-data.csv b/automation/notebooks-table-data.csv index dfa5b3d..11eb528 100644 --- a/automation/notebooks-table-data.csv +++ b/automation/notebooks-table-data.csv @@ -1,4 +1,5 @@ display_name, notebook_name, roboflow_blogpost_path, youtube_video_path, github_repository_path, arxiv_index, should_open_in_sagemaker_labs, readme_section +Fine-Tune GPT-4o, openai-gpt-4o-fine-tuning.ipynb, , , , , False, models YOLO11 Object Detection, train-yolo11-object-detection-on-custom-dataset.ipynb, https://blog.roboflow.com/yolov11-how-to-train-custom-data/, https://www.youtube.com/watch?v=jE_s4tVgPHA, https://github.com/ultralytics/ultralytics, , False, models YOLO11 Instance Segmentation, train-yolo11-instance-segmentation-on-custom-dataset.ipynb, , https://www.youtube.com/watch?v=jE_s4tVgPHA, https://github.com/ultralytics/ultralytics, , False, models Football AI, football-ai.ipynb, https://blog.roboflow.com/camera-calibration-sports-computer-vision/, https://youtu.be/aBVGKoNZQUw, https://github.com/roboflow/sports, , False, skills diff --git a/notebooks/openai-gpt-4o-fine-tuning.ipynb b/notebooks/openai-gpt-4o-fine-tuning.ipynb new file mode 100644 index 0000000..03b9051 --- /dev/null +++ b/notebooks/openai-gpt-4o-fine-tuning.ipynb @@ -0,0 +1,552 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)\n", + "\n", + "# OpenAI GPT-4o fine-tuning\n", + "---" + ], + "metadata": { + "id": "qxRhs8smCtFO" + } + }, + { + "cell_type": 
"markdown", + "source": [ + "## Setup" + ], + "metadata": { + "id": "qgabw2nPowqj" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Configure your API keys\n", + "\n", + "To fine-tune GPT-4o, you need to provide your OpenAI API key and Roboflow API key. Follow these steps:\n", + "\n", + "- Open your [`OpenAI Settings`](https://platform.openai.com/settings) page. Click `User API keys` then `Create new secret key` to generate a new token.\n", + "- Go to your [`Roboflow Settings`](https://app.roboflow.com/settings/api) page. Click `Copy`. This will place your private key in the clipboard.\n", + "- In Colab, go to the left pane and click on `Secrets` (🔑).\n", + " - Store OpenAI API key under the name `OPENAI_API_KEY`.\n", + " - Store Roboflow API key under the name `ROBOFLOW_API_KEY`." + ], + "metadata": { + "id": "_U2G3NfcozRD" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Install dependencies" + ], + "metadata": { + "id": "yGb7ydD1pwC9" + } + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2xC6GaDomrrN", + "outputId": "cde6cdfe-99c6-4235-c03f-7ba1f67fe828" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.4/42.4 kB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.6/2.6 MB\u001b[0m \u001b[31m46.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.0/43.0 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m296.4/296.4 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m151.3/151.3 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m45.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m56.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for flash-attn (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install -q openai roboflow maestro==0.2.0rc5" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Download dataset" + ], + "metadata": { + "id": "QBZNvweDq_O4" + } + }, + { + "cell_type": "code", + "source": [ + "from roboflow import Roboflow\n", + "from google.colab import userdata\n", + "\n", + "ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')\n", + "rf = Roboflow(api_key=ROBOFLOW_API_KEY)\n", + "\n", + "workspace = rf.workspace(\"april-public-yibrz\")\n", + "project = workspace.project(\"focal-length\")\n", + "version = project.version(1)\n", + "dataset = version.download(\"openai\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uF3PKCH6p9Ca", + "outputId": "79a15645-1a86-4db3-a01e-c6c993416b43" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "loading Roboflow workspace...\n", + "loading Roboflow project...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading Dataset Version Zip in Focal-Length-1 to openai:: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 4/4 [00:00<00:00, 2926.94it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + "Extracting Dataset Version Zip to Focal-Length-1 in openai:: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5/5 [00:00<00:00, 1204.71it/s]\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!head -n 5 {dataset.location}/_annotations.train.jsonl" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CKuZ5frBuu4C", + "outputId": "2600b533-abcc-49c9-e431-11929c585958" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful 
assistant.\"},{\"role\":\"user\",\"content\":\"What focal length is this photo?\"},{\"role\":\"user\",\"content\":[{\"type\":\"image_url\",\"image_url\":{\"url\":\"https://transform.roboflow.com/SFgRaqEsIPfd7Vj37buG/018867065caf451d0098b749fae0f310/transformed.jpg\"}}]},{\"role\":\"assistant\",\"content\":\"55.0mm\"}]}\n", + "{\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},{\"role\":\"user\",\"content\":\"What focal length is this photo?\"},{\"role\":\"user\",\"content\":[{\"type\":\"image_url\",\"image_url\":{\"url\":\"https://transform.roboflow.com/SFgRaqEsIPfd7Vj37buG/44561655d9c73836724db71e9639dc63/transformed.jpg\"}}]},{\"role\":\"assistant\",\"content\":\"50.0mm\"}]}\n", + "{\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},{\"role\":\"user\",\"content\":\"What focal length is this photo?\"},{\"role\":\"user\",\"content\":[{\"type\":\"image_url\",\"image_url\":{\"url\":\"https://transform.roboflow.com/SFgRaqEsIPfd7Vj37buG/4b3887ae2840343d284e47a582a47f9d/transformed.jpg\"}}]},{\"role\":\"assistant\",\"content\":\"14.3mm\"}]}\n", + "{\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},{\"role\":\"user\",\"content\":\"What focal length is this photo?\"},{\"role\":\"user\",\"content\":[{\"type\":\"image_url\",\"image_url\":{\"url\":\"https://transform.roboflow.com/SFgRaqEsIPfd7Vj37buG/21db473de379c82b5e348c046fdb9633/transformed.jpg\"}}]},{\"role\":\"assistant\",\"content\":\"50.0mm\"}]}\n", + "{\"messages\":[{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},{\"role\":\"user\",\"content\":\"What focal length is this photo?\"},{\"role\":\"user\",\"content\":[{\"type\":\"image_url\",\"image_url\":{\"url\":\"https://transform.roboflow.com/SFgRaqEsIPfd7Vj37buG/6bec57abe2603f9b701b2370893f80ab/transformed.jpg\"}}]},{\"role\":\"assistant\",\"content\":\"59.0mm\"}]}\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Run GPT-4o 
fine-tuning\n", + "\n", + "**NOTE:** At the time of publishing this notebook, only the `gpt-4o-2024-08-06` model can be fine-tuned with vision datasets." + ], + "metadata": { + "id": "SFcFQvocuqn-" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Initiate OpenAI client\n", + "\n", + "from openai import OpenAI\n", + "from google.colab import userdata\n", + "\n", + "OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')\n", + "\n", + "client = OpenAI(api_key=OPENAI_API_KEY)" + ], + "metadata": { + "id": "uc4fjWVQxP_b" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Upload a training and validation file\n", + "\n", + "training_file_upload_response = client.files.create(\n", + " file=open(f\"{dataset.location}/_annotations.train.jsonl\", \"rb\"),\n", + " purpose=\"fine-tune\"\n", + ")\n", + "\n", + "validation_file_upload_response = client.files.create(\n", + " file=open(f\"{dataset.location}/_annotations.valid.jsonl\", \"rb\"),\n", + " purpose=\"fine-tune\"\n", + ")\n", + "\n", + "print(\"treaining file response:\", training_file_upload_response)\n", + "print(\"validation file response:\", validation_file_upload_response)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "grniVs2Xw8i8", + "outputId": "de402041-f921-409e-d7c6-4f612d186849" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "treaining file response: FileObject(id='file-ByuHoRS2fs7TM1TQEsQJft05', bytes=12146, created_at=1727882579, filename='_annotations.train.jsonl', object='file', purpose='fine-tune', status='processed', status_details=None)\n", + "validation file response: FileObject(id='file-n2qoSkM5FA1AvxYINDJnKXmx', bytes=3471, created_at=1727882579, filename='_annotations.valid.jsonl', object='file', purpose='fine-tune', status='processed', status_details=None)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Create a 
fine-tuned model\n", + "\n", + "import re\n", + "\n", + "def process_suffix(text: str) -> str:\n", + " \"\"\"\n", + " Converts a string into kebab-case, where spaces are replaced with hyphens\n", + " and all letters are lowercase.\n", + "\n", + " Args:\n", + " text (str): The input string to be converted. Typically, words are\n", + " separated by spaces.\n", + "\n", + " Returns:\n", + " str: The kebab-case version of the input string, where spaces are\n", + " replaced by hyphens and the text is lowercase.\n", + "\n", + " Example:\n", + " >>> process_suffix(\"Focal Length\")\n", + " 'focal-length'\n", + " \"\"\"\n", + " return re.sub(r'\\s+', '-', text.strip()).lower()\n", + "\n", + "\n", + "fine_tuning_response = client.fine_tuning.jobs.create(\n", + " training_file=training_file_upload_response.id,\n", + " validation_file=validation_file_upload_response.id,\n", + " suffix=process_suffix(dataset.name),\n", + " model=\"gpt-4o-2024-08-06\"\n", + ")\n", + "\n", + "fine_tuning_response" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "95eRs4VmrNJL", + "outputId": "d6aad0a1-fe81-45c1-d3c8-28184a1a3ef2" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "FineTuningJob(id='ftjob-JCGY6KZqPitJwXD5rPwsu7hb', created_at=1727882809, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='gpt-4o-2024-08-06', object='fine_tuning.job', organization_id='org-sLGE3gXNesVjtWzgho17NkRy', result_files=[], seed=1660100124, status='validating_files', trained_tokens=None, training_file='file-ByuHoRS2fs7TM1TQEsQJft05', validation_file='file-n2qoSkM5FA1AvxYINDJnKXmx', estimated_finish=None, integrations=[], user_provided_suffix='focal-length')" + ] + }, + "metadata": {}, + "execution_count": 16 + } + ] + }, + { + "cell_type": "markdown", + 
"source": [ + "⚠️ After you've started a fine-tuning job, it may take some time to complete. Your job may be queued behind other jobs in OpenAI's system, and training a model can take minutes or hours depending on the model and dataset size. After the model training is completed, the user who created the fine-tuning job will receive an email confirmation.\n", + "\n", + "In addition to creating a fine-tuning job, you can also list existing jobs, retrieve the status of a job, or cancel a job." + ], + "metadata": { + "id": "a1GmSlqtzlT5" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Check training job status\n", + "\n", + "status_response = client.fine_tuning.jobs.retrieve(fine_tuning_response.id)\n", + "\n", + "status_response" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_rGTqH0-vwAM", + "outputId": "8cf7efd4-fac2-40b0-9805-e9c97c26a739" + }, + "execution_count": 25, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "FineTuningJob(id='ftjob-JCGY6KZqPitJwXD5rPwsu7hb', created_at=1727882809, error=Error(code=None, message=None, param=None), fine_tuned_model='ft:gpt-4o-2024-08-06:personal:focal-length:ADvvXOAF', finished_at=1727884296, hyperparameters=Hyperparameters(n_epochs=4, batch_size=1, learning_rate_multiplier=2), model='gpt-4o-2024-08-06', object='fine_tuning.job', organization_id='org-sLGE3gXNesVjtWzgho17NkRy', result_files=['file-gKmIB1gmlxjQsDSHtGoeSkig'], seed=1660100124, status='succeeded', trained_tokens=101972, training_file='file-ByuHoRS2fs7TM1TQEsQJft05', validation_file='file-n2qoSkM5FA1AvxYINDJnKXmx', estimated_finish=None, integrations=[], user_provided_suffix='focal-length')" + ] + }, + "metadata": {}, + "execution_count": 25 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "**NOTE:** When the training status changes to `succeeded`, the model is ready to use." 
+ ], + "metadata": { + "id": "zCqb-Gkx_RIj" + } + }, + { + "cell_type": "code", + "source": [ + "# @title Use a fine-tuned model\n", + "\n", + "import random\n", + "from torch.utils.data import Dataset\n", + "from maestro.trainer.common.utils.file_system import read_jsonl\n", + "\n", + "class JSONLDataset(Dataset):\n", + " @classmethod\n", + " def from_jsonl_file(cls, path: str):\n", + " file_content = read_jsonl(path=path)\n", + " random.shuffle(file_content)\n", + " return cls(jsons=file_content)\n", + "\n", + " def __init__(self, jsons: list[dict]) -> None:\n", + " self.jsons = jsons\n", + "\n", + " def __getitem__(self, index):\n", + " return self.jsons[index]\n", + "\n", + " def __len__(self) -> int:\n", + " return len(self.jsons)\n", + "\n", + " def shuffle(self) -> None:\n", + " random.shuffle(self.jsons)\n", + "\n", + "\n", + "test_dataset = JSONLDataset.from_jsonl_file(f\"{dataset.location}/_annotations.test.jsonl\")" + ], + "metadata": { + "id": "WhmepSHgymL6" + }, + "execution_count": 26, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "test_dataset[0]['messages']" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qRI5rqLH_o7u", + "outputId": "95ee4359-9c9a-4ba8-8247-27f57ab4927b" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[{'role': 'system', 'content': 'You are a helpful assistant.'},\n", + " {'role': 'user', 'content': 'What focal length is this photo?'},\n", + " {'role': 'user',\n", + " 'content': [{'type': 'image_url',\n", + " 'image_url': {'url': 'https://transform.roboflow.com/SFgRaqEsIPfd7Vj37buG/cb5a2dc8fe341a5360aea91ea00bdd15/transformed.jpg'}}]},\n", + " {'role': 'assistant', 'content': '135.0mm'}]" + ] + }, + "metadata": {}, + "execution_count": 30 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "**NOTE:** When querying the model, we need to remove the last element of the messages list, which contains the 
expected model response." + ], + "metadata": { + "id": "HAJLtZ2i-3xC" + } + }, + { + "cell_type": "code", + "source": [ + "completion = client.chat.completions.create(\n", + " model=status_response.fine_tuned_model,\n", + " messages=test_dataset[0]['messages'][:-1]\n", + ")\n", + "print(completion.choices[0].message)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "W8YspUe648Sz", + "outputId": "4146d849-2732-476e-ba00-2fca63f99938" + }, + "execution_count": 31, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "ChatCompletionMessage(content='35.0mm', refusal=None, role='assistant', function_call=None, tool_calls=None)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Evaluate fine-tuned model\n", + "\n", + "from maestro.trainer.common.utils.metrics import WordErrorRateMetric, CharacterErrorRateMetric\n", + "\n", + "targets = []\n", + "predistions = []\n", + "\n", + "for i in range(len(test_dataset)):\n", + " messages = test_dataset[i]['messages'][:-1]\n", + " target = test_dataset[i]['messages'][-1]['content']\n", + "\n", + " completion = client.chat.completions.create(\n", + " model=status_response.fine_tuned_model,\n", + " messages=messages\n", + " )\n", + " prediction = completion.choices[0].message.content\n", + "\n", + " targets.append(target)\n", + " predistions.append(prediction)\n", + "\n", + "wer = WordErrorRateMetric().compute(targets=targets, predictions=predistions)\n", + "cer = CharacterErrorRateMetric().compute(targets=targets, predictions=predistions)\n", + "\n", + "print(f\"WER: {wer}\")\n", + "print(f\"CER: {cer}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qtWhg5rM9d8i", + "outputId": "bf70df48-ef26-4d67-c8af-58809e023c32" + }, + "execution_count": 36, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "WER: {'wer': 1.0}\n", + "CER: {'cer': 0.319047619047619}\n" + ] + } + ] + }, + 
{ + "cell_type": "code", + "source": [ + "for target, prediction in zip(targets, predistions):\n", + " print(f\"Target: {target}\")\n", + " print(f\"Prediction: {prediction}\")\n", + " print()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Oq-7q2hgAEp9", + "outputId": "ca4f1a5e-d782-4a12-e9ed-307aa4ae369f" + }, + "execution_count": 37, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Target: 135.0mm\n", + "Prediction: 87.0mm\n", + "\n", + "Target: 45.0mm\n", + "Prediction: 50.0mm\n", + "\n", + "Target: 56.0mm\n", + "Prediction: 50.0mm\n", + "\n", + "Target: 85.0mm\n", + "Prediction: 66.0mm\n", + "\n", + "Target: 50.0mm\n", + "Prediction: 35.0mm\n", + "\n" + ] + } + ] + } + ] +} \ No newline at end of file