Skip to content

Commit

Permalink
add 繁简转换
Browse files Browse the repository at this point in the history
  • Loading branch information
CjangCjengh committed Aug 22, 2023
1 parent 53eddad commit da72623
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 15 deletions.
7 changes: 7 additions & 0 deletions cleaner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import re
import opencc

def basic_cleaner(text):
    """Normalize whitespace after newlines and at both ends of ``text``.

    Three regex substitutions run in order:

    1. a newline followed by any run of whitespace becomes a single
       newline (this drops blank lines and the leading indentation of
       the following line);
    2. whitespace at the very start of the text is removed;
    3. whitespace at the very end of the text is removed.

    Whitespace that sits directly *before* a newline in the middle of the
    text is left untouched.
    """
    # (pattern, replacement) pairs; order matters, so keep it a sequence.
    passes = (
        (r'\n\s+', '\n'),
        (r'^\s+', ''),
        (r'\s+$', ''),
    )
    cleaned = text
    for pattern, replacement in passes:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned

def s2t(text):
    """Convert Simplified Chinese ``text`` to Traditional Chinese.

    Uses OpenCC's ``s2tw`` configuration (Simplified to Traditional,
    Taiwan standard according to the OpenCC documentation).

    The converter is created lazily on the first call and then cached on
    the function object: building an ``OpenCC`` instance loads its
    dictionaries, and this cleaner is invoked once per translated text
    segment, so rebuilding it on every call is wasted work.

    Returns the converted string.
    """
    converter = getattr(s2t, '_converter', None)
    if converter is None:
        converter = opencc.OpenCC('s2tw')
        s2t._converter = converter
    return converter.convert(text)

def t2s(text):
    """Convert Traditional Chinese ``text`` to Simplified Chinese.

    Uses OpenCC's ``t2s`` configuration.

    The converter is created lazily on the first call and then cached on
    the function object: building an ``OpenCC`` instance loads its
    dictionaries, and this cleaner is invoked once per translated text
    segment, so rebuilding it on every call is wasted work.

    Returns the converted string.
    """
    converter = getattr(t2s, '_converter', None)
    if converter is None:
        converter = opencc.OpenCC('t2s')
        t2s._converter = converter
    return converter.convert(text)
46 changes: 40 additions & 6 deletions ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from PyQt6.QtCore import QTranslator, Qt, QThread, QMetaObject, QGenericArgument, Q_ARG, pyqtSlot
from PyQt6.QtGui import QAction
from utils import Translator
import cleaner

class UISettingsDialog(QDialog):
def __init__(self, parent):
Expand Down Expand Up @@ -129,6 +130,18 @@ def __init__(self, parent):
parent.translate_func.append([beam_size_label.setText, self, "Beam Size"])
layout.addRow(beam_size_label, self.beam_size_spinbox)

self.input_cleaner_combo = QComboBox()
self.fill_cleaner_combo(self.input_cleaner_combo)
input_cleaner_label = QLabel(self.tr("输入转换"))
parent.translate_func.append([input_cleaner_label.setText, self, "输入转换"])
layout.addRow(input_cleaner_label, self.input_cleaner_combo)

self.output_cleaner_combo = QComboBox()
self.fill_cleaner_combo(self.output_cleaner_combo)
output_cleaner_label = QLabel(self.tr("输出转换"))
parent.translate_func.append([output_cleaner_label.setText, self, "输出转换"])
layout.addRow(output_cleaner_label, self.output_cleaner_combo)

self.save_button = QPushButton(self.tr("保存"))
parent.translate_func.append([self.save_button.setText, self, "保存"])
self.save_button.clicked.connect(self.save_translate_settings)
Expand All @@ -143,8 +156,20 @@ def save_translate_settings(self):
settings.setValue('device', parent.device)
parent.beam_size = self.beam_size_spinbox.value()
settings.setValue('beam_size', parent.beam_size)
parent.input_cleaner = self.input_cleaner_combo.currentData()
parent.output_cleaner = self.output_cleaner_combo.currentData()
self.accept()

@staticmethod
def fill_cleaner_combo(combo: QComboBox):
    """Populate ``combo`` with the script-conversion choices.

    Each entry is added as ``(display label, item data)``. The blank
    entry carries ``None`` as its data; the other two carry the strings
    ``"s2t"`` and ``"t2s"``.
    """
    choices = (
        ("", None),
        ("简 → 繁", "s2t"),
        ("繁 → 简", "t2s"),
    )
    for label, cleaner_name in choices:
        combo.addItem(label, cleaner_name)

def init_cleaners(self):
    """Reset the input and output cleaner combo boxes to index -1."""
    for combo in (self.input_cleaner_combo, self.output_cleaner_combo):
        combo.setCurrentIndex(-1)

class BatchTranslateDialog(QDialog):
def __init__(self, parent):
super().__init__(parent)
Expand Down Expand Up @@ -248,9 +273,11 @@ def _batch_translate():
while os.path.exists(output_file):
output_file = f'{self.output_folder}/new_{os.path.basename(output_file)}'
if file.endswith('.epub'):
parent.translator.translate_epub(file, output_file, parent.beam_size, parent.device)
parent.translator.translate_epub(file, output_file, parent.beam_size, parent.device,
parent.input_cleaner, parent.output_cleaner)
else:
parent.translator.translate_txt(file, output_file, parent.beam_size, parent.device)
parent.translator.translate_txt(file, output_file, parent.beam_size, parent.device,
parent.input_cleaner, parent.output_cleaner)
if parent.translator.is_terminated():
break
QMetaObject.invokeMethod(self, "add_translated_file", Qt.ConnectionType.QueuedConnection,
Expand Down Expand Up @@ -284,6 +311,8 @@ def __init__(self, settings):
self.device = settings.value('device') if settings.contains('device') else 'cpu'
self.settings = settings
self.translator = None
self.input_cleaner = None
self.output_cleaner = None
self.init_ui()
self.init_settings()

Expand Down Expand Up @@ -419,12 +448,16 @@ def load_model(self, index):
self.tr("当前版本的翻译姬中不含有{},请更新至最新版本")
.format(self.translator.config['tokenizer']))
return
if self.translator.cleaner is None:
if None in self.translator.input_cleaners or None in self.translator.output_cleaners:
QMessageBox.critical(self, self.tr("错误"),
self.tr("当前版本的翻译姬中不含有{},请更新至最新版本")
.format(self.translator.config['cleaner']))
self.tr("当前版本的翻译姬中不含有所需的cleaner,请更新至最新版本"))
return
self.batch_translate_dialog.finished.connect(self.translator.terminate)

self.input_cleaner = None
self.output_cleaner = None
self.translate_settings_dialog.init_cleaners()

self.max_text_length = self.translator.config['max_len'][0]
self.text_count_label.setText(f"{len(self.original_text_edit.toPlainText())}/{self.max_text_length}")
except:
Expand All @@ -451,7 +484,8 @@ def translate(self):

def _translate():
self.translator._is_terminated = False
translated_text = self.translator.translate(original_text, self.beam_size, self.device)
translated_text = self.translator.translate(original_text, self.beam_size, self.device,
self.input_cleaner, self.output_cleaner)
if translated_text is None:
return
self.translated_index_combo.clear()
Expand Down
33 changes: 24 additions & 9 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ def __init__(self, model_dir, device='cpu'):
self.model.load_state_dict(torch.load(f'{model_dir}/model.pth', map_location=device))
self.model.eval()
self.tokenizer = getattr(tokenizer, self.config['tokenizer'], None)
self.cleaner = getattr(cleaner, self.config['cleaner'], None)

ic_names = self.config.get('input_cleaners', None)
if ic_names is None:
ic_names = [self.config['cleaner']]
oc_names = self.config.get('output_cleaners', [])
self.input_cleaners = [getattr(cleaner, c, None) for c in ic_names]
self.output_cleaners = [getattr(cleaner, c, None) for c in oc_names]

if self.tokenizer is not None:
self.encode, _ = self.tokenizer(self.vocabs_source)
_, self.decode = self.tokenizer(self.vocabs_target)
Expand All @@ -34,12 +41,15 @@ def is_terminated(self):
def terminate(self):
self._is_terminated = True

def translate(self, text, beam_size=3, device='cpu'):
def translate(self, text, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
bos_idx = self.config['bos_idx']
eos_idx = self.config['eos_idx']
pad_idx = self.config['pad_idx']
if self.cleaner is not None:
text = self.cleaner(text)
if self.input_cleaners is not None:
for c in self.input_cleaners:
text = c(text)
if input_cleaner:
text = getattr(cleaner, input_cleaner)(text)
src_tokens = torch.LongTensor([[bos_idx] + self.encode(text) + [eos_idx]])
src_mask = (src_tokens != pad_idx).unsqueeze(-2)
results, _ = beam_search(self.model.to(device), src_tokens.to(device), src_mask.to(device), self.config['max_len'][1],
Expand All @@ -50,12 +60,17 @@ def translate(self, text, beam_size=3, device='cpu'):
for result in results[0]:
index_of_eos = result.index(2) if 2 in result else len(result)
result = result[:index_of_eos + 1]
texts.append(self.decode(result))
text = self.decode(result)
for c in self.output_cleaners:
text = c(text)
if output_cleaner:
text = getattr(cleaner, output_cleaner)(text)
texts.append(text)
return texts

def translate_txt(self, file, output, beam_size=3, device='cpu'):
def translate_txt(self, file, output, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
def translate_and_write(text):
text = self.translate(text, beam_size, device)
text = self.translate(text, beam_size, device, input_cleaner, output_cleaner)
if text is not None:
with open(output, 'a', encoding='utf-8') as f:
f.write(text[0] + '\n')
Expand All @@ -78,9 +93,9 @@ def translate_and_write(text):
except UnicodeDecodeError:
print(f"Error decoding file: {file}. Please ensure that the file is encoded in UTF-8.")

def translate_epub(self, file, output, beam_size=3, device='cpu'):
def translate_epub(self, file, output, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
def translate_and_replace(text, file_text, matches, pre_end):
text = self.translate(text, beam_size, device)
text = self.translate(text, beam_size, device, input_cleaner, output_cleaner)
new_file_text = ''
if text is not None:
text = text[0].split('\n')
Expand Down

0 comments on commit da72623

Please sign in to comment.