aboutsummaryrefslogtreecommitdiff
path: root/aero-proto/src/dav/middleware.rs
blob: 8964699064a73cd079e4d5e06fcdb1f9ad009b4b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
use anyhow::{anyhow, Result};
use base64::Engine;
use hyper::body::Incoming;
use hyper::{Request, Response};

use aero_collections::user::User;
use aero_user::login::ArcLoginProvider;

use super::codec::text_body;
use super::controller::HttpResponse;

/// Reference-counted, shareable handle to a [`User`]; this is the value `auth`
/// passes to the downstream handler once a login has succeeded.
type ArcUser = std::sync::Arc<User>;

/// Decodes the payload of a `Basic` Authorization header into the `user:password` text.
///
/// Surrounding spaces and trailing `=` padding are stripped before decoding, so both
/// padded and unpadded base64 are accepted.
///
/// # Errors
///
/// Fails when the payload is not valid base64 or the decoded bytes are not valid UTF-8.
fn decode_basic_payload(payload: &str) -> Result<String> {
    // Theoretically the Authorization payload is padded, but "be liberal in what you
    // accept": drop the padding and decode with the no-pad engine.
    let unpadded = payload.trim_matches(' ').trim_end_matches('=');

    let raw = base64::engine::general_purpose::STANDARD_NO_PAD
        .decode(unpadded)
        .map_err(|e| anyhow!("Authorization payload is not valid base64: {}", e))?;

    String::from_utf8(raw)
        .map_err(|e| anyhow!("Authorization payload is not valid UTF-8: {}", e))
}

/// Builds a `400 Bad Request` plain-text response carrying `msg`.
fn bad_request(msg: &'static str) -> Result<HttpResponse> {
    Ok(Response::builder().status(400).body(text_body(msg))?)
}

/// Authenticates a request with HTTP Basic credentials, then forwards it to `next`.
///
/// The `Authorization` header is parsed, the username/password pair is checked with
/// `login`, a [`User`] is built from the resulting credentials, and `next` is called
/// with that user and the original request.
///
/// Responses produced here instead of calling `next`:
/// - `401` with a `WWW-Authenticate: Basic` challenge when the header is missing or
///   when `login` returns an error (any login error is reported as wrong credentials);
/// - `400` when the scheme is not `Basic`, or when the header value or its payload is
///   malformed (non-ASCII header, invalid base64/UTF-8, no `:` separator).
///
/// # Errors
///
/// Returns `Err` when a response cannot be built, when `User::new` fails, or when
/// `next` itself returns an error.
pub(super) async fn auth<'a>(
    login: ArcLoginProvider,
    req: Request<Incoming>,
    next: impl Fn(ArcUser, Request<Incoming>) -> futures::future::BoxFuture<'a, Result<HttpResponse>>,
) -> Result<HttpResponse> {
    let auth_val = match req.headers().get(hyper::header::AUTHORIZATION) {
        Some(hv) => match hv.to_str() {
            Ok(v) => v,
            Err(_) => {
                // The header bytes come from the client: answer with a client error
                // rather than propagating an internal one.
                tracing::info!("Malformed authorization field");
                return bad_request("Malformed Authorization field");
            }
        },
        None => {
            tracing::info!("Missing authorization field");
            return Ok(Response::builder()
                .status(401)
                .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
                .body(text_body("Missing Authorization field"))?);
        }
    };

    // Authentication scheme names are case-insensitive (RFC 7235, section 2.1).
    let b64_creds_maybe_padded = match auth_val.split_once(' ') {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("Basic") => rest,
        _ => {
            tracing::info!("Unsupported authorization field");
            return bad_request("Unsupported Authorization field");
        }
    };

    // Decode base64 into the "username:password" text
    let str_creds = match decode_basic_payload(b64_creds_maybe_padded) {
        Ok(decoded) => decoded,
        Err(e) => {
            tracing::info!(err = %e, "Malformed authorization field");
            return bad_request("Malformed Authorization field");
        }
    };

    // Split username and password on the first colon; the password may itself
    // contain colons.
    let (username, password) = match str_creds.split_once(':') {
        Some(pair) => pair,
        None => {
            tracing::info!(
                "Missing colon in Authorization, can't split decoded value into a username/password pair"
            );
            return bad_request("Malformed Authorization field");
        }
    };

    // Call login provider
    let creds = match login.login(username, password).await {
        Ok(c) => c,
        Err(_) => {
            tracing::info!(user = username, "Wrong credentials");
            return Ok(Response::builder()
                .status(401)
                .header("WWW-Authenticate", "Basic realm=\"Aerogramme\"")
                .body(text_body("Wrong credentials"))?);
        }
    };

    // Build a user
    let user = User::new(username.into(), creds).await?;

    // Call router with user
    next(user, req).await
}