From bbe043c094913b079abfa8b804ee094d58c33f17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Zl=C3=A1mal?= Date: Wed, 19 Jul 2023 10:11:35 +0200 Subject: [PATCH] Rossum DMv2 analyze: add export functionality This allows downloading of the new results from DMv2. --- src/rossum-dmv2-analyze/Cargo.lock | 23 +++++ src/rossum-dmv2-analyze/Cargo.toml | 2 + src/rossum-dmv2-analyze/README.md | 2 +- src/rossum-dmv2-analyze/src/api.rs | 64 ++++++++++++ src/rossum-dmv2-analyze/src/clap.rs | 38 +++++-- src/rossum-dmv2-analyze/src/dmv2.rs | 33 +++++++ src/rossum-dmv2-analyze/src/export.rs | 39 ++++++++ src/rossum-dmv2-analyze/src/main.rs | 120 +++++++++++++---------- src/rossum-dmv2-analyze/src/processor.rs | 60 +++++++----- 9 files changed, 294 insertions(+), 87 deletions(-) create mode 100644 src/rossum-dmv2-analyze/src/api.rs create mode 100644 src/rossum-dmv2-analyze/src/export.rs diff --git a/src/rossum-dmv2-analyze/Cargo.lock b/src/rossum-dmv2-analyze/Cargo.lock index d0df4fb387..b461c3aebc 100644 --- a/src/rossum-dmv2-analyze/Cargo.lock +++ b/src/rossum-dmv2-analyze/Cargo.lock @@ -224,6 +224,27 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "csv" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + [[package]] name = "encode_unicode" version = "0.3.6" @@ -826,10 +847,12 @@ dependencies = [ "anyhow", "clap", "colored", + "csv", "indicatif", "reqwest", "serde", "serde_json", + "tempfile", "tokio", ] diff --git a/src/rossum-dmv2-analyze/Cargo.toml b/src/rossum-dmv2-analyze/Cargo.toml index ebf4ff8dff..011c8ccbb0 100644 --- a/src/rossum-dmv2-analyze/Cargo.toml +++ b/src/rossum-dmv2-analyze/Cargo.toml @@ -9,8 +9,10 @@ edition = "2021" anyhow = { version = "1.0.72" } clap = { version = "4.3.16", default-features = false, features = ["std", "cargo", "color", "deprecated", "env", "error-context", "help", "suggestions", "unicode", "usage"] } colored = { version = "2.0.4" } +csv = { version = "1.2.2" } indicatif = { version = "0.17.5" } reqwest = { version = "0.11.18", features = ["json"] } serde = { version = "1.0.171", features = ["derive"] } serde_json = { version = "1.0.103" } +tempfile = { version = "3.6.0" } tokio = { version = "1.29.1", features = ["full"] } diff --git a/src/rossum-dmv2-analyze/README.md b/src/rossum-dmv2-analyze/README.md index 695c75e454..f38071ab14 100644 --- a/src/rossum-dmv2-analyze/README.md +++ b/src/rossum-dmv2-analyze/README.md @@ -9,7 +9,7 @@ brew install adeira/universe/rossum-dmv2-analyze ## Run ```bash -rossum-dmv2-analyze --config-file=./dmv2_config.json --dm-hook-id=252259 --queue-id=852015 --api-token=XXXXX +rossum-dmv2-analyze --dm-config-file=./dmv2_config.json --dm-hook-id=252259 --queue-id=852015 --api-token=XXXXX ``` Try running `rossum-dmv2-analyze --help` for more information. diff --git a/src/rossum-dmv2-analyze/src/api.rs b/src/rossum-dmv2-analyze/src/api.rs new file mode 100644 index 0000000000..e456cd62ef --- /dev/null +++ b/src/rossum-dmv2-analyze/src/api.rs @@ -0,0 +1,64 @@ +use reqwest::{Client, Url}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Document { + pub annotations: Vec, + pub id: u64, + pub mime_type: String, + pub original_file_name: String, + pub s3_name: String, + pub url: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Annotations { + pub results: Vec, + pub pagination: Pagination, +} + +/// See: https://elis.rossum.ai/api/docs/#annotation +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct Annotation { + pub id: i32, + pub document: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Pagination { + pub total: i32, + pub total_pages: i32, + pub next: Option, + pub previous: Option, +} + +pub(crate) async fn get( + http_client: &Client, + url: &String, +) -> reqwest::Result { + http_client.get(url).send().await?.json().await +} + +pub(crate) async fn annotations_get( + http_client: &Client, + queue_id: &str, + current_page: &mut i32, + concurrency: &usize, + annotations_status: &Vec, +) -> anyhow::Result { + Ok(http_client + .get(Url::parse_with_params( + "https://elis.rossum.ai/api/v1/annotations", + &[ + ("queue", queue_id.to_string()), + ("page", current_page.to_string()), + ("page_size", concurrency.to_string()), + ("status", annotations_status.join(",").to_string()), + ], + )?) + .send() + .await? + .json::() + .await?) +} diff --git a/src/rossum-dmv2-analyze/src/clap.rs b/src/rossum-dmv2-analyze/src/clap.rs index d5e76fba7c..891d8418c3 100644 --- a/src/rossum-dmv2-analyze/src/clap.rs +++ b/src/rossum-dmv2-analyze/src/clap.rs @@ -2,6 +2,24 @@ use clap::{Arg, Command}; pub fn generate_clap_app() -> Command { clap::command!() + .arg( + Arg::new("annotations-status") + .long("annotations-status") + .help("Comma-separated list of annotation statuses to analyze.") + .num_args(1) + .value_delimiter(',') + .required(false) + .default_value("confirmed,exported,exporting,reviewing,to_review") + ) + .arg( + Arg::new("api-token") + .long("api-token") + .num_args(1) + .long_help( + "API token belonging to the organization that will be used for calling Rossum API endpoints.", + ) + .required(true), + ) .arg( Arg::new("concurrency") .short('c') @@ -12,8 +30,8 @@ pub fn generate_clap_app() -> Command { .default_value("100"), ) .arg( - Arg::new("config-file") - .long("config-file") + Arg::new("dm-config-file") + .long("dm-config-file") .num_args(1) .long_help( "JSON config file with the new DMv2 configuration. Copy-paste here the whole \ @@ -29,20 +47,22 @@ pub fn generate_clap_app() -> Command { .required(true), ) .arg( - Arg::new("queue-id") - .long("queue-id") + Arg::new("export-results-file") + .long("export-results-file") .num_args(1) .long_help( - "Queue ID that will be used for the analysis of all relevant annotations.", + "Path to the file where DMv2 match (new) results should be exported. \ + The only supported format is currently CSV. \ + If the file already exists, it will be overwritten." ) - .required(true), + .required(false), ) .arg( - Arg::new("api-token") - .long("api-token") + Arg::new("queue-id") + .long("queue-id") .num_args(1) .long_help( - "API token belonging to the organization that will be used for calling Rossum API endpoints.", + "Queue ID that will be used for the analysis of all relevant annotations.", ) .required(true), ) diff --git a/src/rossum-dmv2-analyze/src/dmv2.rs b/src/rossum-dmv2-analyze/src/dmv2.rs index bc72991dda..11a0202b2f 100644 --- a/src/rossum-dmv2-analyze/src/dmv2.rs +++ b/src/rossum-dmv2-analyze/src/dmv2.rs @@ -2,10 +2,39 @@ use crate::score::MessageCounts; use anyhow::Result; use reqwest::{Client, Response}; use serde::{Deserialize, Serialize}; +use serde_json::json; #[derive(Debug, Serialize, Deserialize)] pub(crate) struct Message { + /// Datapoint ID + id: u64, content: String, + r#type: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct Operation { + /// Datapoint ID + pub(crate) id: u64, + pub(crate) op: String, + pub(crate) value: OperationValue, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct OperationValue { + pub(crate) content: OperationValueContent, + pub(crate) options: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct OperationValueContent { + pub(crate) value: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct OperationValueOptionValue { + label: String, + pub(crate) value: String, } fn replace_result_actions( @@ -29,7 +58,9 @@ fn replace_result_actions( // If this array item is an object: if let Some(config_obj) = config_obj.as_object_mut() { // Replace "result_actions" with `new_result_actions`: + // TODO: wrong (remove `result_actions` overwrite) config_obj.insert("result_actions".to_owned(), new_result_actions.clone()); + config_obj.insert("additional_mappings".to_owned(), json!([])); } } } @@ -60,6 +91,8 @@ pub(crate) async fn match_request( http_client: &Client, payload: &serde_json::Value, ) -> Result { + // TODO: remove this because overwriting result_actions with for example `"select": "best_match"` + // could skew the results let new_result_actions: serde_json::Value = serde_json::from_str(include_str!("result_actions_overwrite.json")).unwrap(); diff --git a/src/rossum-dmv2-analyze/src/export.rs b/src/rossum-dmv2-analyze/src/export.rs new file mode 100644 index 0000000000..80eae29cdb --- /dev/null +++ b/src/rossum-dmv2-analyze/src/export.rs @@ -0,0 +1,39 @@ +use csv; +use serde::Serialize; +use std::io; +use tempfile::NamedTempFile; + +#[derive(Serialize)] +pub struct CsvRecord { + pub operation: String, + pub document_original_file_name: String, + pub document_mime_type: String, + pub datapoint_id: String, + pub datapoint_value_content: String, + pub datapoint_value_options: String, +} + +pub struct CsvWriter { + writer: csv::Writer, +} + +impl CsvWriter { + pub fn new(file_path: &Option<&String>) -> Result { + let writer = match file_path { + Some(path) => csv::Writer::from_path(path)?, + None => { + let tmpfile = NamedTempFile::new()?; + csv::Writer::from_path(tmpfile)? + } + }; + Ok(Self { writer }) + } + + pub fn write_record(&mut self, record: &CsvRecord) -> Result<(), csv::Error> { + self.writer.serialize(record) + } + + pub fn flush(&mut self) -> Result<(), io::Error> { + self.writer.flush() + } +} diff --git a/src/rossum-dmv2-analyze/src/main.rs b/src/rossum-dmv2-analyze/src/main.rs index f67029f048..15770decc4 100644 --- a/src/rossum-dmv2-analyze/src/main.rs +++ b/src/rossum-dmv2-analyze/src/main.rs @@ -1,32 +1,20 @@ use crate::clap::generate_clap_app; +use crate::export::{CsvRecord, CsvWriter}; use crate::http::get_http_client; use crate::score::MessageCounts; use colored::*; use indicatif::{ProgressBar, ProgressStyle}; -use reqwest::{Client, Url}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; +mod api; mod clap; mod dmv2; +mod export; mod http; mod processor; mod score; -#[derive(Debug, Serialize, Deserialize)] -struct AnnotationResponse { - results: Vec, - pagination: AnnotationResponsePagination, -} - -#[derive(Debug, Serialize, Deserialize)] -struct AnnotationResponsePagination { - total: i32, - total_pages: i32, - next: Option, - previous: Option, -} - #[derive(Debug, Serialize, Deserialize)] struct DmConfig { todo: bool, @@ -39,10 +27,18 @@ struct Configurations { #[tokio::main] async fn main() -> anyhow::Result<()> { - let cli_matches = generate_clap_app().get_matches(); + let mut cli_matches = generate_clap_app().get_matches(); + let annotations_status: Vec = cli_matches + .remove_many("annotations-status") + .unwrap() + .collect(); let config_file = cli_matches - .get_one::("config-file") + .get_one::("dm-config-file") + .unwrap() + .clone(); + let export_file = cli_matches + .try_get_one::("export-results-file") .unwrap() .clone(); let api_token = cli_matches.get_one::("api-token").unwrap().clone(); @@ -61,8 +57,14 @@ async fn main() -> anyhow::Result<()> { let mut current_page = 1; // 1. fetch all relevant annotations (with pagination) - let mut annotations = - get_annotations(&queue_id, &http_client, &mut current_page, &concurrency).await?; + let mut annotations = api::annotations_get( + &http_client, + &queue_id, + &mut current_page, + &concurrency, + &annotations_status, + ) + .await?; println!( "Analyzing {} annotations…", @@ -73,6 +75,8 @@ async fn main() -> anyhow::Result<()> { pb.set_style(ProgressStyle::with_template("{wide_bar} {pos}/{len}\n{wide_msg}").unwrap()); pb.tick(); // to display empty bar + let mut csv_wtr = CsvWriter::new(&export_file)?; + loop { let mut handles = Vec::new(); @@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> { let config_file = config_file.clone(); let api_token = api_token.clone(); let dm_hook_id = dm_hook_id.clone(); + let annotation = annotation.clone(); let tx = tx.clone(); handles.push(tokio::spawn(async move { let res = processor::process(config_file, api_token, dm_hook_id, annotation).await; @@ -93,16 +98,24 @@ async fn main() -> anyhow::Result<()> { let mut received_messages = 0; while received_messages < annotations_results_len { match rx.recv().await.unwrap() { - Ok((before, after)) => { + Ok(processor_result) => { + let response_document = api::get::( + &http_client, + &processor_result.annotation.document, + ) + .await?; + received_messages += 1; - before_total.one_match_found += before.one_match_found; - before_total.multiple_matches_found += before.multiple_matches_found; - before_total.no_match_found += before.no_match_found; + before_total.one_match_found += processor_result.before.one_match_found; + before_total.multiple_matches_found += + processor_result.before.multiple_matches_found; + before_total.no_match_found += processor_result.before.no_match_found; - after_total.one_match_found += after.one_match_found; - after_total.multiple_matches_found += after.multiple_matches_found; - after_total.no_match_found += after.no_match_found; + after_total.one_match_found += processor_result.after.one_match_found; + after_total.multiple_matches_found += + processor_result.after.multiple_matches_found; + after_total.no_match_found += processor_result.after.no_match_found; pb.inc(1); pb.set_message(format!( @@ -123,6 +136,25 @@ async fn main() -> anyhow::Result<()> { after_total.no_match_found.to_string().red().bold(), score::compare_solutions(&before_total, &after_total) )); + + for operation in processor_result.after_result.operations { + csv_wtr.write_record(&CsvRecord { + operation: operation.op, + datapoint_id: operation.id.to_string(), + datapoint_value_content: operation.value.content.value, + datapoint_value_options: operation + .value + .options + .iter() + .map(|option| option.value.clone()) + .collect::>() + .join("|"), + document_original_file_name: response_document + .original_file_name + .to_string(), + document_mime_type: response_document.mime_type.to_string(), + })?; + } } Err(_e) => { break; @@ -133,38 +165,20 @@ async fn main() -> anyhow::Result<()> { // 4. if there is a next page, increment the page number, otherwise, exit the loop if annotations.pagination.next.is_some() { current_page += 1; - annotations = - get_annotations(&queue_id, &http_client, &mut current_page, &concurrency).await?; + annotations = api::annotations_get( + &http_client, + &queue_id, + &mut current_page, + &concurrency, + &annotations_status, + ) + .await?; } else { break; } } + csv_wtr.flush()?; pb.finish(); Ok(()) } - -async fn get_annotations( - queue_id: &str, - http_client: &Client, - current_page: &mut i32, - concurrency: &usize, -) -> anyhow::Result { - Ok(http_client - .get(Url::parse_with_params( - "https://elis.rossum.ai/api/v1/annotations", - &[ - ("queue", queue_id.to_string()), - ("page", current_page.to_string()), - ("page_size", concurrency.to_string()), - ( - "status", - "confirmed,exported,exporting,reviewing,to_review".to_string(), - ), - ], - )?) - .send() - .await? - .json::() - .await?) -} diff --git a/src/rossum-dmv2-analyze/src/processor.rs b/src/rossum-dmv2-analyze/src/processor.rs index b651009755..b863b0c74f 100644 --- a/src/rossum-dmv2-analyze/src/processor.rs +++ b/src/rossum-dmv2-analyze/src/processor.rs @@ -1,5 +1,6 @@ +use crate::api::Annotation; use crate::dmv2; -use crate::dmv2::Message; +use crate::dmv2::{Message, Operation}; use crate::http::get_http_client; use crate::score::MessageCounts; use serde::{Deserialize, Serialize}; @@ -7,13 +8,17 @@ use serde_json::json; use std::fs::File; #[derive(Debug, Serialize, Deserialize)] -pub(crate) struct Annotation { - id: i32, +pub(crate) struct MatchResult { + messages: Vec, + pub(crate) operations: Vec, } -#[derive(Debug, Serialize, Deserialize)] -struct MatchResult { - messages: Vec, +pub(crate) struct ProcessorResult { + pub(crate) before: MessageCounts, + pub(crate) after: MessageCounts, + pub(crate) before_result: MatchResult, + pub(crate) after_result: MatchResult, + pub(crate) annotation: Annotation, } fn replace_settings( @@ -34,7 +39,7 @@ pub(crate) async fn process( api_token: String, dm_hook_id: String, annotation: Annotation, -) -> anyhow::Result<(MessageCounts, MessageCounts)> { +) -> anyhow::Result { let new_dm_config: serde_json::Value = serde_json::from_reader( File::open(config_file).expect("configuration file should open read only"), ) @@ -42,23 +47,24 @@ pub(crate) async fn process( let http_client = get_http_client(&api_token); + // TODO: preferably, we should remove the following /start and /cancel calls // 1.1. start the annotation (otherwise /generate_payload returns null datapoints) - let _ = http_client - .post(format!( - "https://elis.rossum.ai/api/v1/annotations/{}/start", - annotation.id - )) - .send() - .await?; - - // 1.2. cancel the annotation so users are not locked out of it - let _ = http_client - .post(format!( - "https://elis.rossum.ai/api/v1/annotations/{}/cancel", - annotation.id - )) - .send() - .await?; + // let _ = http_client + // .post(format!( + // "https://elis.rossum.ai/api/v1/annotations/{}/start", + // annotation.id + // )) + // .send() + // .await?; + // + // // 1.2. cancel the annotation so users are not locked out of it + // let _ = http_client + // .post(format!( + // "https://elis.rossum.ai/api/v1/annotations/{}/cancel", + // annotation.id + // )) + // .send() + // .await?; // 2. generate testing payload let generate_payload_res = http_client @@ -90,5 +96,11 @@ pub(crate) async fn process( let before = dmv2::process_dmv2_messages(&dm_result1.messages); let after = dmv2::process_dmv2_messages(&dm_result2.messages); - Ok((before, after)) + Ok(ProcessorResult { + before, + after, + before_result: dm_result1, + after_result: dm_result2, + annotation, + }) }