Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upload organisation logo in S3. #1008

Merged
merged 9 commits into from
Nov 26, 2023
93 changes: 62 additions & 31 deletions src/backend/app/organization/organization_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,31 @@
# You should have received a copy of the GNU General Public License
# along with FMTM. If not, see <https:#www.gnu.org/licenses/>.
#
import os
import random
"""Logic for organization management."""

import re
import string
from io import BytesIO

from fastapi import HTTPException, UploadFile
from loguru import logger as log
from sqlalchemy import func
from sqlalchemy.orm import Session

from ..db import db_models

IMAGEDIR = "app/images/"
from app.config import settings
from app.db import db_models
from app.s3 import add_obj_to_bucket


def get_organisations(
db: Session,
):
"""Get all orgs."""
db_organisation = db.query(db_models.DbOrganisation).all()
return db_organisation


def generate_slug(text: str) -> str:
"""Sanitise the organization name for use in a URL."""
# Remove special characters and replace spaces with hyphens
slug = re.sub(r"[^\w\s-]", "", text).strip().lower().replace(" ", "-")
# Remove consecutive hyphens
Expand All @@ -46,6 +48,10 @@ def generate_slug(text: str) -> str:


async def get_organisation_by_name(db: Session, name: str):
"""Get org by name.

This function is used to check if a org exists with the same name.
"""
# Use SQLAlchemy's query-building capabilities
db_organisation = (
db.query(db_models.DbOrganisation)
Expand All @@ -55,63 +61,87 @@ async def get_organisation_by_name(db: Session, name: str):
return db_organisation


async def upload_image(db: Session, file: UploadFile(None)):
# Check if file with the same name exists
filename = file.filename
file_path = f"{IMAGEDIR}{filename}"
while os.path.exists(file_path):
# Generate a random character
random_char = "".join(random.choices(string.ascii_letters + string.digits, k=3))
async def upload_logo_to_s3(
db_org: db_models.DbOrganisation, logo_file: UploadFile(None)
) -> str:
"""Upload logo using standardised /{org_id}/logo.png format.

# Add the random character to the filename
logo_name, extension = os.path.splitext(filename)
filename = f"{logo_name}_{random_char}{extension}"
file_path = f"{IMAGEDIR}{filename}"
Browsers treat image mimetypes the same, regardless of extension,
so it should not matter if a .jpg is renamed .png.

# Read the file contents
contents = await file.read()
Args:
db_org(db_models.DbOrganisation): The organization database object.
logo_file(UploadFile): The logo image uploaded to FastAPI.

# Save the file
with open(file_path, "wb") as f:
f.write(contents)
Returns:
logo_path(str): The file path in S3.
"""
logo_path = f"/{db_org.id}/logo.png"

return filename
file_bytes = await logo_file.read()
file_obj = BytesIO(file_bytes)

add_obj_to_bucket(
settings.S3_BUCKET_NAME,
file_obj,
logo_path,
content_type=logo_file.content_type,
)

return logo_path


async def create_organization(
db: Session, name: str, description: str, url: str, logo: UploadFile(None)
):
"""Creates a new organization with the given name, description, url, type, and logo.
Saves the logo file to the app/images folder.

Saves the logo file S3 bucket under /{org_id}/logo.png.

Args:
db (Session): database session
name (str): name of the organization
description (str): description of the organization
url (str): url of the organization
type (int): type of the organization
logo (UploadFile, optional): logo file of the organization. Defaults to File(...).
logo (UploadFile, optional): logo file of the organization.
Defaults to File(...).

Returns:
bool: True if organization was created successfully
"""
# create new organization
try:
logo_name = await upload_image(db, logo) if logo else None

# Create new organization without logo set
db_organization = db_models.DbOrganisation(
name=name,
slug=generate_slug(name),
description=description,
url=url,
logo=logo_name,
)

db.add(db_organization)
db.commit()
# Refresh to get the assigned org id
db.refresh(db_organization)

logo_path = await upload_logo_to_s3(db_organization, logo)

# Update the logo field in the database with the correct path
db_organization.logo = (
f"{settings.S3_DOWNLOAD_ROOT}/{settings.S3_BUCKET_NAME}{logo_path}"
)
db.commit()

except Exception as e:
log.error(e)
log.exception(e)
log.debug("Rolling back changes to db organization")
# Rollback any changes
db.rollback()
# Delete the failed organization entry
if db_organization:
log.debug(f"Deleting created organisation ID {db_organization.id}")
db.delete(db_organization)
db.commit()
raise HTTPException(
status_code=400, detail=f"Error creating organization: {e}"
) from e
Expand Down Expand Up @@ -145,6 +175,7 @@ async def update_organization_info(
url: str,
logo: UploadFile,
):
"""Update an existing organisation database entry."""
organization = await get_organisation_by_id(db, organization_id)
if not organization:
raise HTTPException(status_code=404, detail="Organization not found")
Expand All @@ -156,7 +187,7 @@ async def update_organization_info(
if url:
organization.url = url
if logo:
organization.logo = await upload_image(db, logo) if logo else None
organization.logo = await upload_logo_to_s3(organization, logo)

db.commit()
db.refresh(organization)
Expand Down
16 changes: 16 additions & 0 deletions src/backend/app/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def add_file_to_bucket(bucket_name: str, file_path: str, s3_path: str):
file_path (str): The path to the file on the local filesystem.
s3_path (str): The path in the S3 bucket where the file will be stored.
"""
# Ensure s3_path starts with a forward slash
if not s3_path.startswith("/"):
s3_path = f"/{s3_path}"

client = s3_client()
client.fput_object(bucket_name, file_path, s3_path)

Expand All @@ -55,6 +59,10 @@ def add_obj_to_bucket(
kwargs (dict[str, Any]): Any other arguments to pass to client.put_object.

"""
# Ensure s3_path starts with a forward slash
if not s3_path.startswith("/"):
s3_path = f"/{s3_path}"

client = s3_client()
# Set BytesIO object to start, prior to .read()
file_obj.seek(0)
Expand All @@ -77,6 +85,10 @@ def get_file_from_bucket(bucket_name: str, s3_path: str, file_path: str):
file_path (str): The path on the local filesystem where the S3
file will be saved.
"""
# Ensure s3_path starts with a forward slash
if not s3_path.startswith("/"):
s3_path = f"/{s3_path}"

client = s3_client()
client.fget_object(bucket_name, s3_path, file_path)

Expand All @@ -91,6 +103,10 @@ def get_obj_from_bucket(bucket_name: str, s3_path: str) -> BytesIO:
Returns:
BytesIO: A BytesIO object containing the content of the downloaded S3 object.
"""
# Ensure s3_path starts with a forward slash
if not s3_path.startswith("/"):
s3_path = f"/{s3_path}"

client = s3_client()
try:
response = client.get_object(bucket_name, s3_path)
Expand Down