use crate::config_file;
use maxminddb;
use openssl;
use openssl::hash::MessageDigest;
use rusoto_credential;
use std::collections::HashMap;
use std::net::IpAddr;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, SystemTime};
use thiserror::*;
use tracing::*;
use {thrussh, thrussh_keys};
/// `strftime`-style format string for an RFC 1123 date with the zone written
/// as the literal `GMT` (e.g. `Mon, 01 Jan 2024 00:00:00 GMT`).
pub const RFC1123: &str = "%a, %d %b %Y %H:%M:%S GMT";
/// One entry of the `cache`, `cache_gzip` and `cache_br` maps held by `Config`.
pub struct Cached {
/// Raw body bytes of the cached entry.
pub body: bytes::Bytes,
/// Header value stored alongside the body, if any.
/// NOTE(review): presumably sent back as `Content-Type` — confirm at the use sites.
pub content_type: Option<hyper::header::HeaderValue>,
}
/// Runtime configuration and shared state, built once by `from_file` from the
/// parsed configuration file and the process environment.
///
/// Fields without a comment are copied verbatim from the identically named
/// entry of the configuration file.
pub struct Config {
/// CSRF configuration; `from_file` initialises it with `CsrfConfig::new()`.
pub csrf: axum_csrf::CsrfConfig,
pub etcd_server: String,
pub repository_cache_size: usize,
pub max_body_length: u64,
pub hard_max_body_length: u64,
pub host: String,
pub hostname: String,
/// Built from `http.timeout_secs` (seconds).
pub http_timeout: std::time::Duration,
/// SSH server configuration: public-key, password and keyboard-interactive
/// authentication enabled, the key decoded from the `ssh_secret` environment
/// variable, and a maximum packet size of 10 000 000.
pub ssh: Arc<thrussh::server::Config>,
/// Built from `ssh.timeout_secs` (seconds).
pub ssh_timeout: std::time::Duration,
/// PBKDF2-HMAC-SHA512 output derived from the `pbkdf2_password` and
/// `pbkdf2_salt` environment variables with `pbkdf2_iterations` rounds.
pub hmac_secret: [u8; SHA512_OUTPUT_LEN],
pub repositories_path: PathBuf,
/// `UNIX_EPOCH` plus the number of seconds read from the file named by
/// `http.time_file`.
pub version_time: SystemTime,
/// `version_time` formatted with `httpdate::fmt_http_date`.
pub version_time_str: String,
/// Copied from `time.max_relative_days`.
pub max_relative_days: usize,
/// Copied from `time.yesterday_threshold_hours`.
pub yesterday_threshold_hours: usize,
/// MaxMind database, opened from the path in the `GEOLITE2_PATH` environment
/// variable when set, otherwise from `geoip_database`.
pub maxmind: maxminddb::Reader<Vec<u8>>,
/// Built from `failed_auth_timeout_millis` (milliseconds).
pub failed_auth_timeout: Duration,
/// Static AWS credentials from the `aws_access_key_id` and `aws_access_key`
/// environment variables; `None` when either is missing.
pub email: Option<rusoto_credential::StaticProvider>,
/// Copied from `email.source`.
pub email_source: String,
/// Parsed from `email.region`.
pub aws_ses_zone: rusoto_core::Region,
/// Starts empty.
/// NOTE(review): presumably uncompressed bodies keyed by path — confirm at the use sites.
pub cache: RwLock<HashMap<String, Cached>>,
/// Starts empty.
/// NOTE(review): presumably the gzip-encoded counterpart of `cache` — confirm.
pub cache_gzip: RwLock<HashMap<String, Cached>>,
/// Starts empty.
/// NOTE(review): presumably the brotli-encoded counterpart of `cache` — confirm.
pub cache_br: RwLock<HashMap<String, Cached>>,
/// Postgres connection pool, connected to `postgres` from the configuration
/// file or to the `DATABASE_URL` environment variable.
pub db: Db,
pub partial_change_size: u64,
/// Built from `http.ws_timeout_secs` (seconds).
pub ws_timeout: std::time::Duration,
/// LRU cache of change files keyed by `(Uuid, ChangeId)`, with capacity
/// `change_cache_size`.
/// NOTE(review): the `Uuid` presumably identifies the repository — confirm.
pub change_cache: Arc<
std::sync::Mutex<
lru_cache::LruCache<
(uuid::Uuid, libpijul::ChangeId),
Arc<std::sync::Mutex<libpijul::change::ChangeFile>>,
>,
>,
>,
/// LRU cache of change files keyed by `(Uuid, Hash)`; it uses the same
/// `change_cache_size` capacity as `change_cache`.
pub hash_cache: Arc<
std::sync::Mutex<
lru_cache::LruCache<
(uuid::Uuid, libpijul::Hash),
Arc<std::sync::Mutex<libpijul::change::ChangeFile>>,
>,
>,
>,
pub max_password_attempts: i64,
pub editor: config_file::Editor,
/// Stripe settings read from the environment by `stripe_config`; `None` when
/// any of its variables is missing.
pub stripe_config: Option<Stripe>,
/// Built from `webauthn.rp_id` and `webauthn.rp_origin`.
pub webauthn: webauthn_rs::Webauthn,
/// NOTE(review): presumably the price of the "pro" plan in euros — confirm.
pub pro_prix_euros: u32,
/// Syntect's default syntax set (`load_defaults_newlines`).
pub syntax_set: syntect::parsing::SyntaxSet,
pub svelte_socket: Option<String>,
/// LRU cache of rendered responses with a fixed capacity of 1024 entries.
pub render_cache: Mutex<lru_cache::LruCache<CachedItem, Resp>>,
pub basic_size_limit: i64,
pub pro_size_limit: i64,
}
/// Stripe settings, read from the environment by `stripe_config`.
pub struct Stripe {
/// Value of the `stripe_publishable` environment variable.
pub publishable_key: String,
/// Value of the `stripe_webhook_secret` environment variable.
pub webhook_secret: String,
/// Value of the `stripe_pro` environment variable.
/// NOTE(review): presumably the Stripe identifier of the "pro" plan — confirm.
pub pro: String,
}
/// Key type of `Config::render_cache`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub enum CachedItem {
/// An entry identified by the hash of a change.
Change(libpijul::Hash),
}
/// Value type of `Config::render_cache`.
#[derive(Debug, Clone)]
pub struct Resp {
/// Timestamp stored with the entry.
/// NOTE(review): presumably the time the body was rendered — confirm at the use sites.
pub time: std::time::SystemTime,
/// Cached body bytes.
pub body: Vec<u8>,
}
use clap::*;
// Command-line arguments of the server, parsed by `from_app`.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into `--help` text, which would change the program's output.
#[derive(Debug, Parser)]
pub struct App {
// Path to the main configuration file (read as TOML by `from_app`).
#[arg(short, long)]
config: PathBuf,
// Path to the replication configuration file (read as TOML by `from_app`).
#[arg(short, long)]
replication: PathBuf,
}
pub fn from_app() -> (config_file::ConfigFile, replication::ConfigFile) {
let matches = App::parse();
let config = toml::from_str(&std::fs::read_to_string(&matches.config).unwrap()).unwrap();
let repl = toml::from_str(&std::fs::read_to_string(&matches.replication).unwrap()).unwrap();
(config, repl)
}
/// Switches the process to the user and group named in the configuration.
///
/// Nothing happens unless both `c.user` and `c.group` are set.
///
/// # Panics
///
/// Panics if `privdrop` fails to apply the change.
pub fn drop_privileges(c: &config_file::ConfigFile) {
    if let (Some(user), Some(group)) = (c.user.as_ref(), c.group.as_ref()) {
        println!("Dropping privileges");
        let request = privdrop::PrivDrop::default().user(user).group(group);
        request.apply().unwrap();
    }
}
/// Error type in the signature of `make_tls`.
#[derive(Debug, Error)]
pub enum TlsError {
/// An OpenSSL error stack, forwarded transparently; `#[from]` provides the
/// conversion used by `?`.
#[error(transparent)]
OpenSSL(#[from] openssl::error::ErrorStack),
}
/// Builds the rustls configuration from PEM data held in the environment.
///
/// Returns `Ok(Some(_))` when both the `tls_key` and `tls_cert` environment
/// variables are set (their values are used directly as PEM bytes), and
/// `Ok(None)` when either is missing or not valid Unicode.
///
/// # Panics
///
/// Panics if the PEM data is rejected by `RustlsConfig::from_pem`; this
/// function never actually produces an `Err`.
pub async fn make_tls() -> Result<Option<axum_server::tls_rustls::RustlsConfig>, TlsError> {
    let key = std::env::var("tls_key");
    let cert = std::env::var("tls_cert");
    match (key, cert) {
        (Ok(key_pem), Ok(cert_pem)) => {
            let tls = axum_server::tls_rustls::RustlsConfig::from_pem(
                cert_pem.into_bytes(),
                key_pem.into_bytes(),
            )
            .await
            .unwrap();
            Ok(Some(tls))
        }
        _ => Ok(None),
    }
}
/// Length in bytes of a SHA-512 digest (512 bits); the size of `Config::hmac_secret`.
const SHA512_OUTPUT_LEN: usize = 512 / 8;
/// Deadpool-backed pool of asynchronous Diesel Postgres connections.
pub type Db = diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>;
/// Builds the runtime [`Config`] from the parsed configuration file and the
/// process environment.
///
/// Environment variables read here: `DATABASE_URL` (only when
/// `config_file.postgres` is unset), `ssh_secret`, `pbkdf2_password`,
/// `pbkdf2_salt` and, optionally, `GEOLITE2_PATH`. The helpers `email()` and
/// `stripe_config()` read further optional variables.
///
/// # Panics
///
/// Panics if a required environment variable is missing, or if any value that
/// is parsed or opened here is invalid: the connection pool, the SSH key, the
/// PBKDF2 derivation, the time file, the AWS region, the GeoIP database or the
/// WebAuthn relying-party settings.
pub async fn from_file(config_file: &config_file::ConfigFile) -> Config {
    debug!("connecting to dbr");
    // Look at `DATABASE_URL` only when the configuration file has no
    // `postgres` entry. `Option::unwrap_or` evaluates its argument eagerly, so
    // passing the environment lookup to it would panic on a missing variable
    // even when the URL is configured in the file.
    let database_url = match config_file.postgres.as_deref() {
        Some(url) => url.to_owned(),
        None => std::env::var("DATABASE_URL").expect("missing DATABASE_URL"),
    };
    let manager = diesel_async::pooled_connection::AsyncDieselConnectionManager::<
        diesel_async::AsyncPgConnection,
    >::new(database_url);
    let db = diesel_async::pooled_connection::deadpool::Pool::builder(manager)
        .build()
        .unwrap();

    // Server key, decoded without a passphrase.
    let ssh_keys = thrussh_keys::decode_secret_key(
        &std::env::var("ssh_secret").expect("missing ssh_secret"),
        None,
    )
    .unwrap();

    // Derive the HMAC secret with PBKDF2-HMAC-SHA512 so that the raw password
    // itself is never kept in the configuration.
    let hmac_secret = {
        let mut out = [0; SHA512_OUTPUT_LEN];
        openssl::pkcs5::pbkdf2_hmac(
            std::env::var("pbkdf2_password")
                .expect("missing pbkdf2_password")
                .as_bytes(),
            std::env::var("pbkdf2_salt")
                .expect("missing pbkdf2_salt")
                .as_bytes(),
            config_file.pbkdf2_iterations,
            MessageDigest::sha512(),
            &mut out,
        )
        .unwrap();
        out
    };

    let ssh_config = {
        let mut config = thrussh::server::Config::default();
        use thrussh::MethodSet;
        config.methods =
            MethodSet::PUBLICKEY | MethodSet::PASSWORD | MethodSet::KEYBOARD_INTERACTIVE;
        config.keys.push(ssh_keys);
        config.maximum_packet_size = 10_000_000;
        config
    };

    // The time file holds a single integer: seconds since the UNIX epoch.
    let static_time: u64 = std::fs::read_to_string(&config_file.http.time_file)
        .unwrap()
        .trim()
        .parse()
        .unwrap();
    let version_time =
        std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(static_time);

    Config {
        csrf: axum_csrf::CsrfConfig::new(),
        etcd_server: config_file.etcd_server.clone(),
        repository_cache_size: config_file.repository_cache_size,
        max_body_length: config_file.max_body_length,
        hard_max_body_length: config_file.hard_max_body_length,
        http_timeout: std::time::Duration::from_secs(config_file.http.timeout_secs),
        ws_timeout: std::time::Duration::from_secs(config_file.http.ws_timeout_secs),
        email: email(),
        email_source: config_file.email.source.clone(),
        aws_ses_zone: config_file.email.region.parse().unwrap(),
        hostname: config_file.hostname.clone(),
        host: config_file.host.clone(),
        ssh: Arc::new(ssh_config),
        ssh_timeout: std::time::Duration::from_secs(config_file.ssh.timeout_secs),
        hmac_secret,
        cache: RwLock::new(HashMap::new()),
        cache_br: RwLock::new(HashMap::new()),
        cache_gzip: RwLock::new(HashMap::new()),
        repositories_path: PathBuf::from(&config_file.repositories_path),
        version_time,
        version_time_str: httpdate::fmt_http_date(version_time),
        max_relative_days: config_file.time.max_relative_days,
        yesterday_threshold_hours: config_file.time.yesterday_threshold_hours,
        maxmind: {
            debug!("maxminddb config");
            // The environment variable overrides the path in the file.
            if let Ok(path) = std::env::var("GEOLITE2_PATH") {
                maxminddb::Reader::open_readfile(&path).unwrap()
            } else {
                maxminddb::Reader::open_readfile(&config_file.geoip_database).unwrap()
            }
        },
        failed_auth_timeout: Duration::from_millis(config_file.failed_auth_timeout_millis),
        db,
        partial_change_size: config_file.partial_change_size,
        change_cache: Arc::new(std::sync::Mutex::new(lru_cache::LruCache::new(
            config_file.change_cache_size,
        ))),
        // Both change caches are sized by the same `change_cache_size` setting.
        hash_cache: Arc::new(std::sync::Mutex::new(lru_cache::LruCache::new(
            config_file.change_cache_size,
        ))),
        max_password_attempts: config_file.max_password_attempts,
        editor: config_file.editor.clone(),
        stripe_config: stripe_config(),
        webauthn: webauthn_rs::WebauthnBuilder::new(
            &config_file.webauthn.rp_id,
            &url::Url::parse(&config_file.webauthn.rp_origin).unwrap(),
        )
        .unwrap()
        .build()
        .unwrap(),
        pro_prix_euros: config_file.pro_prix_euros,
        syntax_set: syntect::parsing::SyntaxSet::load_defaults_newlines(),
        svelte_socket: config_file.svelte_socket.clone(),
        render_cache: Mutex::new(lru_cache::LruCache::new(1024)),
        basic_size_limit: config_file.basic_size_limit,
        pro_size_limit: config_file.pro_size_limit,
    }
}
/// Reads the Stripe settings from the environment.
///
/// Returns `None` as soon as one of `stripe_publishable`,
/// `stripe_webhook_secret` or `stripe_pro` is missing or not valid Unicode;
/// the variables are read in that order.
fn stripe_config() -> Option<Stripe> {
    let publishable_key = std::env::var("stripe_publishable").ok()?;
    let webhook_secret = std::env::var("stripe_webhook_secret").ok()?;
    let pro = std::env::var("stripe_pro").ok()?;
    Some(Stripe {
        publishable_key,
        webhook_secret,
        pro,
    })
}
/// Builds static AWS credentials from the environment.
///
/// Reads `aws_access_key_id`, then `aws_access_key`, and returns `None` as
/// soon as one of them is missing or not valid Unicode. The two remaining
/// `StaticProvider::new` arguments are left as `None`.
fn email() -> Option<rusoto_credential::StaticProvider> {
    let access_key_id = std::env::var("aws_access_key_id").ok()?;
    let secret_access_key = std::env::var("aws_access_key").ok()?;
    let provider =
        rusoto_credential::StaticProvider::new(access_key_id, secret_access_key, None, None);
    Some(provider)
}
impl Config {
    /// Looks `ip` up in the MaxMind database held in `self.maxmind`.
    ///
    /// An IPv6 address for which `Ipv6Addr::to_ipv4` yields a value is looked
    /// up as that IPv4 address; every other address is passed through
    /// unchanged. The result of `Reader::lookup` is returned as-is.
    ///
    /// NOTE(review): per the standard library docs, `to_ipv4` also converts
    /// IPv4-compatible addresses such as `::1` (to `0.0.0.1`), not only
    /// IPv4-mapped ones — confirm this is wanted.
    pub fn ip_lookup<'a>(
        &'a self,
        ip: IpAddr,
    ) -> Result<Option<maxminddb::geoip2::City<'a>>, maxminddb::MaxMindDbError> {
        let normalized = match ip {
            IpAddr::V6(v6) => v6.to_ipv4().map_or(IpAddr::V6(v6), IpAddr::V4),
            v4 @ IpAddr::V4(_) => v4,
        };
        self.maxmind.lookup(normalized)
    }
}
/// Typed view of the `Accept` header: `true` when one of the listed media
/// types is `application/json`.
pub struct AcceptJson(pub bool);
use http::HeaderValue;
impl headers::Header for AcceptJson {
    /// This type decodes and encodes the `Accept` header.
    fn name() -> &'static headers::HeaderName {
        &http::header::ACCEPT
    }

    /// Decodes to `AcceptJson(true)` when any member of any `Accept` value is
    /// the media type `application/json`, ignoring parameters such as `;q=0.9`.
    ///
    /// Values that are not visible ASCII are skipped; this never returns `Err`.
    fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
    where
        I: Iterator<Item = &'i HeaderValue>,
    {
        Ok(AcceptJson(
            values
                .filter_map(|value| value.to_str().ok())
                .flat_map(|value| value.split(','))
                .any(|range| {
                    // List members are usually separated by ", " and media
                    // types are case-insensitive, so trim the optional
                    // whitespace and compare without regard to ASCII case.
                    range.split(';').next().map_or(false, |media_type| {
                        media_type.trim().eq_ignore_ascii_case("application/json")
                    })
                }),
        ))
    }

    /// Emits `application/json` when the flag is set, and nothing otherwise.
    fn encode<E>(&self, values: &mut E)
    where
        E: Extend<headers::HeaderValue>,
    {
        if self.0 {
            values.extend(std::iter::once(headers::HeaderValue::from_static(
                "application/json",
            )))
        }
    }
}
/// An RGB colour packed as `0xRRGGBB` in the low 24 bits of an `i32`.
///
/// Bits above the low 24 are ignored by both [`Color::to_string`] and
/// [`Color::fg`].
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Color(pub i32);

impl Color {
    /// Formats the colour as a lowercase CSS hex string, e.g. `#1a2b3c`.
    pub fn to_string(&self) -> String {
        format!(
            "#{:02x}{:02x}{:02x}",
            (self.0 >> 16) & 0xff,
            (self.0 >> 8) & 0xff,
            self.0 & 0xff
        )
    }

    /// Picks the foreground colour (white or black) to draw on top of `self`.
    ///
    /// Computes the relative luminance of the background from its linearised
    /// sRGB channels and returns white when the contrast against white is
    /// strictly greater than the contrast against black, otherwise black.
    pub fn fg(&self) -> Color {
        // sRGB channel in [0, 1] to linear light.
        let linear = |c: f64| {
            if c <= 0.03928 {
                c / 12.92
            } else {
                ((c + 0.055f64) / 1.055f64).powf(2.4)
            }
        };
        // Every channel is masked to 8 bits, exactly as `to_string` does, so
        // stray high bits (or a negative value) cannot leak into the
        // luminance of the red channel.
        let channel = |shift: u32| linear(((self.0 >> shift) & 0xff) as f64 / 255.);
        let (r, g, b) = (channel(16), channel(8), channel(0));
        let lum = 0.2126 * r + 0.7152 * g + 0.0722 * b;
        let black_contrast = (lum + 0.05) / 0.05;
        // NOTE(review): WCAG defines the contrast against white as
        // 1.05 / (lum + 0.05); the 1.1 used here biases the choice slightly
        // towards white text — confirm whether that is intentional.
        let white_contrast = 1.1 / (lum + 0.05);
        if white_contrast > black_contrast {
            Color(0xffffff)
        } else {
            Color(0)
        }
    }
}