Skip to content

Commit

Permalink
Add the ability to save licenses for later use
Browse files Browse the repository at this point in the history
  • Loading branch information
zmb3 committed Oct 24, 2024
1 parent b5dc810 commit d5e7bd2
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 10 deletions.
36 changes: 33 additions & 3 deletions src/core/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use crate::model::link::{Link, Stream};
use crate::nla::ntlm::Ntlm;
use std::io::{Read, Write};

use super::license::MemoryLicenseStore;
use super::LicenseStore;

impl From<&str> for KeyboardLayout {
fn from(e: &str) -> Self {
match e {
Expand Down Expand Up @@ -133,7 +136,10 @@ impl<S: Read + Write> RdpClient<S> {
}
}

pub struct Connector {
pub struct Connector<L = MemoryLicenseStore>
where
L: LicenseStore,
{
/// Screen width
width: u16,
/// Screen height
Expand Down Expand Up @@ -164,9 +170,11 @@ pub struct Connector {
/// Use network level authentication
/// default TRUE
use_nla: bool,
/// Stores RDS licenses for reuse.
license_store: L,
}

impl Connector {
impl<L: LicenseStore> Connector<L> {
/// Create a new RDP client
/// You can configure your client
///
Expand All @@ -178,7 +186,7 @@ impl Connector {
/// .credentials("domain".to_string(), "username".to_string(), "password".to_string());
/// ```
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
pub fn new_with_license_store(license_store: L) -> Self {
Connector {
width: 800,
height: 600,
Expand All @@ -193,9 +201,29 @@ impl Connector {
check_certificate: false,
name: "rdp-rs".to_string(),
use_nla: true,
license_store,
}
}
}

impl<L: LicenseStore + Default> Default for Connector<L> {
fn default() -> Self {
Self::new_with_license_store(Default::default())
}
}

impl Connector<Box<dyn LicenseStore>> {
pub fn new() -> Self {
Self::new_with_license_store(Box::new(MemoryLicenseStore::new()))
}

pub fn use_license_store(mut self, license_store: Box<dyn LicenseStore>) -> Self {
self.license_store = license_store;
self
}
}

impl<L: LicenseStore> Connector<L> {
/// Connect to a target server
/// This function will produce a RdpClient object
/// use to interact with server
Expand Down Expand Up @@ -255,6 +283,7 @@ impl Connector {
self.auto_logon,
None,
None,
&mut self.license_store,
)?;
} else {
sec::connect(
Expand All @@ -266,6 +295,7 @@ impl Connector {
self.auto_logon,
None,
None,
&mut self.license_store,
)?;
}

Expand Down
97 changes: 92 additions & 5 deletions src/core/license.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::model::error::{Error, RdpError, RdpErrorKind, RdpResult};
use crate::model::rnd::random;
use crate::model::unicode;
use num_enum::TryFromPrimitive;
use std::collections::HashMap;
use std::ffi::CStr;
use std::ffi::CString;
use std::io::{self, Cursor, Read, Write};
Expand All @@ -22,6 +23,8 @@ use rsa::{PublicKeyParts, RsaPublicKey};
use uuid::Uuid;
use x509_parser::{certificate::X509Certificate, prelude::FromDer};

use super::LicenseStore;

const SIGNATURE_ALG_RSA: u32 = 0x00000001;
const KEY_EXCHANGE_ALG_RSA: u32 = 0x00000001;
const CERT_CHAIN_VERSION_1: u32 = 0x00000001;
Expand Down Expand Up @@ -1025,18 +1028,16 @@ fn license_response(message_type: MessageType, data: Vec<u8>) -> RdpResult<Vec<u
Ok(buf)
}

pub fn client_connect<T: Read + Write>(
pub fn client_connect<T: Read + Write, L: LicenseStore>(
mcs: &mut mcs::Client<T>,
client_machine: &str, // must be a UUID
username: &str,
mut license_store: L,
) -> RdpResult<()> {
// We use the UUID that identifies the client as both the client machine name,
// and (in binary form) the hardware identifier for the client.
let client_uuid = Uuid::try_parse(client_machine)?;

// TODO(zmb3): attempt to load an existing license
let existing_license: Option<Vec<u8>> = None;

let (channel, payload) = mcs.read()?;
let session_encryption_data = match LicenseMessage::new(payload)? {
// When we get the `NewLicense` message at the start of the
Expand All @@ -1051,6 +1052,21 @@ pub fn client_connect<T: Read + Write>(
request.certificate,
);

let mut existing_license: Option<Vec<u8>> = None;
for issuer in request.scopes {
let l = license_store.read_license(
request.version_major,
request.version_minor,
&request.company_name,
&issuer,
&request.product_id,
);
if l.is_some() {
existing_license.replace(l.unwrap());
break;
}
}

// we either send information about a previously obtained license
// or a new license request, depending on whether we have a license
// cached from a previous attempt
Expand Down Expand Up @@ -1133,11 +1149,82 @@ pub fn client_connect<T: Read + Write>(
}
};

// TODO(zmb3): save the license
license_store.write_license(
license.version_major,
license.version_minor,
&license.company_name,
&license.scope,
&license.product_id,
&license.cert_data,
);

Ok(())
}

#[derive(PartialEq, Eq, Hash)]
struct LicenseStoreKey {
major: u16,
minor: u16,
company: String,
issuer: String,
product_id: String,
}

/// MemoryLicenseStore stores licenses in memory.
/// It is not concurrency-safe.
#[derive(Default)]
pub struct MemoryLicenseStore {
licenses: HashMap<LicenseStoreKey, Vec<u8>>,
}

impl MemoryLicenseStore {
pub fn new() -> Self {
Default::default()
}
}

impl LicenseStore for MemoryLicenseStore {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
) {
self.licenses.insert(
LicenseStoreKey {
major,
minor,
company: company.to_owned(),
issuer: issuer.to_owned(),
product_id: product_id.to_owned(),
},
license.to_vec(),
);
}

fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>> {
self.licenses
.get(&LicenseStoreKey {
major,
minor,
company: company.to_owned(),
issuer: issuer.to_owned(),
product_id: product_id.to_owned(),
})
.cloned()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
71 changes: 71 additions & 0 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,74 @@ pub mod per;
pub mod sec;
pub mod tpkt;
pub mod x224;

/// LicenseStore provides the ability to save (and later retrieve)
/// RDS licenses.
pub trait LicenseStore {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
);
fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>>;
}

impl<L: LicenseStore + ?Sized> LicenseStore for &mut L {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
) {
(**self).write_license(major, minor, company, issuer, product_id, license)
}

fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>> {
(**self).read_license(major, minor, company, issuer, product_id)
}
}

impl<T: LicenseStore + ?Sized> LicenseStore for Box<T> {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
) {
(**self).write_license(major, minor, company, issuer, product_id, license)
}
fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>> {
(**self).read_license(major, minor, company, issuer, product_id)
}
}
7 changes: 5 additions & 2 deletions src/core/sec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use crate::model::error::RdpResult;
use crate::model::unicode::Unicode;
use std::io::{Read, Write};

use super::LicenseStore;

/// Security flag send as header flage in core ptotocol
/// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-rdpbcgr/e13405c5-668b-4716-94b2-1c2654ca1ad4?redirectedfrom=MSDN
#[repr(u16)]
Expand Down Expand Up @@ -151,7 +153,7 @@ fn rdp_infos(
/// sec::connect(&mut mcs).unwrap();
/// ```
#[allow(clippy::too_many_arguments)]
pub fn connect<T: Read + Write>(
pub fn connect<T: Read + Write, L: LicenseStore>(
mcs: &mut mcs::Client<T>,
agent_id: &str,
domain: &String,
Expand All @@ -160,6 +162,7 @@ pub fn connect<T: Read + Write>(
auto_logon: bool,
info_flags: Option<u32>,
extended_info_flags: Option<u32>,
license_store: L,
) -> RdpResult<()> {
let perf_flags = if mcs.is_rdp_version_5_plus() {
extended_info_flags
Expand All @@ -176,6 +179,6 @@ pub fn connect<T: Read + Write>(
],
)?;

license::client_connect(mcs, agent_id, username)?;
license::client_connect(mcs, agent_id, username, license_store)?;
Ok(())
}

0 comments on commit d5e7bd2

Please sign in to comment.