Review notes (text before the first "diff --git" header is ignored by git apply).

* The patch was reflowed to one diff line per row. Leading whitespace was not
  recoverable from the input, so the indentation of context lines is a
  best-effort reconstruction, especially in nested code such as
  _batch_translate and load_model. TODO: confirm against the working tree
  before applying.
* cleaner.py: s2t/t2s reuse one cached OpenCC converter per config instead of
  constructing a new converter on every call.
* utils.py: dropped the `self.input_cleaners is not None` guard; __init__
  always assigns a list, and the output-side loop was already unguarded.
* Short docstrings/comments were added and the hunk headers updated to match.
  The `index` lines are copied from the original patch, so their post-image
  hashes no longer describe the edited content.

diff --git a/cleaner.py b/cleaner.py
index ee04833..3b431e2 100644
--- a/cleaner.py
+++ b/cleaner.py
@@ -1,7 +1,24 @@
+import functools
 import re
+import opencc
 
 def basic_cleaner(text):
     text = re.sub(r'\n\s+', '\n', text)
     text = re.sub(r'^\s+', '', text)
     text = re.sub(r'\s+$', '', text)
     return text
+
+@functools.lru_cache(maxsize=None)
+def _get_converter(config):
+    """Return the shared OpenCC converter for *config*, creating it on first use."""
+    # translate() runs the cleaners once per text chunk; reuse one converter
+    # per config rather than constructing a new OpenCC object on every call.
+    return opencc.OpenCC(config)
+
+def s2t(text):
+    """Convert Simplified Chinese to Traditional Chinese using the 's2tw' config."""
+    return _get_converter('s2tw').convert(text)
+
+def t2s(text):
+    """Convert Traditional Chinese to Simplified Chinese using the 't2s' config."""
+    return _get_converter('t2s').convert(text)
diff --git a/ui.py b/ui.py
index 7602a9f..6cae261 100644
--- a/ui.py
+++ b/ui.py
@@ -5,6 +5,7 @@
 from PyQt6.QtCore import QTranslator, Qt, QThread, QMetaObject, QGenericArgument, Q_ARG, pyqtSlot
 from PyQt6.QtGui import QAction
 from utils import Translator
+import cleaner
 
 class UISettingsDialog(QDialog):
     def __init__(self, parent):
@@ -129,6 +130,18 @@ def __init__(self, parent):
         parent.translate_func.append([beam_size_label.setText, self, "Beam Size"])
         layout.addRow(beam_size_label, self.beam_size_spinbox)
 
+        self.input_cleaner_combo = QComboBox()
+        self.fill_cleaner_combo(self.input_cleaner_combo)
+        input_cleaner_label = QLabel(self.tr("输入转换"))
+        parent.translate_func.append([input_cleaner_label.setText, self, "输入转换"])
+        layout.addRow(input_cleaner_label, self.input_cleaner_combo)
+
+        self.output_cleaner_combo = QComboBox()
+        self.fill_cleaner_combo(self.output_cleaner_combo)
+        output_cleaner_label = QLabel(self.tr("输出转换"))
+        parent.translate_func.append([output_cleaner_label.setText, self, "输出转换"])
+        layout.addRow(output_cleaner_label, self.output_cleaner_combo)
+
         self.save_button = QPushButton(self.tr("保存"))
         parent.translate_func.append([self.save_button.setText, self, "保存"])
         self.save_button.clicked.connect(self.save_translate_settings)
@@ -143,8 +156,22 @@ def save_translate_settings(self):
         settings.setValue('device', parent.device)
         parent.beam_size = self.beam_size_spinbox.value()
         settings.setValue('beam_size', parent.beam_size)
+        parent.input_cleaner = self.input_cleaner_combo.currentData()
+        parent.output_cleaner = self.output_cleaner_combo.currentData()
         self.accept()
 
+    @staticmethod
+    def fill_cleaner_combo(combo: QComboBox):
+        """Add the selectable conversions (none, s2t, t2s) to *combo*."""
+        combo.addItem("", None)
+        combo.addItem("简 → 繁", "s2t")
+        combo.addItem("繁 → 简", "t2s")
+
+    def init_cleaners(self):
+        """Clear the current selection of both cleaner combo boxes."""
+        self.input_cleaner_combo.setCurrentIndex(-1)
+        self.output_cleaner_combo.setCurrentIndex(-1)
+
 class BatchTranslateDialog(QDialog):
     def __init__(self, parent):
         super().__init__(parent)
@@ -248,9 +275,11 @@ def _batch_translate():
                 while os.path.exists(output_file):
                     output_file = f'{self.output_folder}/new_{os.path.basename(output_file)}'
                 if file.endswith('.epub'):
-                    parent.translator.translate_epub(file, output_file, parent.beam_size, parent.device)
+                    parent.translator.translate_epub(file, output_file, parent.beam_size, parent.device,
+                                                     parent.input_cleaner, parent.output_cleaner)
                 else:
-                    parent.translator.translate_txt(file, output_file, parent.beam_size, parent.device)
+                    parent.translator.translate_txt(file, output_file, parent.beam_size, parent.device,
+                                                    parent.input_cleaner, parent.output_cleaner)
                 if parent.translator.is_terminated():
                     break
                 QMetaObject.invokeMethod(self, "add_translated_file", Qt.ConnectionType.QueuedConnection,
@@ -284,6 +313,8 @@ def __init__(self, settings):
         self.device = settings.value('device') if settings.contains('device') else 'cpu'
         self.settings = settings
         self.translator = None
+        self.input_cleaner = None
+        self.output_cleaner = None
         self.init_ui()
         self.init_settings()
 
@@ -419,12 +450,16 @@ def load_model(self, index):
                                      self.tr("当前版本的翻译姬中不含有{},请更新至最新版本")
                                      .format(self.translator.config['tokenizer']))
                 return
-            if self.translator.cleaner is None:
+            if None in self.translator.input_cleaners or None in self.translator.output_cleaners:
                 QMessageBox.critical(self, self.tr("错误"),
-                                     self.tr("当前版本的翻译姬中不含有{},请更新至最新版本")
-                                     .format(self.translator.config['cleaner']))
+                                     self.tr("当前版本的翻译姬中不含有所需的cleaner,请更新至最新版本"))
                 return
             self.batch_translate_dialog.finished.connect(self.translator.terminate)
+
+            self.input_cleaner = None
+            self.output_cleaner = None
+            self.translate_settings_dialog.init_cleaners()
+
             self.max_text_length = self.translator.config['max_len'][0]
             self.text_count_label.setText(f"{len(self.original_text_edit.toPlainText())}/{self.max_text_length}")
         except:
@@ -451,7 +486,8 @@ def translate(self):
 
         def _translate():
             self.translator._is_terminated = False
-            translated_text = self.translator.translate(original_text, self.beam_size, self.device)
+            translated_text = self.translator.translate(original_text, self.beam_size, self.device,
+                                                        self.input_cleaner, self.output_cleaner)
             if translated_text is None:
                 return
             self.translated_index_combo.clear()
diff --git a/utils.py b/utils.py
index 987587e..ca61549 100644
--- a/utils.py
+++ b/utils.py
@@ -23,7 +23,15 @@ def __init__(self, model_dir, device='cpu'):
         self.model.load_state_dict(torch.load(f'{model_dir}/model.pth', map_location=device))
         self.model.eval()
         self.tokenizer = getattr(tokenizer, self.config['tokenizer'], None)
-        self.cleaner = getattr(cleaner, self.config['cleaner'], None)
+
+        # Configs without 'input_cleaners' fall back to the single 'cleaner' entry.
+        ic_names = self.config.get('input_cleaners', None)
+        if ic_names is None:
+            ic_names = [self.config['cleaner']]
+        oc_names = self.config.get('output_cleaners', [])
+        self.input_cleaners = [getattr(cleaner, c, None) for c in ic_names]
+        self.output_cleaners = [getattr(cleaner, c, None) for c in oc_names]
+
         if self.tokenizer is not None:
             self.encode, _ = self.tokenizer(self.vocabs_source)
             _, self.decode = self.tokenizer(self.vocabs_target)
@@ -34,12 +42,15 @@ def is_terminated(self):
     def terminate(self):
         self._is_terminated = True
 
-    def translate(self, text, beam_size=3, device='cpu'):
+    def translate(self, text, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
         bos_idx = self.config['bos_idx']
         eos_idx = self.config['eos_idx']
         pad_idx = self.config['pad_idx']
-        if self.cleaner is not None:
-            text = self.cleaner(text)
+        # Model-defined cleaners run first, then the optional caller-selected one.
+        for c in self.input_cleaners:
+            text = c(text)
+        if input_cleaner:
+            text = getattr(cleaner, input_cleaner)(text)
         src_tokens = torch.LongTensor([[bos_idx] + self.encode(text) + [eos_idx]])
         src_mask = (src_tokens != pad_idx).unsqueeze(-2)
         results, _ = beam_search(self.model.to(device), src_tokens.to(device), src_mask.to(device), self.config['max_len'][1],
@@ -50,12 +61,17 @@ def translate(self, text, beam_size=3, device='cpu'):
         for result in results[0]:
             index_of_eos = result.index(2) if 2 in result else len(result)
             result = result[:index_of_eos + 1]
-            texts.append(self.decode(result))
+            text = self.decode(result)
+            for c in self.output_cleaners:
+                text = c(text)
+            if output_cleaner:
+                text = getattr(cleaner, output_cleaner)(text)
+            texts.append(text)
         return texts
 
-    def translate_txt(self, file, output, beam_size=3, device='cpu'):
+    def translate_txt(self, file, output, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
         def translate_and_write(text):
-            text = self.translate(text, beam_size, device)
+            text = self.translate(text, beam_size, device, input_cleaner, output_cleaner)
             if text is not None:
                 with open(output, 'a', encoding='utf-8') as f:
                     f.write(text[0] + '\n')
@@ -78,9 +94,9 @@ def translate_and_write(text):
         except UnicodeDecodeError:
             print(f"Error decoding file: {file}. Please ensure that the file is encoded in UTF-8.")
 
-    def translate_epub(self, file, output, beam_size=3, device='cpu'):
+    def translate_epub(self, file, output, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
         def translate_and_replace(text, file_text, matches, pre_end):
-            text = self.translate(text, beam_size, device)
+            text = self.translate(text, beam_size, device, input_cleaner, output_cleaner)
             new_file_text = ''
             if text is not None:
                 text = text[0].split('\n')