Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Hide internal statuses and fix submitted status filtering for BCeID users - 2050 #2075

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backend/lcfs/db/models/compliance/ComplianceReportStatus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ class ComplianceReportStatusEnum(enum.Enum):
Reassessed = "Reassessed"
Rejected = "Rejected"

def underscore_value(self) -> str:
"""
Return the status as an underscored string.
"""
return self.value.replace(" ", "_")


class ComplianceReportStatus(BaseModel, EffectiveDates):
__tablename__ = "compliance_report_status"
Expand Down
61 changes: 30 additions & 31 deletions backend/lcfs/web/api/compliance_report/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import defaultdict
from datetime import datetime
from typing import List, Optional, Dict, Union, Tuple
from typing import List, Optional, Dict, Union

import structlog
from fastapi import Depends
from sqlalchemy import func, select, and_, asc, desc, update, or_, String, cast
from sqlalchemy import func, select, and_, asc, desc, update, String, cast
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, contains_eager, aliased
from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.inspection import inspect

from lcfs.db.dependencies import get_async_db_session
Expand Down Expand Up @@ -73,22 +73,33 @@ def apply_filters(self, pagination, conditions):
filter_type = filter.filter_type
if filter.field == "status":
field = cast(
get_field_for_filter(
ComplianceReportListView, "report_status"),
get_field_for_filter(ComplianceReportListView, "report_status"),
String,
)
# Check if filter_value is a comma-separated string
if isinstance(filter_value, str) and "," in filter_value:
filter_value = filter_value.split(",") # Convert to list

if isinstance(filter_value, list):
filter_value = [value.replace(" ", "_")
for value in filter_value]

def underscore_string(val):
"""
If the item is an enum member, get its `.value`
Then do .replace(" ", "_") so we get underscores
"""
if isinstance(val, ComplianceReportStatusEnum):
val = val.value # convert enum to string
return val.replace(" ", "_")

filter_value = [underscore_string(val) for val in filter_value]
filter_type = "set"
else:
if isinstance(filter_value, ComplianceReportStatusEnum):
filter_value = filter_value.value
filter_value = filter_value.replace(" ", "_")

elif filter.field == "type":
field = get_field_for_filter(
ComplianceReportListView, "report_type")
field = get_field_for_filter(ComplianceReportListView, "report_type")
elif filter.field == "organization":
field = get_field_for_filter(
ComplianceReportListView, "organization_name"
Expand All @@ -98,12 +109,10 @@ def apply_filters(self, pagination, conditions):
ComplianceReportListView, "compliance_period"
)
else:
field = get_field_for_filter(
ComplianceReportListView, filter.field)
field = get_field_for_filter(ComplianceReportListView, filter.field)

conditions.append(
apply_filter_conditions(
field, filter_value, filter_option, filter_type)
apply_filter_conditions(field, filter_value, filter_option, filter_type)
)

@repo_handler
Expand Down Expand Up @@ -158,8 +167,7 @@ async def get_compliance_period(self, period: str) -> CompliancePeriod:
Retrieve a compliance period from the database
"""
result = await self.db.scalar(
select(CompliancePeriod).where(
CompliancePeriod.description == period)
select(CompliancePeriod).where(CompliancePeriod.description == period)
)
return result

Expand Down Expand Up @@ -198,8 +206,7 @@ async def get_compliance_report_status_by_desc(
Retrieve the compliance report status ID from the database based on the description.
Replaces spaces with underscores in the status description.
"""
status_enum = status.replace(
" ", "_") # frontend sends status with spaces
status_enum = status.replace(" ", "_") # frontend sends status with spaces
result = await self.db.execute(
select(ComplianceReportStatus).where(
ComplianceReportStatus.status
Expand Down Expand Up @@ -386,17 +393,15 @@ async def get_reports_paginated(
self.apply_filters(pagination, conditions)

# Pagination and offset setup
offset = 0 if (pagination.page < 1) else (
pagination.page - 1) * pagination.size
offset = 0 if (pagination.page < 1) else (pagination.page - 1) * pagination.size
limit = pagination.size

# Build the main query
query = query.where(and_(*conditions))

# Apply sorting from pagination
if len(pagination.sort_orders) < 1:
field = get_field_for_filter(
ComplianceReportListView, "update_date")
field = get_field_for_filter(ComplianceReportListView, "update_date")
query = query.order_by(desc(field))
for order in pagination.sort_orders:
sort_method = asc if order.direction == "asc" else desc
Expand Down Expand Up @@ -731,15 +736,13 @@ def aggregate_quantities(
isinstance(record, FuelSupply)
and record.fuel_type.fossil_derived == fossil_derived
):
fuel_category = self._format_category(
record.fuel_category.category)
fuel_category = self._format_category(record.fuel_category.category)
fuel_quantities[fuel_category] += record.quantity
elif (
isinstance(record, OtherUses)
and record.fuel_type.fossil_derived == fossil_derived
):
fuel_category = self._format_category(
record.fuel_category.category)
fuel_category = self._format_category(record.fuel_category.category)
fuel_quantities[fuel_category] += record.quantity_supplied

return dict(fuel_quantities)
Expand Down Expand Up @@ -891,15 +894,11 @@ async def get_compliance_report_group_id(self, report_id):

@repo_handler
async def get_changelog_data(
self,
pagination: PaginationRequestSchema,
compliance_report_id: int,
selection
self, pagination: PaginationRequestSchema, compliance_report_id: int, selection
):

conditions = [selection.compliance_report_id == compliance_report_id]
offset = 0 if pagination.page < 1 else (
pagination.page - 1) * pagination.size
offset = 0 if pagination.page < 1 else (pagination.page - 1) * pagination.size
limit = pagination.size

# Create an alias for the previous version row.
Expand Down
37 changes: 19 additions & 18 deletions backend/lcfs/web/api/compliance_report/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Union, Type

import structlog
from fastapi import Depends, Request
from fastapi import Depends

from lcfs.db.models.compliance.ComplianceReport import (
ComplianceReport,
Expand All @@ -29,7 +29,6 @@
from lcfs.web.api.organization_snapshot.services import OrganizationSnapshotService
from lcfs.web.core.decorators import service_handler
from lcfs.web.exception.exceptions import DataNotFoundException, ServiceException
from lcfs.db.base import ActionTypeEnum

logger = structlog.get_logger(__name__)

Expand All @@ -39,7 +38,6 @@ def __init__(
self,
repo: ComplianceReportRepository = Depends(),
snapshot_services: OrganizationSnapshotService = Depends(),

) -> None:
self.repo = repo
self.snapshot_services = snapshot_services
Expand All @@ -66,8 +64,7 @@ async def create_compliance_report(
report_data.status
)
if not draft_status:
raise DataNotFoundException(
f"Status '{report_data.status}' not found.")
raise DataNotFoundException(f"Status '{report_data.status}' not found.")

# Generate a new group_uuid for the new report series
group_uuid = str(uuid.uuid4())
Expand Down Expand Up @@ -173,7 +170,10 @@ async def create_supplemental_report(

@service_handler
async def get_compliance_reports_paginated(
self, pagination, organization_id: int = None, bceid_user: bool = False
self,
pagination,
organization_id: int = None,
bceid_user: bool = False,
):
"""Fetches all compliance reports"""
if bceid_user:
Expand Down Expand Up @@ -208,8 +208,8 @@ async def get_compliance_reports_paginated(

def _mask_report_status(self, reports: List) -> List:
recommended_statuses = {
ComplianceReportStatusEnum.Recommended_by_analyst.value,
ComplianceReportStatusEnum.Recommended_by_manager.value,
ComplianceReportStatusEnum.Recommended_by_analyst.underscore_value(),
ComplianceReportStatusEnum.Recommended_by_manager.underscore_value(),
}

masked_reports = []
Expand Down Expand Up @@ -263,8 +263,7 @@ async def get_compliance_report_by_id(

if apply_masking:
# Apply masking to each report in the chain
masked_chain = self._mask_report_status(
compliance_report_chain)
masked_chain = self._mask_report_status(compliance_report_chain)
# Apply history masking to each report in the chain
masked_chain = [
self._mask_report_status_for_history(report, apply_masking)
Expand Down Expand Up @@ -317,7 +316,7 @@ def _model_to_dict(self, record) -> dict:
"""Safely convert a model to a dict, skipping lazy-loaded attributes that raise errors."""
result = {}
for key, value in record.__dict__.items():
if key == '_sa_instance_state':
if key == "_sa_instance_state":
continue
try:
result[key] = value
Expand All @@ -330,10 +329,11 @@ async def get_changelog_data(
self,
pagination: PaginationResponseSchema,
compliance_report_id: int,
selection: Type[Union[FuelSupply, OtherUses,
NotionalTransfer, FuelExport]]
selection: Type[Union[FuelSupply, OtherUses, NotionalTransfer, FuelExport]],
):
changelog, total_count = await self.repo.get_changelog_data(pagination, compliance_report_id, selection)
changelog, total_count = await self.repo.get_changelog_data(
pagination, compliance_report_id, selection
)

groups = {}
for record in changelog:
Expand All @@ -359,12 +359,13 @@ async def get_changelog_data(
changelog = [record for group in groups.values() for record in group]

return {
'pagination': PaginationResponseSchema(
"pagination": PaginationResponseSchema(
total=total_count,
page=pagination.page,
size=pagination.size,
total_pages=math.ceil(
total_count / pagination.size) if pagination.size else 0,
total_pages=(
math.ceil(total_count / pagination.size) if pagination.size else 0
),
),
'changelog': changelog,
"changelog": changelog,
}