use crate::config_file;
use sha2::Digest;
use rusoto_credential;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, SystemTime};
use thiserror::*;
use tracing::*;
use {thrussh, thrussh_keys};
/// A cached response body together with its optional `Content-Type` value.
///
/// Stored as the value type of `Config::cache`, `Config::cache_gzip` and
/// `Config::cache_br`.
pub struct Cached {
    /// The cached body bytes.
    pub body: bytes::Bytes,
    /// `Content-Type` to send with the body, if one was recorded.
    pub content_type: Option<hyper::header::HeaderValue>,
}
/// Runtime configuration and shared server state, built by `from_file`.
pub struct Config {
    /// CSRF settings (`axum_csrf::CsrfConfig::new()` in `from_file`).
    pub csrf: axum_csrf::CsrfConfig,
    /// Copied from the config file's `repository_cache_size`.
    pub repository_cache_size: usize,
    /// Copied from the config file's `max_body_length`.
    pub max_body_length: u64,
    /// Copied from the config file's `hard_max_body_length`.
    pub hard_max_body_length: u64,
    /// `hostname`, with `:<https_port>` appended when the HTTPS port is not 443.
    pub host: String,
    /// The bare host name from the config file's `host`.
    pub hostname: String,
    /// From the config file's `http.timeout_secs`.
    pub http_timeout: Duration,
    /// Shared thrussh server configuration (host key taken from `ssh_secret`).
    pub ssh: Arc<thrussh::server::Config>,
    /// From the config file's `ssh.timeout_secs`.
    pub ssh_timeout: Duration,
    /// SHA-512 digest of the `hmac_secret` environment variable.
    pub hmac_secret: [u8; SHA512_OUTPUT_LEN],
    /// Directory from the config file's `repositories_path`.
    pub repositories_path: PathBuf,
    /// `version_time` parsed as an HTTP date.
    pub version_time: SystemTime,
    /// `version_time` exactly as written in the config file.
    pub version_time_str: String,
    /// From the config file's `time.max_relative_days`.
    pub max_relative_days: usize,
    /// From the config file's `time.yesterday_threshold_hours`.
    pub yesterday_threshold_hours: usize,
    /// From the config file's `failed_auth_timeout_millis`.
    pub failed_auth_timeout: Duration,
    /// Static AWS credentials read from `aws_access_key_id` / `aws_access_key`.
    pub email: rusoto_credential::StaticProvider,
    /// Copied from the config file's `email.source`.
    pub email_source: String,
    /// String-keyed response cache.
    pub cache: RwLock<HashMap<String, Cached>>,
    /// String-keyed response cache (presumably gzip-encoded bodies — confirm).
    pub cache_gzip: RwLock<HashMap<String, Cached>>,
    /// String-keyed response cache (presumably brotli-encoded bodies — confirm).
    pub cache_br: RwLock<HashMap<String, Cached>>,
    /// PostgreSQL connection pool.
    pub db: Db,
    /// Copied from the config file's `partial_change_size`.
    pub partial_change_size: u64,
    /// From the config file's `http.ws_timeout_secs`.
    pub ws_timeout: Duration,
    /// LRU of opened change files keyed by `(Uuid, ChangeId)`; capacity is
    /// the config file's `change_cache_size`.
    pub change_cache: Arc<
        Mutex<
            lru_cache::LruCache<
                (uuid::Uuid, libpijul::ChangeId),
                Arc<Mutex<libpijul::change::ChangeFile>>,
            >,
        >,
    >,
    /// LRU of opened change files keyed by `(Uuid, Hash)`; capacity is the
    /// config file's `change_cache_size`.
    pub hash_cache: Arc<
        Mutex<
            lru_cache::LruCache<
                (uuid::Uuid, libpijul::Hash),
                Arc<Mutex<libpijul::change::ChangeFile>>,
            >,
        >,
    >,
    /// Copied from the config file's `max_password_attempts`.
    pub max_password_attempts: i64,
    /// syntect's default syntax set (newline-terminated variant).
    pub syntax_set: syntect::parsing::SyntaxSet,
    /// Copied from the config file's `svelte_socket`.
    pub svelte_socket: Option<String>,
    /// LRU of rendered responses (capacity 1024 in `from_file`).
    pub render_cache: Mutex<lru_cache::LruCache<CachedItem, Resp>>,
    /// Config file's `size_limit`, or 0 when it is not set.
    pub size_limit: i64,
}
/// Key of an entry in `Config::render_cache`.
#[derive(Debug, Hash, Eq, PartialEq)]
pub enum CachedItem {
    /// The change identified by this hash.
    Change(libpijul::Hash),
}
/// A response stored in `Config::render_cache`.
#[derive(Clone, Debug)]
pub struct Resp {
    /// Timestamp attached to the response (presumably when it was rendered — confirm).
    pub time: SystemTime,
    /// Raw body bytes.
    pub body: Vec<u8>,
}
use clap::*;
// Command-line arguments. These are plain `//` comments on purpose: clap turns
// `///` doc comments into `--help` text, which would change the program's output.
#[derive(Debug, Parser)]
pub struct App {
    // Path to the main TOML configuration file (`-c` / `--config`).
    #[arg(long, short)]
    config: PathBuf,
    // Path to the replication TOML configuration file (`-r` / `--replication`).
    #[arg(long, short)]
    replication: PathBuf,
}
pub fn from_app() -> (config_file::ConfigFile, replication::ConfigFile) {
let matches = App::parse();
let config = toml::from_str(&std::fs::read_to_string(&matches.config).unwrap()).unwrap();
let repl = toml::from_str(&std::fs::read_to_string(&matches.replication).unwrap()).unwrap();
(config, repl)
}
pub fn drop_privileges(c: &config_file::ConfigFile) {
if let (&Some(ref user), &Some(ref group)) = (&c.user, &c.group) {
println!("Dropping privileges");
privdrop::PrivDrop::default()
.user(&user)
.group(&group)
.apply()
.unwrap();
}
}
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
OpenSSL(#[from] openssl::error::ErrorStack),
}
pub async fn make_tls() -> Result<axum_server::tls_rustls::RustlsConfig, TlsError> {
let key = std::env::var("tls_key").unwrap().into_bytes();
let certs = std::env::var("tls_cert").unwrap().into_bytes();
Ok(axum_server::tls_rustls::RustlsConfig::from_pem(certs, key)
.await
.unwrap())
}
/// Size in bytes of a SHA-512 digest (512 bits), and thus of `Config::hmac_secret`.
const SHA512_OUTPUT_LEN: usize = 64;
/// Pool of asynchronous PostgreSQL connections (diesel-async over deadpool).
pub type Db =
    diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>;
/// Builds the runtime [`Config`] from the parsed configuration file and the
/// process environment.
///
/// Environment variables read:
/// - `DATABASE_URL`: only when the config file has no `postgres` URL;
/// - `ssh_secret`: SSH host secret key, decoded with `thrussh_keys::decode_secret_key`;
/// - `hmac_secret`: hashed with SHA-512 to produce `Config::hmac_secret`;
/// - `aws_access_key_id` / `aws_access_key`: static credentials for `Config::email`.
///
/// # Panics
///
/// Panics if a required environment variable is missing, if the database pool
/// cannot be built, if `ssh_secret` cannot be decoded, or if `version_time` is
/// not a valid HTTP date.
pub async fn from_file(config_file: &config_file::ConfigFile) -> Config {
    debug!("connecting to db");
    // Only look up DATABASE_URL when no `postgres` URL is configured. The previous
    // `unwrap_or(&env::var(..).unwrap())` evaluated the lookup eagerly, so an unset
    // DATABASE_URL panicked even when `postgres` was present in the config file.
    let database_url = match config_file.postgres.as_deref() {
        Some(url) => url.to_owned(),
        None => std::env::var("DATABASE_URL")
            .expect("missing DATABASE_URL (and no `postgres` URL in the config file)"),
    };
    let manager = diesel_async::pooled_connection::AsyncDieselConnectionManager::<
        diesel_async::AsyncPgConnection,
    >::new(database_url.as_str());
    let db = diesel_async::pooled_connection::deadpool::Pool::builder(manager)
        .build()
        .expect("failed to build the database connection pool");

    let ssh_keys = thrussh_keys::decode_secret_key(
        &std::env::var("ssh_secret").expect("missing ssh_secret"),
        None,
    )
    .expect("ssh_secret is not a valid SSH secret key");

    // The HMAC key is the SHA-512 digest of the secret, giving a fixed 64-byte
    // key whatever the length of the configured secret.
    let hmac_secret = {
        let secret = std::env::var("hmac_secret").expect("missing hmac_secret");
        let mut hasher = sha2::Sha512::new();
        hasher.update(secret.as_bytes());
        let mut out = [0u8; SHA512_OUTPUT_LEN];
        out.copy_from_slice(&hasher.finalize());
        out
    };

    let ssh_config = {
        use thrussh::MethodSet;
        let mut config = thrussh::server::Config::default();
        config.methods =
            MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE;
        config.keys.push(ssh_keys);
        config.maximum_packet_size = 10_000_000;
        config
    };

    Config {
        csrf: axum_csrf::CsrfConfig::new(),
        repository_cache_size: config_file.repository_cache_size,
        max_body_length: config_file.max_body_length,
        hard_max_body_length: config_file.hard_max_body_length,
        http_timeout: Duration::from_secs(config_file.http.timeout_secs),
        ws_timeout: Duration::from_secs(config_file.http.ws_timeout_secs),
        email: rusoto_credential::StaticProvider::new(
            std::env::var("aws_access_key_id").expect("missing aws_access_key_id"),
            std::env::var("aws_access_key").expect("missing aws_access_key"),
            None,
            None,
        ),
        email_source: config_file.email.source.clone(),
        hostname: config_file.host.clone(),
        // Public host: the port is only spelled out when it is not the HTTPS default.
        host: {
            let mut h = config_file.host.clone();
            if config_file.http.https_port != 443 {
                h.push_str(&format!(":{}", config_file.http.https_port));
            }
            h
        },
        ssh: Arc::new(ssh_config),
        ssh_timeout: Duration::from_secs(config_file.ssh.timeout_secs),
        hmac_secret,
        cache: RwLock::new(HashMap::new()),
        cache_br: RwLock::new(HashMap::new()),
        cache_gzip: RwLock::new(HashMap::new()),
        repositories_path: PathBuf::from(&config_file.repositories_path),
        version_time: httpdate::parse_http_date(&config_file.version_time)
            .expect("version_time is not a valid HTTP date"),
        version_time_str: config_file.version_time.clone(),
        max_relative_days: config_file.time.max_relative_days,
        yesterday_threshold_hours: config_file.time.yesterday_threshold_hours,
        failed_auth_timeout: Duration::from_millis(config_file.failed_auth_timeout_millis),
        db,
        partial_change_size: config_file.partial_change_size,
        change_cache: Arc::new(std::sync::Mutex::new(lru_cache::LruCache::new(
            config_file.change_cache_size,
        ))),
        hash_cache: Arc::new(std::sync::Mutex::new(lru_cache::LruCache::new(
            config_file.change_cache_size,
        ))),
        max_password_attempts: config_file.max_password_attempts,
        syntax_set: syntect::parsing::SyntaxSet::load_defaults_newlines(),
        svelte_socket: config_file.svelte_socket.clone(),
        render_cache: Mutex::new(lru_cache::LruCache::new(1024)),
        size_limit: config_file.size_limit.unwrap_or(0),
    }
}
/// Typed `Accept` header: the flag records whether the request accepts
/// `application/json`.
pub struct AcceptJson(
    /// `true` when `application/json` was listed in `Accept`.
    pub bool,
);
use http::HeaderValue;
impl headers::Header for AcceptJson {
    fn name() -> &'static headers::HeaderName {
        &http::header::ACCEPT
    }

    /// Reports whether any `Accept` media range is `application/json`.
    ///
    /// Each header value is split on `,`. Parameters after `;` (including
    /// q-values) are ignored, surrounding whitespace is trimmed, and the
    /// comparison is ASCII case-insensitive because media types are
    /// case-insensitive. Without the trim, the common form
    /// `text/html, application/json` was not recognised. Values that are not
    /// visible ASCII are skipped. Wildcards such as `*/*` do not count.
    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
    where
        I: Iterator<Item = &'i HeaderValue>,
    {
        let wants_json = values
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .any(|range| {
                let media_type = range.split(';').next().unwrap_or("").trim();
                media_type.eq_ignore_ascii_case("application/json")
            });
        Ok(AcceptJson(wants_json))
    }

    /// Emits `Accept: application/json` when the flag is set, and nothing otherwise.
    fn encode<E>(&self, values: &mut E)
    where
        E: Extend<headers::HeaderValue>,
    {
        if self.0 {
            values.extend(std::iter::once(headers::HeaderValue::from_static(
                "application/json",
            )))
        }
    }
}
/// An RGB colour packed as `0xRRGGBB` in the low 24 bits of an `i32`.
///
/// Bits above the low 24 are ignored by every method.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Color(pub i32);

impl Color {
    /// Formats the colour as a lowercase CSS hex string, e.g. `#1a2b3c`.
    pub fn to_string(&self) -> String {
        format!(
            "#{:02x}{:02x}{:02x}",
            (self.0 >> 16) & 0xff,
            (self.0 >> 8) & 0xff,
            self.0 & 0xff
        )
    }

    /// Chooses black or white as the foreground (text) colour on this background.
    ///
    /// Computes the relative luminance of the colour from linearised sRGB
    /// channels (WCAG formula), then returns white if its contrast score
    /// exceeds black's, and black otherwise.
    pub fn fg(&self) -> Color {
        // sRGB channel (0..=1) -> linear light.
        let linearize = |c: f64| {
            if c <= 0.03928 {
                c / 12.92
            } else {
                ((c + 0.055f64) / 1.055f64).powf(2.4)
            }
        };
        // Red is masked like green and blue. Without the mask, any bits above 24
        // leaked into the red channel and produced a luminance far above 1.
        let r = linearize(((self.0 >> 16) & 0xff) as f64 / 255.);
        let g = linearize(((self.0 >> 8) & 0xff) as f64 / 255.);
        let b = linearize((self.0 & 0xff) as f64 / 255.);
        let lum = 0.2126 * r + 0.7152 * g + 0.0722 * b;
        let black_contrast = (lum + 0.05) / 0.05;
        // NOTE(review): WCAG's white contrast uses 1.05 in the numerator. The 1.1
        // here slightly favours white text; presumably deliberate, so it is kept.
        let white_contrast = 1.1 / (lum + 0.05);
        if white_contrast > black_contrast {
            Color(0xffffff)
        } else {
            Color(0)
        }
    }
}