Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose endpoint for category filtering #70

Merged
merged 8 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 167 additions & 65 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ rand = "0.8"
reqwest = "0.11"
serde = { version = "1.0", features = ["derive"] }
sha2 = "0.10"
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "time"] }
snapd = { git = "https://github.com/ZoopOTheGoop/snapd-rs", branch = "framework" }
sqlx = { version = "0.7", features = [
"runtime-tokio-rustls",
"postgres",
"time",
] }
strum = { version = "0.26.1", features = ["derive"] }
thiserror = "1.0"
time = { version = "0.3", features = ["macros"] }
tokio = { version = "1.36", features = ["full"] }
Expand Down
5 changes: 5 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.build_server(true)
.file_descriptor_set_path(descriptor_set_path)
.out_dir(out_dir)
.type_attribute("Category", "#[derive(sqlx::Type, strum::EnumString)]")
.type_attribute(
"Category",
r#"#[strum(serialize_all = "kebab_case", ascii_case_insensitive)]"#,
)
.compile(files, &["proto"])?;

Ok(())
Expand Down
28 changes: 28 additions & 0 deletions proto/ratings_features_chart.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ service Chart {

message GetChartRequest {
Timeframe timeframe = 1;
optional Category category = 2;
}

message GetChartResponse {
Timeframe timeframe = 1;
repeated ChartData ordered_chart_data = 2;
optional Category category = 3;
}

message ChartData {
Expand All @@ -27,3 +29,29 @@ enum Timeframe {
TIMEFRAME_WEEK = 1;
TIMEFRAME_MONTH = 2;
}

// The categories that can be selected, these
// are taken directly from `curl -sS -X GET --unix-socket /run/snapd.socket "http://localhost/v2/categories"`
// On 2024-02-03, it may need to be kept in sync.
enum Category {
ART_AND_DESIGN = 0;
BOOK_AND_REFERENCE = 1;
DEVELOPMENT = 2;
DEVICES_AND_IOT = 3;
EDUCATION = 4;
ENTERTAINMENT = 5;
FEATURED = 6;
FINANCE = 7;
GAMES = 8;
HEALTH_AND_FITNESS = 9;
MUSIC_AND_AUDIO = 10;
NEWS_AND_WEATHER = 11;
PERSONALISATION = 12;
PHOTO_AND_VIDEO = 13;
PRODUCTIVITY = 14;
SCIENCE = 15;
SECURITY = 16;
SERVER_AND_CLOUD = 17;
SOCIAL = 18;
UTILITIES = 19;
}
8 changes: 8 additions & 0 deletions sql/migrations/20240208071037_categories.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Add up migration script here

CREATE TABLE snap_categories (
id SERIAL PRIMARY KEY,
snap_id CHAR(32) NOT NULL,
category INTEGER NOT NULL,
CONSTRAINT category CHECK (category BETWEEN 0 AND 19)
);
58 changes: 37 additions & 21 deletions src/features/chart/infrastructure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
//! [`Chart`]: crate::features::chart::entities::Chart
use crate::{
app::AppContext,
features::{chart::errors::ChartError, common::entities::VoteSummary, pb::chart::Timeframe},
features::{
chart::errors::ChartError,
common::entities::VoteSummary,
pb::chart::{Category, Timeframe},
},
};
use sqlx::QueryBuilder;
use tracing::error;

/// Retrieves the vote summary in the given [`AppContext`] over a given [`Timeframe`]
/// from the database.
pub(crate) async fn get_votes_summary_by_timeframe(
pub(crate) async fn get_votes_summary(
app_ctx: &AppContext,
timeframe: Timeframe,
category: Option<Category>,
) -> Result<Vec<VoteSummary>, ChartError> {
let mut pool = app_ctx
.infrastructure()
Expand All @@ -22,28 +28,38 @@ pub(crate) async fn get_votes_summary_by_timeframe(
ChartError::FailedToGetChart
})?;

// Generate WHERE clause based on timeframe
let where_clause = match timeframe {
Timeframe::Week => "WHERE votes.created >= NOW() - INTERVAL '1 week'",
Timeframe::Month => "WHERE votes.created >= NOW() - INTERVAL '1 month'",
Timeframe::Unspecified => "", // Adjust as needed for Unspecified case
};

let query = format!(
let mut builder = QueryBuilder::new(
r#"
SELECT
votes.snap_id,
COUNT(*) AS total_votes,
COUNT(*) FILTER (WHERE votes.vote_up) AS positive_votes
FROM
votes
{}
GROUP BY votes.snap_id
"#,
where_clause
SELECT
votes.snap_id,
COUNT(*) AS total_votes,
COUNT(*) FILTER (WHERE votes.vote_up) AS positive_votes
FROM
votes"#,
);

let result = sqlx::query_as::<_, VoteSummary>(&query)
builder.push(match timeframe {
Timeframe::Week => " WHERE votes.created >= NOW() - INTERVAL '1 week'",
Timeframe::Month => " WHERE votes.created >= NOW() - INTERVAL '1 month'",
Timeframe::Unspecified => "", // Adjust as needed for Unspecified case
});

if let Some(category) = category {
builder
.push(
r#"
WHERE votes.snap_id IN (
SELECT snap_categories.snap_id FROM snap_categories
WHERE snap_categories.category = "#,
)
.push_bind(category)
.push(")");
}

builder.push(" GROUP BY votes.snap_id");

let result = builder
.build_query_as()
.fetch_all(&mut *pool)
.await
.map_err(|error| {
Expand Down
23 changes: 15 additions & 8 deletions src/features/chart/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
app::AppContext,
features::{
chart::{errors::ChartError, service::ChartService, use_cases},
pb::chart::{chart_server::Chart, GetChartRequest, GetChartResponse, Timeframe},
pb::chart::{chart_server::Chart, Category, GetChartRequest, GetChartResponse, Timeframe},
},
};
use tonic::{Request, Response, Status};
Expand All @@ -17,16 +17,22 @@ impl Chart for ChartService {
) -> Result<Response<GetChartResponse>, Status> {
let app_ctx = request.extensions().get::<AppContext>().unwrap().clone();

let GetChartRequest { timeframe } = request.into_inner();
let GetChartRequest {
timeframe,
category,
} = request.into_inner();

let timeframe = match timeframe {
0 => Timeframe::Unspecified,
1 => Timeframe::Week,
2 => Timeframe::Month,
_ => Timeframe::Unspecified,
let category = match category {
Some(category) => Some(
Category::try_from(category)
.map_err(|_| Status::invalid_argument("invalid category value"))?,
),
None => None,
};

let result = use_cases::get_chart(&app_ctx, timeframe).await;
let timeframe = Timeframe::try_from(timeframe).unwrap_or(Timeframe::Unspecified);

let result = use_cases::get_chart(&app_ctx, timeframe, category).await;

match result {
Ok(result) => {
Expand All @@ -38,6 +44,7 @@ impl Chart for ChartService {

let payload = GetChartResponse {
timeframe: timeframe.into(),
category: category.map(|v| v.into()),
ordered_chart_data,
};
Ok(Response::new(payload))
Expand Down
14 changes: 8 additions & 6 deletions src/features/chart/use_cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
use crate::{
app::AppContext,
features::{
chart::{
entities::Chart, errors::ChartError, infrastructure::get_votes_summary_by_timeframe,
},
pb::chart::Timeframe,
chart::{entities::Chart, errors::ChartError, infrastructure::get_votes_summary},
pb::chart::{Category, Timeframe},
},
};
use tracing::error;

/// Gets a chart over the given [`Timeframe`] within the given [`AppContext`]. Either ends up returning
/// a [`Chart`] or else one of the many [`ChartError`]s in case the timeframe is invalid or another database error
/// happens.
pub async fn get_chart(app_ctx: &AppContext, timeframe: Timeframe) -> Result<Chart, ChartError> {
let votes = get_votes_summary_by_timeframe(app_ctx, timeframe)
pub async fn get_chart(
app_ctx: &AppContext,
timeframe: Timeframe,
category: Option<Category>,
) -> Result<Chart, ChartError> {
let votes = get_votes_summary(app_ctx, timeframe, category)
.await
.map_err(|error| {
error!("{error:?}");
Expand Down
2 changes: 1 addition & 1 deletion src/features/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Contains various feature implementations for the ratings backend.

pub mod chart;
pub mod pb;
pub mod rating;
pub mod user;

mod common;
mod pb;
7 changes: 7 additions & 0 deletions src/features/user/errors.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Errors related to user voting
use snapd::SnapdClientError;
use thiserror::Error;

/// Errors that can occur when a user votes.
Expand All @@ -16,6 +17,12 @@ pub enum UserError {
/// The user was unable to cast a vote
#[error("failed to cast vote")]
FailedToCastVote,
/// Errors from `snapd-rs`
#[error("an error occurred when calling snapd: {0}")]
SnapdError(#[from] SnapdClientError),
/// An error that occurred in category updating
#[error("an error occurred with the DB when getting categories: {0}")]
CategoryDBError(#[from] sqlx::Error),
/// Anything else that can go wrong
#[error("unknown user error")]
Unknown,
Expand Down
92 changes: 88 additions & 4 deletions src/features/user/infrastructure.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
//! Infrastructure for user handling
use sqlx::Row;
use snapd::{
api::{convenience::SnapNameFromId, find::FindSnapByName},
SnapdClient,
};
use sqlx::{Acquire, Executor, Row};
use tracing::error;

use crate::{
app::AppContext,
features::user::{
entities::{User, Vote},
errors::UserError,
features::{
pb::chart::Category,
user::{
entities::{User, Vote},
errors::UserError,
},
},
};

Expand Down Expand Up @@ -179,6 +186,58 @@ pub(crate) async fn save_vote_to_db(app_ctx: &AppContext, vote: Vote) -> Result<
Ok(result.rows_affected())
}

/// Convenience function for getting categories by their snap ID, since it takes multiple API calls
async fn snapd_categories_by_snap_id(
client: &SnapdClient,
snap_id: &str,
) -> Result<Vec<Category>, UserError> {
let snap_name = SnapNameFromId::get_name(snap_id.into(), client).await?;

Ok(FindSnapByName::get_categories(snap_name, client)
.await?
.into_iter()
.map(|v| Category::try_from(v.name.as_ref()).expect("got unknown category?"))
.collect())
}

/// Update the category (we do this every time we get a vote for the time being)
pub(crate) async fn update_category(app_ctx: &AppContext, snap_id: &str) -> Result<(), UserError> {
let mut pool = app_ctx
.infrastructure()
.repository()
.await
.map_err(|error| {
error!("{error:?}");
UserError::Unknown
})?;

let snapd_client = &app_ctx.infrastructure().snapd_client;

let categories = snapd_categories_by_snap_id(snapd_client, snap_id).await?;

// Do a transaction because bulk querying doesn't seem to work cleanly
let mut tx = pool.begin().await?;

// Reset the categories since we're refreshing all of them
tx.execute(
sqlx::query("DELETE FROM snap_categories WHERE snap_categories.snap_id = $1;")
.bind(snap_id),
)
.await?;

for category in categories.iter() {
tx.execute(
sqlx::query("INSERT INTO snap_categories (snap_id, category) VALUES ($1, $2); ")
.bind(snap_id)
.bind(category),
)
.await?;
}

tx.commit().await?;
Ok(())
}

/// Retrieve all votes for a given [`User`], within the current [`AppContext`].
///
/// May be filtered for a given snap ID.
Expand Down Expand Up @@ -238,3 +297,28 @@ pub(crate) async fn find_user_votes(

Ok(votes)
}

#[cfg(test)]
mod test {
use std::collections::HashSet;

use snapd::SnapdClient;

use crate::features::pb::chart::Category;

use super::snapd_categories_by_snap_id;
const TESTING_SNAP_ID: &str = "3Iwi803Tk3KQwyD6jFiAJdlq8MLgBIoD";
const TESTING_SNAP_CATEGORIES: [Category; 2] = [Category::Utilities, Category::Development];

#[tokio::test]
async fn get_categories() {
let categories = snapd_categories_by_snap_id(&SnapdClient::default(), TESTING_SNAP_ID)
.await
.unwrap();

assert_eq!(
TESTING_SNAP_CATEGORIES.into_iter().collect::<HashSet<_>>(),
categories.into_iter().collect::<HashSet<_>>()
)
}
}
Loading
Loading