diff --git a/difftastic/src/lib.rs b/difftastic/src/lib.rs
index 758122539..824855805 100644
--- a/difftastic/src/lib.rs
+++ b/difftastic/src/lib.rs
@@ -129,6 +129,7 @@ pub fn generate_sidecar_diff(
         &DiffOptions::default(),
         &[],
     );
+    dbg!(&diff_result);
     crate::display::side_by_side::print(
         diff_result.hunks.as_slice(),
         &display_options,
@@ -763,7 +764,6 @@ fn diff_directories<'a>(
 }
 
 fn print_diff_result(display_options: &DisplayOptions, summary: &DiffResult) {
-    dbg!(&summary.hunks);
     match (&summary.lhs_src, &summary.rhs_src) {
         (FileContent::Text(lhs_src), FileContent::Text(rhs_src)) => {
             let hunks = &summary.hunks;
diff --git a/llm_client/src/bin/codestory_provider.rs b/llm_client/src/bin/codestory_provider.rs
index 8db23e6e3..232ca81b3 100644
--- a/llm_client/src/bin/codestory_provider.rs
+++ b/llm_client/src/bin/codestory_provider.rs
@@ -15,7 +15,7 @@ async fn main() {
     let codestory_client = CodeStoryClient::new("http://localhost:8080");
     let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel();
     let request = LLMClientCompletionRequest::new(
-        LLMType::ClaudeOpus,
+        LLMType::Gpt4,
         vec![
             LLMClientMessage::system("you are a python expert".to_owned()),
             LLMClientMessage::user("Can you write 1 to 300 in a new line for me".to_owned()),
diff --git a/sidecar/src/bin/testing.rs b/sidecar/src/bin/testing.rs
new file mode 100644
index 000000000..e018139bc
--- /dev/null
+++ b/sidecar/src/bin/testing.rs
@@ -0,0 +1,404 @@
+use std::path::PathBuf;
+
+use regex::Regex;
+use tracing::debug;
+
+use crate::repo::types::RepoRef;
+
+use super::{
+    javascript::javascript_language_config,
+    languages::TSLanguageConfig,
+    python::python_language_config,
+    rust::rust_language_config,
+    text_document::{DocumentSymbol, Position, Range, TextDocument},
+    types::FunctionInformation,
+    typescript::typescript_language_config,
+};
+
+/// Here we will parse the document we get from the editor using symbol level
+/// information, as it's very fast
+
+#[derive(Debug, Clone)]
+pub struct EditorParsing {
+    configs: Vec<TSLanguageConfig>,
+}
+
+impl Default for EditorParsing {
+    fn default() -> Self {
+        Self {
+            configs: vec![
+                rust_language_config(),
+                javascript_language_config(),
+                typescript_language_config(),
+                python_language_config(),
+            ],
+        }
+    }
+}
+
+impl EditorParsing {
+    pub fn ts_language_config(&self, language: &str) -> Option<&TSLanguageConfig> {
+        self.configs
+            .iter()
+            .find(|config| config.language_ids.contains(&language))
+    }
+
+    pub fn for_file_path(&self, file_path: &str) -> Option<&TSLanguageConfig> {
+        let file_path = PathBuf::from(file_path);
+        let file_extension = file_path
+            .extension()
+            .map(|extension| extension.to_str())
+            .map(|extension| extension.to_owned())
+            .flatten();
+        match file_extension {
+            Some(extension) => self
+                .configs
+                .iter()
+                .find(|config| config.file_extensions.contains(&extension)),
+            None => None,
+        }
+    }
+
+    fn is_node_identifier(
+        &self,
+        node: &tree_sitter::Node,
+        language_config: &TSLanguageConfig,
+    ) -> bool {
+        match language_config
+            .language_ids
+            .first()
+            .expect("language_id to be present")
+            .to_lowercase()
+            .as_ref()
+        {
+            "typescript" | "typescriptreact" | "javascript" | "javascriptreact" => {
+                Regex::new(r"(definition|declaration|declarator|export_statement)")
+                    .unwrap()
+                    .is_match(node.kind())
+            }
+            "golang" => Regex::new(r"(definition|declaration|declarator|var_spec)")
+                .unwrap()
+                .is_match(node.kind()),
+            "cpp" => Regex::new(r"(definition|declaration|declarator|class_specifier)")
+                .unwrap()
+                .is_match(node.kind()),
+            "ruby" => Regex::new(r"(module|class|method|assignment)")
Regex::new(r"(module|class|method|assignment)") + .unwrap() + .is_match(node.kind()), + "rust" => Regex::new(r"(item)").unwrap().is_match(node.kind()), + _ => Regex::new(r"(definition|declaration|declarator)") + .unwrap() + .is_match(node.kind()), + } + } + + /** + * This function aims to process nodes from a tree sitter parsed structure + * based on their intersection with a given range and identify nodes that + * represent declarations or definitions specific to a programming language. + * + * @param {Object} t - The tree sitter node. + * @param {Object} e - The range (or point structure) with which intersections are checked. + * @param {string} r - The programming language (e.g., "typescript", "golang"). + * + * @return {Object|undefined} - Returns the most relevant node or undefined. + */ + // function KX(t, e, r) { + // // Initial setup with the root node and an empty list for potential matches + // let n = [t.rootNode], i = []; + + // while (true) { + // // For each node in 'n', calculate its intersection size with 'e' + // let o = n.map(s => [s, rs.intersectionSize(s, e)]) + // .filter(([s, a]) => a > 0) + // .sort(([s, a], [l, c]) => c - a); // sort in decreasing order of intersection size + + // // If there are no intersections, either return undefined or the most relevant node from 'i' + // if (o.length === 0) return i.length === 0 ? void 0 : tX(i, ([s, a], [l, c]) => a - c)[0]; + + // // For the nodes in 'o', calculate a relevance score and filter the ones that are declarations or definitions for language 'r' + // let s = o.map(([a, l]) => { + // let c = rs.len(a), // Length of the node + // u = Math.abs(rs.len(e) - l), // Difference between length of 'e' and its intersection size + // p = (l - u) / c; // Relevance score + // return [a, p]; + // }); + + // // Filter nodes based on the ZL function and push to 'i' + // i.push(...s.filter(([a, l]) => ZL(a, r))); + + // // Prepare for the next iteration by setting 'n' to the children of the nodes in 'o' + // n = []; + // n.push(...s.flatMap(([a, l]) => a.children)); + // } + // } + fn get_identifier_node_fully_contained<'a>( + &'a self, + tree_sitter_node: tree_sitter::Node<'a>, + range: &'a Range, + language_config: &'a TSLanguageConfig, + source_str: &str, + ) -> Option> { + let mut nodes = vec![tree_sitter_node]; + let mut identifier_nodes: Vec<(tree_sitter::Node, f64)> = vec![]; + loop { + // Here we take the nodes in [nodes] which have an intersection + // with the range we are interested in + let mut intersecting_nodes = nodes + .into_iter() + .map(|tree_sitter_node| { + ( + tree_sitter_node, + Range::for_tree_node(&tree_sitter_node).intersection_size(range) as f64, + ) + }) + .filter(|(_, intersection_size)| intersection_size > &0.0) + .collect::>(); + // we sort the nodes by their intersection size + // we want to keep the biggest size here on the top + intersecting_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("partial_cmp to work")); + + // if there are no nodes, then we return none or the most relevant nodes + // from i, which is the biggest node here + if intersecting_nodes.is_empty() { + return if identifier_nodes.is_empty() { + None + } else { + Some({ + let mut current_node = identifier_nodes[0]; + for identifier_node in &identifier_nodes[1..] 
+                            if identifier_node.1 - current_node.1 > 0.0 {
+                                current_node = identifier_node.clone();
+                            }
+                        }
+                        current_node.0
+                    })
+                };
+            }
+
+            // For the nodes in intersecting_nodes, calculate a relevance score and filter the ones that are declarations or definitions for language 'language_config'
+            let identifier_nodes_sorted = intersecting_nodes
+                .iter()
+                .map(|(tree_sitter_node, intersection_size)| {
+                    let len = Range::for_tree_node(&tree_sitter_node).len();
+                    let diff = ((range.len() as f64 - intersection_size) as f64).abs();
+                    let relevance_score = (intersection_size - diff) as f64 / len as f64;
+                    (tree_sitter_node.clone(), relevance_score)
+                })
+                .collect::<Vec<_>>();
+
+            // now we filter out the nodes which are here based on the identifier function and set it to identifier nodes
+            // which we want to find for documentation
+            identifier_nodes.extend(
+                identifier_nodes_sorted
+                    .into_iter()
+                    .filter(|(tree_sitter_node, _)| {
+                        self.is_node_identifier(tree_sitter_node, language_config)
+                    })
+                    .map(|(tree_sitter_node, score)| (tree_sitter_node, score))
+                    .collect::<Vec<_>>(),
+            );
+
+            // Now we prepare for the next iteration by setting nodes to the children of the nodes
+            // in intersecting_nodes
+            nodes = intersecting_nodes
+                .into_iter()
+                .map(|(tree_sitter_node, _)| {
+                    let mut cursor = tree_sitter_node.walk();
+                    tree_sitter_node.children(&mut cursor).collect::<Vec<_>>()
+                })
+                .flatten()
+                .collect::<Vec<_>>();
+        }
+    }
+
+    fn get_identifier_node_by_expanding<'a>(
+        &'a self,
+        tree_sitter_node: tree_sitter::Node<'a>,
+        range: &Range,
+        language_config: &TSLanguageConfig,
+    ) -> Option<tree_sitter::Node<'a>> {
+        let tree_sitter_range = range.to_tree_sitter_range();
+        let mut expanding_node = tree_sitter_node
+            .descendant_for_byte_range(tree_sitter_range.start_byte, tree_sitter_range.end_byte);
+        loop {
+            // Here we expand this node until we hit a identifier node, this is
+            // a very easy way to get to the best node we are interested in by
+            // bubbling up
+            if expanding_node.is_none() {
+                return None;
+            }
+            match expanding_node {
+                Some(expanding_node_val) => {
+                    // if this is not a identifier and the parent is there, we keep
+                    // going up
+                    if !self.is_node_identifier(&expanding_node_val, &language_config)
+                        && expanding_node_val.parent().is_some()
+                    {
+                        expanding_node = expanding_node_val.parent()
+                    // if we have a node identifier, return right here!
+                    } else if self.is_node_identifier(&expanding_node_val, &language_config) {
+                        return Some(expanding_node_val.clone());
+                    } else {
+                        // so we don't have a node identifier and neither a parent, so
+                        // just return None
+                        return None;
+                    }
+                }
+                None => {
+                    return None;
+                }
+            }
+        }
+    }
+
+    pub fn get_documentation_node(
+        &self,
+        text_document: &TextDocument,
+        language_config: &TSLanguageConfig,
+        range: Range,
+    ) -> Vec<DocumentSymbol> {
+        let language = language_config.grammar;
+        let source = text_document.get_content_buffer();
+        let mut parser = tree_sitter::Parser::new();
+        parser.set_language(language()).unwrap();
+        let tree = parser
+            .parse(text_document.get_content_buffer().as_bytes(), None)
+            .unwrap();
+        if let Some(identifier_node) = self.get_identifier_node_fully_contained(
+            tree.root_node(),
+            &range,
+            &language_config,
+            source,
+        ) {
+            // we have a identifier node right here, so lets get the document symbol
+            // for this and return it back
+            return DocumentSymbol::from_tree_node(
+                &identifier_node,
+                language_config,
+                text_document.get_content_buffer(),
+            )
+            .into_iter()
+            .collect();
+        }
+        // or else we try to expand the node out so we can get a symbol back
+        if let Some(expanded_node) =
+            self.get_identifier_node_by_expanding(tree.root_node(), &range, &language_config)
+        {
+            // we get the expanded node here again
+            return DocumentSymbol::from_tree_node(
+                &expanded_node,
+                language_config,
+                text_document.get_content_buffer(),
+            )
+            .into_iter()
+            .collect();
+        }
+        // or else we return nothing here
+        vec![]
+    }
+
+    pub fn get_documentation_node_for_range(
+        &self,
+        source_str: &str,
+        language: &str,
+        relative_path: &str,
+        fs_file_path: &str,
+        start_position: &Position,
+        end_position: &Position,
+        repo_ref: &RepoRef,
+    ) -> Vec<DocumentSymbol> {
+        // First we need to find the language config which matches up with
+        // the language we are interested in
+        let language_config = self.ts_language_config(&language);
+        if let None = language_config {
+            return vec![];
+        }
+        // okay now we have a language config, lets parse it
+        self.get_documentation_node(
+            &TextDocument::new(
+                source_str.to_owned(),
+                repo_ref.clone(),
+                fs_file_path.to_owned(),
+                relative_path.to_owned(),
+            ),
+            &language_config.expect("if let None check above to work"),
+            Range::new(start_position.clone(), end_position.clone()),
+        )
+    }
+
+    pub fn function_information_nodes(
+        &self,
+        source_code: &[u8],
+        language: &str,
+    ) -> Vec<FunctionInformation> {
+        let language_config = self.ts_language_config(&language);
+        if let None = language_config {
+            return vec![];
+        }
+        language_config
+            .expect("if let None check above")
+            .function_information_nodes(source_code)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        chunking::{
+            languages::TSLanguageParsing,
+            text_document::{Position, Range, TextDocument},
+        },
+        repo::types::RepoRef,
+    };
+
+    use super::EditorParsing;
+
+    #[test]
+    fn rust_selection_parsing() {
+        let editor_parsing = EditorParsing::default();
+        // This is from the configuration file
+        let source_str = "use std::{\n    num::NonZeroUsize,\n    path::{Path, PathBuf},\n};\n\nuse clap::Parser;\nuse serde::{Deserialize, Serialize};\n\nuse crate::repo::state::StateSource;\n\n#[derive(Serialize, Deserialize, Parser, Debug, Clone)]\n#[clap(author, version, about, long_about = None)]\npub struct Configuration {\n    #[clap(short, long, default_value_os_t = default_index_dir())]\n    #[serde(default = \"default_index_dir\")]\n    /// Directory to store all persistent state\n    pub index_dir: PathBuf,\n\n    #[clap(long, default_value_t = default_port())]\n    #[serde(default = \"default_port\")]\n    /// Bind the webserver to `<port>`\n    pub port: u16,\n\n    #[clap(long)]\n    /// Path to the embedding model directory\n    pub model_dir: PathBuf,\n\n    #[clap(long, default_value_t = default_host())]\n    #[serde(default = \"default_host\")]\n    /// Bind the webserver to `<host>`\n    pub host: String,\n\n    #[clap(flatten)]\n    #[serde(default)]\n    pub state_source: StateSource,\n\n    #[clap(short, long, default_value_t = default_parallelism())]\n    #[serde(default = \"default_parallelism\")]\n    /// Maximum number of parallel background threads\n    pub max_threads: usize,\n\n    #[clap(short, long, default_value_t = default_buffer_size())]\n    #[serde(default = \"default_buffer_size\")]\n    /// Size of memory to use for file indexes\n    pub buffer_size: usize,\n\n    /// Qdrant url here can be mentioned if we are running it remotely or have\n    /// it running on its own process\n    #[clap(long)]\n    #[serde(default = \"default_qdrant_url\")]\n    pub qdrant_url: String,\n\n    /// The folder where the qdrant binary is present so we can start the server\n    /// and power the qdrant client\n    #[clap(short, long)]\n    pub qdrant_binary_directory: Option<PathBuf>,\n\n    /// The location for the dylib directory where we have the runtime binaries\n    /// required for ort\n    #[clap(short, long)]\n    pub dylib_directory: PathBuf,\n\n    /// Qdrant allows us to create collections and we need to provide it a default\n    /// value to start with\n    #[clap(short, long, default_value_t = default_collection_name())]\n    #[serde(default = \"default_collection_name\")]\n    pub collection_name: String,\n\n    #[clap(long, default_value_t = interactive_batch_size())]\n    #[serde(default = \"interactive_batch_size\")]\n    /// Batch size for batched embeddings\n    pub embedding_batch_len: NonZeroUsize,\n\n    #[clap(long, default_value_t = default_user_id())]\n    #[serde(default = \"default_user_id\")]\n    user_id: String,\n\n    /// If we should poll the local repo for updates auto-magically. Disabled\n    /// by default, until we figure out the delta sync method where we only\n    /// reindex the files which have changed\n    #[clap(long)]\n    pub enable_background_polling: bool,\n}\n\nimpl Configuration {\n    /// Directory where logs are written to\n    pub fn log_dir(&self) -> PathBuf {\n        self.index_dir.join(\"logs\")\n    }\n\n    pub fn index_path(&self, name: impl AsRef<Path>) -> impl AsRef<Path> {\n        self.index_dir.join(name)\n    }\n\n    pub fn qdrant_storage(&self) -> PathBuf {\n        self.index_dir.join(\"qdrant_storage\")\n    }\n}\n\nfn default_index_dir() -> PathBuf {\n    match directories::ProjectDirs::from(\"ai\", \"codestory\", \"sidecar\") {\n        Some(dirs) => dirs.data_dir().to_owned(),\n        None => \"codestory_sidecar\".into(),\n    }\n}\n\nfn default_port() -> u16 {\n    42424\n}\n\nfn default_host() -> String {\n    \"127.0.0.1\".to_owned()\n}\n\npub fn default_parallelism() -> usize {\n    std::thread::available_parallelism().unwrap().get()\n}\n\nconst fn default_buffer_size() -> usize {\n    100_000_000\n}\n\nfn default_collection_name() -> String {\n    \"codestory\".to_owned()\n}\n\nfn interactive_batch_size() -> NonZeroUsize {\n    NonZeroUsize::new(1).unwrap()\n}\n\nfn default_qdrant_url() -> String {\n    \"http://127.0.0.1:6334\".to_owned()\n}\n\nfn default_user_id() -> String {\n    \"codestory\".to_owned()\n}\n";
+        let range = Range::new(Position::new(134, 7, 3823), Position::new(137, 0, 3878));
+        let ts_lang_parsing = TSLanguageParsing::init();
+        let rust_config = ts_lang_parsing.for_lang("rust");
+        let mut documentation_nodes = editor_parsing.get_documentation_node(
+            &TextDocument::new(
+                source_str.to_owned(),
+                RepoRef::local("/Users/skcd/testing/").expect("test to work"),
+                "".to_owned(),
+                "".to_owned(),
+            ),
+            &rust_config.expect("rust config to be present"),
+            range,
+        );
+        assert!(!documentation_nodes.is_empty());
+        let first_entry = documentation_nodes.remove(0);
+        assert_eq!(first_entry.start_position, Position::new(134, 0, 3816));
+        assert_eq!(first_entry.end_position, Position::new(136, 1, 3877));
+    }
+}
\ No newline at end of file
diff --git a/sidecar/src/chunking/languages.rs b/sidecar/src/chunking/languages.rs
index 4fb3cc7a4..ab2f92df5 100644
--- a/sidecar/src/chunking/languages.rs
+++ b/sidecar/src/chunking/languages.rs
@@ -1724,6 +1724,7 @@ mod tests {
     use crate::chunking::types::FunctionNodeType;
 
     use super::naive_chunker;
+    use super::TSLanguageConfig;
     use super::TSLanguageParsing;
 
     fn get_naive_chunking_test_string<'a>() -> &'a str {
@@ -2357,4 +2358,398 @@ trait SomethingTrait {
         let outline = ts_language_config.generate_outline(source_code.as_bytes(), &tree);
         assert_eq!(outline.len(), 6);
     }
+
+    #[test]
+    fn test_class_with_functions_parsing() {
+        let source_code = r#"
+    use std::path::PathBuf;
+
+    use regex::Regex;
+    use tracing::debug;
+
+    use crate::repo::types::RepoRef;
+
+    use super::{
+        javascript::javascript_language_config,
+        languages::TSLanguageConfig,
+        python::python_language_config,
+        rust::rust_language_config,
+        text_document::{DocumentSymbol, Position, Range, TextDocument},
+        types::FunctionInformation,
+        typescript::typescript_language_config,
+    };
+
+    /// Here we will parse the document we get from the editor using symbol level
+    /// information, as it's very fast
+
+    #[derive(Debug, Clone)]
+    pub struct EditorParsing {
+        configs: Vec<TSLanguageConfig>,
+    }
+
+    impl Default for EditorParsing {
+        fn default() -> Self {
+            Self {
+                configs: vec![
+                    rust_language_config(),
+                    javascript_language_config(),
+                    typescript_language_config(),
+                    python_language_config(),
+                ],
+            }
+        }
+    }
+
+    impl EditorParsing {
+        pub fn ts_language_config(&self, language: &str) -> Option<&TSLanguageConfig> {
+            self.configs
+                .iter()
+                .find(|config| config.language_ids.contains(&language))
+        }
+
+        pub fn for_file_path(&self, file_path: &str) -> Option<&TSLanguageConfig> {
+            let file_path = PathBuf::from(file_path);
+            let file_extension = file_path
+                .extension()
+                .map(|extension| extension.to_str())
+                .map(|extension| extension.to_owned())
+                .flatten();
+            match file_extension {
+                Some(extension) => self
+                    .configs
+                    .iter()
+                    .find(|config| config.file_extensions.contains(&extension)),
+                None => None,
+            }
+        }
+
+        fn is_node_identifier(
+            &self,
+            node: &tree_sitter::Node,
+            language_config: &TSLanguageConfig,
+        ) -> bool {
+            match language_config
+                .language_ids
+                .first()
+                .expect("language_id to be present")
+                .to_lowercase()
+                .as_ref()
+            {
+                "typescript" | "typescriptreact" | "javascript" | "javascriptreact" => {
+                    Regex::new(r"(definition|declaration|declarator|export_statement)")
+                        .unwrap()
+                        .is_match(node.kind())
+                }
+                "golang" => Regex::new(r"(definition|declaration|declarator|var_spec)")
+                    .unwrap()
+                    .is_match(node.kind()),
+                "cpp" => Regex::new(r"(definition|declaration|declarator|class_specifier)")
+                    .unwrap()
+                    .is_match(node.kind()),
+                "ruby" => Regex::new(r"(module|class|method|assignment)")
+                    .unwrap()
+                    .is_match(node.kind()),
+                "rust" => Regex::new(r"(item)").unwrap().is_match(node.kind()),
+                _ => Regex::new(r"(definition|declaration|declarator)")
+                    .unwrap()
+                    .is_match(node.kind()),
+            }
+        }
+
+        /**
+         * This function aims to process nodes from a tree sitter parsed structure
+         * based on their intersection with a given range and identify nodes that
+         * represent declarations or definitions specific to a programming language.
+         *
+         * @param {Object} t - The tree sitter node.
+         * @param {Object} e - The range (or point structure) with which intersections are checked.
+         * @param {string} r - The programming language (e.g., "typescript", "golang").
+         *
+         * @return {Object|undefined} - Returns the most relevant node or undefined.
+         */
+        // function KX(t, e, r) {
+        //     // Initial setup with the root node and an empty list for potential matches
+        //     let n = [t.rootNode], i = [];
+
+        //     while (true) {
+        //         // For each node in 'n', calculate its intersection size with 'e'
+        //         let o = n.map(s => [s, rs.intersectionSize(s, e)])
+        //             .filter(([s, a]) => a > 0)
+        //             .sort(([s, a], [l, c]) => c - a); // sort in decreasing order of intersection size
+
+        //         // If there are no intersections, either return undefined or the most relevant node from 'i'
+        //         if (o.length === 0) return i.length === 0 ? void 0 : tX(i, ([s, a], [l, c]) => a - c)[0];
+
+        //         // For the nodes in 'o', calculate a relevance score and filter the ones that are declarations or definitions for language 'r'
+        //         let s = o.map(([a, l]) => {
+        //             let c = rs.len(a), // Length of the node
+        //                 u = Math.abs(rs.len(e) - l), // Difference between length of 'e' and its intersection size
+        //                 p = (l - u) / c; // Relevance score
+        //             return [a, p];
+        //         });
+
+        //         // Filter nodes based on the ZL function and push to 'i'
+        //         i.push(...s.filter(([a, l]) => ZL(a, r)));
+
+        //         // Prepare for the next iteration by setting 'n' to the children of the nodes in 'o'
+        //         n = [];
+        //         n.push(...s.flatMap(([a, l]) => a.children));
+        //     }
+        // }
+        fn get_identifier_node_fully_contained<'a>(
+            &'a self,
+            tree_sitter_node: tree_sitter::Node<'a>,
+            range: &'a Range,
+            language_config: &'a TSLanguageConfig,
+            source_str: &str,
+        ) -> Option<tree_sitter::Node<'a>> {
+            let mut nodes = vec![tree_sitter_node];
+            let mut identifier_nodes: Vec<(tree_sitter::Node, f64)> = vec![];
+            loop {
+                // Here we take the nodes in [nodes] which have an intersection
+                // with the range we are interested in
+                let mut intersecting_nodes = nodes
+                    .into_iter()
+                    .map(|tree_sitter_node| {
+                        (
+                            tree_sitter_node,
+                            Range::for_tree_node(&tree_sitter_node).intersection_size(range) as f64,
+                        )
+                    })
+                    .filter(|(_, intersection_size)| intersection_size > &0.0)
+                    .collect::<Vec<_>>();
+                // we sort the nodes by their intersection size
+                // we want to keep the biggest size here on the top
+                intersecting_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("partial_cmp to work"));
+
+                // if there are no nodes, then we return none or the most relevant nodes
+                // from i, which is the biggest node here
+                if intersecting_nodes.is_empty() {
+                    return if identifier_nodes.is_empty() {
+                        None
+                    } else {
+                        Some({
+                            let mut current_node = identifier_nodes[0];
+                            for identifier_node in &identifier_nodes[1..] {
+                                if identifier_node.1 - current_node.1 > 0.0 {
+                                    current_node = identifier_node.clone();
+                                }
+                            }
+                            current_node.0
+                        })
+                    };
+                }
+
+                // For the nodes in intersecting_nodes, calculate a relevance score and filter the ones that are declarations or definitions for language 'language_config'
+                let identifier_nodes_sorted = intersecting_nodes
+                    .iter()
+                    .map(|(tree_sitter_node, intersection_size)| {
+                        let len = Range::for_tree_node(&tree_sitter_node).len();
+                        let diff = ((range.len() as f64 - intersection_size) as f64).abs();
+                        let relevance_score = (intersection_size - diff) as f64 / len as f64;
+                        (tree_sitter_node.clone(), relevance_score)
+                    })
+                    .collect::<Vec<_>>();
+
+                // now we filter out the nodes which are here based on the identifier function and set it to identifier nodes
+                // which we want to find for documentation
+                identifier_nodes.extend(
+                    identifier_nodes_sorted
+                        .into_iter()
+                        .filter(|(tree_sitter_node, _)| {
+                            self.is_node_identifier(tree_sitter_node, language_config)
+                        })
+                        .map(|(tree_sitter_node, score)| (tree_sitter_node, score))
+                        .collect::<Vec<_>>(),
+                );
+
+                // Now we prepare for the next iteration by setting nodes to the children of the nodes
+                // in intersecting_nodes
+                nodes = intersecting_nodes
+                    .into_iter()
+                    .map(|(tree_sitter_node, _)| {
+                        let mut cursor = tree_sitter_node.walk();
+                        tree_sitter_node.children(&mut cursor).collect::<Vec<_>>()
+                    })
+                    .flatten()
+                    .collect::<Vec<_>>();
+            }
+        }
+
+        fn get_identifier_node_by_expanding<'a>(
+            &'a self,
+            tree_sitter_node: tree_sitter::Node<'a>,
+            range: &Range,
+            language_config: &TSLanguageConfig,
+        ) -> Option<tree_sitter::Node<'a>> {
+            let tree_sitter_range = range.to_tree_sitter_range();
+            let mut expanding_node = tree_sitter_node
+                .descendant_for_byte_range(tree_sitter_range.start_byte, tree_sitter_range.end_byte);
+            loop {
+                // Here we expand this node until we hit a identifier node, this is
+                // a very easy way to get to the best node we are interested in by
+                // bubbling up
+                if expanding_node.is_none() {
+                    return None;
+                }
+                match expanding_node {
+                    Some(expanding_node_val) => {
+                        // if this is not a identifier and the parent is there, we keep
+                        // going up
+                        if !self.is_node_identifier(&expanding_node_val, &language_config)
+                            && expanding_node_val.parent().is_some()
+                        {
+                            expanding_node = expanding_node_val.parent()
+                        // if we have a node identifier, return right here!
+                        } else if self.is_node_identifier(&expanding_node_val, &language_config) {
+                            return Some(expanding_node_val.clone());
+                        } else {
+                            // so we don't have a node identifier and neither a parent, so
+                            // just return None
+                            return None;
+                        }
+                    }
+                    None => {
+                        return None;
+                    }
+                }
+            }
+        }
+
+        pub fn get_documentation_node(
+            &self,
+            text_document: &TextDocument,
+            language_config: &TSLanguageConfig,
+            range: Range,
+        ) -> Vec<DocumentSymbol> {
+            let language = language_config.grammar;
+            let source = text_document.get_content_buffer();
+            let mut parser = tree_sitter::Parser::new();
+            parser.set_language(language()).unwrap();
+            let tree = parser
+                .parse(text_document.get_content_buffer().as_bytes(), None)
+                .unwrap();
+            if let Some(identifier_node) = self.get_identifier_node_fully_contained(
+                tree.root_node(),
+                &range,
+                &language_config,
+                source,
+            ) {
+                // we have a identifier node right here, so lets get the document symbol
+                // for this and return it back
+                return DocumentSymbol::from_tree_node(
+                    &identifier_node,
+                    language_config,
+                    text_document.get_content_buffer(),
+                )
+                .into_iter()
+                .collect();
+            }
+            // or else we try to expand the node out so we can get a symbol back
+            if let Some(expanded_node) =
+                self.get_identifier_node_by_expanding(tree.root_node(), &range, &language_config)
+            {
+                // we get the expanded node here again
+                return DocumentSymbol::from_tree_node(
+                    &expanded_node,
+                    language_config,
+                    text_document.get_content_buffer(),
+                )
+                .into_iter()
+                .collect();
+            }
+            // or else we return nothing here
+            vec![]
+        }
+
+        pub fn get_documentation_node_for_range(
+            &self,
+            source_str: &str,
+            language: &str,
+            relative_path: &str,
+            fs_file_path: &str,
+            start_position: &Position,
+            end_position: &Position,
+            repo_ref: &RepoRef,
+        ) -> Vec<DocumentSymbol> {
+            // First we need to find the language config which matches up with
+            // the language we are interested in
+            let language_config = self.ts_language_config(&language);
+            if let None = language_config {
+                return vec![];
+            }
+            // okay now we have a language config, lets parse it
+            self.get_documentation_node(
+                &TextDocument::new(
+                    source_str.to_owned(),
+                    repo_ref.clone(),
+                    fs_file_path.to_owned(),
+                    relative_path.to_owned(),
+                ),
+                &language_config.expect("if let None check above to work"),
+                Range::new(start_position.clone(), end_position.clone()),
+            )
+        }
+
+        pub fn function_information_nodes(
+            &self,
+            source_code: &[u8],
+            language: &str,
+        ) -> Vec<FunctionInformation> {
+            let language_config = self.ts_language_config(&language);
+            if let None = language_config {
+                return vec![];
+            }
+            language_config
+                .expect("if let None check above")
+                .function_information_nodes(source_code)
+        }
+    }
+
+    #[cfg(test)]
+    mod tests {
+        use crate::{
+            chunking::{
+                languages::TSLanguageParsing,
+                text_document::{Position, Range, TextDocument},
+            },
+            repo::types::RepoRef,
+        };
+
+        use super::EditorParsing;
+
+        #[test]
+        fn rust_selection_parsing() {
+            let editor_parsing = EditorParsing::default();
+            // This is from the configuration file
+            let source_str = "use std::{\n    num::NonZeroUsize,\n    path::{Path, PathBuf},\n};\n\nuse clap::Parser;\nuse serde::{Deserialize, Serialize};\n\nuse crate::repo::state::StateSource;\n\n#[derive(Serialize, Deserialize, Parser, Debug, Clone)]\n#[clap(author, version, about, long_about = None)]\npub struct Configuration {\n    #[clap(short, long, default_value_os_t = default_index_dir())]\n    #[serde(default = \"default_index_dir\")]\n    /// Directory to store all persistent state\n    pub index_dir: PathBuf,\n\n    #[clap(long, default_value_t = default_port())]\n    #[serde(default = \"default_port\")]\n    /// Bind the webserver to `<port>`\n    pub port: u16,\n\n    #[clap(long)]\n    /// Path to the embedding model directory\n    pub model_dir: PathBuf,\n\n    #[clap(long, default_value_t = default_host())]\n    #[serde(default = \"default_host\")]\n    /// Bind the webserver to `<host>`\n    pub host: String,\n\n    #[clap(flatten)]\n    #[serde(default)]\n    pub state_source: StateSource,\n\n    #[clap(short, long, default_value_t = default_parallelism())]\n    #[serde(default = \"default_parallelism\")]\n    /// Maximum number of parallel background threads\n    pub max_threads: usize,\n\n    #[clap(short, long, default_value_t = default_buffer_size())]\n    #[serde(default = \"default_buffer_size\")]\n    /// Size of memory to use for file indexes\n    pub buffer_size: usize,\n\n    /// Qdrant url here can be mentioned if we are running it remotely or have\n    /// it running on its own process\n    #[clap(long)]\n    #[serde(default = \"default_qdrant_url\")]\n    pub qdrant_url: String,\n\n    /// The folder where the qdrant binary is present so we can start the server\n    /// and power the qdrant client\n    #[clap(short, long)]\n    pub qdrant_binary_directory: Option<PathBuf>,\n\n    /// The location for the dylib directory where we have the runtime binaries\n    /// required for ort\n    #[clap(short, long)]\n    pub dylib_directory: PathBuf,\n\n    /// Qdrant allows us to create collections and we need to provide it a default\n    /// value to start with\n    #[clap(short, long, default_value_t = default_collection_name())]\n    #[serde(default = \"default_collection_name\")]\n    pub collection_name: String,\n\n    #[clap(long, default_value_t = interactive_batch_size())]\n    #[serde(default = \"interactive_batch_size\")]\n    /// Batch size for batched embeddings\n    pub embedding_batch_len: NonZeroUsize,\n\n    #[clap(long, default_value_t = default_user_id())]\n    #[serde(default = \"default_user_id\")]\n    user_id: String,\n\n    /// If we should poll the local repo for updates auto-magically. Disabled\n    /// by default, until we figure out the delta sync method where we only\n    /// reindex the files which have changed\n    #[clap(long)]\n    pub enable_background_polling: bool,\n}\n\nimpl Configuration {\n    /// Directory where logs are written to\n    pub fn log_dir(&self) -> PathBuf {\n        self.index_dir.join(\"logs\")\n    }\n\n    pub fn index_path(&self, name: impl AsRef<Path>) -> impl AsRef<Path> {\n        self.index_dir.join(name)\n    }\n\n    pub fn qdrant_storage(&self) -> PathBuf {\n        self.index_dir.join(\"qdrant_storage\")\n    }\n}\n\nfn default_index_dir() -> PathBuf {\n    match directories::ProjectDirs::from(\"ai\", \"codestory\", \"sidecar\") {\n        Some(dirs) => dirs.data_dir().to_owned(),\n        None => \"codestory_sidecar\".into(),\n    }\n}\n\nfn default_port() -> u16 {\n    42424\n}\n\nfn default_host() -> String {\n    \"127.0.0.1\".to_owned()\n}\n\npub fn default_parallelism() -> usize {\n    std::thread::available_parallelism().unwrap().get()\n}\n\nconst fn default_buffer_size() -> usize {\n    100_000_000\n}\n\nfn default_collection_name() -> String {\n    \"codestory\".to_owned()\n}\n\nfn interactive_batch_size() -> NonZeroUsize {\n    NonZeroUsize::new(1).unwrap()\n}\n\nfn default_qdrant_url() -> String {\n    \"http://127.0.0.1:6334\".to_owned()\n}\n\nfn default_user_id() -> String {\n    \"codestory\".to_owned()\n}\n";
+            let range = Range::new(Position::new(134, 7, 3823), Position::new(137, 0, 3878));
+            let ts_lang_parsing = TSLanguageParsing::init();
+            let rust_config = ts_lang_parsing.for_lang("rust");
+            let mut documentation_nodes = editor_parsing.get_documentation_node(
+                &TextDocument::new(
+                    source_str.to_owned(),
+                    RepoRef::local("/Users/skcd/testing/").expect("test to work"),
+                    "".to_owned(),
+                    "".to_owned(),
+                ),
+                &rust_config.expect("rust config to be present"),
+                range,
+            );
+            assert!(!documentation_nodes.is_empty());
+            let first_entry = documentation_nodes.remove(0);
+            assert_eq!(first_entry.start_position, Position::new(134, 0, 3816));
+            assert_eq!(first_entry.end_position, Position::new(136, 1, 3877));
+        }
+    }
+    "#;
+        let language = "rust";
+        let tree_sitter_parsing = TSLanguageParsing::init();
+        let ts_language_config = tree_sitter_parsing
+            .for_lang(language)
+            .expect("test to work");
+        let file_symbols = ts_language_config.generate_file_symbols(source_code.as_bytes());
+        dbg!(&file_symbols);
+        assert!(false);
+    }
 }
diff --git a/sidecar/src/repo/types.rs b/sidecar/src/repo/types.rs
index 117ec4936..bc880f436 100644
--- a/sidecar/src/repo/types.rs
+++ b/sidecar/src/repo/types.rs
@@ -45,9 +45,11 @@ impl RepoRef {
     pub fn new(backend: Backend, name: &(impl AsRef<str> + ?Sized)) -> Result<Self, RepoError> {
         let path = Path::new(name.as_ref());
 
-        if !path.is_absolute() {
-            return Err(RepoError::NonAbsoluteLocal);
-        }
+        // disabling this for now, it should start working later on
+        // but on windows this check might not be valid
+        // if !path.is_absolute() {
+        //     return Err(RepoError::NonAbsoluteLocal);
+        // }
 
         for component in path.components() {
             use std::path::Component::*;
@@ -121,7 +123,7 @@ impl FromStr for RepoRef {
     type Err = RepoError;
 
     fn from_str(refstr: &str) -> Result<Self, Self::Err> {
-        match refstr.trim_start_matches('/').split_once('/') {
+        match dbg!(refstr.trim_start_matches('/').split_once('/')) {
             // // github.com/...
             // Some(("github.com", name)) => RepoRef::new(Backend::Github, name),
             // local/...
@@ -259,3 +261,16 @@ fn get_unix_time(time: SystemTime) -> u64 {
         .expect("system time error")
         .as_secs()
 }
+
+#[cfg(test)]
+mod tests {
+    use std::str::FromStr;
+
+    use super::RepoRef;
+
+    #[test]
+    fn test_repo_ref_parsing_windows() {
+        let repo_ref = RepoRef::from_str("local/c:\\Users\\someone\\pifuhd");
+        assert!(repo_ref.is_ok());
+    }
+}
diff --git a/sidecar/src/webserver/agent.rs b/sidecar/src/webserver/agent.rs
index 628d5df24..1976721af 100644
--- a/sidecar/src/webserver/agent.rs
+++ b/sidecar/src/webserver/agent.rs
@@ -603,15 +603,24 @@ pub async fn followup_chat(
 
 #[cfg(test)]
 mod tests {
+    use crate::webserver::model_selection::LLMClientConfig;
+
     use super::FollowupChatRequest;
     use serde_json;
 
     #[test]
     fn test_parsing() {
         let input_string = r#"
-        {"repo_ref":"local//Users/skcd/scratch/website","query":"whats happenign here","thread_id":"7cb05252-1bb8-4d5e-a942-621ab5d5e114","deep_context":{"repoRef":"local//Users/skcd/scratch/website","preciseContext":[{"symbol":{"fuzzyName":"Author"},"fsFilePath":"/Users/skcd/scratch/website/interfaces/author.ts","relativeFilePath":"interfaces/author.ts","range":{"startLine":0,"startCharacter":0,"endLine":6,"endCharacter":1},"hoverText":["\n```typescript\n(alias) type Author = {\n  name: string;\n  picture: string;\n  twitter: string;\n  linkedin: string;\n  github: string;\n}\nimport Author\n```\n",""],"definitionSnippet":"type Author = {\n  name: string\n  picture: string\n  twitter: string\n  linkedin: string\n  github: string\n}"}],"cursorPosition":{"startPosition":{"line":16,"character":0},"endPosition":{"line":16,"character":0}},"currentViewPort":{"startPosition":{"line":0,"character":0},"endPosition":{"line":16,"character":0},"fsFilePath":"/Users/skcd/scratch/website/interfaces/post.ts","relativePath":"interfaces/post.ts","textOnScreen":"import type Author from './author'\n\ntype PostType = {\n  slug: string\n  title: string\n  date: string\n  coverImage: string\n  author: Author\n  excerpt: string\n  ogImage: {\n    url: string\n  }\n  content: string\n}\n\nexport default PostType\n"}}}
+{"repo_ref":"local/c:\\Users\\keert\\pifuhd","query":"tell me","thread_id":"b265857b-9bf5-4db4-897c-a07d1c4c3b67","user_context":{"variables":[],"file_content_map":[]},"project_labels":["python","pip"],"active_window_data":{"file_path":"c:\\Users\\keert\\pifuhd\\lib\\colab_util.py","file_content":"'''\r\nMIT License\r\n\r\nCopyright (c) 2019 Shunsuke Saito, Zeng Huang, and Ryota Natsume\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n'''\r\nimport io\r\nimport os\r\nimport torch\r\nfrom skimage.io import imread\r\nimport numpy as np\r\nimport cv2\r\nfrom tqdm import tqdm_notebook as tqdm\r\nimport base64\r\nfrom IPython.display import HTML\r\n\r\n# Util function for loading meshes\r\nfrom pytorch3d.io import load_objs_as_meshes\r\n\r\nfrom IPython.display import HTML\r\nfrom base64 import b64encode\r\n\r\n# Data structures and functions for rendering\r\nfrom pytorch3d.structures import Meshes\r\nfrom pytorch3d.renderer import (\r\n    look_at_view_transform,\r\n    OpenGLOrthographicCameras, \r\n    PointLights, \r\n    DirectionalLights, \r\n    Materials, \r\n    RasterizationSettings, \r\n    MeshRenderer, \r\n    MeshRasterizer, \r\n    HardPhongShader,\r\n    TexturesVertex\r\n)\r\n\r\ndef set_renderer():\r\n    # Setup\r\n    device = torch.device(\"cuda:0\")\r\n    torch.cuda.set_device(device)\r\n\r\n    # Initialize an OpenGL perspective camera.\r\n    R, T = look_at_view_transform(2.0, 0, 180) \r\n    cameras = OpenGLOrthographicCameras(device=device, R=R, T=T)\r\n\r\n    raster_settings = RasterizationSettings(\r\n        image_size=512, \r\n        blur_radius=0.0, \r\n        faces_per_pixel=1, \r\n        bin_size = None, \r\n        max_faces_per_bin = None\r\n    )\r\n\r\n    lights = PointLights(device=device, location=((2.0, 2.0, 2.0),))\r\n\r\n    renderer = MeshRenderer(\r\n        rasterizer=MeshRasterizer(\r\n            cameras=cameras, \r\n            raster_settings=raster_settings\r\n        ),\r\n        shader=HardPhongShader(\r\n            device=device, \r\n            cameras=cameras,\r\n            lights=lights\r\n        )\r\n    )\r\n    return renderer\r\n\r\ndef get_verts_rgb_colors(obj_path):\r\n    rgb_colors = []\r\n\r\n    f = open(obj_path)\r\n    lines = f.readlines()\r\n    for line in lines:\r\n        ls = line.split(' ')\r\n        if len(ls) == 7:\r\n            rgb_colors.append(ls[-3:])\r\n\r\n    return np.array(rgb_colors, dtype='float32')[None, :, :]\r\n\r\ndef generate_video_from_obj(obj_path, image_path, video_path, renderer):\r\n    input_image = cv2.imread(image_path)\r\n    input_image = input_image[:,:input_image.shape[1]//3]\r\n    input_image = cv2.resize(input_image, (512,512))\r\n\r\n    # Setup\r\n    device = torch.device(\"cuda:0\")\r\n    torch.cuda.set_device(device)\r\n\r\n    # Load obj file\r\n    verts_rgb_colors = get_verts_rgb_colors(obj_path)\r\n    verts_rgb_colors = torch.from_numpy(verts_rgb_colors).to(device)\r\n    textures = TexturesVertex(verts_features=verts_rgb_colors)\r\n    # wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors)*0.75)\r\n\r\n    # Load obj\r\n    mesh = load_objs_as_meshes([obj_path], device=device)\r\n\r\n    # Set mesh\r\n    vers = mesh._verts_list\r\n    faces = mesh._faces_list\r\n    mesh_w_tex = Meshes(vers, faces, textures)\r\n    # mesh_wo_tex = Meshes(vers, faces, wo_textures)\r\n\r\n    # create VideoWriter\r\n    fourcc = cv2.VideoWriter_fourcc(*'MP4V')\r\n    out = cv2.VideoWriter(video_path, fourcc, 20.0, (1024,512))\r\n\r\n    for i in tqdm(range(90)):\r\n        R, T = look_at_view_transform(1.8, 0, i*4, device=device)\r\n        images_w_tex = renderer(mesh_w_tex, R=R, T=T)\r\n        images_w_tex = np.clip(images_w_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255\r\n        # images_wo_tex = renderer(mesh_wo_tex, R=R, T=T)\r\n        # images_wo_tex = np.clip(images_wo_tex[0, ..., :3].cpu().numpy(), 0.0, 1.0)[:, :, ::-1] * 255\r\n        image = np.concatenate([input_image, images_w_tex], axis=1)\r\n        out.write(image.astype('uint8'))\r\n    out.release()\r\n\r\ndef video(path):\r\n    mp4 = open(path,'rb').read()\r\n    data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\r\n    return HTML('<video width=500 controls loop> <source src=\"%s\" type=\"video/mp4\"></video>' % data_url)\r\n","visible_range_content":"            raster_settings=raster_settings\r\n        ),\r\n        shader=HardPhongShader(\r\n            device=device, \r\n            cameras=cameras,\r\n            lights=lights\r\n        )\r\n    )\r\n    return renderer\r\n\r\ndef get_verts_rgb_colors(obj_path):\r\n    rgb_colors = []\r\n\r\n    f = open(obj_path)\r\n    lines = f.readlines()\r\n    for line in lines:\r\n        ls = line.split(' ')\r\n        if len(ls) == 7:\r\n            rgb_colors.append(ls[-3:])\r\n\r\n    return np.array(rgb_colors, dtype='float32')[None, :, :]\r\n\r\ndef generate_video_from_obj(obj_path, image_path, video_path, renderer):\r\n    input_image = cv2.imread(image_path)\r\n    input_image = input_image[:,:input_image.shape[1]//3]\r\n    input_image = cv2.resize(input_image, (512,512))\r\n\r\n    # Setup\r\n    device = torch.device(\"cuda:0\")\r\n    torch.cuda.set_device(device)\r\n\r\n    # Load obj file\r\n    verts_rgb_colors = get_verts_rgb_colors(obj_path)\r\n    verts_rgb_colors = torch.from_numpy(verts_rgb_colors).to(device)\r\n    textures = TexturesVertex(verts_features=verts_rgb_colors)\r\n    # wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors)*0.75)\r\n\r\n    # Load obj\r\n    mesh = load_objs_as_meshes([obj_path], device=device)","start_line":77,"end_line":115,"language":"python"},"openai_key":null,"model_config":{"slow_model":"Gpt4","fast_model":"Gpt4","models":{"Gpt4":{"context_length":8192,"temperature":0.2,"provider":{"CodeStory":{"llm_type":null}}},"GPT3_5_16k":{"context_length":16385,"temperature":0.2,"provider":{"CodeStory":{"llm_type":null}}},"DeepSeekCoder1.3BInstruct":{"context_length":16384,"temperature":0.2,"provider":"Ollama"},"DeepSeekCoder6BInstruct":{"context_length":16384,"temperature":0.2,"provider":"Ollama"},"ClaudeOpus":{"context_length":200000,"temperature":0.2,"provider":"Anthropic"},"ClaudeSonnet":{"context_length":200000,"temperature":0.2,"provider":"Anthropic"}},"providers":["CodeStory",{"Ollama":{}},{"Anthropic":{"api_key":"soemthing"}}]},"user_id":"keert"}
         "#;
         let parsed_response = serde_json::from_str::<FollowupChatRequest>(&input_string);
+        let model_config = r#"
+        {"slow_model":"Gpt4","fast_model":"Gpt4","models":{"Gpt4":{"context_length":8192,"temperature":0.2,"provider":{"CodeStory":{"llm_type":null}}},"GPT3_5_16k":{"context_length":16385,"temperature":0.2,"provider":{"CodeStory":{"llm_type":null}}},"DeepSeekCoder1.3BInstruct":{"context_length":16384,"temperature":0.2,"provider":"Ollama"},"DeepSeekCoder6BInstruct":{"context_length":16384,"temperature":0.2,"provider":"Ollama"},"ClaudeOpus":{"context_length":200000,"temperature":0.2,"provider":"Anthropic"},"ClaudeSonnet":{"context_length":200000,"temperature":0.2,"provider":"Anthropic"}},"providers":["CodeStory",{"Ollama":{}},{"Anthropic":{"api_key":"soemthing"}}]}
+        "#.to_owned();
+        let parsed_model_config = serde_json::from_str::<LLMClientConfig>(&model_config);
+        dbg!(&parsed_response);
+        dbg!(&parsed_model_config);
+        assert!(parsed_model_config.is_ok());
         assert!(parsed_response.is_ok());
     }
 }
diff --git a/sidecar/src/webserver/file_edit.rs b/sidecar/src/webserver/file_edit.rs
index e86cb47ae..c2767a027 100644
--- a/sidecar/src/webserver/file_edit.rs
+++ b/sidecar/src/webserver/file_edit.rs
@@ -39,7 +39,6 @@ pub struct EditFileRequest {
     pub user_query: String,
     pub session_id: String,
     pub code_block_index: usize,
-    pub openai_key: Option<String>,
     pub model_config: LLMClientConfig,
 }
 
@@ -145,7 +144,6 @@ pub async fn file_edit(
         user_query,
         session_id,
         code_block_index,
-        openai_key,
         model_config,
     }): Json<EditFileRequest>,
 ) -> Result<impl IntoResponse> {
@@ -163,6 +161,10 @@ pub async fn file_edit(
         app.language_parsing.clone(),
     )
     .await;
+    file_diff_content.clone().and_then(|file_diff_content| {
+        dbg!(file_diff_content.join("\n"));
+        Some("".to_owned())
+    });
 
     if let None = file_diff_content {
         let cloned_session_id = session_id.clone();
@@ -210,6 +212,7 @@ pub async fn file_edit(
         app.language_parsing.clone(),
    )
    .await;
+    dbg!("nearest_range_of_symbols", &nearest_range_for_symbols);
 
     // Now we apply the edits and send it over to the user
     // After generating the git diff we want to send back the responses to the
@@ -323,6 +326,8 @@ async fn find_nearest_position_for_code_edit(
     }
     let class_with_funcs_llm = language_parser.generate_file_symbols(new_content.as_bytes());
     let class_with_funcs = language_parser.generate_file_symbols(file_content.as_bytes());
+    dbg!(&class_with_funcs);
+    dbg!(&class_with_funcs_llm);
     let types_llm = language_parser.capture_type_data(new_content.as_bytes());
     let types_file = language_parser.capture_type_data(file_content.as_bytes());
     // First we want to try and match all the classes as much as possible
@@ -380,6 +385,10 @@ async fn find_nearest_position_for_code_edit(
         .flatten()
         .collect::<Vec<_>>();
     // These are the independent functions which are present in the file
+    // TODO(skcd): Pick up from here, for some reason the functions are not matching
+    // up properly in the new content and the file content, we want to get the proper
+    // function matches so we can ask the llm to rewrite it, and also difftastic is not required
+    // as a dependency anymore (yay?) so we can skip it completely :)))
     let independent_functions_from_file = class_with_funcs
         .into_iter()
         .filter_map(|class_with_func| {
@@ -564,6 +573,7 @@ pub async fn generate_file_diff(
     if language_parser.is_none() {
         return None;
     }
+    dbg!("the file path", &file_path);
     // we can get the extension from the file path
     let file_extension = PathBuf::from(file_path)
         .extension()