Repository : https://github.com/AntoineBendafiSchulmann/doc_detector_pl
Le projet Document Detector utilise PyTorch Lightning pour entraîner un modèle de détection basé sur l'architecture UNet qui est une structure de réseaux de neurones conçu pour la segmentation d'images, c'est à dire reconnaitre et localiser des éléments spécifiques dans une image. Il permet de détecter, segmenter et recadrer automatiquement des documents dans des images. Le projet est organisé pour un environnement d'entraînement et d'évaluation flexible, en exploitant TensorBoard pour visualiser les courbes d'entraînement et OpenCV pour le traitement des masques.
-
data/
: Données utilisées pour l'entraînement et les tests.images/
: Images d'entraînement.masks/
: Masques d'entraînement correspondants.debug/
: Contient les images de debug, où les masques générés sont superposés aux images originales. Ces fichiers permettent de vérifier visuellement la qualité et la pertinence des masques.test_images/
: Images de test pour vérifier les performances du modèle.
-
models/
: Fichiers des modèles sauvegardés :.pth
: Modèle PyTorch (format entraînable)..onnx
: Modèle exporté pour l'inférence.
-
outputs/
: Comparaisons générées entre les images originales, les masques attendus, les prédictions, et les recadrages. -
logs/
: Fichiers de logs générés par PyTorch Lightning pour le suivi via TensorBoard. -
Scripts principaux :
generate_masks.py
: Génère des masques à partir des images d'entraînement.train_model.py
: Lance l'entraînement du modèle.test_model.py
: Effectue une prédiction sur les images de test, génère les comparaisons, et recadre les images.
-
Cloner le dépôt :
git clone https://github.com/AntoineBendafiSchulmann/doc_detector_pl cd doc_detector_pl
-
Installer les dépendances :
pip install -r requirements.txt
j'ai généré un fichier
requirements.txt
contenant la liste des dépendances installées sur mon environnement virtuel python (il se peut qu'il y ait des trucs pertinent pour moi dedans genre certaines bibliothèques commepytorch-triton-rocm
oujax
pour les gpu amd désolé)Ceci est la commande pour générer le fichier directement à partir de son environnement virtuel :
pip freeze > requirements.txt
Pour préparer les données d'entraînement, des masques correspondant aux documents présents dans les images sont générés à l'aide du script generate_masks.py
. Les masques servent de cibles pour l'entraînement du modèle
python generate_masks.py
Le script utilise OpenCV pour détecter automatiquement les documents dans les images. Voici les étapes principales :
- Floutage des Images : Utilise un flou gaussien pour réduire les bruits.
- Seuil adaptatif : Applique une méthode de seuil adaptatif pour séparer le fond des zones importantes.
- Contours : Identifie les contours dans l'image.
- Filtrage : Retient uniquement les contours ayant :
- Quatre côtés (approximé comme un quadrilatère).
- Une aire minimale (par défaut, > 1000 pixels).
- Masque Final : Dessine le quadrilatère détecté.
En cas de besoin, un répertoire debug/
est utilisé pour sauvegarder des images où les masques générés sont superposés aux images originales. Cela permet de vérifier visuellement que les contours détectés et les masques générés sont conformes aux attentes.
- Les fichiers dans
debug/
aident à repérer rapidement les images problématiques.
Dans le terminal principal, exécutez la commande suivante pour lancer l'entraînement :
python train_model.py
Ouvrez un deuxième terminal, placez-vous dans le dossier doc_detector_pl
et exécutez:
tensorboard --logdir logs/ --bind_all
ouvrez l'url pour accéder à TensorBoard
- Les courbes d'entraînement sont visibles dans TensorBoard.
- Le modèle entraîné est sauvegardé dans le dossier
models/
.
Testez le modèle sur une image de test et générez des comparaisons visuelles :
python test_model.py
- Les comparaisons (image originale, masque attendu, masque prédit, image recadrée) sont sauvegardées dans
outputs/
.
Les courbes affichées dans TensorBoard sont des outils essentiels pour suivre et comprendre l'entraînement du modèle. Elles permettent d'identifier les progrès réalisés mais aussi de détecter des problèmes comme l'overfitting.
- Description : Cette courbe indique le nombre d'époques complétées au fur et à mesure de l'entraînement.
- Elle sert simplement à visualiser la progression. Une augmentation linéaire est attendue et normale.
- Description : Cette courbe montre comment la "perte", une mesure de l'erreur entre la prédiction du modèle et les résultats attendus, évolue à chaque époque (un passage complet de toutes les images du jeu d'entraînement dans le modèle).
- Une courbe qui diminue régulièrement indique que le modèle apprend bien.
- Si la perte stagne ou augmente après une phase de diminution, cela peut signaler que le modèle atteint sa capacité maximale ou commence à mémoriser les données (overfitting).
- Description : Cette courbe montre les variations de la perte après chaque groupe d'images (mini-lot ou "batch") traité par le modèle.
- De légères fluctuations sont normales et reflètent la diversité des données.
- Une tendance globale à la baisse est un bon signe. Cependant, si les variations sont trop importantes ou ne diminuent pas, cela peut indiquer que l'apprentissage est instable.
- Description : Cette courbe représente l'évolution de la perte (ou erreur) calculée sur les données de validation après chaque époque. Contrairement à la perte d'entraînement, la perte de validation mesure la capacité du modèle à généraliser sur des données qu'il n'a jamais vues auparavant.
- Tendance à la baisse : Si la courbe diminue régulièrement, cela signifie que le modèle apprend à bien généraliser sur de nouvelles données.
- Augmentation ou stagnation : Si la courbe stagne ou augmente alors que la perte d'entraînement continue de diminuer, cela peut indiquer un début d'overfitting.
- Oscillations : Une légère variation est normale. Cependant, des fluctuations importantes peuvent refléter une instabilité dans l'apprentissage ou un jeu de validation trop petit.(dans mon cas actuel il se pouurait qu'on soit pas bon xd)
L'overfitting se produit lorsque le modèle apprend trop bien sur les données d'entraînement, au point de mémoriser des détails spécifiques qui ne se généralisent pas à de nouvelles données. Cela se traduit par une bonne performance sur les données d'entraînement, mais de mauvaises performances sur des données jamais vues auparavant. Voici quelques techniques pour éviter ce problème
- Augmentation de la
val_loss
: Si la courbe de validation (val_loss
) commence à augmenter ou stagner alors que la perte d'entraînement (train_loss
) continue de diminuer, cela indique que le modèle mémorise les données d'entraînement sans généraliser correctement. - Écart significatif entre
train_loss
etval_loss
: Une différence importante entre la perte d'entraînement (très faible) et la perte de validation (plus élevée) reflète un sur-apprentissage sur les données d'entraînement. - Fluctuations importantes de
val_loss
: De grandes oscillations peuvent indiquer que le modèle est instable ou que les données de validation ne sont pas représentatives.
- Augmenter le volume des données : Ajouter davantage d'exemples, si possible, pour couvrir une plus large gamme de variations (angles, résolutions, types de documents).
- Augmentation des données (Data Augmentation) : Utiliser des transformations comme des rotations, des inversions, des changements de luminosité ou des zooms. Cela peut être implémenté via des outils comme torchvision.transforms.
- Diviser les données en plusieurs groupes et entraîner le modèle sur différents sous-ensembles de données. Cela aide à vérifier la capacité de généralisation du modèle et réduit la dépendance à une seule division "données d'entraînement / validation".
- Configurer un mécanisme dans PyTorch Lightning pour arrêter l'entraînement lorsque la perte de validation (
val_loss
) n'améliore plus après un certain nombre d'époch. Cela empêche le modèle d'apprendre des détails inutiles des données d'entraînement.
- Utiliser des couches de Dropout dans le modèle U-Net pour désactiver aléatoirement une fraction des neurones pendant l'entraînement. Cela force le modèle à s'appuyer sur différentes combinaisons de caractéristiques et améliore la généralisation.
- Ajouter une régularisation L2 (via l'option
weight_decay
dans l'optimiseur Adam). Cela limite la croissance excessive des poids du modèle et empêche la mémorisation.
- Si le modèle est trop complexe pour la taille des données disponibles (par exemple, trop de couches ou de paramètres), réduire sa taille peut améliorer la généralisation.
- S'assurer que les données de validation reflètent la diversité des données réelles. Cela permet d'avoir une meilleure idée des performances du modèle sur des cas non vus.
Le script test_model
permet d'effectuer une prédiction sur une image de test et de recadrer automatiquement le document détecté en fonction du masque prédit par le modèle UNet. Cette étape simule comment le modèle peut être utilisé dans un pipeline complet quand le modèle sera exporté sur tensorflow js pour être utilisable sur un navigateur, allant de la détection à l'extraction ciblée d'un document.
- Chargement de l'image et du masque attendu :
- L'image de test est lue et préparée pour être passée dans le modèle.
- Un masque attendu peut également être utilisé pour comparaison.
- Prédiction du masque :
- Le modèle génère un masque binaire prédisant les zones correspondant au document.
- Redimensionnement du masque :
- Le masque prédit est redimensionné pour correspondre aux dimensions originales de l'image.
- Détection et recadrage :
- À l'aide des coordonnées du masque prédit, la zone correspondant au document est recadrée automatiquement.
- Visualisation et sauvegarde :
- L'image originale
- Le masque attendu
- Le masque prédit
- L'image recadrée
Exemple de Commande : Pour tester cette fonctionnalité, exécutez simplement la commande suivante :
python test_model.py
Voici la sortie générée par ce script :
(je sais il reste du travail au vu du résultat mais vous avez l'idée)
- Image originale : Affiche l'image d'entrée brute.
- Masque attendu : Masque utilisé pour l'entraînement.
- Masque prédit : Généré automatiquement par le modèle.
- Image recadrée : Image finale extraite en fonction des prédictions
Le modèle entraîné peut être exporté dans différents formats comme ONNX ou TensorFlow.js, permettant son utilisation dans des systèmes backend ou frontend.
- ONNX : Permet l'intégration avec des frameworks comme ONNX Runtime pour des prédictions rapides.
- TensorFlow.js : Rendra le modèle directement utilisable dans les navigateurs.
Le modèle peut également être déployé via une API Flask pour permettre une interaction en temps réel avec des systèmes externes. Voici les étapes principales d'une intégration avec Flask :
Création de l'API : L'API reçoit une image en entrée (par exemple sous forme de fichier envoyé dans une requête POST), l'analyse avec le modèle U-Net, et retourne :
- Le masque prédit.
- L'image recadrée.
- Les métriques associées.