Skip to content

Commit

Permalink
Merge pull request #35 from pozitronik/output_path_refactored
Browse files Browse the repository at this point in the history
Output path refactored
  • Loading branch information
pozitronik authored Jul 25, 2023
2 parents f376d14 + 88fdabc commit 8595261
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 61 deletions.
20 changes: 17 additions & 3 deletions sinner/Core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sinner.processors.frame.BaseFrameProcessor import BaseFrameProcessor
from sinner.State import State
from sinner.typing import Frame
from sinner.utilities import is_image, is_video, delete_subdirectories, list_class_descendants, resolve_relative_path, get_app_dir, TEMP_DIRECTORY, suggest_output_path
from sinner.utilities import is_image, is_video, delete_subdirectories, list_class_descendants, resolve_relative_path, get_app_dir, TEMP_DIRECTORY
from sinner.validators.AttributeLoader import AttributeLoader, Rules

# single thread doubles cuda performance - needs to be set before torch import
Expand All @@ -33,6 +33,7 @@

class Core(AttributeLoader, Status):
target_path: str
output_path: str
frame_processor: List[str]
frame_handler: str
temp_dir: str
Expand All @@ -53,6 +54,11 @@ def rules(self) -> Rules:
'required': True,
'help': 'Select the target file or the directory'
},
{
'parameter': {'output', 'output-path'},
'attribute': 'output_path',
'default:': lambda: self.suggest_output_path(),
},
{
'parameter': 'frame-processor',
'default': ['FaceSwapper'],
Expand Down Expand Up @@ -114,7 +120,7 @@ def run(self, set_progress: Callable[[int], None] | None = None) -> None:
temp_resources.append(state.in_dir)

if temp_resources is not []:
output_filename = current_processor.output_path if current_processor is not None else suggest_output_path(self.target_path)
output_filename = current_processor.output_path if current_processor is not None else self.output_path
final_handler = BaseFrameHandler.create(handler_name=self.frame_handler, parameters=self.parameters, target_path=self.target_path)
if final_handler.result(from_dir=current_target_path, filename=output_filename, audio_target=self.target_path) is True:
if self.keep_frames is False:
Expand Down Expand Up @@ -157,10 +163,18 @@ def get_frame(self, frame_number: int = 0, extractor_handler: BaseFrameHandler |
self.preview_processors[processor_name].load(self.parameters)
frame = self.preview_processors[processor_name].process_frame(frame)
result.append((frame, processor_name))
except Exception as exception: # skip, if parameters is not enough for processors
except Exception as exception: # skip, if parameters is not enough for processor
self.update_status(message=str(exception), mood=Mood.BAD)
pass
return result

def stop(self) -> None:
self._stop_flag = True

def suggest_output_path(self) -> str:
target_name, target_extension = os.path.splitext(os.path.basename(self.target_path))
if self.output_path is None:
return os.path.join(os.path.dirname(self.target_path), 'result-' + target_name + target_extension)
if os.path.isdir(self.output_path):
return os.path.join(self.output_path, 'result-' + target_name + target_extension)
return self.output_path
4 changes: 2 additions & 2 deletions sinner/State.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Any, Dict, List

from sinner.Status import Status, Mood
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.typing import Frame
from sinner.utilities import write_image
from sinner.validators.AttributeLoader import AttributeLoader, Rules

OUT_DIR = 'OUT'
Expand Down Expand Up @@ -113,7 +113,7 @@ def state_path(self, dir_type: str) -> str:
return os.path.join(self.temp_dir, *sub_path)

def save_temp_frame(self, frame: Frame, index: int) -> None:
if not write_image(frame, self.get_frame_processed_name(index)):
if not CV2VideoHandler.write_image(frame, self.get_frame_processed_name(index)):
raise Exception(f"Error saving frame: {self.get_frame_processed_name(index)}")

# Checks if some frame already processed
Expand Down
31 changes: 26 additions & 5 deletions sinner/handlers/frame/CV2VideoHandler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import glob
import os.path
import platform
from pathlib import Path
from typing import List

import cv2
from cv2 import VideoCapture
from numpy import fromfile, uint8
from tqdm import tqdm

from sinner.Status import Mood
from sinner.handlers.frame.BaseFrameHandler import BaseFrameHandler
from sinner.typing import NumeratedFrame, NumeratedFramePath
from sinner.utilities import write_image, get_file_name
from sinner.typing import NumeratedFrame, NumeratedFramePath, Frame
from sinner.utilities import get_file_name
from sinner.validators.AttributeLoader import Rules


Expand Down Expand Up @@ -87,7 +89,7 @@ def get_frames_paths(self, path: str) -> List[NumeratedFramePath]:
ret, frame = capture.read()
if not ret:
break
write_image(frame, os.path.join(path, str(i + 1).zfill(filename_length) + ".png"))
self.write_image(frame, os.path.join(path, str(i + 1).zfill(filename_length) + ".png"))
progress.update()
i += 1
capture.release()
Expand All @@ -111,12 +113,12 @@ def result(self, from_dir: str, filename: str, audio_target: str | None = None)
try:
Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
frame_files = glob.glob(os.path.join(glob.escape(from_dir), '*.png'))
first_frame = cv2.imread(frame_files[0])
first_frame = self.read_image(frame_files[0])
height, width, channels = first_frame.shape
fourcc = self.suggest_codec()
video_writer = cv2.VideoWriter(filename, fourcc, self.output_fps, (width, height))
for frame_path in frame_files:
frame = cv2.imread(frame_path)
frame = self.read_image(frame_path)
video_writer.write(frame)
video_writer.release()
return True
Expand All @@ -132,3 +134,22 @@ def suggest_codec(self) -> int:
self.update_status(message=f"Suggested codec: {fourcc}", mood=Mood.NEUTRAL)
return fourcc
raise NotImplementedError('No supported codecs found')

@staticmethod
def read_image(path: str) -> Frame:
if platform.system().lower() == 'windows': # issue #511
image = cv2.imdecode(fromfile(path, dtype=uint8), cv2.IMREAD_UNCHANGED)
if image.shape[2] == 4: # fixes the alpha-channel issue
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
return image
else:
return cv2.imread(path)

@staticmethod
def write_image(image: Frame, path: str) -> bool:
if platform.system().lower() == 'windows': # issue #511
is_success, im_buf_arr = cv2.imencode(".png", image)
im_buf_arr.tofile(path)
return is_success
else:
return cv2.imwrite(path, image)
5 changes: 3 additions & 2 deletions sinner/handlers/frame/DirectoryHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from typing import List

from sinner.handlers.frame.BaseFrameHandler import BaseFrameHandler
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.typing import NumeratedFrame, NumeratedFramePath
from sinner.utilities import read_image, is_image, get_file_name
from sinner.utilities import is_image, get_file_name


class DirectoryHandler(BaseFrameHandler):
Expand All @@ -27,7 +28,7 @@ def get_frames_paths(self, path: str) -> List[NumeratedFramePath]:
return [(int(get_file_name(file_path)), file_path) for file_path in frames_path if is_image(file_path)]

def extract_frame(self, frame_number: int) -> NumeratedFrame:
return frame_number, read_image(self.get_frames_paths(self._target_path)[frame_number - 1][1]) # zero-based sorted frames list
return frame_number, CV2VideoHandler.read_image(self.get_frames_paths(self._target_path)[frame_number - 1][1]) # zero-based sorted frames list

def result(self, from_dir: str, filename: str, audio_target: str | None = None) -> bool:
self.update_status(f"Copying results from {from_dir} to {filename}")
Expand Down
2 changes: 1 addition & 1 deletion sinner/handlers/frame/FFmpegVideoHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def detect_fps(self) -> float:
def detect_fc(self) -> int:
try:
command = ['ffprobe', '-v', 'error', '-count_frames', '-select_streams', 'v:0', '-show_entries', 'stream=nb_frames', '-of', 'default=nokey=1:noprint_wrappers=1', self._target_path]
output = subprocess.check_output(command, stderr=subprocess.STDOUT).decode('utf-8').strip()
output = subprocess.check_output(command, stderr=subprocess.STDOUT).decode('utf-8').strip() # can be very slow!
if 'N/A' == output:
return 1 # non-frame files, still processable
return int(output)
Expand Down
5 changes: 3 additions & 2 deletions sinner/handlers/frame/ImageHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from sinner.Status import Mood
from sinner.handlers.frame.BaseFrameHandler import BaseFrameHandler
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.typing import NumeratedFrame, NumeratedFramePath
from sinner.utilities import read_image, is_image
from sinner.utilities import is_image


class ImageHandler(BaseFrameHandler):
Expand All @@ -27,7 +28,7 @@ def get_frames_paths(self, path: str) -> List[NumeratedFramePath]:
return [(1, self._target_path)]

def extract_frame(self, frame_number: int) -> NumeratedFrame:
return frame_number, read_image(self._target_path)
return frame_number, CV2VideoHandler.read_image(self._target_path)

def result(self, from_dir: str, filename: str, audio_target: str | None = None) -> bool:
try:
Expand Down
5 changes: 3 additions & 2 deletions sinner/processors/frame/BaseFrameProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from argparse import Namespace

from sinner.handlers.frame.BaseFrameHandler import BaseFrameHandler
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.validators.AttributeLoader import AttributeLoader, Rules
from sinner.State import State
from sinner.typing import Frame, FramesDataType, FrameDataType, NumeratedFrame
from sinner.utilities import load_class, get_mem_usage, read_image, suggest_execution_threads, suggest_execution_providers, decode_execution_providers, suggest_max_memory
from sinner.utilities import load_class, get_mem_usage, suggest_execution_threads, suggest_execution_providers, decode_execution_providers, suggest_max_memory


class BaseFrameProcessor(ABC, AttributeLoader):
Expand Down Expand Up @@ -96,7 +97,7 @@ def process_frames(self, frame_data: FrameDataType, state: State) -> None: # ty
if isinstance(frame_data, int):
frame_num, frame = self.extract_frame_method(frame_data)
else:
frame = read_image(frame_data[1])
frame = CV2VideoHandler.read_image(frame_data[1])
frame_num = frame_data[0]
state.save_temp_frame(self.process_frame(frame), frame_num)
except Exception as exception:
Expand Down
5 changes: 3 additions & 2 deletions sinner/processors/frame/FaceSwapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import insightface

from sinner.FaceAnalyser import FaceAnalyser
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.validators.AttributeLoader import Rules
from sinner.processors.frame.BaseFrameProcessor import BaseFrameProcessor
from sinner.typing import Face, Frame, FaceSwapperType
from sinner.utilities import conditional_download, read_image, get_app_dir, is_image, is_video, get_file_name
from sinner.utilities import conditional_download, get_app_dir, is_image, is_video, get_file_name


class FaceSwapper(BaseFrameProcessor):
Expand Down Expand Up @@ -65,7 +66,7 @@ def suggest_output_path(self) -> str:
@property
def source_face(self) -> Face | None:
if self._source_face is None:
self._source_face = self.face_analyser.get_one_face(read_image(self.source_path))
self._source_face = self.face_analyser.get_one_face(CV2VideoHandler.read_image(self.source_path))
return self._source_face

@property
Expand Down
32 changes: 0 additions & 32 deletions sinner/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@
import urllib
from typing import List, Literal, Any, get_type_hints

import cv2
import onnxruntime
import psutil
import tensorflow
from numpy import uint8, fromfile
from tqdm import tqdm

from sinner.typing import Frame

TEMP_DIRECTORY = 'temp'


Expand Down Expand Up @@ -77,25 +73,6 @@ def resolve_relative_path(path: str, from_file: str | None = None) -> str:
return os.path.abspath(os.path.join(os.path.dirname(from_file), path)) # type: ignore[arg-type]


def read_image(path: str) -> Frame:
if platform.system().lower() == 'windows': # issue #511
image = cv2.imdecode(fromfile(path, dtype=uint8), cv2.IMREAD_UNCHANGED)
if image.shape[2] == 4: # fixes the alpha-channel issue
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
return image
else:
return cv2.imread(path)


def write_image(image: Frame, path: str) -> bool:
if platform.system().lower() == 'windows': # issue #511
is_success, im_buf_arr = cv2.imencode(".png", image)
im_buf_arr.tofile(path)
return is_success
else:
return cv2.imwrite(path, image)


def get_mem_usage(param: Literal['rss', 'vms', 'shared', 'text', 'lib', 'data', 'dirty'] = 'rss', size: Literal['b', 'k', 'm', 'g'] = 'm') -> int:
"""
The `memory_info()` method of the `psutil.Process` class provides information about the memory usage of a process. It returns a named tuple containing the following attributes:
Expand Down Expand Up @@ -209,12 +186,3 @@ def declared_attr_type(obj: object, attribute: str) -> Any:
if attribute in declared_typed_variables:
return declared_typed_variables[attribute]
return None


def suggest_output_path(target_path: str, output_path: str | None = None) -> str:
target_name, target_extension = os.path.splitext(os.path.basename(target_path))
if output_path is None:
return os.path.join(os.path.dirname(target_path), 'result-' + target_name + target_extension)
if os.path.isdir(output_path):
return os.path.join(output_path, 'result-' + target_name + target_extension)
return output_path
4 changes: 2 additions & 2 deletions tests/processors/test_base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from sinner.Parameters import Parameters
from sinner.handlers.frame.BaseFrameHandler import BaseFrameHandler
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.handlers.frame.VideoHandler import VideoHandler
from sinner.processors.frame.BaseFrameProcessor import BaseFrameProcessor
from sinner.processors.frame.DummyProcessor import DummyProcessor
from sinner.State import State
from sinner.typing import Frame
from sinner.utilities import read_image
from tests.constants import source_jpg, target_png, IMAGE_SHAPE, target_mp4, tmp_dir, TARGET_FC

parameters: Namespace = Parameters(f'--frame-processor=DummyProcessor --execution-provider=cpu --execution-threads={multiprocessing.cpu_count()} --source-path="{source_jpg}" --target-path="{target_mp4}" --output-path="{tmp_dir}"').parameters
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_init():


def test_process_frame():
processed_frame = get_test_object().process_frame(read_image(target_png))
processed_frame = get_test_object().process_frame(CV2VideoHandler.read_image(target_png))
assert (processed_frame, Frame)
assert processed_frame.shape == IMAGE_SHAPE

Expand Down
4 changes: 2 additions & 2 deletions tests/processors/test_face_enhancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from sinner.Parameters import Parameters

from sinner.FaceAnalyser import FaceAnalyser
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.processors.frame.FaceEnhancer import FaceEnhancer
from sinner.State import State
from sinner.typing import Frame
from sinner.utilities import read_image
from tests.constants import target_png, IMAGE_SHAPE, tmp_dir

parameters: Namespace = Parameters(f'--execution-provider=cpu --execution-threads={multiprocessing.cpu_count()} --max-memory=12 --target-path="{target_png}" --output-path="{tmp_dir}"').parameters
Expand Down Expand Up @@ -36,6 +36,6 @@ def test_init():


def test_process_frame():
processed_frame = get_test_object().process_frame(read_image(target_png))
processed_frame = get_test_object().process_frame(CV2VideoHandler.read_image(target_png))
assert (processed_frame, Frame)
assert processed_frame.shape == IMAGE_SHAPE
4 changes: 2 additions & 2 deletions tests/processors/test_face_swapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from sinner.Parameters import Parameters
from sinner.FaceAnalyser import FaceAnalyser
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.processors.frame.FaceSwapper import FaceSwapper
from sinner.typing import Frame, FaceSwapperType, Face
from sinner.utilities import read_image
from tests.constants import source_jpg, target_png, IMAGE_SHAPE, tmp_dir

parameters: Namespace = Parameters(f'--execution-provider=cpu --execution-threads={multiprocessing.cpu_count()} --max-memory=12 --source-path="{source_jpg}" --target-path="{target_png}" --output-path="{tmp_dir}"').parameters
Expand All @@ -30,6 +30,6 @@ def test_face_analysis():


def test_process_frame():
processed_frame = get_test_object().process_frame(read_image(target_png))
processed_frame = get_test_object().process_frame(CV2VideoHandler.read_image(target_png))
assert (processed_frame, Frame)
assert processed_frame.shape == IMAGE_SHAPE
8 changes: 4 additions & 4 deletions tests/test_face_analyser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List

from sinner.FaceAnalyser import FaceAnalyser
from sinner.handlers.frame.CV2VideoHandler import CV2VideoHandler
from sinner.typing import Face
from sinner.utilities import read_image
from tests.constants import source_jpg, target_faces


Expand All @@ -12,23 +12,23 @@ def get_test_object() -> FaceAnalyser:

def test_one_face():
analyser = get_test_object()
face = analyser.get_one_face(read_image(source_jpg))
face = analyser.get_one_face(CV2VideoHandler.read_image(source_jpg))
assert (face, Face)
assert face.age == 31
assert face.sex == 'F'


def test_one_face_from_many():
analyser = get_test_object()
face = analyser.get_one_face(read_image(target_faces))
face = analyser.get_one_face(CV2VideoHandler.read_image(target_faces))
assert (face, Face)
assert face.age == 47
assert face.sex == 'M'


def test_many_faces():
analyser = get_test_object()
faces = analyser.get_many_faces(read_image(target_faces))
faces = analyser.get_many_faces(CV2VideoHandler.read_image(target_faces))
assert (faces, List)
assert len(faces) == 2
assert faces[0].age == 28
Expand Down

0 comments on commit 8595261

Please sign in to comment.