-
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
287 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# 1.1.0 | ||
|
||
Support for deleting potential CSAM from pict-rs Local Storage | ||
|
||
# 1.0.0 | ||
|
||
Support for deleting potential CSAM from pict-rs Object Storage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,11 @@ | ||
## Make a copy of this file into .env and change the below fields | ||
OBJECT_STORAGE_ENDPOINT="https://eu2.example.com" | ||
PICTRS_BUCKET="pictrs" | ||
AWS_ACCESS_KEY_ID=1234asdf5678zxxcvb890qwerty | ||
AWS_SECRET_ACCESS_KEY=1234567890qwertyuiopasdfghjkl | ||
AWS_DEFAULT_REGION=auto | ||
OBJECT_STORAGE_ENDPOINT="https://eu2.example.com" # Fill in when using object storage | ||
PICTRS_BUCKET="pictrs" # Fill in when using object storage | ||
AWS_ACCESS_KEY_ID=1234asdf5678zxxcvb890qwerty # Fill in when using object storage | ||
AWS_SECRET_ACCESS_KEY=1234567890qwertyuiopasdfghjkl # Fill in when using object storage | ||
AWS_DEFAULT_REGION=auto # Fill in when using object storage | ||
SSH_HOSTNAME="127.0.0.1" # Fill in when using filesystem storage | ||
SSH_PORT=22 # Fill in when using filesystem storage | ||
SSH_USERNAME="root" # This user should have read/write access to your pict-rs files | ||
SSH_PRIVKEY="/home/username/.ssh/id_rsa" # Path to your private key file | ||
SSH_PICTRS_FILES_DIRECTORY="/lemmy/lemmy.example.com/volumes/pictrs/files" # Path to your pictrs files directory |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import os | ||
import datetime | ||
import paramiko | ||
import sys | ||
from getpass import getpass | ||
from PIL import Image | ||
from io import BytesIO | ||
import stat | ||
from pathlib import Path | ||
from loguru import logger | ||
import pytz | ||
|
||
hostname = os.getenv("SSH_HOSTNAME") | ||
if hostname is None: | ||
logger.error("You need to provide an SSH_HOSTNAME var in your .env file") | ||
sys.exit(1) | ||
port = os.getenv("SSH_PORT") | ||
if hostname is None: | ||
logger.error("You need to provide an SSH_PORT var in your .env file") | ||
sys.exit(1) | ||
port = int(port) | ||
username = os.getenv("SSH_USERNAME") | ||
if hostname is None: | ||
logger.error("You need to provide an SSH_USERNAME var in your .env file") | ||
sys.exit(1) | ||
private_key_path = os.getenv("SSH_PRIVKEY") | ||
if hostname is None: | ||
logger.error("You need to provide an SSH_PRIVKEY var in your .env file") | ||
sys.exit(1) | ||
remote_base_directory = os.getenv("SSH_PICTRS_FILES_DIRECTORY") | ||
if hostname is None: | ||
logger.error("You need to provide an SSH_PICTRS_FILES_DIRECTORY var in your .env file") | ||
sys.exit(1) | ||
|
||
private_key_passphrase = getpass(prompt="Enter passphrase for private key: ") | ||
private_key = paramiko.RSAKey(filename=private_key_path, password=private_key_passphrase) | ||
|
||
def get_connection(): | ||
# I can't re-use the same connection when using threading | ||
# So we have to initiate a new connection per thread | ||
transport = paramiko.Transport((hostname, port)) | ||
transport.connect(username=username, pkey=private_key) | ||
sftp = paramiko.SFTPClient.from_transport(transport) | ||
return sftp | ||
|
||
def get_all_images(min_date=None): | ||
sftp = get_connection() | ||
filelist = [] | ||
|
||
def list_files_recursively(remote_directory): | ||
files = sftp.listdir_attr(remote_directory) | ||
for file_info in files: | ||
file_path = os.path.join(remote_directory, file_info.filename) | ||
if stat.S_ISREG(file_info.st_mode): # Check if it's a regular file | ||
modify_time = datetime.datetime.fromtimestamp(file_info.st_mtime, tz=pytz.UTC) | ||
if min_date is None or modify_time >= min_date: | ||
filelist.append( | ||
{ | ||
"key": str(Path(file_path).relative_to(Path(remote_base_directory))), | ||
"filepath": Path(file_path), | ||
"mtime": modify_time, | ||
} | ||
) | ||
elif stat.S_ISDIR(file_info.st_mode): # Check if it's a directory | ||
list_files_recursively(file_path) | ||
|
||
list_files_recursively(remote_base_directory) | ||
return filelist | ||
|
||
def download_image(remote_path): | ||
sftp = get_connection() | ||
remote_file = sftp.open(remote_path, "rb") | ||
image_bytes = remote_file.read() | ||
remote_file.close() | ||
image_pil = Image.open(BytesIO(image_bytes)) | ||
return image_pil | ||
|
||
|
||
def delete_image(remote_path): | ||
sftp = get_connection() | ||
try: | ||
sftp.remove(remote_path) | ||
except FileNotFoundError: | ||
logger.error(f"File not found: {remote_path}") | ||
except Exception as e: | ||
logger.error(f"Error deleting file {remote_path}: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import time | ||
import logging | ||
from datetime import datetime, timedelta, timezone | ||
from concurrent.futures import ThreadPoolExecutor | ||
import argparse | ||
import PIL.Image | ||
|
||
from loguru import logger | ||
import sys | ||
|
||
from lemmy_safety.check import check_image | ||
from lemmy_safety import local_storage | ||
from lemmy_safety import database | ||
from PIL import UnidentifiedImageError | ||
|
||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s', level=logging.WARNING) | ||
|
||
|
||
arg_parser = argparse.ArgumentParser() | ||
arg_parser.add_argument('--all', action="store_true", required=False, default=False, help="Check all images in the storage account") | ||
arg_parser.add_argument('-t', '--threads', action="store", required=False, default=100, type=int, help="How many threads to use. The more threads, the more VRAM requirements, but the faster the processing.") | ||
arg_parser.add_argument('-m', '--minutes', action="store", required=False, default=20, type=int, help="The images of the past how many minutes to check.") | ||
arg_parser.add_argument('--dry_run', action="store_true", required=False, default=False, help="Will check and reprt but will not delete") | ||
args = arg_parser.parse_args() | ||
|
||
|
||
def check_and_delete_filename(file_details): | ||
try: | ||
image: PIL.Image.Image = local_storage.download_image(str(file_details["filepath"])) | ||
except UnidentifiedImageError: | ||
logger.warning("Image could not be read. Returning it as CSAM to be sure.") | ||
is_csam = True | ||
if not image: | ||
is_csam = None | ||
else: | ||
is_csam = check_image(image) | ||
if is_csam and not args.dry_run: | ||
local_storage.delete_image(str(file_details["filepath"])) | ||
return is_csam, file_details | ||
|
||
def run_cleanup(cutoff_time = None): | ||
with ThreadPoolExecutor(max_workers=10) as executor: | ||
futures = [] | ||
for file_details in local_storage.get_all_images(cutoff_time): | ||
if not database.is_image_checked(file_details["key"]): | ||
futures.append(executor.submit(check_and_delete_filename, file_details)) | ||
if len(futures) >= args.threads: | ||
for future in futures: | ||
result, fdetails = future.result() | ||
database.record_image(fdetails["key"],csam=result) | ||
logger.info(f"Safety Checked Images: {len(futures)}") | ||
futures = [] | ||
for future in futures: | ||
result, fdetails = future.result() | ||
database.record_image(fdetails["key"],csam=result) | ||
logger.info(f"Safety Checked Images: {len(futures)}") | ||
|
||
if __name__ == "__main__": | ||
if args.all: | ||
run_cleanup() | ||
else: | ||
while True: | ||
try: | ||
cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=args.minutes) | ||
run_cleanup(cutoff_time) | ||
time.sleep(30) | ||
except: | ||
time.sleep(30) | ||
|
Oops, something went wrong.