[Feature] add XFUND dataset and project LayoutLMv3 #1809

Open · wants to merge 58 commits into base: dev-1.x

Changes from 1 commit (of 58 commits)
24119ec
Interim commit
KevinNuNu Mar 23, 2023
38fbba5
Merge branch 'dev-1.x' into layoutlm
KevinNuNu Mar 23, 2023
3ae3f84
Refactor the XFUND dataset config file structure
KevinNuNu Mar 25, 2023
35d0dd9
Add the XFUND zh dataset
KevinNuNu Mar 25, 2023
125cce2
[Fix] Fix files generated by jsondumper not displaying Chinese correctly
KevinNuNu Mar 25, 2023
f4f1dac
[Fix] Fix a path concatenation bug
KevinNuNu Mar 25, 2023
ffe8909
Merge branch 'dev-1.x' into layoutlm
KevinNuNu Mar 25, 2023
5b203ad
Add config files for the other 6 datasets
KevinNuNu Mar 25, 2023
2016921
Add the XFUND RE task
KevinNuNu Mar 25, 2023
717ac03
pre-commit fix
KevinNuNu Mar 25, 2023
078cc83
[Fix] Simplify the XFUND parser and improve the final dataset directory structure
KevinNuNu Mar 27, 2023
d3e16ad
[Fix] Revert and remove the huggingface dataset format (it served no purpose); update the ser/re packer metainfo; partially add S…
KevinNuNu Mar 28, 2023
1d0c5e3
Partially complete SERDataset data loading
KevinNuNu Mar 28, 2023
deb96cc
Improve the ser/re packers: decide whether to include an entry based on whether the `words` key exists
KevinNuNu Mar 30, 2023
443e979
Improve the naming of the XFUND dataset config_generator, making the config_generator directory structure clearer
KevinNuNu Mar 30, 2023
a88a129
Rename SERDataset to XFUNDSERDataset
KevinNuNu Mar 30, 2023
25f084a
ser/re packer docstring fix
KevinNuNu Mar 30, 2023
f8f2614
add SERDataSample structure and PackSERInputs transforms
KevinNuNu Mar 30, 2023
c8a7b68
Build the initial model file structure for the SER part; LayoutLMv3DataPreprocessor parameters are aligned with HuggingFace's LayoutLM…
KevinNuNu Mar 30, 2023
81a4527
Merge branch 'dev-1.x' into layoutlm
KevinNuNu Apr 10, 2023
e22e466
Remove id2label info from the packer metainfo
KevinNuNu Apr 11, 2023
ceb66dc
Improve xfund_dataset
KevinNuNu Apr 11, 2023
2eb79c3
Make the types of the added metainfo explicit
KevinNuNu Apr 11, 2023
a6bbe12
Simplified version of the layoutlmv3 code
KevinNuNu Apr 17, 2023
7951200
Improve the layoutlmv3 preprocessing code and integrate it into datasets/transforms for clarity
KevinNuNu Apr 17, 2023
3ddf780
Add test scripts
KevinNuNu Apr 17, 2023
60b2a52
Merge branch 'dev-1.x' into layoutlm
KevinNuNu Apr 17, 2023
84be264
Refactor the XFUND dataset into MMOCR format
KevinNuNu Apr 18, 2023
2767fcc
Simplify XFUNDDataset; no longer split it by ser/re task
KevinNuNu Apr 19, 2023
4b4b343
Move all preprocessing previously done inside XFUNDDataset into the pipeline; refactor the preprocessing code into LoadProcessorFromPretrain…
KevinNuNu Apr 19, 2023
8399f94
Update the project test scripts
KevinNuNu Apr 19, 2023
bda6742
Get the train.py training pipeline running end to end
KevinNuNu Apr 19, 2023
44c68b1
Modify the SERDataSample format
KevinNuNu Apr 19, 2023
023b0cf
Fix a naming error in SERPostprocessor
KevinNuNu Apr 19, 2023
de98eb1
Merge branch 'dev-1.x' into layoutlm
KevinNuNu Apr 28, 2023
a05a2e1
Tidy up the config directory
KevinNuNu Apr 28, 2023
3664773
Add an evaluation module for the SER task
KevinNuNu Apr 28, 2023
6c1f5be
Improve PackSERInputs
KevinNuNu Apr 29, 2023
d21a181
Move the data processing code into the project
KevinNuNu May 1, 2023
40cfe65
fix an error
KevinNuNu May 1, 2023
d1f43e7
Move ser_data_sample into projects
KevinNuNu May 1, 2023
e102ef2
Merge branch 'dev-1.x' into layoutlm
KevinNuNu May 8, 2023
50fa7f9
Standardize the XFUND dataset preparation script files
KevinNuNu May 8, 2023
a04cd51
[Fix] Fix a bug present during inference
KevinNuNu May 8, 2023
81b8f86
Use custom_imports to streamline importing custom modules
KevinNuNu May 8, 2023
059e203
Improve the visualization of SER task results
KevinNuNu May 8, 2023
f0a03ac
Standardize config file naming
KevinNuNu May 25, 2023
d9a3a5e
Simplify: replace the previous default_collate-based long_text_data_collate with a clearer, easier-to-understand ser_collate
KevinNuNu May 25, 2023
b04e126
Fix bugs in ser_postprocessor and ser_visualizer for the inference case where gt_label is absent.
KevinNuNu May 25, 2023
b6f55f8
Improve ser_postprocessor
KevinNuNu May 29, 2023
edf7fe8
[Fix] Fix a bug where, when tokenization yields exactly 510*n tokens, stripping the leading and trailing None markers leaves no end marker, so the last label is never added to the results
KevinNuNu Jun 12, 2023
8a1e37b
[Fix] Reset word_biolabels to prevent duplicate appends
KevinNuNu Jun 12, 2023
dbe9145
Merge branch 'dev-1.x' into layoutlm
KevinNuNu Jun 26, 2023
0f0f8ca
Remove all absolute paths from the project; fill out README.md
KevinNuNu Jun 27, 2023
ae8c426
fix lint
gaotongxiao Oct 18, 2023
ab14bb0
Merge branch 'ci' into layoutlm
gaotongxiao Oct 20, 2023
db5673f
fix ci
gaotongxiao Oct 20, 2023
c7a3895
ci
gaotongxiao Oct 28, 2023
Fix bugs in ser_postprocessor and ser_visualizer for the inference case where gt_label is absent.
KevinNuNu committed May 25, 2023
commit b04e126acdaa96beb8ac3afec7e3bf16fb9d6e15
13 changes: 7 additions & 6 deletions projects/LayoutLMv3/datasets/transforms/formatting.py
@@ -98,22 +98,23 @@ def transform(self, results: dict) -> dict:
         for key in self.ser_keys:
             if key not in results:
                 continue
-            value = to_tensor(results[key])
-            inputs[key] = value
+            inputs[key] = to_tensor(results[key])
         packed_results['inputs'] = inputs

         # pack `data_samples`
         data_samples = []
         for truncation_idx in range(truncation_number):
             data_sample = SERDataSample()
             gt_label = LabelData()
-            assert 'labels' in results, 'key `labels` not in results.'
-            value = to_tensor(results['labels'][truncation_idx])
-            gt_label.item = value
+            if results.get('labels', None):
+                gt_label.item = to_tensor(results['labels'][truncation_idx])
             data_sample.gt_label = gt_label
             meta = {}
             for key in self.meta_keys:
-                meta[key] = results[key]
+                if key == 'truncation_word_ids':
+                    meta[key] = results[key][truncation_idx]
+                else:
+                    meta[key] = results[key]
             data_sample.set_metainfo(meta)
             data_samples.append(data_sample)
         packed_results['data_samples'] = data_samples
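In short, PackSERInputs now tolerates inputs without a `labels` key (the inference case) and slices the `truncation_word_ids` meta entry so that each packed data sample only carries the word ids of its own truncation window. A minimal, self-contained sketch of that meta-key handling, with toy data and hypothetical key names rather than the project API:

```python
# Sketch only: mimics the per-truncation meta packing after this change.
meta_keys = ('id', 'img_path', 'truncation_word_ids')  # hypothetical keys

def pack_meta(results: dict, truncation_idx: int) -> dict:
    meta = {}
    for key in meta_keys:
        if key == 'truncation_word_ids':
            # each data sample keeps only the word ids of its own window
            meta[key] = results[key][truncation_idx]
        else:
            meta[key] = results[key]
    return meta

# inference-style input: no `labels` key, two truncation windows
results = {
    'id': 'xfund_zh_0001',
    'img_path': 'demo.png',
    'truncation_word_ids': [[None, 0, 1, None], [None, 2, 3, None]],
}
print(pack_meta(results, truncation_idx=1))
# {'id': 'xfund_zh_0001', 'img_path': 'demo.png',
#  'truncation_word_ids': [None, 2, 3, None]}
```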
65 changes: 41 additions & 24 deletions projects/LayoutLMv3/models/ser_postprocessor.py
@@ -16,13 +16,10 @@
 class SERPostprocessor(nn.Module):
     """PostProcessor for SER."""

-    def __init__(self,
-                 classes: Union[tuple, list],
-                 ignore_index: int = -100) -> None:
+    def __init__(self, classes: Union[tuple, list]) -> None:
         super().__init__()
         self.other_label_name = find_other_label_name_of_biolabel(classes)
         self.id2biolabel = self._generate_id2biolabel_map(classes)
-        self.ignore_index = ignore_index
         self.softmax = nn.Softmax(dim=-1)

     def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
@@ -43,42 +40,62 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
     def __call__(self, outputs: torch.Tensor,
                  data_samples: Sequence[SERDataSample]
                  ) -> Sequence[SERDataSample]:
+        # merge several truncation data_sample to one data_sample
+        assert all('truncation_word_ids' in d for d in data_samples), \
+            'The key `truncation_word_ids` should be specified' \
+            'in PackSERInputs.'
+        truncation_word_ids = []
+        for data_sample in data_samples:
+            truncation_word_ids.append(data_sample.pop('truncation_word_ids'))
+        merged_data_sample = copy.deepcopy(data_samples[0])
+        merged_data_sample.set_metainfo(
+            dict(truncation_word_ids=truncation_word_ids))
+        flattened_word_ids = [
+            word_id for word_ids in truncation_word_ids for word_id in word_ids
+        ]
+
         # convert outputs dim from (truncation_num, max_length, label_num)
         # to (truncation_num * max_length, label_num)
         outputs = outputs.cpu().detach()
-        truncation_num = outputs.size(0)
         outputs = torch.reshape(outputs, (-1, outputs.size(-1)))
-        # merge gt label ids from data_samples
-        gt_label_ids = [
-            data_samples[truncation_idx].gt_label.item
-            for truncation_idx in range(truncation_num)
-        ]
-        gt_label_ids = torch.cat(gt_label_ids, dim=0).cpu().detach().numpy()
         # get pred label ids/scores from outputs
         probs = self.softmax(outputs)
         max_value, max_idx = torch.max(probs, -1)
         pred_label_ids = max_idx.numpy()
         pred_label_scores = max_value.numpy()
-        # select valid token and convert iid to biolabel
-        gt_biolabels = [
-            self.id2biolabel[g] for (g, p) in zip(gt_label_ids, pred_label_ids)
-            if g != self.ignore_index
-        ]
+
+        # determine whether it is an inference process
+        if 'item' in data_samples[0].gt_label:
+            # merge gt label ids from data_samples
+            gt_label_ids = [
+                data_sample.gt_label.item for data_sample in data_samples
+            ]
+            gt_label_ids = torch.cat(
+                gt_label_ids, dim=0).cpu().detach().numpy()
+            gt_biolabels = [
+                self.id2biolabel[g]
+                for (w, g) in zip(flattened_word_ids, gt_label_ids)
+                if w is not None
+            ]
+            # update merged gt_label
+            merged_data_sample.gt_label.item = gt_biolabels
+
+        # inference process do not have item in gt_label,
+        # so select valid token with flattened_word_ids
+        # rather than with gt_label_ids like official code.
         pred_biolabels = [
-            self.id2biolabel[p] for (g, p) in zip(gt_label_ids, pred_label_ids)
-            if g != self.ignore_index
+            self.id2biolabel[p]
+            for (w, p) in zip(flattened_word_ids, pred_label_ids)
+            if w is not None
         ]
         pred_biolabel_scores = [
-            s for (g, s) in zip(gt_label_ids, pred_label_scores)
-            if g != self.ignore_index
+            s for (w, s) in zip(flattened_word_ids, pred_label_scores)
+            if w is not None
        ]
         # record pred_label
         pred_label = LabelData()
         pred_label.item = pred_biolabels
         pred_label.score = pred_biolabel_scores
-        # merge several truncation data_sample to one data_sample
-        merged_data_sample = copy.deepcopy(data_samples[0])
         merged_data_sample.pred_label = pred_label
-        # update merged gt_label
-        merged_data_sample.gt_label.item = gt_biolabels

         return [merged_data_sample]
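The core of the fix: valid tokens used to be selected by comparing gt label ids against `ignore_index`, which requires ground truth and therefore breaks at inference time. Selecting by `word_ids` instead, where `None` marks special tokens such as [CLS]/[SEP]/padding in the usual HuggingFace tokenizer convention, works in both modes. A standalone sketch of that idea with toy values, not the project API:

```python
import torch

id2biolabel = {0: 'O', 1: 'B-question', 2: 'I-question'}  # toy label map

def select_valid_biolabels(flattened_word_ids, pred_label_ids):
    # keep predictions only for real tokens; no gt labels are needed,
    # so the same path works for both evaluation and inference
    return [
        id2biolabel[int(p)]
        for w, p in zip(flattened_word_ids, pred_label_ids)
        if w is not None
    ]

logits = torch.randn(1, 6, 3)              # (truncation_num, max_length, label_num)
probs = logits.reshape(-1, 3).softmax(-1)  # flatten truncation windows
scores, pred_ids = probs.max(-1)
word_ids = [None, 0, 0, 1, 2, None]        # word 0 split into two sub-tokens
print(select_valid_biolabels(word_ids, pred_ids))  # 4 labels, one per real token
```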
70 changes: 44 additions & 26 deletions projects/LayoutLMv3/visualization/ser_visualizer.py
@@ -91,19 +91,13 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
                 line_width=self.line_width,
                 alpha=self.alpha)

-        # draw gt/pred labels
-        if gt_labels is not None and pred_labels is not None:
-            areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
-            scales = _get_adaptive_scales(areas)
-            positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2
-
+        if gt_labels is not None:
             gt_tokens_biolabel = gt_labels.item
             gt_words_label = []
-            pred_tokens_biolabel = pred_labels.item
-            pred_words_label = []
-
-            if 'score' in pred_labels:
-                pred_tokens_biolabel_score = pred_labels.score
-                pred_words_label_score = []
-            else:
-                pred_tokens_biolabel_score = None
-                pred_words_label_score = None

             pre_word_id = None
             for idx, cur_word_id in enumerate(word_ids):
@@ -112,36 +106,60 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
                         gt_words_label_name = gt_tokens_biolabel[idx][2:] \
                             if gt_tokens_biolabel[idx] != 'O' else 'other'
                         gt_words_label.append(gt_words_label_name)
+                pre_word_id = cur_word_id
+            assert len(gt_words_label) == len(bboxes)
+        if pred_labels is not None:
+            pred_tokens_biolabel = pred_labels.item
+            pred_words_label = []
+            pred_tokens_biolabel_score = pred_labels.score
+            pred_words_label_score = []
+
+            pre_word_id = None
+            for idx, cur_word_id in enumerate(word_ids):
+                if cur_word_id is not None:
+                    if cur_word_id != pre_word_id:
                         pred_words_label_name = pred_tokens_biolabel[idx][2:] \
                             if pred_tokens_biolabel[idx] != 'O' else 'other'
                         pred_words_label.append(pred_words_label_name)
-                        if pred_tokens_biolabel_score is not None:
-                            pred_words_label_score.append(
-                                pred_tokens_biolabel_score[idx])
+                        pred_words_label_score.append(
+                            pred_tokens_biolabel_score[idx])
                 pre_word_id = cur_word_id
-            assert len(gt_words_label) == len(bboxes)
             assert len(pred_words_label) == len(bboxes)

+        areas = (bboxes[:, 3] - bboxes[:, 1]) * (
+            bboxes[:, 2] - bboxes[:, 0])
+        scales = _get_adaptive_scales(areas)
+        positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2
+
+        # draw gt or pred labels
+        if gt_labels is not None and pred_labels is not None:
             for i, (pos, gt, pred) in enumerate(
                     zip(positions, gt_words_label, pred_words_label)):
-                if pred_words_label_score is not None:
-                    score = round(float(pred_words_label_score[i]) * 100, 1)
-                    label_text = f'{gt} | {pred}({score})'
-                else:
-                    label_text = f'{gt} | {pred}'
-
+                score = round(float(pred_words_label_score[i]) * 100, 1)
+                label_text = f'{gt} | {pred}({score})'
                 self.draw_texts(
                     label_text,
                     pos,
                     colors=self.label_color if gt == pred else 'r',
                     font_sizes=int(13 * scales[i]),
                     vertical_alignments='center',
                     horizontal_alignments='center')
+        elif pred_labels is not None:
+            for i, (pos, pred) in enumerate(zip(positions, pred_words_label)):
+                score = round(float(pred_words_label_score[i]) * 100, 1)
+                label_text = f'Pred: {pred}({score})'
+                self.draw_texts(
+                    label_text,
+                    pos,
+                    colors=self.label_color,
+                    font_sizes=int(13 * scales[i]),
+                    vertical_alignments='center',
+                    horizontal_alignments='center')
+        elif gt_labels is not None:
+            for i, (pos, gt) in enumerate(zip(positions, gt_words_label)):
+                label_text = f'GT: {gt}'
+                self.draw_texts(
+                    label_text,
+                    pos,
+                    colors=self.label_color,
+                    font_sizes=int(13 * scales[i]),
+                    vertical_alignments='center',
+                    horizontal_alignments='center')

         return self.get_image()
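The visualizer now derives the word-level labels for gt and pred independently, so it can render ground-truth-only or prediction-only samples (the inference case). The collapse from token-level BIO labels to word-level labels keeps only each word's first sub-token and strips the B-/I- prefix; a toy sketch of that collapse with illustrative names, not the project API:

```python
def words_from_tokens(word_ids, token_biolabels):
    # take each word's first sub-token and strip the BIO prefix
    words, pre_word_id = [], None
    for idx, cur_word_id in enumerate(word_ids):
        if cur_word_id is not None and cur_word_id != pre_word_id:
            bio = token_biolabels[idx]
            words.append(bio[2:] if bio != 'O' else 'other')
        pre_word_id = cur_word_id
    return words

# [CLS], two sub-tokens of word 0, one token for word 1, [SEP]
word_ids = [None, 0, 0, 1, None]
print(words_from_tokens(word_ids, ['O', 'B-header', 'I-header', 'O', 'O']))
# -> ['header', 'other']
```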