Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Everything working except printing on the Lua side #7

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified artifacts/libffi.dylib
Binary file not shown.
8 changes: 5 additions & 3 deletions benches/benches/build_index.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::ffi::c_char;

use criterion::{criterion_group, criterion_main, Criterion};

extern "C" fn cb(_progress: f64) {}
extern "C" fn cb(_msg: *const c_char) {}

fn build_index() {
let mut index = rfsee_tf_idf::TfIdf::default();
index.par_load_rfcs(cb, cb).unwrap();
index.finish();
index.par_load_rfcs(cb).unwrap();
index.finish(cb);
let path = std::path::PathBuf::from("/tmp/bench_index.json");
index.save(&path)
}
Expand Down
25 changes: 17 additions & 8 deletions crates/cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{fs::File, path::PathBuf, time::Instant};
use std::{
ffi::{c_char, CStr},
fs::File,
path::PathBuf,
time::Instant,
};

use clap::{Parser, Subcommand};
use rfsee_tf_idf::{
Expand Down Expand Up @@ -27,12 +32,15 @@ enum Command {
},
}

extern "C" fn fetch_progress_cb(progress: f64) {
println!("Fetching RFCs progress: {progress:.2}%")
}
extern "C" fn print_c_char(ptr: *const c_char) {
if ptr.is_null() {
return;
}

extern "C" fn parse_progress_cb(progress: f64) {
println!("Parsing RFCs progress: {progress:.2}%")
let msg = unsafe { CStr::from_ptr(ptr) };
if let Ok(msg) = msg.to_str() {
println!("{msg}")
}
}

fn handle_command(args: Args) -> RFSeeResult<()> {
Expand All @@ -42,13 +50,14 @@ fn handle_command(args: Args) -> RFSeeResult<()> {
println!("Indexing RFCs");
let start = Instant::now();
let mut index = TfIdf::default();
index.par_load_rfcs(fetch_progress_cb, parse_progress_cb)?;
index.par_load_rfcs(print_c_char)?;
println!("Loading RFCs took {:?}", start.elapsed());
let building_index_start = Instant::now();
index.finish();
index.finish(print_c_char);
println!("Building index took {:?}", building_index_start.elapsed());
let saving_start = Instant::now();
let index_path = get_index_path(path)?;
println!("Saving index");
index.save(&index_path);
println!("Saving index took {:?}", saving_start.elapsed());
}
Expand Down
14 changes: 6 additions & 8 deletions crates/ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,14 @@ struct RfcSearchResultsContainer {
}

#[no_mangle]
pub extern "C" fn build_index(
fetch_progress_cb: extern "C" fn(progress: f64),
parse_progress_cb: extern "C" fn(progress: f64),
) {
pub extern "C" fn build_index(progress_cb: extern "C" fn(msg: *const c_char)) {
let path = rfsee_tf_idf::get_index_path(None).unwrap();
let mut index = rfsee_tf_idf::TfIdf::default();
index
.par_load_rfcs(fetch_progress_cb, parse_progress_cb)
.unwrap();
index.finish();
index.par_load_rfcs(progress_cb).unwrap();
index.finish(progress_cb);
if let Ok(cstr) = CString::new("Saving index to disk") {
progress_cb(cstr.as_ptr())
}
index.save(&path);
}

Expand Down
40 changes: 29 additions & 11 deletions crates/tf_idf/src/index.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
collections::HashMap,
ffi::{c_char, CString},
path::{Path, PathBuf},
sync::{Arc, Mutex},
time::Duration,
Expand Down Expand Up @@ -142,8 +143,7 @@ impl TfIdf {
/// Load the RFCs in parallel using a threadpool
pub fn par_load_rfcs(
&mut self,
fetch_progress_cb: extern "C" fn(progress: f64),
parse_progress_cb: extern "C" fn(progress: f64),
progress_cb: extern "C" fn(progress: *const c_char),
) -> RFSeeResult<()> {
let pool = threadpool::ThreadPool::new(12);
let raw_rfc_index = fetch_rfc_index()?;
Expand All @@ -164,11 +164,6 @@ impl TfIdf {
if let Ok(r) = fetch_rfc(&string) {
let mut guard = parsed_rfcs.lock().unwrap();
guard.push(r);
let processed = guard.len();
if processed % 100 == 0 {
let progress = (processed as f64 / rfcs_count as f64) * 100_f64;
fetch_progress_cb(progress)
}
};
let mut guard = remaining.lock().unwrap();
*guard -= 1;
Expand All @@ -179,10 +174,16 @@ impl TfIdf {
while !finished {
let remaining = remaining.clone();
let guard = remaining.lock().unwrap();
// We need to log here, and not in the thread pool, because we can't have different threads
// call the callback
if let Ok(msg) = CString::new(format!("{} remaining RFCs to fetch", *guard)) {
progress_cb(msg.as_ptr())
}
if *guard == 0 {
finished = true
} else {
drop(guard);
// Don't want to go crazy locking the Mutex, so we only check every 5 seconds
std::thread::sleep(Duration::from_secs(5));
}
}
Expand All @@ -194,7 +195,11 @@ impl TfIdf {
self.add_rfc_entry(rfc);
if i % 100 == 0 {
let progress = (i as f64 / rfcs_count as f64) * 100_f64;
parse_progress_cb(progress)
if let Ok(msg) =
CString::new(format!("Parse progress: {progress:0.0}%"))
{
progress_cb(msg.as_ptr())
}
}
}
}
Expand Down Expand Up @@ -248,7 +253,10 @@ impl TfIdf {

/// Take all the processed documents and their term frequencies to compute the final term
/// scores
pub fn finish(&mut self) {
pub fn finish(&mut self, progress_cb: extern "C" fn(*const c_char)) {
if let Ok(msg) = CString::new("Collecting terms") {
progress_cb(msg.as_ptr())
}
// First, we collect all terms and the number of docs they appear in
let mut term_counts: HashMap<&String, usize> = HashMap::new();
for indexed_rfc in self.processed_rfcs.values() {
Expand All @@ -261,6 +269,9 @@ impl TfIdf {
}
}

if let Ok(msg) = CString::new("Computing inverse document frequencies") {
progress_cb(msg.as_ptr())
}
// Then we compute the inverse document frequency for each term
let total_docs = self.processed_rfcs.len();
for (term, docs_with_term) in term_counts {
Expand All @@ -269,6 +280,9 @@ impl TfIdf {
self.idfs.insert(term.clone(), scaled);
}

if let Ok(msg) = CString::new("Scoring documents") {
progress_cb(msg.as_ptr())
}
// Then we compute the score for each term in all documents
self.processed_rfcs.iter().for_each(|(_doc, rfc)| {
for (doc_term, freq) in &rfc.term_freqs {
Expand Down Expand Up @@ -357,8 +371,12 @@ pub fn search_index(search: String, index: Index) -> Vec<RfcSearchResult> {

#[cfg(test)]
mod tests {
use std::ffi::c_char;

use super::{parse_rfc_index, RfcEntry, TfIdf};

extern "C" fn dummy_cb(_msg: *const c_char) {}

#[test]
fn test_parse_index() {
let index_contents = std::fs::read_to_string("../../data/rfc_index.txt").unwrap();
Expand All @@ -377,7 +395,7 @@ mod tests {
url: "https://www.rfsee.com/1".to_string(),
};
tf_idf.add_rfc_entry(entry);
tf_idf.finish();
tf_idf.finish(dummy_cb);

assert_eq!(tf_idf.index.rfc_details.len(), 1);
assert_eq!(tf_idf.index.term_scores.len(), 2);
Expand All @@ -398,7 +416,7 @@ mod tests {
url: "https://www.rfsee.com/1".to_string(),
};
tf_idf.add_rfc_entry(entry);
tf_idf.finish();
tf_idf.finish(dummy_cb);

assert_eq!(tf_idf.index.rfc_details.len(), 1);
// This should be 1 once we update parsing to treat "Hello" and "hello" the same
Expand Down
5 changes: 3 additions & 2 deletions lua/rfsee/ffi.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
local ffi = require("ffi")

ffi.cdef([[
typedef void (*progress_callback_t)(double progress);
void build_index(progress_callback_t fetch_cb, progress_callback_t parse_cb);
typedef void (*progress_callback_t)(const char* msg);
void build_index(progress_callback_t progress_cb);
void test_print(progress_callback_t progress_cb);

struct RfcSearchResult {
const char* url;
Expand Down
15 changes: 4 additions & 11 deletions lua/rfsee/index.lua
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,14 @@ function M.refresh()
local buf, win = window.create_progress_window()
window.update_progress_window(buf, "Building RFC index")

local function fetch_progress_cb(pct)
local msg = string.format("Downloading RFCs progress: %.1f%%", pct)
local function progress_cb(ptr)
local msg = ffi.string(ptr)
window.update_progress_window(buf, msg)
end

local progress_cb_c = ffi.cast("progress_callback_t", progress_cb)

local function parse_progress_cb(pct)
local msg = string.format("Parsing RFCs progress: %.1f%%", pct)
window.update_progress_window(buf, msg)
end

local fetch_progress_cb_c = ffi.cast("progress_callback_t", fetch_progress_cb)
local parse_progress_cb_c = ffi.cast("progress_callback_t", parse_progress_cb)

lib.build_index(fetch_progress_cb_c, parse_progress_cb_c)
lib.build_index(progress_cb_c)
local end_time = os.clock()
window.update_progress_window(buf, string.format("Built RFC index", end_time - start_time))
-- Brief pause before closing
Expand Down
6 changes: 5 additions & 1 deletion tests/generate-data/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::ffi::c_char;

use rfsee_tf_idf::{RfcEntry, TfIdf};

extern "C" fn dummy_cb(_msg: *const c_char) {}

fn main() {
let mut tf_idf = TfIdf::default();
let rfc1 = RfcEntry {
Expand All @@ -18,7 +22,7 @@ fn main() {
tf_idf.add_rfc_entry(rfc1);
tf_idf.add_rfc_entry(rfc2);

tf_idf.finish();
tf_idf.finish(dummy_cb);
let path = rfsee_tf_idf::get_index_path(None).unwrap();
tf_idf.save(&path);
}
Loading