Skip to content

Commit

Permalink
Update PPOCRLabel v2.0
Browse files Browse the repository at this point in the history
Update PPOCRLabel v2.0
  • Loading branch information
Evezerest committed May 31, 2022
1 parent 115c140 commit 7f7c17c
Show file tree
Hide file tree
Showing 23 changed files with 2,065 additions and 763 deletions.
1,238 changes: 865 additions & 373 deletions PPOCRLabel.py

Large diffs are not rendered by default.

204 changes: 139 additions & 65 deletions README.md

Large diffs are not rendered by default.

228 changes: 147 additions & 81 deletions README_ch.md

Large diffs are not rendered by default.

Binary file added data/gif/kie.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/gif/multi-point.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/gif/table.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
151 changes: 151 additions & 0 deletions gen_ocr_train_val_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# coding:utf8
import os
import shutil
import random
import argparse


def isCreateOrDeleteFolder(path, flag):
    """Recreate the sub-folder ``flag`` of ``path`` as an empty directory.

    Any existing train/val/test split folder of that name is removed
    together with its contents before a fresh, empty one is created.

    Args:
        path: Parent directory of the split folder.
        flag: Name of the split folder (e.g. "train", "val" or "test").

    Returns:
        The absolute path of the newly created folder.
    """
    targetDir = os.path.abspath(os.path.join(path, flag))

    # Start from a clean slate: drop whatever an earlier run left behind.
    if os.path.exists(targetDir):
        shutil.rmtree(targetDir)
    os.makedirs(targetDir)

    return targetDir


def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
    """Shuffle one annotation file and distribute its records over train/val/test.

    Each record of the label file is a line ``<image path>\\t<label>``. The
    image named by the record is copied into the train, val or test folder and
    a line ``<copied image path>\\t<label>`` is written to the matching text
    file. Records are shuffled first; the share of each split follows
    ``args.trainValTestRatio``.

    This function reads the module-level ``args`` namespace (created in the
    ``__main__`` block): ``detLabelFileName``, ``recLabelFileName``,
    ``recImageDirName`` and ``trainValTestRatio``.

    Args:
        root: Folder holding the label file and the images.
        absTrainRootPath: Destination folder for training images.
        absValRootPath: Destination folder for validation images.
        absTestRootPath: Destination folder for test images.
        trainTxt: Writable text file receiving the training records.
        valTxt: Writable text file receiving the validation records.
        testTxt: Writable text file receiving the test records.
        flag: "det" (images sit directly in ``root``) or "rec" (images sit in
            ``root/<args.recImageDirName>``).

    Raises:
        ValueError: If ``flag`` is neither "det" nor "rec", or if the ratio is
            not three non-negative numbers with a positive sum.
        IndexError: If a non-blank record contains no tab separator.
    """
    dataAbsPath = os.path.abspath(root)

    if flag == "det":
        labelFileName = args.detLabelFileName
        imageDirPath = dataAbsPath
    elif flag == "rec":
        labelFileName = args.recLabelFileName
        # Join the components separately so the path is valid on every OS
        # (a hard-coded backslash only works on Windows).
        imageDirPath = os.path.join(dataAbsPath, args.recImageDirName)
    else:
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))

    # Parse the ratio once, without eval, and normalise by its sum so that
    # ratios such as "8:1:1", "3:1:1" or "0.6:0.2:0.2" all work.
    ratioParts = [float(part) for part in args.trainValTestRatio.split(":")]
    ratioTotal = sum(ratioParts)
    if len(ratioParts) != 3 or min(ratioParts) < 0 or ratioTotal <= 0:
        raise ValueError(
            "trainValTestRatio must look like 'train:val:test' with non-negative "
            "numbers, got {!r}".format(args.trainValTestRatio))
    trainRatio = ratioParts[0] / ratioTotal
    valRatio = (ratioParts[0] + ratioParts[1]) / ratioTotal

    labelFilePath = os.path.join(dataAbsPath, labelFileName)
    with open(labelFilePath, "r", encoding="UTF-8") as labelFile:
        # Drop the line terminator here and add it back when writing; otherwise
        # a final line without "\n" could be shuffled into the middle and get
        # glued to the following record. Blank lines carry no record.
        labelFileContent = [line.rstrip("\n") for line in labelFile if line.strip()]

    random.shuffle(labelFileContent)
    labelRecordLen = len(labelFileContent)

    for index, labelRecordInfo in enumerate(labelFileContent):
        recordFields = labelRecordInfo.split("\t")
        imageName = os.path.basename(recordFields[0])
        imageLabel = recordFields[1]
        imagePath = os.path.join(imageDirPath, imageName)

        # Position in the shuffled list decides which split the record joins.
        curRatio = index / labelRecordLen
        if curRatio < trainRatio:
            targetRootPath, targetTxt = absTrainRootPath, trainTxt
        elif curRatio < valRatio:
            targetRootPath, targetTxt = absValRootPath, valTxt
        else:
            targetRootPath, targetTxt = absTestRootPath, testTxt

        imageCopyPath = os.path.join(targetRootPath, imageName)
        shutil.copy(imagePath, imageCopyPath)
        targetTxt.write("{}\t{}\n".format(imageCopyPath, imageLabel))


def removeFile(path):
    """Delete the file at ``path``; do nothing when nothing exists there."""
    if not os.path.exists(path):
        return
    os.remove(path)


def genDetRecTrainVal(args):
    """Build the detection and recognition train/val/test splits.

    Recreates empty ``train``/``val``/``test`` folders and fresh
    ``train.txt``/``val.txt``/``test.txt`` files under ``args.detRootPath`` and
    ``args.recRootPath``, splits the detection data found in
    ``args.datasetRootPath``, and then splits the recognition data if the
    dataset root contains the cropped-image folder.

    Args:
        args: Parsed command-line namespace providing ``detRootPath``,
            ``recRootPath``, ``datasetRootPath`` and ``recImageDirName``.
    """
    # Local import: keeps this change independent of the file's import block.
    from contextlib import ExitStack

    splitNames = ("train", "val", "test")

    detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath = [
        isCreateOrDeleteFolder(args.detRootPath, splitName) for splitName in splitNames
    ]
    recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath = [
        isCreateOrDeleteFolder(args.recRootPath, splitName) for splitName in splitNames
    ]

    detTxtPaths = [os.path.join(args.detRootPath, splitName + ".txt") for splitName in splitNames]
    recTxtPaths = [os.path.join(args.recRootPath, splitName + ".txt") for splitName in splitNames]
    for txtPath in detTxtPaths + recTxtPaths:
        removeFile(txtPath)

    # ExitStack closes (and therefore flushes) all six list files even when a
    # split fails half-way; previously they were never closed explicitly.
    with ExitStack() as stack:
        detTrainTxt, detValTxt, detTestTxt = [
            stack.enter_context(open(txtPath, "a", encoding="UTF-8")) for txtPath in detTxtPaths
        ]
        recTrainTxt, recValTxt, recTestTxt = [
            stack.enter_context(open(txtPath, "a", encoding="UTF-8")) for txtPath in recTxtPaths
        ]

        splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath,
                      detTrainTxt, detValTxt, detTestTxt, "det")

        # NOTE(review): indentation was lost in the copy this was restored
        # from; the ``break`` is read as ending the walk after the first
        # (top-level) directory, so nested folders are not inspected -- confirm.
        for root, dirNames, _fileNames in os.walk(args.datasetRootPath):
            for dirName in dirNames:
                # Match the configurable folder name that splitTrainVal reads
                # the cropped images from (default "crop_img").
                if dirName == args.recImageDirName:
                    splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath,
                                  recTrainTxt, recValTxt, recTestTxt, "rec")
            break



if __name__ == "__main__":
    # Purpose: split the detection data and the recognition data into
    # training, validation and test sets.
    # Note: adjust the arguments to your own paths and needs. Image data is
    # often annotated in batches by several people, each batch living in its
    # own folder labelled with PPOCRLabel, so several annotated folders may
    # need to be gathered and split into training/validation/test sets.
    parser = argparse.ArgumentParser()

    # (option, default, help) -- registered in this order, all as strings.
    optionSpecs = (
        ("--trainValTestRatio", "6:2:2",
         "ratio of trainset:valset:testset"),
        ("--datasetRootPath", "../train_data/",
         "path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."),
        ("--detRootPath", "../train_data/det",
         "the path where the divided detection dataset is placed"),
        ("--recRootPath", "../train_data/rec",
         "the path where the divided recognition dataset is placed"),
        ("--detLabelFileName", "Label.txt",
         "the name of the detection annotation file"),
        ("--recLabelFileName", "rec_gt.txt",
         "the name of the recognition annotation file"),
        ("--recImageDirName", "crop_img",
         "the name of the folder where the cropped recognition dataset is located"),
    )
    for optionName, optionDefault, optionHelp in optionSpecs:
        parser.add_argument(optionName, type=str, default=optionDefault, help=optionHelp)

    # ``args`` stays a module-level name: splitTrainVal reads it as a global.
    args = parser.parse_args()
    genDetRecTrainVal(args)
11 changes: 10 additions & 1 deletion libs/autoDialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from PyQt4.QtGui import *
from PyQt4.QtCore import *

import time
import datetime
import json
import cv2
import numpy as np
Expand Down Expand Up @@ -80,8 +82,9 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No
self.parent = parent
self.ocr = ocr
self.mImgList = mImgList
self.lender = lenbar
self.pb = QProgressBar()
self.pb.setRange(0, lenbar)
self.pb.setRange(0, self.lender)
self.pb.setValue(0)

layout = QVBoxLayout()
Expand All @@ -108,10 +111,16 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
self.thread_1.listValue.connect(self.handleListWidgetSingal)
self.thread_1.endsignal.connect(self.handleEndsignalSignal)
self.time_start = time.time() # save start time

def handleProgressBarSingal(self, i):
    """Move the progress bar to ``i`` and show the estimated time left in the title.

    The estimate is the average time per completed step since
    ``self.time_start`` multiplied by the number of steps still to do
    (``self.lender - i``).

    Args:
        i: Progress value received from the worker's progress signal.
    """
    self.pb.setValue(i)

    # Nothing has completed yet, so there is no average to extrapolate from
    # (and dividing by i would raise ZeroDivisionError).
    if i <= 0:
        return

    # Use the average time per step to smooth out fluctuations.
    avg_time = (time.time() - self.time_start) / i
    # Clamp so an overshoot never renders as a negative duration.
    steps_left = max(self.lender - i, 0)
    # str(timedelta) may end in ".microseconds"; keep only H:MM:SS.
    time_left = str(datetime.timedelta(seconds=avg_time * steps_left)).split(".")[0]
    self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}")

def handleListWidgetSingal(self, i):
self.listWidget.addItem(i)
titem = self.listWidget.item(self.listWidget.count() - 1)
Expand Down
76 changes: 45 additions & 31 deletions libs/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,20 @@
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

try:
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *

#from PyQt4.QtOpenGL import *
import copy

from PyQt5.QtCore import Qt, pyqtSignal, QPointF, QPoint
from PyQt5.QtGui import QPainter, QBrush, QColor, QPixmap
from PyQt5.QtWidgets import QWidget, QMenu, QApplication
from libs.shape import Shape
from libs.utils import distance
import copy

CURSOR_DEFAULT = Qt.ArrowCursor
CURSOR_POINT = Qt.PointingHandCursor
CURSOR_DRAW = Qt.CrossCursor
CURSOR_MOVE = Qt.ClosedHandCursor
CURSOR_GRAB = Qt.OpenHandCursor

# class Canvas(QGLWidget):


class Canvas(QWidget):
zoomRequest = pyqtSignal(int)
Expand Down Expand Up @@ -87,6 +79,10 @@ def __init__(self, *args, **kwargs):
#initialisation for panning
self.pan_initial_pos = QPoint()

#lockedshapes related
self.lockedShapes = []
self.isInTheSameImage = False

def setDrawingColor(self, qColor):
self.drawingLineColor = qColor
self.drawingRectColor = qColor
Expand Down Expand Up @@ -125,7 +121,6 @@ def unHighlight(self):
def selectedVertex(self):
    """Return True when ``self.hVertex`` holds a vertex index (is not None)."""
    return self.hVertex is not None


def mouseMoveEvent(self, ev):
"""Update line with last point and current coordinates."""
pos = self.transformPos(ev.pos())
Expand Down Expand Up @@ -233,7 +228,7 @@ def mouseMoveEvent(self, ev):
self.hVertex, self.hShape = index, shape
shape.highlightVertex(index, shape.MOVE_VERTEX)
self.overrideCursor(CURSOR_POINT)
self.setToolTip("Click & drag to move point") #move point
self.setToolTip("Click & drag to move point")
self.setStatusTip(self.toolTip())
self.update()
break
Expand Down Expand Up @@ -268,18 +263,10 @@ def mousePressEvent(self, ev):
if self.current.isClosed():
# print('1111')
self.finalise()
elif self.drawSquare: # 增加
elif self.drawSquare:
assert len(self.current.points) == 1
self.current.points = self.line.points
self.finalise()

if self.canCloseShape() and len(self.current) > 3 and self.current[0].x() + 2 >= pos.x() >= \
self.current[0].x() - 2 and self.current[0].y() + 2 >= pos.y() >= self.current[0].y() - 2:
print('鼠标单击事件')
if len(self.current) > 4:
self.current.popPoint() # Eliminate the extra point from the last click.
self.finalise()

elif not self.outOfPixmap(pos):
# Create new shape.
self.current = Shape()
Expand Down Expand Up @@ -337,7 +324,6 @@ def mouseReleaseEvent(self, ev):

self.movingShape = False


def endMove(self, copy=False):
assert self.selectedShapes and self.selectedShapesCopy
assert len(self.selectedShapesCopy) == len(self.selectedShapes)
Expand Down Expand Up @@ -414,7 +400,6 @@ def selectShapes(self, shapes):
self.selectionChanged.emit(shapes)
self.update()


def selectShapePoint(self, point, multiple_selection_mode):
"""Select the first shape created which contains this point."""
if self.selectedVertex(): # A vertex is marked for selection.
Expand Down Expand Up @@ -496,10 +481,8 @@ def boundedMoveVertex(self, pos):
shape.moveVertexBy(lindex, lshift)

else:
#move point
shape.moveVertexBy(index, shiftPos)


def boundedMoveShape(self, shapes, pos):
if type(shapes).__name__ != 'list': shapes = [shapes]
if self.outOfPixmap(pos):
Expand All @@ -520,6 +503,7 @@ def boundedMoveShape(self, shapes, pos):
if dp:
for shape in shapes:
shape.moveBy(dp)
shape.close()
self.prevPoint = pos
return True
return False
Expand Down Expand Up @@ -562,7 +546,7 @@ def boundedShiftShapes(self, shapes):
# Give up if both fail.
for shape in shapes:
point = shape[0]
offset = QPointF(2.0, 2.0)
offset = QPointF(5.0, 5.0)
self.calculateOffsets(shape, point)
self.prevPoint = point
if not self.boundedMoveShape(shape, point - offset):
Expand Down Expand Up @@ -659,7 +643,7 @@ def outOfPixmap(self, p):

def finalise(self):
assert self.current
if len(self.current) < 4 and self.current.points[0] == self.current.points[-1]:
if self.current.points[0] == self.current.points[-1]:
# print('finalse')
self.current = None
self.drawingPolygon.emit(False)
Expand Down Expand Up @@ -713,8 +697,9 @@ def wheelEvent(self, ev):

def keyPressEvent(self, ev):
key = ev.key()
shapesBackup = []
shapesBackup = copy.deepcopy(self.shapes)
if len(shapesBackup) == 0:
return
self.shapesBackups.pop()
self.shapesBackups.append(shapesBackup)
if key == Qt.Key_Escape and self.current:
Expand All @@ -732,6 +717,31 @@ def keyPressEvent(self, ev):
self.moveOnePixel('Up')
elif key == Qt.Key_Down and self.selectedShapes:
self.moveOnePixel('Down')
elif key == Qt.Key_X and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
if self.rotateOutOfBound(0.01):
continue
self.selectedShape.rotate(0.01)
self.shapeMoved.emit()
self.update()

elif key == Qt.Key_C and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
if self.rotateOutOfBound(-0.01):
continue
self.selectedShape.rotate(-0.01)
self.shapeMoved.emit()
self.update()

def rotateOutOfBound(self, angle):
    """Tell whether rotating the selected shapes by ``angle`` leaves the pixmap.

    Every point of every shape in ``self.selectedShapes`` is passed through
    the shape's ``rotatePoint`` and checked with ``self.outOfPixmap``.

    Args:
        angle: Rotation angle handed to ``Shape.rotatePoint``.

    Returns:
        True as soon as one rotated point is reported out of the pixmap,
        otherwise False (also when nothing is selected).
    """
    # Iterate with a local name instead of overwriting self.selectedShape:
    # keyPressEvent sets self.selectedShape to the shape it is about to rotate
    # and then calls this check, so clobbering it here made the caller rotate
    # the last selected shape instead of the current one.
    for shape in self.selectedShapes:
        for point in shape.points:
            if self.outOfPixmap(shape.rotatePoint(point, angle)):
                return True
    return False

def moveOnePixel(self, direction):
# print(self.selectedShape.points)
Expand Down Expand Up @@ -773,14 +783,18 @@ def moveOutOfBound(self, step):
points = [p1+p2 for p1, p2 in zip(self.selectedShape.points, [step]*4)]
return True in map(self.outOfPixmap, points)

def setLastLabel(self, text, line_color=None, fill_color=None):
def setLastLabel(self, text, line_color=None, fill_color=None, key_cls=None):
    """Label the last shape in ``self.shapes`` and optionally restyle it.

    Args:
        text: Label text; must be truthy.
        line_color: Assigned to the shape's ``line_color`` when truthy.
        fill_color: Assigned to the shape's ``fill_color`` when truthy.
        key_cls: Assigned to the shape's ``key_cls`` when truthy.

    Returns:
        The last shape in ``self.shapes`` after ``self.storeShapes()`` ran.
    """
    assert text
    newestShape = self.shapes[-1]
    newestShape.label = text

    # Only overwrite the attributes for which a truthy value was supplied.
    optionalAttrs = (
        ("line_color", line_color),
        ("fill_color", fill_color),
        ("key_cls", key_cls),
    )
    for attrName, attrValue in optionalAttrs:
        if attrValue:
            setattr(newestShape, attrName, attrValue)

    self.storeShapes()

    # Look the shape up again rather than reusing the local reference, as the
    # original does after storeShapes().
    return self.shapes[-1]
Expand Down
Loading

0 comments on commit 7f7c17c

Please sign in to comment.