Skip to content

Commit

Permalink
Rossum DMv2 analyze: add export functionality
Browse files Browse the repository at this point in the history
This allows downloading of the new results from DMv2.
  • Loading branch information
mrtnzlml committed Jul 19, 2023
1 parent 4441d26 commit bbe043c
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 87 deletions.
23 changes: 23 additions & 0 deletions src/rossum-dmv2-analyze/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/rossum-dmv2-analyze/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ edition = "2021"
anyhow = { version = "1.0.72" }
clap = { version = "4.3.16", default-features = false, features = ["std", "cargo", "color", "deprecated", "env", "error-context", "help", "suggestions", "unicode", "usage"] }
colored = { version = "2.0.4" }
csv = { version = "1.2.2" }
indicatif = { version = "0.17.5" }
reqwest = { version = "0.11.18", features = ["json"] }
serde = { version = "1.0.171", features = ["derive"] }
serde_json = { version = "1.0.103" }
tempfile = { version = "3.6.0" }
tokio = { version = "1.29.1", features = ["full"] }
2 changes: 1 addition & 1 deletion src/rossum-dmv2-analyze/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ brew install adeira/universe/rossum-dmv2-analyze
## Run

```bash
rossum-dmv2-analyze --config-file=./dmv2_config.json --dm-hook-id=252259 --queue-id=852015 --api-token=XXXXX
rossum-dmv2-analyze --dm-config-file=./dmv2_config.json --dm-hook-id=252259 --queue-id=852015 --api-token=XXXXX
```

Try running `rossum-dmv2-analyze --help` for more information.
Expand Down
64 changes: 64 additions & 0 deletions src/rossum-dmv2-analyze/src/api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use reqwest::{Client, Url};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Document {
pub annotations: Vec<String>,
pub id: u64,
pub mime_type: String,
pub original_file_name: String,
pub s3_name: String,
pub url: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Annotations {
pub results: Vec<Annotation>,
pub pagination: Pagination,
}

/// See: https://elis.rossum.ai/api/docs/#annotation
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct Annotation {
pub id: i32,
pub document: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Pagination {
pub total: i32,
pub total_pages: i32,
pub next: Option<String>,
pub previous: Option<String>,
}

pub(crate) async fn get<T: DeserializeOwned>(
http_client: &Client,
url: &String,
) -> reqwest::Result<T> {
http_client.get(url).send().await?.json().await
}

pub(crate) async fn annotations_get(
http_client: &Client,
queue_id: &str,
current_page: &mut i32,
concurrency: &usize,
annotations_status: &Vec<String>,
) -> anyhow::Result<Annotations> {
Ok(http_client
.get(Url::parse_with_params(
"https://elis.rossum.ai/api/v1/annotations",
&[
("queue", queue_id.to_string()),
("page", current_page.to_string()),
("page_size", concurrency.to_string()),
("status", annotations_status.join(",").to_string()),
],
)?)
.send()
.await?
.json::<Annotations>()
.await?)
}
38 changes: 29 additions & 9 deletions src/rossum-dmv2-analyze/src/clap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@ use clap::{Arg, Command};

pub fn generate_clap_app() -> Command {
clap::command!()
.arg(
Arg::new("annotations-status")
.long("annotations-status")
.help("Comma-separated list of annotation statuses to analyze.")
.num_args(1)
.value_delimiter(',')
.required(false)
.default_value("confirmed,exported,exporting,reviewing,to_review")
)
.arg(
Arg::new("api-token")
.long("api-token")
.num_args(1)
.long_help(
"API token belonging to the organization that will be used for calling Rossum API endpoints.",
)
.required(true),
)
.arg(
Arg::new("concurrency")
.short('c')
Expand All @@ -12,8 +30,8 @@ pub fn generate_clap_app() -> Command {
.default_value("100"),
)
.arg(
Arg::new("config-file")
.long("config-file")
Arg::new("dm-config-file")
.long("dm-config-file")
.num_args(1)
.long_help(
"JSON config file with the new DMv2 configuration. Copy-paste here the whole \
Expand All @@ -29,20 +47,22 @@ pub fn generate_clap_app() -> Command {
.required(true),
)
.arg(
Arg::new("queue-id")
.long("queue-id")
Arg::new("export-results-file")
.long("export-results-file")
.num_args(1)
.long_help(
"Queue ID that will be used for the analysis of all relevant annotations.",
"Path to the file where DMv2 match (new) results should be exported. \
The only supported format is currently CSV. \
If the file already exists, it will be overwritten."
)
.required(true),
.required(false),
)
.arg(
Arg::new("api-token")
.long("api-token")
Arg::new("queue-id")
.long("queue-id")
.num_args(1)
.long_help(
"API token belonging to the organization that will be used for calling Rossum API endpoints.",
"Queue ID that will be used for the analysis of all relevant annotations.",
)
.required(true),
)
Expand Down
33 changes: 33 additions & 0 deletions src/rossum-dmv2-analyze/src/dmv2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,39 @@ use crate::score::MessageCounts;
use anyhow::Result;
use reqwest::{Client, Response};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Message {
/// Datapoint ID
id: u64,
content: String,
r#type: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Operation {
/// Datapoint ID
pub(crate) id: u64,
pub(crate) op: String,
pub(crate) value: OperationValue,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OperationValue {
pub(crate) content: OperationValueContent,
pub(crate) options: Vec<OperationValueOptionValue>,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OperationValueContent {
pub(crate) value: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OperationValueOptionValue {
label: String,
pub(crate) value: String,
}

fn replace_result_actions(
Expand All @@ -29,7 +58,9 @@ fn replace_result_actions(
// If this array item is an object:
if let Some(config_obj) = config_obj.as_object_mut() {
// Replace "result_actions" with `new_result_actions`:
// TODO: wrong (remove `result_actions` overwrite)
config_obj.insert("result_actions".to_owned(), new_result_actions.clone());
config_obj.insert("additional_mappings".to_owned(), json!([]));
}
}
}
Expand Down Expand Up @@ -60,6 +91,8 @@ pub(crate) async fn match_request(
http_client: &Client,
payload: &serde_json::Value,
) -> Result<Response> {
// TODO: remove this because overwriting result_actions with for example `"select": "best_match"`
// could skew the results
let new_result_actions: serde_json::Value =
serde_json::from_str(include_str!("result_actions_overwrite.json")).unwrap();

Expand Down
39 changes: 39 additions & 0 deletions src/rossum-dmv2-analyze/src/export.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use csv;
use serde::Serialize;
use std::io;
use tempfile::NamedTempFile;

#[derive(Serialize)]
pub struct CsvRecord {
pub operation: String,
pub document_original_file_name: String,
pub document_mime_type: String,
pub datapoint_id: String,
pub datapoint_value_content: String,
pub datapoint_value_options: String,
}

pub struct CsvWriter {
writer: csv::Writer<std::fs::File>,
}

impl CsvWriter {
pub fn new(file_path: &Option<&String>) -> Result<Self, csv::Error> {
let writer = match file_path {
Some(path) => csv::Writer::from_path(path)?,
None => {
let tmpfile = NamedTempFile::new()?;
csv::Writer::from_path(tmpfile)?
}
};
Ok(Self { writer })
}

pub fn write_record(&mut self, record: &CsvRecord) -> Result<(), csv::Error> {
self.writer.serialize(record)
}

pub fn flush(&mut self) -> Result<(), io::Error> {
self.writer.flush()
}
}
Loading

0 comments on commit bbe043c

Please sign in to comment.