use crate::{
    internal_events::{GrpcServerRequestReceived, GrpcServerResponseSent},
    shutdown::{ShutdownSignal, ShutdownSignalToken},
    tls::MaybeTlsSettings,
};
use futures::FutureExt;
use http::{Request, Response};
use hyper::Body;
use std::{convert::Infallible, net::SocketAddr, time::Duration};
use tonic::transport::server::Routes;
use tonic::{body::BoxBody, server::NamedService, transport::server::Server};
use tower::Service;
use tower_http::{
    classify::{GrpcErrorsAsFailures, SharedClassifier},
    trace::TraceLayer,
};
use tracing::Span;
mod decompression;
pub use self::decompression::{DecompressionAndMetrics, DecompressionAndMetricsLayer};
/// Binds `address` (with TLS if `tls_settings` enables it) and serves the
/// single gRPC `service` until `shutdown` resolves.
///
/// Requests pass through the layer returned by [`build_grpc_trace_layer`],
/// parented to the span that is current when this function is called, and
/// through [`DecompressionAndMetricsLayer`].
///
/// When the shutdown signal fires, its [`ShutdownSignalToken`] is not dropped
/// immediately. It is forwarded over a oneshot channel and released only after
/// the server has returned from serving.
/// NOTE(review): presumably this lets the shutdown coordinator know the server
/// has fully stopped — confirm against `ShutdownSignalToken`.
///
/// # Errors
///
/// Returns an error if binding the listener fails or if the server itself
/// returns an error.
pub async fn run_grpc_server<S>(
    address: SocketAddr,
    tls_settings: MaybeTlsSettings,
    service: S,
    shutdown: ShutdownSignal,
) -> crate::Result<()>
where
    S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible>
        + NamedService
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    // Snapshot of the caller's current span; per-request spans become its children.
    let parent_span = Span::current();
    let (token_tx, token_rx) = tokio::sync::oneshot::channel::<ShutdownSignalToken>();

    let listener = tls_settings.bind(&address).await?;
    let incoming = listener.accept_stream();

    info!(%address, "Building gRPC server.");

    // Once the shutdown signal resolves, park its token in the channel instead of
    // dropping it. `token_rx` is held by this function until serving has finished,
    // so the send cannot hit a closed receiver.
    let shutdown_and_keep_token = shutdown.map(move |token| token_tx.send(token).unwrap());

    Server::builder()
        .layer(build_grpc_trace_layer(parent_span.clone()))
        .layer(DecompressionAndMetricsLayer)
        .add_service(service)
        .serve_with_incoming_shutdown(incoming, shutdown_and_keep_token)
        .await?;

    // The server has stopped; release the token now. If serving ended without the
    // signal firing, the sender was dropped with the server future and this yields
    // an error, which is deliberately ignored.
    let _ = token_rx.await;
    Ok(())
}
/// Binds `address` (with TLS if `tls_settings` enables it) and serves the
/// pre-assembled gRPC `routes` until `shutdown` resolves.
///
/// This is the multi-service counterpart of serving a single service. The same
/// layers are applied: [`build_grpc_trace_layer`] (parented to the span that is
/// current when this function is called) and [`DecompressionAndMetricsLayer`].
///
/// When the shutdown signal fires, its [`ShutdownSignalToken`] is forwarded
/// over a oneshot channel and released only after the server has returned from
/// serving.
/// NOTE(review): presumably this signals completed shutdown to the
/// coordinator — confirm against `ShutdownSignalToken`.
///
/// # Errors
///
/// Returns an error if binding the listener fails or if the server itself
/// returns an error.
pub async fn run_grpc_server_with_routes(
    address: SocketAddr,
    tls_settings: MaybeTlsSettings,
    routes: Routes,
    shutdown: ShutdownSignal,
) -> crate::Result<()> {
    // Snapshot of the caller's current span; per-request spans become its children.
    let parent_span = Span::current();
    let (token_tx, token_rx) = tokio::sync::oneshot::channel::<ShutdownSignalToken>();

    let listener = tls_settings.bind(&address).await?;
    let incoming = listener.accept_stream();

    info!(%address, "Building gRPC server.");

    // Keep the shutdown token alive by moving it into the channel. The receiver
    // outlives the serve call below, so this send always has a live receiver.
    let shutdown_and_keep_token = shutdown.map(move |token| token_tx.send(token).unwrap());

    Server::builder()
        .layer(build_grpc_trace_layer(parent_span.clone()))
        .layer(DecompressionAndMetricsLayer)
        .add_routes(routes)
        .serve_with_incoming_shutdown(incoming, shutdown_and_keep_token)
        .await?;

    // Serving is over: drop the token (or the receive error if the signal never
    // fired and the sender went away with the server future).
    let _ = token_rx.await;
    Ok(())
}
/// Builds the [`TraceLayer`] that the gRPC server runners in this module install.
///
/// The layer starts from `TraceLayer::new_for_grpc`, so responses are classified
/// with [`GrpcErrorsAsFailures`]. For each request it:
///
/// * creates an `error_span!` named `grpc-request`, parented to `span`, with
///   `grpc_service` / `grpc_method` fields taken from the request path
///   (`/<service>/<method>`). A missing or empty component is recorded as
///   `_unknown`.
/// * emits [`GrpcServerRequestReceived`] from the `on_request` callback.
/// * emits [`GrpcServerResponseSent`] (with the response and its latency) from
///   the `on_response` callback.
///
/// The failure, body-chunk and end-of-stream callbacks are set to `()` (no-op).
pub fn build_grpc_trace_layer(
    span: Span,
) -> TraceLayer<
    SharedClassifier<GrpcErrorsAsFailures>,
    impl Fn(&Request<Body>) -> Span + Clone,
    impl Fn(&Request<Body>, &Span) + Clone,
    impl Fn(&Response<BoxBody>, Duration, &Span) + Clone,
    (),
    (),
    (),
> {
    TraceLayer::new_for_grpc()
        .make_span_with(move |request: &Request<Body>| {
            let (service, method) = grpc_service_and_method(request.uri().path());
            error_span!(
               parent: &span,
               "grpc-request",
               grpc_service = service,
               grpc_method = method,
            )
        })
        // A plain capture-free closure already satisfies `Fn + Clone`; no boxing needed.
        .on_request(|_request: &Request<Body>, _span: &Span| {
            emit!(GrpcServerRequestReceived);
        })
        .on_response(
            |response: &Response<BoxBody>, latency: Duration, _span: &Span| {
                emit!(GrpcServerResponseSent { response, latency });
            },
        )
        .on_failure(())
        .on_body_chunk(())
        .on_eos(())
}

/// Splits a gRPC request path of the form `/<service>/<method>` into its
/// `(service, method)` components.
///
/// The leading `/` produces an empty first segment, so the service is the second
/// `/`-separated segment and the method is the third. Any further segments are
/// ignored. A component that is absent or empty (e.g. for `"/"` or `"/svc/"`)
/// is reported as `"_unknown"`, so span fields are never blank.
fn grpc_service_and_method(path: &str) -> (&str, &str) {
    const UNKNOWN: &str = "_unknown";

    let mut segments = path.split('/');
    let service = segments
        .nth(1)
        .filter(|segment| !segment.is_empty())
        .unwrap_or(UNKNOWN);
    let method = segments
        .next()
        .filter(|segment| !segment.is_empty())
        .unwrap_or(UNKNOWN);
    (service, method)
}