From faa9940f2c4733083980ba5043d66bd2637f0883 Mon Sep 17 00:00:00 2001 From: Yakov Date: Mon, 11 May 2020 17:44:54 +0300 Subject: [PATCH] Add import encoding parametrization. Issue #1 --- import_me/parsers/csv.py | 20 +++++++++++++++++--- tests/conftest.py | 10 +++++++--- tests/test_parsers/test_csv.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/import_me/parsers/csv.py b/import_me/parsers/csv.py index 6de87a8..6ea23af 100644 --- a/import_me/parsers/csv.py +++ b/import_me/parsers/csv.py @@ -2,7 +2,7 @@ import io from contextlib import contextmanager -from typing import Optional, Iterator, Tuple, List, Any +from typing import Optional, Iterator, Tuple, List, Any, Dict from import_me.exceptions import StopParsing from import_me.parsers.base import BaseParser @@ -24,11 +24,25 @@ def header_row_offset(self) -> Optional[int]: raise StopParsing('Invalid row index.') return index + @property + def _open_file_params(self) -> Dict[str, Any]: + return { + key: self._params[key] + for key in ['encoding', 'buffering', 'newline', 'errors'] + if key in self._params + } + + @property + def _reader_params(self) -> Dict[str, Any]: + reader_params = [i for i in dir(csv.Dialect) if not i.startswith('_')] + reader_params.append('dialect') + return {key: self._params[key] for key in reader_params if key in self._params} + @contextmanager def open_file(self) -> Iterator: if self.file_path: try: - file_obj = open(self.file_path, 'r') + file_obj = open(self.file_path, 'r', **self._open_file_params) yield file_obj finally: file_obj.close() @@ -41,7 +55,7 @@ def open_file(self) -> Iterator: def iterate_file_rows(self) -> Iterator[Tuple[int, List[Any]]]: with self.open_file() as csv_file: - reader = csv.reader(csv_file, **self._params) + reader = csv.reader(csv_file, **self._reader_params) self.validate_headers(reader) csv_file.seek(0) diff --git a/tests/conftest.py b/tests/conftest.py index 9db2934..150d203 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,10 +96,14 @@ def _xlsx_file_factory(header=None, data=None, header_row_index=0, data_row_inde @pytest.fixture def csv_file_factory(): - def _csv_file_factory(header=None, data=None, header_row_index=0, data_row_index=1): + def _csv_file_factory( + header=None, data=None, header_row_index=0, data_row_index=1, file_kwargs=None, writer_kwargs=None, + ): + file_kwargs = file_kwargs or {} + writer_kwargs = writer_kwargs or {} csv_file = tempfile.NamedTemporaryFile(suffix='.csv') - with open(csv_file.name, 'w') as file: - writer = csv.writer(file) + with open(csv_file.name, 'w', **file_kwargs) as file: + writer = csv.writer(file, **writer_kwargs) if header is not None: for _row_index in range(header_row_index): diff --git a/tests/test_parsers/test_csv.py b/tests/test_parsers/test_csv.py index 7cbef38..3277754 100644 --- a/tests/test_parsers/test_csv.py +++ b/tests/test_parsers/test_csv.py @@ -31,3 +31,36 @@ class CSVParser(BaseCSVParser): 'row_index': 2, }, ] + + +def test_base_csv_parser_additional_params(csv_file_factory): + class CSVParser(BaseCSVParser): + columns = [ + Column('first_name', index=0, header='First Name'), + Column('last_name', index=1, header='Last Name'), + ] + + csv_file = csv_file_factory( + header=['First Name', 'Last Name'], + data=[ + ['Ivan', 'Ivanov'], + ['Petr', 'Petrov'], + ], + file_kwargs={'encoding': 'cp1251'}, + writer_kwargs={'delimiter': ';'}, + ) + parser = CSVParser(file_path=csv_file.name, encoding='cp1251', delimiter=';') + parser() + + assert parser.cleaned_data == [ + { + 'first_name': 'Ivan', + 'last_name': 'Ivanov', + 'row_index': 1, + }, + { + 'first_name': 'Petr', + 'last_name': 'Petrov', + 'row_index': 2, + }, + ]