Skip to content

Commit

Permalink
fix: centralized payload and attributes validation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
evilsocket committed Jun 22, 2024
1 parent c9e4679 commit 51d9f3a
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 141 deletions.
120 changes: 56 additions & 64 deletions src/agent/namespaces/filesystem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,49 +75,45 @@ impl Action for ReadFolder {
_attributes: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if let Some(folder) = payload {
// adapted from https://gist.github.com/mre/91ebb841c34df69671bd117ead621a8b
let ret = fs::read_dir(&folder);
if let Ok(paths) = ret {
let mut output = format!("Contents of {} :\n\n", &folder);

for path in paths {
if let Ok(entry) = path {
let full_path = entry.path().canonicalize().unwrap();
let metadata = entry.metadata().unwrap();
let size = metadata.len();
let modified: DateTime<Local> =
DateTime::from(metadata.modified().unwrap());
let mode = metadata.permissions().mode();

output += &format!(
"{} {:>5} {} [{}] {}\n",
parse_permissions(mode),
size,
modified.format("%_d %b %H:%M"),
parse_type(metadata.file_type()),
full_path.display()
);
} else {
eprintln!("ERROR: {:?}", path);
}
// adapted from https://gist.github.com/mre/91ebb841c34df69671bd117ead621a8b
let folder = payload.unwrap();
let ret = fs::read_dir(&folder);
if let Ok(paths) = ret {
let mut output = format!("Contents of {} :\n\n", &folder);

for path in paths {
if let Ok(entry) = path {
let full_path = entry.path().canonicalize().unwrap();
let metadata = entry.metadata().unwrap();
let size = metadata.len();
let modified: DateTime<Local> = DateTime::from(metadata.modified().unwrap());
let mode = metadata.permissions().mode();

output += &format!(
"{} {:>5} {} [{}] {}\n",
parse_permissions(mode),
size,
modified.format("%_d %b %H:%M"),
parse_type(metadata.file_type()),
full_path.display()
);
} else {
eprintln!("ERROR: {:?}", path);
}

println!(
"<{}> {} -> {} bytes",
self.name().bold(),
folder.yellow(),
output.len()
);

return Ok(Some(output));
} else {
eprintln!("<{}> {} -> {:?}", self.name().bold(), folder.red(), &ret);
return Err(anyhow!("can't read {}: {:?}", folder, ret));
}
}

Err(anyhow!("no content specified for read-file"))
println!(
"<{}> {} -> {} bytes",
self.name().bold(),
folder.yellow(),
output.len()
);

Ok(Some(output))
} else {
eprintln!("<{}> {} -> {:?}", self.name().bold(), folder.red(), &ret);
Err(anyhow!("can't read {}: {:?}", folder, ret))
}
}
}

Expand All @@ -143,31 +139,27 @@ impl Action for ReadFile {
_attributes: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if let Some(filepath) = payload {
let ret = std::fs::read_to_string(&filepath);
if let Ok(contents) = ret {
println!(
"<{}> {} -> {} bytes",
self.name().bold(),
filepath.yellow(),
contents.len()
);
return Ok(Some(contents));
} else {
let err = ret.err().unwrap();
println!(
"<{}> {} -> {:?}",
self.name().bold(),
filepath.yellow(),
&err
);

return Err(anyhow!(err));
}
let filepath = payload.unwrap();
let ret = std::fs::read_to_string(&filepath);
if let Ok(contents) = ret {
println!(
"<{}> {} -> {} bytes",
self.name().bold(),
filepath.yellow(),
contents.len()
);
Ok(Some(contents))
} else {
let err = ret.err().unwrap();
println!(
"<{}> {} -> {:?}",
self.name().bold(),
filepath.yellow(),
&err
);

Err(anyhow!(err))
}

// TODO: check for mandatory payload and attributes while parsing
Err(anyhow!("no content specified for read-file"))
}
}

Expand Down
59 changes: 21 additions & 38 deletions src/agent/namespaces/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,14 @@ impl Action for SaveMemory {
attributes: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if let Some(attrs) = attributes {
if let Some(key) = attrs.get("key") {
if let Some(data) = payload {
state.get_storage("memories")?.add_tagged(key, &data);
return Ok(Some("memory saved".to_string()));
}
let attrs = attributes.unwrap();
let key = attrs.get("key").unwrap();

return Err(anyhow!("no content specified for save-memory"));
}
state
.get_storage("memories")?
.add_tagged(key, payload.unwrap().as_str());

return Err(anyhow!("no key attribute specified for save-memory"));
}

Err(anyhow!("no attributes specified for save-memory"))
Ok(Some("memory saved".to_string()))
}
}

Expand Down Expand Up @@ -79,19 +73,13 @@ impl Action for DeleteMemory {
attributes: Option<HashMap<String, String>>,
_: Option<String>,
) -> Result<Option<String>> {
if let Some(attrs) = attributes {
if let Some(key) = attrs.get("key") {
return if state.get_storage("memories")?.del_tagged(key).is_some() {
return Ok(Some("memory deleted".to_string()));
} else {
Err(anyhow!("memory '{}' not found", key))
};
}

return Err(anyhow!("no key attribute specified for delete-memory"));
let attrs = attributes.unwrap();
let key = attrs.get("key").unwrap();
if state.get_storage("memories")?.del_tagged(key).is_some() {
Ok(Some("memory deleted".to_string()))
} else {
Err(anyhow!("memory '{}' not found", key))
}

Err(anyhow!("no attributes specified for delete-memory"))
}
}

Expand Down Expand Up @@ -121,21 +109,16 @@ impl Action for RecallMemory {
attributes: Option<HashMap<String, String>>,
_: Option<String>,
) -> Result<Option<String>> {
if let Some(attrs) = attributes {
if let Some(key) = attrs.get("key") {
return if let Some(memory) = state.get_storage("memories")?.get_tagged(key) {
println!("<{}> recalling {}", "memories".bold(), key);
return Ok(Some(memory));
} else {
eprintln!("<{}> memory {} does not exist", "memories".bold(), key);
Err(anyhow!("memory '{}' not found", key))
};
}

return Err(anyhow!("no key attribute specified for delete-memory"));
let attrs = attributes.unwrap();
let key = attrs.get("key").unwrap();

if let Some(memory) = state.get_storage("memories")?.get_tagged(key) {
println!("<{}> recalling {}", "memories".bold(), key);
Ok(Some(memory))
} else {
eprintln!("<{}> memory {} does not exist", "memories".bold(), key);
Err(anyhow!("memory '{}' not found", key))
}

Err(anyhow!("no attributes specified for delete-memory"))
}
}

Expand Down
20 changes: 2 additions & 18 deletions src/agent/namespaces/planning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@ impl Action for AddStep {
_: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if payload.is_none() {
Err(anyhow!("no step description provided"))
} else {
state.get_storage("plan")?.add_completion(&payload.unwrap());
Ok(Some("step added to the plan".to_string()))
}
state.get_storage("plan")?.add_completion(&payload.unwrap());
Ok(Some("step added to the plan".to_string()))
}
}

Expand All @@ -58,10 +54,6 @@ impl Action for DeleteStep {
_: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if payload.is_none() {
return Err(anyhow!("no position provided"));
}

state
.get_storage("plan")?
.del_completion(payload.unwrap().parse::<usize>()?);
Expand Down Expand Up @@ -91,10 +83,6 @@ impl Action for SetComplete {
_: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if payload.is_none() {
return Err(anyhow!("no position provided"));
}

let pos = payload.unwrap().parse::<usize>()?;
if state.get_storage("plan")?.set_complete(pos).is_some() {
Ok(Some(format!("step {} marked as completed", pos)))
Expand Down Expand Up @@ -126,10 +114,6 @@ impl Action for SetIncomplete {
_: Option<HashMap<String, String>>,
payload: Option<String>,
) -> Result<Option<String>> {
if payload.is_none() {
return Err(anyhow!("no position provided"));
}

let pos = payload.unwrap().parse::<usize>()?;
if state.get_storage("plan")?.set_incomplete(pos).is_some() {
Ok(Some(format!("step {} marked as incomplete", pos)))
Expand Down
96 changes: 75 additions & 21 deletions src/agent/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ impl State {

pub fn add_error_to_history(&self, invocation: Invocation, error: String) {
if let Ok(mut guard) = self.history.lock() {
// eprintln!("[{}] -> {}", &invocation.action, error.red());
guard.push(Execution::with_error(invocation, error));
}
}
Expand All @@ -222,33 +223,86 @@ impl State {
}
}

pub async fn execute(&self, invocation: Invocation) -> Result<()> {
// println!("[INVOKE]");

#[allow(clippy::borrowed_box)]
fn get_action(&self, name: &str) -> Option<&Box<dyn namespaces::Action>> {
for group in &self.namespaces {
for action in &group.actions {
if invocation.action == action.name() {
// execute the action
let inv = invocation.clone();
let ret = action.run(self, invocation.attributes, invocation.payload);
if let Err(error) = ret {
// tell the model about the error
self.add_error_to_history(inv, error.to_string());
} else {
// tell the model about the output
self.add_success_to_history(inv, ret.unwrap());
}

return Ok(());
if name == action.name() {
return Some(action);
}
}
}

// tell the model that the action name is wrong
self.add_error_to_history(
invocation.clone(),
format!("'{}' is not a valid action name", invocation.action),
);
None
}

pub async fn execute(&self, invocation: Invocation) -> Result<()> {
if let Some(action) = self.get_action(&invocation.action) {
// validate prerequisites
let payload_required = action.example_payload().is_some();
let attrs_required = action.attributes().is_some();

if payload_required && invocation.payload.is_none() {
// payload required and not specified
self.add_error_to_history(
invocation.clone(),
format!("no content specified for '{}'", invocation.action),
);
return Ok(());
} else if attrs_required && invocation.attributes.is_none() {
// attributes required and not specified at all
self.add_error_to_history(
invocation.clone(),
format!("no attributes specified for '{}'", invocation.action),
);
return Ok(());
} else if attrs_required {
// validate each required attribute
let required_attrs: Vec<String> = action
.attributes()
.unwrap()
.keys()
.map(|s| s.to_owned())
.collect();
let passed_attrs: Vec<String> = invocation
.clone()
.attributes
.unwrap()
.keys()
.map(|s| s.to_owned())
.collect();

for required in required_attrs {
if !passed_attrs.contains(&required) {
self.add_error_to_history(
invocation.clone(),
format!(
"no '{}' attribute specified for '{}'",
required, invocation.action
),
);
return Ok(());
}
}
}

// execute the action
let inv = invocation.clone();
let ret = action.run(self, invocation.attributes, invocation.payload);
if let Err(error) = ret {
// tell the model about the error
self.add_error_to_history(inv, error.to_string());
} else {
// tell the model about the output
self.add_success_to_history(inv, ret.unwrap());
}
} else {
// tell the model that the action name is wrong
self.add_error_to_history(
invocation.clone(),
format!("'{}' is not a valid action name", invocation.action),
);
}

Ok(())
}
Expand Down

0 comments on commit 51d9f3a

Please sign in to comment.