An upgrade to WebSockets: the Rust WebTransport library

Published 10 months ago

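WebTransport is a new protocol for low-latency, bidirectional communication between a client and a server over the web. It is built primarily on HTTP/3 over QUIC. It aims to overcome the limitations of the WebSocket protocol by providing a more efficient and more flexible transport layer: a single session can carry several independent streams as well as unreliable datagrams.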
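One concrete reason the transport is "more efficient" is head-of-line blocking. A WebSocket carries every message over one ordered TCP byte stream, so a single lost packet holds back everything queued behind it. WebTransport can put messages on independent QUIC streams, so a loss only delays the stream it belongs to. The sketch below is a deliberately simplified, std-only model, not a network stack. It assumes that each message fits in one packet and that exactly one packet is lost. The function names are hypothetical.

```rust
/// Simplified model: each message travels in one packet, and the packet at
/// index `lost` must be retransmitted. Both functions return the messages the
/// application can read before that retransmission arrives.

/// One ordered stream (WebSocket over TCP): nothing after the gap can be
/// delivered until the lost packet is retransmitted.
fn deliverable_single_stream(messages: &[&str], lost: usize) -> Vec<String> {
    messages.iter().take(lost).map(|m| m.to_string()).collect()
}

/// One message per independent stream (WebTransport over QUIC): only the
/// stream whose packet was lost has to wait.
fn deliverable_independent_streams(messages: &[&str], lost: usize) -> Vec<String> {
    messages
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != lost)
        .map(|(_, m)| m.to_string())
        .collect()
}

fn main() {
    let messages = ["a", "b", "c", "d"];
    // Suppose the packet carrying "b" (index 1) is lost.
    println!("single stream:       {:?}", deliverable_single_stream(&messages, 1));
    println!("independent streams: {:?}", deliverable_independent_streams(&messages, 1));
}
```

In this model the single stream can deliver only "a", while independent streams can deliver "a", "c" and "d". Real QUIC behavior also depends on congestion control and on how messages are split into packets.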

Advantages of WebTransport:

Let's walk through an example that uses WebTransport.

Setting up the project

First, create a new Rust project:

cargo new webtransport-example

Next, add the dependencies to Cargo.toml:

[dependencies]
tokio = { version = "1.28.1", default-features = false, features = ["macros", "rt-multi-thread", "time"] }
wtransport = {version = "0.1.8", features = ["dangerous-configuration"]}
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
anyhow = "1.0.71"
wtransport-proto = "0.1.8"
base64 = "0.21.0"
rcgen = "0.11.1"
ring = "0.17.0"
time = "0.3.21"

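Next, create a src/bin directory in the project and add two files to it: server.rs and client.rs. Cargo automatically builds each file in src/bin as a separate binary, which is what lets the cargo run --bin commands used later find them.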

The code for server.rs is as follows:

use std::time::Duration;

use anyhow::Result;
use tracing::error;
use tracing::info;
use tracing::info_span;
use tracing::Instrument;
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::EnvFilter;
use wtransport::endpoint::IncomingSession;
use wtransport::tls::Certificate;
use wtransport::Endpoint;
use wtransport::ServerConfig;

#[tokio::main]
async fn main() -> Result<()> {
    init_logging();
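    // Configure the server
    let config = ServerConfig::builder()
        // Listen on port 4433
        .with_bind_default(4433)
        // Load the certificate and private key produced by gencert
        .with_certificate(Certificate::load("cert.pem", "key.pem").await?)
        // Send keep-alive packets every 3 seconds
        .keep_alive_interval(Some(Duration::from_secs(3)))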
        .build();

    let server = Endpoint::server(config)?;

    info!("Server ready!");

    for id in 0.. {
        let incoming_session = server.accept().await;
        tokio::spawn(handle_connection(incoming_session).instrument(info_span!("Connection", id)));
    }

    Ok(())
}

async fn handle_connection(incoming_session: IncomingSession) {
    let result = handle_connection_impl(incoming_session).await;
    error!("{:?}", result);
}

async fn handle_connection_impl(incoming_session: IncomingSession) -> Result<()> {
    let mut buffer = vec![0; 65536].into_boxed_slice();

    info!("Waiting for session request...");
    // Wait for the client's session request
    let session_request = incoming_session.await?;

    info!(
        "New session: Authority: '{}', Path: '{}'",
        session_request.authority(),
        session_request.path()
    );

    let connection = session_request.accept().await?;

    info!("Waiting for data from client...");

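    // Wait for data from the client: select! multiplexes bidirectional streams,
    // unidirectional streams and datagrams over the same connection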
    loop {
        tokio::select! {
            // Message on a bidirectional stream: reply on the same stream
            stream = connection.accept_bi() => {
                let mut stream = stream?;
                info!("Accepted BI stream");

                let bytes_read = match stream.1.read(&mut buffer).await? {
                    Some(bytes_read) => bytes_read,
                    None => continue,
                };

                let str_data = std::str::from_utf8(&buffer[..bytes_read])?;

                info!("Received (bi) '{str_data}' from client");

                stream.0.write_all(b"ACK").await?;
            }
            // Message on a unidirectional stream: reply on a new unidirectional stream
            stream = connection.accept_uni() => {
                let mut stream = stream?;
                info!("Accepted UNI stream");

                let bytes_read = match stream.read(&mut buffer).await? {
                    Some(bytes_read) => bytes_read,
                    None => continue,
                };

                let str_data = std::str::from_utf8(&buffer[..bytes_read])?;

                info!("Received (uni) '{str_data}' from client");

                let mut stream = connection.open_uni().await?.await?;
                stream.write_all(b"ACK").await?;
            }
            // Datagram: reply with a datagram
            dgram = connection.receive_datagram() => {
                let dgram = dgram?;
                let str_data = std::str::from_utf8(&dgram)?;

                info!("Received (dgram) '{str_data}' from client");
                connection.send_datagram(b"ACK")?;
            }
        }
    }
}

fn init_logging() {
    let env_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env_lossy();

    tracing_subscriber::fmt()
        .with_target(true)
        .with_level(true)
        .with_env_filter(env_filter)
        .init();
}
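The select! loop in the server reacts to three kinds of incoming data and replies to each on a matching channel. The std-only sketch below models only that dispatch logic, so the branching can be read and tested without a network. The Incoming and Reply enums and the handle function are hypothetical stand-ins; they are not wtransport types.

```rust
/// Hypothetical stand-in for what accept_bi, accept_uni and receive_datagram
/// hand to the server loop (just the payload bytes).
enum Incoming {
    Bi(Vec<u8>),
    Uni(Vec<u8>),
    Datagram(Vec<u8>),
}

/// How the server in this article replies to each kind of input.
#[derive(Debug, PartialEq)]
enum Reply {
    /// Written back on the send half of the same bidirectional stream.
    SameBiStream(&'static [u8]),
    /// A unidirectional stream is one-way, so the server opens its own.
    NewUniStream(&'static [u8]),
    /// Sent back as a separate datagram.
    Datagram(&'static [u8]),
}

/// Decodes the payload as UTF-8, as the server does with std::str::from_utf8,
/// and picks the reply channel. Invalid UTF-8 is returned as an error.
fn handle(incoming: &Incoming) -> Result<(String, Reply), std::str::Utf8Error> {
    let (payload, reply) = match incoming {
        Incoming::Bi(data) => (data, Reply::SameBiStream(b"ACK")),
        Incoming::Uni(data) => (data, Reply::NewUniStream(b"ACK")),
        Incoming::Datagram(data) => (data, Reply::Datagram(b"ACK")),
    };
    Ok((std::str::from_utf8(payload)?.to_string(), reply))
}

fn main() {
    let inputs = [
        Incoming::Bi(b"HELLO".to_vec()),
        Incoming::Uni(b"WORLD".to_vec()),
        Incoming::Datagram(b"Hello, world!".to_vec()),
    ];
    for input in &inputs {
        match handle(input) {
            Ok((text, reply)) => println!("received '{text}', reply via {reply:?}"),
            Err(e) => println!("invalid UTF-8: {e}"),
        }
    }
}
```

In the real server, the ? after from_utf8 means that a single non-UTF-8 message ends the whole connection handler. That is acceptable for a demo, but a real service would probably log the error and continue.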

The code for client.rs is as follows:

use std::time::Duration;

use wtransport::ClientConfig;
use wtransport::Endpoint;

#[tokio::main]
async fn main() {
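    // Configure the client
    let config = ClientConfig::builder()
        .with_bind_default()
        // Skip certificate validation; for local testing only
        // (requires the "dangerous-configuration" feature)
        .with_no_cert_validation()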
        .build();

    // Connect to the server on port 4433
    let connection = Endpoint::client(config)
        .unwrap()
        .connect("https://[::1]:4433")
        .await
        .unwrap();

    // Send a message over a bidirectional stream, then finish the stream
    let mut stream = connection.open_bi().await.unwrap().await.unwrap();
    stream.0.write_all(b"HELLO").await.unwrap();
    stream.0.finish().await.unwrap();

    // Send a message over a unidirectional stream, then finish the stream
    let mut stream = connection.open_uni().await.unwrap().await.unwrap();
    stream.write_all(b"WORLD").await.unwrap();
    stream.finish().await.unwrap();

    // Send a datagram
    connection.send_datagram(b"Hello, world!").unwrap();

    // Keep the connection open briefly so the server can receive everything
    tokio::time::sleep(Duration::from_secs(3)).await;
}
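The client calls finish() on each send stream to tell the peer that no more data will follow. The server issues a single read into a 64 KiB buffer. That is enough for these short test messages, but a stream read can in general return only part of the data that was sent. The std-only sketch below shows the usual pattern of reading until end-of-stream. ChunkedStream is a hypothetical type that mimics the Option<usize> convention of the server's read calls: Some(n) means n bytes were read, and None means the stream is finished.

```rust
/// Hypothetical stream that hands out its data in pieces of at most `chunk`
/// bytes (chunk must be greater than 0), then reports end-of-stream with None.
struct ChunkedStream {
    data: Vec<u8>,
    pos: usize,
    chunk: usize,
}

impl ChunkedStream {
    fn read(&mut self, buf: &mut [u8]) -> Option<usize> {
        if self.pos >= self.data.len() {
            return None; // end of stream: the sender called finish()
        }
        let n = self.chunk.min(buf.len()).min(self.data.len() - self.pos);
        buf[..n].copy_from_slice(&self.data[self.pos..self.pos + n]);
        self.pos += n;
        Some(n)
    }
}

/// Keeps reading until end-of-stream and returns every byte received.
fn read_to_end(stream: &mut ChunkedStream) -> Vec<u8> {
    let mut buffer = [0u8; 4];
    let mut out = Vec::new();
    while let Some(n) = stream.read(&mut buffer) {
        out.extend_from_slice(&buffer[..n]);
    }
    out
}

fn main() {
    // A single read only sees the first chunk...
    let mut stream = ChunkedStream { data: b"HELLO".to_vec(), pos: 0, chunk: 2 };
    let mut buf = [0u8; 64];
    let n = stream.read(&mut buf).unwrap();
    println!("single read: {:?}", std::str::from_utf8(&buf[..n]).unwrap()); // "HE"

    // ...while reading until end-of-stream recovers the whole message.
    let mut stream = ChunkedStream { data: b"HELLO".to_vec(), pos: 0, chunk: 2 };
    println!("read_to_end: {:?}", String::from_utf8(read_to_end(&mut stream)).unwrap());
}
```

Datagrams do not need this loop, because each one arrives whole or not at all.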

In the same bin directory, create a third file, gencert.rs, which generates the certificate:

use base64::engine::general_purpose::STANDARD as Base64Engine;
use base64::Engine;
use rcgen::CertificateParams;
use rcgen::DistinguishedName;
use rcgen::DnType;
use rcgen::KeyPair;
use rcgen::PKCS_ECDSA_P256_SHA256;
use ring::digest::digest;
use ring::digest::SHA256;
use std::fs;
use std::io::Write;
use time::Duration;
use time::OffsetDateTime;

fn main() {
    const COMMON_NAME: &str = "localhost";

    let mut dname = DistinguishedName::new();
    dname.push(DnType::CommonName, COMMON_NAME);

    // Generate an ECDSA P-256 key pair
    let keypair = KeyPair::generate(&PKCS_ECDSA_P256_SHA256).unwrap();

    // Fingerprint: SHA-256 digest of the DER-encoded public key
    let digest = digest(&SHA256, &keypair.public_key_der());

    let mut cert_params = CertificateParams::new(vec![COMMON_NAME.to_string()]);
    cert_params.distinguished_name = dname;
    cert_params.alg = &PKCS_ECDSA_P256_SHA256;
    cert_params.key_pair = Some(keypair);
    // Short validity window: from two days ago until two days from now
    cert_params.not_before = OffsetDateTime::now_utc()
        .checked_sub(Duration::days(2))
        .unwrap();
    cert_params.not_after = OffsetDateTime::now_utc()
        .checked_add(Duration::days(2))
        .unwrap();

    let certificate = rcgen::Certificate::from_params(cert_params).unwrap();

    fs::File::create("cert.pem")
        .unwrap()
        .write_all(certificate.serialize_pem().unwrap().as_bytes())
        .unwrap();

    fs::File::create("key.pem")
        .unwrap()
        .write_all(certificate.serialize_private_key_pem().as_bytes())
        .unwrap();

    println!("Certificate generated");
    println!("Fingerprint: {}", Base64Engine.encode(digest));
}
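The fingerprint printed by gencert is the SHA-256 digest of the public key, encoded with the standard Base64 alphabet by the base64 crate. As an illustration only, here is a small std-only standard Base64 encoder with = padding (RFC 4648). It shows why a 32-byte SHA-256 digest always prints as a 44-character string ending in a single =. The function name encode_base64 is hypothetical; real code should keep using the base64 crate.

```rust
const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Standard Base64 with '=' padding: every 3 input bytes become 4 output
/// characters, and a final group of 1 or 2 bytes is padded to 4 characters.
fn encode_base64(input: &[u8]) -> String {
    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
    for chunk in input.chunks(3) {
        // Pack up to 3 bytes into a 24-bit group (missing bytes count as 0).
        let b0 = chunk[0] as u32;
        let b1 = *chunk.get(1).unwrap_or(&0) as u32;
        let b2 = *chunk.get(2).unwrap_or(&0) as u32;
        let group = (b0 << 16) | (b1 << 8) | b2;

        // One character per 6-bit slice; pad where the input ran out.
        out.push(ALPHABET[((group >> 18) & 0x3f) as usize] as char);
        out.push(ALPHABET[((group >> 12) & 0x3f) as usize] as char);
        out.push(if chunk.len() > 1 { ALPHABET[((group >> 6) & 0x3f) as usize] as char } else { '=' });
        out.push(if chunk.len() > 2 { ALPHABET[(group & 0x3f) as usize] as char } else { '=' });
    }
    out
}

fn main() {
    // Stand-in for a SHA-256 digest, which is always 32 bytes long.
    let digest = [0u8; 32];
    let fingerprint = encode_base64(&digest);
    println!("{fingerprint} ({} chars)", fingerprint.len());
}
```

The length follows directly from the arithmetic. 32 bytes form 10 full 3-byte groups plus one 2-byte group, which gives 11 groups of 4 characters, or 44 characters, with one = of padding.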

Testing

Generate the certificate with:

cargo run --bin gencert

This creates two files, cert.pem and key.pem, in the project root.
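gencert sets not_before two days in the past and not_after two days in the future. The certificate is therefore valid only for a four-day window around the moment it was generated. After that, clients that validate certificates will reject it, and gencert has to be run again. This example's client skips validation, so it is not affected. Below is a std-only sketch of that window using SystemTime. The helper names validity_window and is_valid_at are hypothetical.

```rust
use std::time::{Duration, SystemTime};

const DAY: Duration = Duration::from_secs(24 * 60 * 60);

/// Validity window mirroring gencert: from two days before `now`
/// until two days after it.
fn validity_window(now: SystemTime) -> (SystemTime, SystemTime) {
    (now - 2 * DAY, now + 2 * DAY)
}

/// True if `t` lies inside the window (both ends inclusive).
fn is_valid_at(window: (SystemTime, SystemTime), t: SystemTime) -> bool {
    window.0 <= t && t <= window.1
}

fn main() {
    let generated_at = SystemTime::now();
    let window = validity_window(generated_at);
    println!("valid right now:   {}", is_valid_at(window, SystemTime::now()));
    println!("valid in 3 days:   {}", is_valid_at(window, generated_at + 3 * DAY));
}
```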

Start the server:

cargo run --bin server

Then open a new terminal and run the client:

cargo run --bin client

The server then prints output like the following:

INFO server: Server ready!
INFO Connection{id=0}: server: Waiting for session request...
INFO Connection{id=0}: server: New session: Authority: '[::1]:4433', Path: '/'
INFO Connection{id=0}: server: Waiting for data from client...
INFO Connection{id=0}: server: Accepted BI stream
INFO Connection{id=0}: server: Received (bi) 'HELLO' from client
INFO Connection{id=0}: server: Accepted UNI stream
INFO Connection{id=0}: server: Received (uni) 'WORLD' from client
INFO Connection{id=0}: server: Received (dgram) 'Hello, world!' from client
