Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAPI Codegen #1103

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ reqwest = { version = "0.11", features = ["blocking", "multipart", "json"] }
criterion = { version = "0.5.1", features = ["async_tokio"] }
vcpkg = "0.2"
cc = "1.0"

oasgen = { version = "0.22.0", features = ["axum", "chrono"] }
once_cell = "1.20.2"
sentry = { version = "0.36.0", features = ["tracing"] }

Expand Down
1 change: 1 addition & 0 deletions screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ which = "7.0.0"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
oasgen = { workspace = true }

# Cross-platform audio capture
cpal = { git = "https://github.com/Kree0/cpal.git", branch = "master" }
Expand Down
6 changes: 4 additions & 2 deletions screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use std::sync::Arc;
use std::time::Duration;
use std::{fmt, thread};
use tokio::sync::{broadcast, oneshot};
use oasgen::OaSchema;

lazy_static! {
pub static ref LAST_AUDIO_CAPTURE: AtomicU64 = AtomicU64::new(
std::time::SystemTime::now()
Expand Down Expand Up @@ -51,13 +53,13 @@ pub struct DeviceControl {
pub is_paused: bool,
}

#[derive(Clone, Eq, PartialEq, Hash, Serialize, Debug, Deserialize)]
#[derive(OaSchema, Clone, Eq, PartialEq, Hash, Serialize, Debug, Deserialize)]
pub enum DeviceType {
Input,
Output,
}

#[derive(Clone, Eq, PartialEq, Hash, Serialize, Debug)]
#[derive(OaSchema, Clone, Eq, PartialEq, Hash, Serialize, Debug)]
pub struct AudioDevice {
pub name: String,
pub device_type: DeviceType,
Expand Down
1 change: 1 addition & 0 deletions screenpipe-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ edition = { workspace = true }
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
oasgen = { workspace = true }

screenpipe-events = { path = "../screenpipe-events" }
screenpipe-vision = { path = "../screenpipe-vision" }
Expand Down
4 changes: 2 additions & 2 deletions screenpipe-server/src/bin/screenpipe-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use screenpipe_server::{
create_migration_worker, handle_index_command,
pipe_manager::PipeInfo,
start_continuous_recording, watch_pid, DatabaseManager, MigrationCommand, MigrationConfig,
MigrationStatus, PipeManager, ResourceMonitor, Server,
MigrationStatus, PipeManager, ResourceMonitor, SCServer,
};
use screenpipe_vision::monitor::list_monitors;
#[cfg(target_os = "macos")]
Expand Down Expand Up @@ -796,7 +796,7 @@ async fn main() -> anyhow::Result<()> {
let (audio_devices_tx, _) = broadcast::channel(100);

// TODO: Add SSE stream for realtime audio transcription
let server = Server::new(
let server = SCServer::new(
db_server,
SocketAddr::from(([127, 0, 0, 1], cli.port)),
local_data_dir_clone_2,
Expand Down
39 changes: 20 additions & 19 deletions screenpipe-server/src/db_types.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use chrono::{DateTime, Utc};
use oasgen::OaSchema;
use screenpipe_audio::DeviceType;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use std::error::Error as StdError;
use std::fmt::{self, Display};

#[derive(Debug)]
#[derive(OaSchema, Debug)]
pub struct DatabaseError(pub String);

impl fmt::Display for DatabaseError {
Expand All @@ -16,7 +17,7 @@ impl fmt::Display for DatabaseError {

impl StdError for DatabaseError {}

#[derive(Debug, Serialize, Deserialize)]
#[derive(OaSchema, Debug, Serialize, Deserialize)]
pub enum SearchResult {
OCR(OCRResult),
Audio(AudioResult),
Expand Down Expand Up @@ -48,7 +49,7 @@ pub struct OCRResultRaw {
pub focused: Option<bool>,
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(OaSchema, Debug, Serialize, Deserialize)]
pub struct OCRResult {
pub frame_id: i64,
pub frame_name: String,
Expand All @@ -65,7 +66,7 @@ pub struct OCRResult {
pub focused: Option<bool>,
}

#[derive(Debug, Deserialize, PartialEq, Default, Clone)]
#[derive(OaSchema, Debug, Deserialize, PartialEq, Default, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ContentType {
#[default]
Expand Down Expand Up @@ -100,14 +101,14 @@ pub struct AudioResultRaw {
pub end_time: Option<f64>,
}

#[derive(Debug, Serialize, Deserialize, FromRow, Clone)]
#[derive(OaSchema, Debug, Serialize, Deserialize, FromRow, Clone)]
pub struct Speaker {
pub id: i64,
pub name: String,
pub metadata: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[derive(OaSchema, Debug, Serialize, Deserialize)]
pub struct AudioResult {
pub audio_chunk_id: i64,
pub transcription: String,
Expand All @@ -123,14 +124,14 @@ pub struct AudioResult {
pub end_time: Option<f64>,
}

#[derive(Debug, Deserialize, PartialEq)]
#[derive(OaSchema, Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum TagContentType {
Vision,
Audio,
}

#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
#[derive(OaSchema, Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct UiContent {
pub id: i64,
#[sqlx(rename = "text_output")]
Expand All @@ -147,7 +148,7 @@ pub struct UiContent {
pub browser_url: Option<String>,
}

#[derive(Debug, Clone)]
#[derive(OaSchema, Debug, Clone)]
pub struct FrameData {
pub frame_id: i64,
pub timestamp: DateTime<Utc>,
Expand All @@ -156,7 +157,7 @@ pub struct FrameData {
pub audio_entries: Vec<AudioEntry>,
}

#[derive(Debug, Clone)]
#[derive(OaSchema, Debug, Clone)]
pub struct OCREntry {
pub text: String,
pub app_name: String,
Expand All @@ -165,7 +166,7 @@ pub struct OCREntry {
pub video_file_path: String,
}

#[derive(Debug, Clone)]
#[derive(OaSchema, Debug, Clone)]
pub struct AudioEntry {
pub transcription: String,
pub device_name: String,
Expand All @@ -174,14 +175,14 @@ pub struct AudioEntry {
pub duration_secs: f64,
}

#[derive(Debug, Clone)]
#[derive(OaSchema, Debug, Clone)]
pub struct TimeSeriesChunk {
pub frames: Vec<FrameData>,
pub start_time: DateTime<Utc>,
pub end_time: DateTime<Utc>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(OaSchema, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentSource {
Screen,
Audio,
Expand All @@ -196,14 +197,14 @@ impl Display for ContentSource {
}
}

#[derive(Debug, FromRow)]
#[derive(OaSchema, Debug, FromRow)]
pub struct AudioChunk {
pub id: i64,
pub file_path: String,
pub timestamp: DateTime<Utc>,
}

#[derive(Debug, FromRow)]
#[derive(OaSchema, Debug, FromRow)]
pub struct AudioChunksResponse {
pub audio_chunk_id: i64,
pub start_time: Option<f64>,
Expand All @@ -228,22 +229,22 @@ pub struct OcrTextBlock {
pub line_num: String,
}

#[derive(Debug, Serialize, Clone)]
#[derive(OaSchema, Debug, Serialize, Clone)]
pub struct TextPosition {
pub text: String,
pub confidence: f32,
pub bounds: TextBounds,
}

#[derive(Debug, Serialize, Clone)]
#[derive(OaSchema, Debug, Serialize, Clone)]
pub struct TextBounds {
pub left: f32,
pub top: f32,
pub width: f32,
pub height: f32,
}

#[derive(Serialize)]
#[derive(OaSchema, Serialize)]
pub struct SearchMatch {
pub frame_id: i64,
pub timestamp: DateTime<Utc>,
Expand All @@ -267,7 +268,7 @@ pub struct FrameRow {
pub text_json: String,
}

#[derive(Deserialize, PartialEq, Default)]
#[derive(Deserialize, OaSchema, PartialEq, Default)]
pub enum Order {
#[serde(rename = "ascending")]
Ascending,
Expand Down
12 changes: 7 additions & 5 deletions screenpipe-server/src/embedding/embedding_endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use axum::Json;
use oasgen::{oasgen, OaSchema};
use once_cell::sync::OnceCell;
use screenpipe_core::model::EmbeddingModel;
use serde::{Deserialize, Serialize};
Expand All @@ -9,7 +10,7 @@ use tracing::info;
static EMBEDDING_MODEL: OnceCell<Arc<Mutex<EmbeddingModel>>> = OnceCell::new();

// OpenAI-like request/response types
#[derive(Deserialize)]
#[derive(OaSchema, Deserialize)]
#[allow(dead_code)]
pub struct EmbeddingRequest {
model: String,
Expand All @@ -20,7 +21,7 @@ pub struct EmbeddingRequest {
user: Option<String>,
}

#[derive(Deserialize)]
#[derive(OaSchema, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
#[serde(rename = "single")]
Expand All @@ -29,22 +30,22 @@ pub enum EmbeddingInput {
Multiple(Vec<String>),
}

#[derive(Serialize)]
#[derive(OaSchema, Serialize)]
pub struct EmbeddingResponse {
object: String,
data: Vec<EmbeddingData>,
model: String,
usage: Usage,
}

#[derive(Serialize)]
#[derive(OaSchema, Serialize)]
pub struct EmbeddingData {
object: String,
embedding: Vec<f32>,
index: usize,
}

#[derive(Serialize)]
#[derive(OaSchema, Serialize)]
pub struct Usage {
prompt_tokens: usize,
total_tokens: usize,
Expand All @@ -71,6 +72,7 @@ pub async fn get_or_initialize_model() -> anyhow::Result<Arc<Mutex<EmbeddingMode
.map(|model| model.clone())
}

#[oasgen]
pub async fn create_embeddings(
Json(request): Json<EmbeddingRequest>,
) -> Result<Json<EmbeddingResponse>, (axum::http::StatusCode, String)> {
Expand Down
3 changes: 1 addition & 2 deletions screenpipe-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ pub use migration_worker::{
pub use pipe_manager::PipeManager;
pub use resource_monitor::{ResourceMonitor, RestartSignal};
pub use screenpipe_core::Language;
pub use server::create_router;
pub use server::health_check;
pub use server::AppState;
pub use server::ContentItem;
pub use server::HealthCheckResponse;
pub use server::PaginatedResponse;
pub use server::Server;
pub use server::SCServer;
pub use server::{api_list_monitors, MonitorInfo};
pub use video::VideoCapture;
Loading
Loading