1use std::{env, fs::File, io::BufReader, path::PathBuf, time::Duration};
39
40use log::{debug, error, info, warn};
41use tonic::transport::{Channel, Endpoint};
42#[cfg(feature = "mtls")]
43use rustls::ClientConfig;
44#[cfg(feature = "mtls")]
45use rustls::RootCertStore;
46
/// Address used to reach Mountain when no explicit address is configured.
pub const DEFAULT_MOUNTAIN_ADDRESS:&str = "[::1]:50051";

/// Connection timeout, in seconds, used when none is configured.
pub const DEFAULT_CONNECTION_TIMEOUT_SECS:u64 = 5;

/// Request timeout, in seconds, used when none is configured.
pub const DEFAULT_REQUEST_TIMEOUT_SECS:u64 = 30;
60
/// TLS settings for the connection to Mountain.
///
/// Only compiled when the `mtls` feature is enabled.
#[cfg(feature = "mtls")]
#[derive(Debug, Clone)]
pub struct TlsConfig {
	/// PEM file containing the CA certificate(s) to trust. When `None`,
	/// `create_tls_client_config` loads the system root certificates instead.
	pub ca_cert_path:Option<PathBuf>,

	/// PEM file containing the client certificate chain. Client
	/// authentication is only configured when this and `client_key_path`
	/// are both set.
	pub client_cert_path:Option<PathBuf>,

	/// PEM file containing the private key matching `client_cert_path`.
	pub client_key_path:Option<PathBuf>,

	/// Server name presented to the server; `MountainClient::connect` falls
	/// back to `"localhost"` when this is `None`.
	pub server_name:Option<String>,

	/// Whether server certificates should be verified. In this file a value
	/// of `false` only results in a logged warning; no code here relaxes
	/// verification.
	pub verify_certs:bool,
}
84
#[cfg(feature = "mtls")]
impl Default for TlsConfig {
	/// Returns a configuration with no certificate paths, no server name
	/// and `verify_certs` set to `true`.
	fn default() -> Self {
		Self {
			verify_certs:true,
			server_name:None,
			ca_cert_path:None,
			client_cert_path:None,
			client_key_path:None,
		}
	}
}
97
#[cfg(feature = "mtls")]
impl TlsConfig {
	/// Builds a server-authentication-only configuration.
	///
	/// The server is verified against `ca_cert_path`, no client certificate
	/// is set, and the server name is `"localhost"`.
	pub fn server_auth(ca_cert_path:PathBuf) -> Self {
		Self {
			ca_cert_path:Some(ca_cert_path),
			server_name:Some(String::from("localhost")),
			// Remaining fields: no client identity, verification requested.
			..Self::default()
		}
	}

	/// Builds a mutual-TLS configuration.
	///
	/// Same as [`TlsConfig::server_auth`], plus the client certificate and
	/// private key paths.
	pub fn mtls(ca_cert_path:PathBuf, client_cert_path:PathBuf, client_key_path:PathBuf) -> Self {
		Self {
			client_cert_path:Some(client_cert_path),
			client_key_path:Some(client_key_path),
			..Self::server_auth(ca_cert_path)
		}
	}
}
136
/// Builds a rustls [`ClientConfig`] from `tls_config`.
///
/// * Trust roots come from `ca_cert_path` when set, otherwise from the
///   system certificate store. System certificates that cannot be parsed
///   are skipped with a warning rather than failing the whole call.
/// * Client authentication is configured only when both `client_cert_path`
///   and `client_key_path` are set. If only one is set, a warning is logged
///   and the configuration falls back to server authentication only.
/// * ALPN is set to `h2`.
///
/// # Errors
///
/// Returns an error when a configured file cannot be opened or parsed, when
/// a configured file contains no certificate / private key, when a
/// configured CA certificate is rejected by the root store, or when rustls
/// rejects the client certificate and key.
#[cfg(feature = "mtls")]
pub fn create_tls_client_config(tls_config:&TlsConfig) -> Result<ClientConfig, Box<dyn std::error::Error>> {
	info!("Creating TLS client configuration");

	let mut root_store = RootCertStore::empty();

	if let Some(ca_path) = &tls_config.ca_cert_path {
		debug!("Loading CA certificate from {:?}", ca_path);

		let ca_file = File::open(ca_path).map_err(|e| format!("Failed to open CA certificate file: {}", e))?;
		let mut reader = BufReader::new(ca_file);

		let certs = rustls_pemfile::certs(&mut reader)
			.collect::<Result<Vec<_>, _>>()
			.map_err(|e| format!("Failed to parse CA certificate: {}", e))?;

		if certs.is_empty() {
			return Err("No CA certificates found in file".into());
		}

		// An explicitly configured CA must load completely, so any rejected
		// certificate is a hard error here.
		for cert in certs {
			root_store
				.add(cert)
				.map_err(|e| format!("Failed to add CA certificate to root store: {}", e))?;
		}

		info!("Loaded CA certificate from {:?}", ca_path);
	} else {
		debug!("Loading system root certificates");

		let cert_result = rustls_native_certs::load_native_certs();

		if !cert_result.errors.is_empty() {
			warn!("Encountered errors loading system certificates: {:?}", cert_result.errors);
		}

		if cert_result.certs.is_empty() {
			warn!("No system root certificates found");
		}

		// Best effort: one unparsable entry in the system store must not
		// make the whole configuration fail.
		let (_, ignored) = root_store.add_parsable_certificates(cert_result.certs);

		if ignored > 0 {
			warn!("Ignored {} system root certificates that could not be parsed", ignored);
		}

		info!("Loaded {} system root certificates", root_store.len());
	}

	let client_identity = match (&tls_config.client_cert_path, &tls_config.client_key_path) {
		(Some(cert_path), Some(key_path)) => {
			debug!("Loading client certificate from {:?}", cert_path);

			let cert_file =
				File::open(cert_path).map_err(|e| format!("Failed to open client certificate file: {}", e))?;
			let mut cert_reader = BufReader::new(cert_file);

			let certs = rustls_pemfile::certs(&mut cert_reader)
				.collect::<Result<Vec<_>, _>>()
				.map_err(|e| format!("Failed to parse client certificate: {}", e))?;

			if certs.is_empty() {
				return Err("No client certificates found in file".into());
			}

			debug!("Loading client private key from {:?}", key_path);

			let key_file = File::open(key_path).map_err(|e| format!("Failed to open private key file: {}", e))?;
			let mut key_reader = BufReader::new(key_file);

			let key = rustls_pemfile::private_key(&mut key_reader)
				.map_err(|e| format!("Failed to parse private key: {}", e))?
				.ok_or("No private key found in file")?;

			Some((certs, key))
		},
		(Some(_), None) => {
			warn!("Client certificate configured without a private key; client authentication will not be used");
			None
		},
		(None, Some(_)) => {
			warn!("Client private key configured without a certificate; client authentication will not be used");
			None
		},
		(None, None) => None,
	};

	let mut config = match client_identity {
		Some((certs, key)) => {
			let client_config = ClientConfig::builder()
				.with_root_certificates(root_store)
				.with_client_auth_cert(certs, key)
				.map_err(|e| format!("Failed to configure client authentication: {}", e))?;

			info!("Configured mTLS with client certificate");

			client_config
		},
		None => {
			let client_config = ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();

			info!("Configured TLS with server authentication only");

			client_config
		},
	};

	// gRPC runs over HTTP/2, so advertise only `h2`.
	config.alpn_protocols = vec![b"h2".to_vec()];

	if !tls_config.verify_certs {
		// Nothing in this function installs a permissive verifier, so the
		// log must not claim that verification was turned off.
		warn!("verify_certs=false was requested but is not supported; certificate verification remains enabled");
	}

	info!("TLS client configuration created successfully");

	Ok(config)
}
269
/// Settings used to establish a connection to Mountain.
#[derive(Debug, Clone)]
pub struct MountainClientConfig {
	/// Address of the Mountain endpoint.
	pub address:String,

	/// Maximum time, in seconds, allowed for establishing the connection.
	pub connection_timeout_secs:u64,

	/// Timeout, in seconds, used for requests issued by the client.
	pub request_timeout_secs:u64,

	/// TLS settings; `None` means an unencrypted connection. Only present
	/// when the `mtls` feature is enabled.
	#[cfg(feature = "mtls")]
	pub tls_config:Option<TlsConfig>,
}
286
287impl Default for MountainClientConfig {
288 fn default() -> Self {
289 Self {
290 address:DEFAULT_MOUNTAIN_ADDRESS.to_string(),
291 connection_timeout_secs:DEFAULT_CONNECTION_TIMEOUT_SECS,
292 request_timeout_secs:DEFAULT_REQUEST_TIMEOUT_SECS,
293 #[cfg(feature = "mtls")]
294 tls_config:None,
295 }
296 }
297}
298
299impl MountainClientConfig {
300 pub fn new(address:impl Into<String>) -> Self { Self { address:address.into(), ..Default::default() } }
308
309 pub fn from_env() -> Self {
329 let address = env::var("MOUNTAIN_ADDRESS").unwrap_or_else(|_| DEFAULT_MOUNTAIN_ADDRESS.to_string());
330
331 let connection_timeout_secs = env::var("MOUNTAIN_CONNECTION_TIMEOUT_SECS")
332 .ok()
333 .and_then(|s| s.parse().ok())
334 .unwrap_or(DEFAULT_CONNECTION_TIMEOUT_SECS);
335
336 let request_timeout_secs = env::var("MOUNTAIN_REQUEST_TIMEOUT_SECS")
337 .ok()
338 .and_then(|s| s.parse().ok())
339 .unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS);
340
341 #[cfg(feature = "mtls")]
342 let tls_config = if env::var("MOUNTAIN_TLS_ENABLED")
343 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
344 .unwrap_or(false)
345 {
346 Some(TlsConfig {
347 ca_cert_path:env::var("MOUNTAIN_CA_CERT").ok().map(PathBuf::from),
348 client_cert_path:env::var("MOUNTAIN_CLIENT_CERT").ok().map(PathBuf::from),
349 client_key_path:env::var("MOUNTAIN_CLIENT_KEY").ok().map(PathBuf::from),
350 server_name:env::var("MOUNTAIN_SERVER_NAME").ok(),
351 verify_certs:env::var("MOUNTAIN_VERIFY_CERTS")
352 .map(|v| v != "0" && !v.eq_ignore_ascii_case("false"))
353 .unwrap_or(true),
354 })
355 } else {
356 None
357 };
358
359 #[cfg(not(feature = "mtls"))]
360 let tls_config = None;
361
362 Self {
363 address,
364 connection_timeout_secs,
365 request_timeout_secs,
366 #[cfg(feature = "mtls")]
367 tls_config,
368 }
369 }
370
371 pub fn with_connection_timeout(mut self, timeout_secs:u64) -> Self {
379 self.connection_timeout_secs = timeout_secs;
380 self
381 }
382
383 pub fn with_request_timeout(mut self, timeout_secs:u64) -> Self {
391 self.request_timeout_secs = timeout_secs;
392 self
393 }
394
395 #[cfg(feature = "mtls")]
403 pub fn with_tls(mut self, tls_config:TlsConfig) -> Self {
404 self.tls_config = Some(tls_config);
405 self
406 }
407}
408
409#[derive(Debug, Clone)]
415pub struct MountainClient {
416 channel:Channel,
418
419 config:MountainClientConfig,
421}
422
423impl MountainClient {
424 pub async fn connect(config:MountainClientConfig) -> Result<Self, Box<dyn std::error::Error>> {
435 info!("Connecting to Mountain at {}", config.address);
436
437 let endpoint = Endpoint::from_shared(config.address.clone())?
438 .connect_timeout(Duration::from_secs(config.connection_timeout_secs));
439
440 #[cfg(feature = "mtls")]
442 if let Some(tls_config) = &config.tls_config {
443 info!("TLS configuration provided, configuring secure connection");
444
445 let _client_config = create_tls_client_config(tls_config).map_err(|e| {
446 error!("Failed to create TLS client configuration: {}", e);
447 format!("TLS configuration error: {}", e)
448 })?;
449
450 let domain_name = tls_config.server_name.clone().unwrap_or_else(|| "localhost".to_string());
452 info!("Setting server name for SNI: {}", domain_name);
453
454 let tls = tonic::transport::ClientTlsConfig::new().domain_name(domain_name.clone());
456 let channel = endpoint
457 .tcp_keepalive(Some(Duration::from_secs(60)))
458 .tls_config(tls)?
459 .connect()
460 .await
461 .map_err(|e| format!("Failed to connect with TLS: {}", e))?;
462
463 info!("Successfully connected to Mountain at {} with TLS", config.address);
464 return Ok(Self { channel, config });
465 }
466
467 debug!("Using unencrypted connection");
469 let channel = endpoint.connect().await?;
470 info!("Successfully connected to Mountain at {}", config.address);
471
472 Ok(Self { channel, config })
473 }
474
475 pub fn channel(&self) -> &Channel { &self.channel }
480
481 pub fn config(&self) -> &MountainClientConfig { &self.config }
486
487 pub async fn health_check(&self) -> Result<bool, Box<dyn std::error::Error>> {
494 debug!("Checking Mountain health");
495
496 match tokio::time::timeout(Duration::from_secs(self.config.request_timeout_secs), async {
498 Ok::<(), Box<dyn std::error::Error>>(())
501 })
502 .await
503 {
504 Ok(Ok(())) => {
505 debug!("Mountain health check: healthy");
506 Ok(true)
507 },
508 Ok(Err(e)) => {
509 warn!("Mountain health check: disconnected - {}", e);
510 Ok(false)
511 },
512 Err(_) => {
513 warn!("Mountain health check: timeout");
514 Ok(false)
515 },
516 }
517 }
518
519 pub async fn get_status(&self) -> Result<String, Box<dyn std::error::Error>> {
527 debug!("Getting Mountain status");
528
529 Ok("connected".to_string())
532 }
533
534 pub async fn get_config(&self, key:&str) -> Result<Option<String>, Box<dyn std::error::Error>> {
545 debug!("Getting Mountain config: {}", key);
546
547 Ok(None)
550 }
551
552 pub async fn set_config(&self, key:&str, value:&str) -> Result<(), Box<dyn std::error::Error>> {
564 debug!("Setting Mountain config: {} = {}", key, value);
565
566 Ok(())
569 }
570}
571
572pub async fn connect_to_mountain() -> Result<MountainClient, Box<dyn std::error::Error>> {
577 MountainClient::connect(MountainClientConfig::default()).await
578}
579
580pub async fn connect_to_mountain_at(address:impl Into<String>) -> Result<MountainClient, Box<dyn std::error::Error>> {
588 MountainClient::connect(MountainClientConfig::new(address)).await
589}
590
#[cfg(test)]
mod tests {
	use std::sync::Mutex;

	use super::*;

	/// Serializes the tests that modify the process environment. The test
	/// harness runs tests on several threads, and the environment is shared
	/// process-wide, so unsynchronized `set_var` / `remove_var` calls would
	/// race with each other.
	static ENV_LOCK:Mutex<()> = Mutex::new(());

	/// Every environment variable read by `MountainClientConfig::from_env`.
	const MOUNTAIN_ENV_VARS:[&str; 9] = [
		"MOUNTAIN_ADDRESS",
		"MOUNTAIN_CONNECTION_TIMEOUT_SECS",
		"MOUNTAIN_REQUEST_TIMEOUT_SECS",
		"MOUNTAIN_TLS_ENABLED",
		"MOUNTAIN_CA_CERT",
		"MOUNTAIN_CLIENT_CERT",
		"MOUNTAIN_CLIENT_KEY",
		"MOUNTAIN_SERVER_NAME",
		"MOUNTAIN_VERIFY_CERTS",
	];

	/// Removes every variable in `MOUNTAIN_ENV_VARS`. Callers must hold
	/// `ENV_LOCK`.
	fn clear_mountain_env() {
		for name in MOUNTAIN_ENV_VARS {
			// SAFETY: the caller holds `ENV_LOCK`, so no other test in this
			// module reads or writes the environment at the same time.
			unsafe {
				env::remove_var(name);
			}
		}
	}

	/// Runs `MountainClientConfig::from_env` with exactly `overrides` set
	/// among the Mountain variables, then clears them again.
	fn config_from_env_with(overrides:&[(&str, &str)]) -> MountainClientConfig {
		// A poisoned lock only means another test panicked; the environment
		// is reset below, so it is safe to continue.
		let _guard = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

		clear_mountain_env();

		for &(name, value) in overrides {
			// SAFETY: `ENV_LOCK` is held for the whole function, see above.
			unsafe {
				env::set_var(name, value);
			}
		}

		let config = MountainClientConfig::from_env();

		clear_mountain_env();

		config
	}

	#[test]
	fn test_default_config() {
		let config = MountainClientConfig::default();
		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[test]
	fn test_config_builder() {
		let config = MountainClientConfig::new("[::1]:50060")
			.with_connection_timeout(10)
			.with_request_timeout(60);

		assert_eq!(config.address, "[::1]:50060");
		assert_eq!(config.connection_timeout_secs, 10);
		assert_eq!(config.request_timeout_secs, 60);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_server_auth() {
		let tls = TlsConfig::server_auth(PathBuf::from("/path/to/ca.pem"));
		assert_eq!(tls.server_name, Some("localhost".to_string()));
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.ca_cert_path.is_some());
		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_mtls() {
		let tls = TlsConfig::mtls(
			PathBuf::from("/path/to/ca.pem"),
			PathBuf::from("/path/to/cert.pem"),
			PathBuf::from("/path/to/key.pem"),
		);
		assert!(tls.client_cert_path.is_some());
		assert!(tls.client_key_path.is_some());
		assert!(tls.ca_cert_path.is_some());
		assert!(tls.verify_certs);
		assert_eq!(tls.server_name, Some("localhost".to_string()));
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_tls_config_default() {
		let tls = TlsConfig::default();
		assert!(tls.ca_cert_path.is_none());
		assert!(tls.client_cert_path.is_none());
		assert!(tls.client_key_path.is_none());
		assert!(tls.server_name.is_none());
		assert!(tls.verify_certs);
	}

	#[test]
	fn test_from_env_default() {
		let config = config_from_env_with(&[]);

		assert_eq!(config.address, DEFAULT_MOUNTAIN_ADDRESS);
		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[test]
	fn test_from_env_custom() {
		let config = config_from_env_with(&[
			("MOUNTAIN_ADDRESS", "[::1]:50060"),
			("MOUNTAIN_CONNECTION_TIMEOUT_SECS", "10"),
			("MOUNTAIN_REQUEST_TIMEOUT_SECS", "60"),
		]);

		assert_eq!(config.address, "[::1]:50060");
		assert_eq!(config.connection_timeout_secs, 10);
		assert_eq!(config.request_timeout_secs, 60);
	}

	#[test]
	fn test_from_env_invalid_timeout_falls_back() {
		let config = config_from_env_with(&[("MOUNTAIN_CONNECTION_TIMEOUT_SECS", "not-a-number")]);

		assert_eq!(config.connection_timeout_secs, DEFAULT_CONNECTION_TIMEOUT_SECS);
		assert_eq!(config.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_tls() {
		let config = config_from_env_with(&[
			("MOUNTAIN_TLS_ENABLED", "1"),
			("MOUNTAIN_CA_CERT", "/path/to/ca.pem"),
			("MOUNTAIN_SERVER_NAME", "mymountain.com"),
		]);

		assert!(config.tls_config.is_some());
		let tls = config.tls_config.unwrap();
		assert_eq!(tls.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
		assert_eq!(tls.server_name, Some("mymountain.com".to_string()));
		assert!(tls.verify_certs);
	}

	#[cfg(feature = "mtls")]
	#[test]
	fn test_from_env_mtls() {
		let config = config_from_env_with(&[
			("MOUNTAIN_TLS_ENABLED", "true"),
			("MOUNTAIN_CA_CERT", "/path/to/ca.pem"),
			("MOUNTAIN_CLIENT_CERT", "/path/to/cert.pem"),
			("MOUNTAIN_CLIENT_KEY", "/path/to/key.pem"),
		]);

		assert!(config.tls_config.is_some());
		let tls = config.tls_config.unwrap();
		assert_eq!(tls.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
		assert_eq!(tls.client_cert_path, Some(PathBuf::from("/path/to/cert.pem")));
		assert_eq!(tls.client_key_path, Some(PathBuf::from("/path/to/key.pem")));
		assert!(tls.verify_certs);
	}
}