Type-erasing trait parameters in Rust

2023-01-11 22:22:00 +00:00 · Freyja

Rust traits are great, but sometimes it's useful to erase compile-time trait parameters and associated types.

In this blog post, I'll show how to take a trait with type parameters and build a container type that can hold multiple implementations with different type parameters and associated types.

Basic type erasure in Rust

Let's start out with the basics: how do you erase types known at compile time in your program, so that you can store values of different types in the same location?

There are two common approaches in Rust:

  1. Using enums, with one variant for each potential type.
  2. Using trait objects (dyn Trait).

The former is useful for a number of cases, especially when there's a limited number of known potential types (such as JSON types), but for this post I'll be talking about the second case.

A trait Trait has an associated trait object type dyn Trait if it satisfies a property called object safety. There are a lot of specific details around this, but in particular it means that every method must take a self receiver (self, &self, &mut self, self: Box<Self>, etc.), must not have generic type parameters of its own, and must not mention the Self type anywhere else in its signature.
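To make those rules concrete, here is a minimal sketch using a hypothetical Shape trait (not from the original post). The unit constructor has no receiver and returns Self, which would normally rule out a trait object; bounding it with where Self: Sized excludes it from dyn Shape, so the area method can still be called through the trait object.

```rust
use std::f64::consts::PI;

// Hypothetical trait used only to illustrate object safety.
trait Shape {
    // Object safe: `&self` receiver, `Self` does not appear elsewhere.
    fn area(&self) -> f64;

    // No receiver and returns `Self`: not callable on `dyn Shape`.
    // The `Self: Sized` bound excludes it from the trait object.
    fn unit() -> Self
    where
        Self: Sized;
}

struct Square(f64);
struct Circle(f64);

impl Shape for Square {
    fn area(&self) -> f64 { self.0 * self.0 }
    fn unit() -> Self { Square(1.0) }
}

impl Shape for Circle {
    fn area(&self) -> f64 { PI * self.0 * self.0 }
    fn unit() -> Self { Circle(1.0) }
}

/// Sums the areas of differently-typed shapes through `dyn Shape`.
fn total_area(shapes: &[Box<dyn Shape>]) -> f64 {
    shapes.iter().map(|s| s.area()).sum()
}

fn main() {
    let shapes: Vec<Box<dyn Shape>> = vec![Box::new(Square::unit()), Box::new(Circle(2.0))];
    println!("{}", total_area(&shapes));
}
```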

For example, std::io::Read is an object safe trait (shown here simplified to its one required method):

trait Read {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error>;
}

If we have a pointer (including references) to a type that implements Read, we can cast it into a pointer to the trait object dyn Read instead. In this way we're "erasing" the compile-time type of our reader, and only retaining the fact that it implements Read. This means that we can use the same trait object type for multiple underlying reader implementations:

use std::fs;
use std::io;
use std::path::PathBuf;

/// If `path` is specified, read from a file; otherwise, read from a byte array.
fn read_or_default(path: Option<PathBuf>) -> Result<Box<dyn io::Read>, io::Error> {
    match path {
        // Box<File>  ->  Box<dyn Read>
        Some(path) => Ok(Box::new(fs::File::open(path)?)), 
        // Box<&'static [u8]>  ->  Box<dyn Read>
        None => Ok(Box::new(b"default data".as_slice())),  
    }
}
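Since both branches now produce the same Box<dyn io::Read> type, readers of different concrete types can also sit side by side in one collection. A small sketch, with a hypothetical read_all_to_string helper that is not part of the original post:

```rust
use std::io::{self, Cursor, Read};

/// Reads every reader to the end and concatenates the results.
/// Each element may be a different concrete type behind `dyn Read`.
fn read_all_to_string(readers: Vec<Box<dyn Read>>) -> io::Result<String> {
    let mut out = String::new();
    for mut reader in readers {
        // `Box<dyn Read>` itself implements `Read`; this appends to `out`.
        reader.read_to_string(&mut out)?;
    }
    Ok(out)
}

fn main() -> io::Result<()> {
    let readers: Vec<Box<dyn Read>> = vec![
        Box::new(b"default ".as_slice()),       // &'static [u8]
        Box::new(Cursor::new(b"data".to_vec())), // Cursor<Vec<u8>>
    ];
    println!("{}", read_all_to_string(readers)?);
    Ok(())
}
```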

Type erasure with type parameters

Traits with type parameters or associated types can also be object safe. However, all of their type parameters and associated types must be specified as part of the trait object type.

For example, for the Service trait defined below (a simplified version of the tower::Service trait), the trait object must be specified as dyn Service<Request, Response = Response, Future = Future>.

use std::future::Future;

trait Service<Request> {
    type Response;
    type Future: Future<Output = Self::Response>;

    fn call(&mut self, req: Request) -> Self::Future;
}
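To see what "fully specified" means in practice, here is a self-contained sketch with two hypothetical services, Echo and Shout, that share both Response = String and Future = std::future::Ready<String>, so both fit in one Vec of trait objects. A service with any other Future type (an async block, say) would have a different trait object type and could not be stored alongside them. The block_on function is a minimal thread-parking executor written for this example, standing in for a real runtime such as tokio; DynStringService and call_all are likewise hypothetical names.

```rust
use std::future::{ready, Future, Ready};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

trait Service<Request> {
    type Response;
    type Future: Future<Output = Self::Response>;
    fn call(&mut self, req: Request) -> Self::Future;
}

// Hypothetical services sharing the same `Response` *and* `Future` types.
struct Echo;
impl Service<String> for Echo {
    type Response = String;
    type Future = Ready<String>;
    fn call(&mut self, req: String) -> Ready<String> { ready(req) }
}

struct Shout;
impl Service<String> for Shout {
    type Response = String;
    type Future = Ready<String>;
    fn call(&mut self, req: String) -> Ready<String> { ready(req.to_uppercase()) }
}

// Every type parameter and associated type is spelled out in the object type.
type DynStringService = Box<dyn Service<String, Response = String, Future = Ready<String>>>;

// Minimal executor: poll, and park the thread until the waker unparks it.
struct ThreadWaker(Thread);
impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) { self.0.unpark(); }
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

/// Sends the same request to every service in the collection.
fn call_all(services: &mut [DynStringService], req: &str) -> Vec<String> {
    services.iter_mut().map(|svc| block_on(svc.call(req.to_string()))).collect()
}

fn main() {
    let mut services: Vec<DynStringService> = vec![Box::new(Echo), Box::new(Shout)];
    println!("{:?}", call_all(&mut services, "hi")); // ["hi", "HI"]
}
```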

Sometimes we also want to erase some or all of the type parameters and associated types, and this quickly becomes trickier. To start with, when we make a trait object of Service, we may care about the types of Request and Service::Response but not about the type of Service::Future. After all, we'll likely call it as svc.call(req).await anyway, and the .await keyword doesn't care about the type of Service::Future beyond the fact that it implements the Future trait [1].

Since Future is an object safe trait, we can use Pin<Box<dyn 'a + Future<Output = Service::Response>>> [2] to erase the future type itself, though we still need to specify the output type of the future (the type that future.await evaluates to). The lifetime 'a is a lower bound on how long the underlying future type must remain valid (a boxed trait object defaults to 'static if no lifetime is given). We wrap the boxed future in Pin because a future must be pinned before it can be polled: Box<dyn Future> does not implement Future itself (dyn Future is not Unpin), while Pin<Box<dyn Future>> does, so it can be awaited directly.
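As a sketch of that erasure on its own, the futures below have two different concrete types (std::future::Ready<u32> and an async block) but are stored under one Pin<Box<dyn Future<Output = u32>>> type, here given the hypothetical alias ErasedFuture. The block_on function is again a minimal stand-in executor written for this example.

```rust
use std::future::{ready, Future};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Hypothetical alias: a boxed, pinned future with only its output type known.
// The trait object's lifetime defaults to 'static here.
type ErasedFuture<T> = Pin<Box<dyn Future<Output = T>>>;

/// Stores futures of two different concrete types under one erased type.
fn make_futures() -> Vec<ErasedFuture<u32>> {
    let mut futures: Vec<ErasedFuture<u32>> = Vec::new();
    futures.push(Box::pin(ready(1u32))); // std::future::Ready<u32>
    futures.push(Box::pin(async { 40u32 + 2 })); // anonymous async block type
    futures
}

// Minimal executor: poll, and park the thread until the waker unparks it.
struct ThreadWaker(Thread);
impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) { self.0.unpark(); }
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

/// Awaits each erased future; `Pin<Box<dyn Future>>` implements `Future` itself.
fn sum_all(futures: Vec<ErasedFuture<u32>>) -> u32 {
    block_on(async move {
        let mut total = 0;
        for fut in futures {
            total += fut.await;
        }
        total
    })
}

fn main() {
    println!("{}", sum_all(make_futures())); // 43
}
```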

Based on this, we can start making an UnfuturedService type that erases the Service::Future type of a service. The type itself is just a thin wrapper around the underlying service type:

use std::marker::PhantomData;

struct UnfuturedService<'a, S, Request>
where
    S: Service<Request>,
    S::Future: 'a,
{
    svc: S,
    // Marks `Request` and `'a` as used without storing a `Request`.
    _phantom: PhantomData<&'a fn(Request)>,
}

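The _phantom field tells the compiler that the Request type parameter and the 'a lifetime are "used"; if it's omitted, the compiler rejects the struct with error E0392 (parameter is never used). The function pointer type fn(Request) reflects that the wrapper only accepts Request values and never owns one.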
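A small sketch of what the marker costs and what the fn(Request) form buys, using a hypothetical Wrapper type with the same shape as UnfuturedService: PhantomData is zero-sized, and because function pointer types are always Send and Sync, the wrapper stays Send even when Request is not (Rc<u8> here). Using PhantomData<Request> instead would make the wrapper inherit Request's auto traits.

```rust
use std::marker::PhantomData;
use std::mem::size_of;
use std::rc::Rc;

// Hypothetical wrapper with the same shape as `UnfuturedService`.
struct Wrapper<'a, S, Request> {
    svc: S,
    _phantom: PhantomData<&'a fn(Request)>,
}

impl<'a, S, Request> Wrapper<'a, S, Request> {
    fn new(svc: S) -> Self {
        Wrapper { svc, _phantom: PhantomData }
    }
}

/// Compiles only if `T: Send`.
fn assert_send<T: Send>() {}

fn main() {
    // The marker field is zero-sized: the wrapper is exactly as big as `S`.
    assert_eq!(size_of::<Wrapper<'static, u64, String>>(), size_of::<u64>());

    // `Rc<u8>` is not `Send`, but `fn(Rc<u8>)` is, so the wrapper stays `Send`.
    assert_send::<Wrapper<'static, u64, Rc<u8>>>();

    let w: Wrapper<'static, u64, String> = Wrapper::new(7);
    println!("{}", w.svc);
}
```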

The next step is to implement Service for the wrapper type, where we replace the Future associated type with the type-erased Pin<Box<dyn 'a + Future<Output = S::Response>>>:

use std::pin::Pin;

impl<'a, S, Request> Service<Request> for UnfuturedService<'a, S, Request>
where
    S: Service<Request>,
    S::Future: 'a,
{
    type Response = S::Response;
    type Future = Pin<Box<dyn 'a + Future<Output = S::Response>>>;

    fn call(&mut self, req: Request) -> Self::Future {
        Box::pin(self.svc.call(req))
    }
}

Box::pin is a convenient function that lets us construct a Pin<Box<_>> without any unsafe code; other ways of constructing pinned values often require unsafe.

Finally, we can turn UnfuturedService itself into a trait object to erase the information about the underlying type S, only leaving the request and response types:

type BoxUnfuturedService<'a, Request, Response> =
    Box<dyn 'a + Service<
        Request,
        Response = Response,
        Future = Pin<Box<dyn 'a + Future<Output = Response>>>,
    >>;

fn boxed_unfutured_service<'a, S, Request>(svc: S) -> BoxUnfuturedService<'a, Request, S::Response>
where
    S: 'a + Service<Request>,
    S::Future: 'a,
    Request: 'a,
{
    Box::new(UnfuturedService { svc, _phantom: PhantomData })
}
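Putting this section's pieces together, here is a self-contained sketch in which two hypothetical services with the same request and response types but different Future types (std::future::Ready and a hand-written Greeting future) end up behind the same BoxUnfuturedService type. The struct-level where clause is omitted since the impl carries the bounds, and block_on is a minimal stand-in executor written for this example; Doubler, Greeter, Greeting and call_all are likewise hypothetical.

```rust
use std::future::{ready, Future, Ready};
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

trait Service<Request> {
    type Response;
    type Future: Future<Output = Self::Response>;
    fn call(&mut self, req: Request) -> Self::Future;
}

// `UnfuturedService` from the post (struct-level where clause omitted).
struct UnfuturedService<'a, S, Request> {
    svc: S,
    _phantom: PhantomData<&'a fn(Request)>,
}

impl<'a, S, Request> Service<Request> for UnfuturedService<'a, S, Request>
where
    S: Service<Request>,
    S::Future: 'a,
{
    type Response = S::Response;
    type Future = Pin<Box<dyn 'a + Future<Output = S::Response>>>;

    fn call(&mut self, req: Request) -> Self::Future {
        // Box and pin the concrete future, erasing its type.
        Box::pin(self.svc.call(req))
    }
}

type BoxUnfuturedService<'a, Request, Response> = Box<
    dyn 'a + Service<Request, Response = Response, Future = Pin<Box<dyn 'a + Future<Output = Response>>>>,
>;

fn boxed_unfutured_service<'a, S, Request>(svc: S) -> BoxUnfuturedService<'a, Request, S::Response>
where
    S: 'a + Service<Request>,
    S::Future: 'a,
    Request: 'a,
{
    Box::new(UnfuturedService { svc, _phantom: PhantomData })
}

// Hypothetical hand-written future that completes on its first poll.
struct Greeting(Option<String>);
impl Future for Greeting {
    type Output = String;
    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<String> {
        Poll::Ready(self.0.take().expect("Greeting polled after completion"))
    }
}

// Two hypothetical services: same Request/Response types, different Future types.
struct Doubler;
impl Service<u32> for Doubler {
    type Response = String;
    type Future = Ready<String>;
    fn call(&mut self, req: u32) -> Ready<String> { ready((req * 2).to_string()) }
}

struct Greeter;
impl Service<u32> for Greeter {
    type Response = String;
    type Future = Greeting;
    fn call(&mut self, req: u32) -> Greeting { Greeting(Some(format!("hello #{req}"))) }
}

// Minimal executor: poll, and park the thread until the waker unparks it.
struct ThreadWaker(Thread);
impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) { self.0.unpark(); }
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(out) => return out,
            Poll::Pending => thread::park(),
        }
    }
}

/// Calls every service through the same erased type.
fn call_all(req: u32) -> Vec<String> {
    let mut services: Vec<BoxUnfuturedService<'static, u32, String>> = Vec::new();
    services.push(boxed_unfutured_service(Doubler));
    services.push(boxed_unfutured_service(Greeter));
    services.iter_mut().map(|svc| block_on(svc.call(req))).collect()
}

fn main() {
    println!("{:?}", call_all(21)); // ["42", "hello #21"]
}
```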

Erasing all type parameters

Going one step further, we may want to erase all of the type parameters and associated types for a service. This may be the case if we want to keep a collection of services which can have many different type parameters.

For demonstration purposes, we'll consider a map type ServiceMap where services can be registered for different request types. Afterwards, the ServiceMap can be called with any of those request types, and it will dispatch the request to the matching service and print the response using its Debug implementation. In pseudo-code, our ServiceMap would look something like this:

struct ServiceMap { ... }

impl ServiceMap {
    fn register<S, Request>(&mut self, svc: S)
    where
        S: Service<Request>,
        S::Response: Debug,
    {
        ...
    }

    async fn call<Request>(&mut self, req: Request) {
        if let Some(svc) = ... /* impl Service<Request> */ {
            let resp = svc.call(req).await;
            println!("{resp:?}");
        }
    }
}

We can start out by identifying the constraints on each of our type parameters:

  1. Request needs to implement Any (which also requires it to be 'static), because we need to be able to convert it into a trait object and back. Box<dyn Any> has a downcast method that recovers the original value when its type is known at compile time.
  2. Response needs to implement Debug, since we only need to print it.
  3. Service::Future needs to implement Future<Output = Response>, which the Service trait already requires.

Since Any, Debug and Future are all object safe traits, we can use Box<dyn Any>, Box<dyn Debug> and Pin<Box<dyn Future>> to represent these with the types erased.
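A quick sketch of the first two erasures in isolation (the round_trip and debug_all helpers are hypothetical): a value goes into Box<dyn Any> and comes back out through downcast, and values of unrelated types are formatted through Box<dyn Debug>.

```rust
use std::any::Any;
use std::fmt::Debug;

#[derive(Debug)]
struct Point(f64, f64);

/// Erases an `i64` to `Box<dyn Any>`, then recovers it with `downcast`.
fn round_trip(value: i64) -> i64 {
    let erased: Box<dyn Any> = Box::new(value);
    // `downcast::<T>` returns `Result<Box<T>, Box<dyn Any>>`; `*` moves the value out.
    *erased.downcast::<i64>().expect("value should be an i64")
}

/// Formats differently-typed values through `Box<dyn Debug>`.
fn debug_all() -> Vec<String> {
    let values: Vec<Box<dyn Debug>> = vec![
        Box::new("text"),
        Box::new(Point(1.0, 2.0)),
        Box::new(42u8),
    ];
    values.iter().map(|v| format!("{v:?}")).collect()
}

fn main() {
    println!("{}", round_trip(123));
    println!("{:?}", debug_all());
}
```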

As before, we need to build a wrapper service, but this time with a bit more code:

use std::any::Any;
use std::fmt::Debug;

struct UntypedService<'a, S, Request, Response>
where
    S: Service<Request, Response = Response>,
    S::Future: 'a,
    Request: Any,
    Response: Debug,
{
    svc: S,
    _phantom: PhantomData<&'a fn(Request) -> Response>,
}

impl<'a, S, Request, Response> Service<Box<dyn Any>> for UntypedService<'a, S, Request, Response>
where
    S: Service<Request, Response = Response>,
    S::Future: 'a,
    Request: Any,
    Response: 'a + Debug,
{
    type Response = Box<dyn 'a + Debug>;
    type Future = Pin<Box<dyn 'a + Future<Output = Self::Response>>>;

    fn call(&mut self, req: Box<dyn Any>) -> Self::Future {
        let req: Request = *req.downcast::<Request>()
            .ok()
            .expect("wrong request type");

        let future = self.svc.call(req);

        Box::pin(async move { Box::new(future.await) as Box<dyn '_ + Debug> })
    }
}

The most notable part here is the use of the aforementioned Box<dyn Any>::downcast method to convert a Box<dyn Any> into a Box<Request>, which is then dereferenced with the * operator to move the Request out of the box. The downcast method returns Result<Box<Request>, Box<dyn Any>>; in this case we simply panic if it fails, i.e. if the service is called with the wrong request type.
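If panicking isn't desirable, note that downcast hands back the original Box<dyn Any> in its Err variant, so a caller can try another type or report the mismatch. A minimal sketch with a hypothetical describe function:

```rust
use std::any::Any;

/// Tries a few concrete types in turn; each failed `downcast`
/// returns the original box so the next attempt can reuse it.
fn describe(value: Box<dyn Any>) -> String {
    let value = match value.downcast::<i64>() {
        Ok(n) => return format!("i64: {n}"),
        Err(original) => original,
    };
    match value.downcast::<String>() {
        Ok(s) => format!("String: {s}"),
        Err(_) => "unknown type".to_string(),
    }
}

fn main() {
    println!("{}", describe(Box::new(5i64)));
    println!("{}", describe(Box::new(String::from("hi"))));
    println!("{}", describe(Box::new(1.5f32)));
}
```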

As before we can turn UntypedService itself into a trait object, this time erasing all of the type information:

type BoxUntypedService<'a> =
    Box<dyn 'a + Service<
        Box<dyn Any>,
        Response = Box<dyn 'a + Debug>,
        Future = Pin<Box<dyn 'a + Future<Output = Box<dyn 'a + Debug>>>>,
    >>;

fn boxed_untyped_service<'a, S, Request>(svc: S) -> BoxUntypedService<'a>
where
    S: 'a + Service<Request>,
    S::Future: 'a,
    S::Response: 'a + Debug,
    Request: Any,
{
    Box::new(UntypedService { svc, _phantom: PhantomData })
}

Building the ServiceMap

The final remaining step is to build the ServiceMap type itself.

The type itself will be quite simple: it contains only a HashMap whose keys are TypeIds identifying each Request type at runtime, and whose values are the BoxUntypedService type we defined in the last section:

use std::any::TypeId;
use std::collections::HashMap;

#[derive(Default)]
struct ServiceMap {
    map: HashMap<TypeId, BoxUntypedService<'static>>,
}
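TypeId::of gives every 'static type a distinct runtime key, which is what allows one HashMap to index entries by Request type. A small sketch with a hypothetical string registry standing in for the boxed services (build_registry and lookup are illustrative names):

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Looks up the entry registered for the type `T`, if any.
fn lookup<T: Any>(registry: &HashMap<TypeId, &'static str>) -> Option<&'static str> {
    registry.get(&TypeId::of::<T>()).copied()
}

fn build_registry() -> HashMap<TypeId, &'static str> {
    let mut registry = HashMap::new();
    registry.insert(TypeId::of::<i64>(), "i64 handler");
    registry.insert(TypeId::of::<(f64, f64)>(), "point handler");
    registry
}

fn main() {
    let registry = build_registry();
    println!("{:?}", lookup::<i64>(&registry)); // Some("i64 handler")
    println!("{:?}", lookup::<i32>(&registry)); // None: i32 is a different type
    println!("{:?}", lookup::<&'static str>(&registry)); // None
}
```

One pitfall worth knowing: calling .type_id() on a Box<dyn Any> returns the TypeId of the Box itself rather than of its contents. ServiceMap avoids this by computing TypeId::of::<Request>() from the generic parameter.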

The first method we wanted was register, which adds a new service to the map. This is actually quite straightforward (registering a second service for the same Request type replaces the first, since HashMap::insert overwrites existing entries):

impl ServiceMap {
    fn register<S, Request>(&mut self, svc: S)
    where
        S: 'static + Service<Request>,
        S::Future: 'static,
        S::Response: 'static + Debug,
        Request: Any,
    {
        self.map.insert(
            TypeId::of::<Request>(),
            boxed_untyped_service(svc),
        );
    }
}

To call a service, we look it up in the map by the TypeId of the request type and convert the request into the Box<dyn Any> that our "untyped" service accepts; if no service is registered for that type, call returns without doing anything:

impl ServiceMap {
    async fn call<Request>(&mut self, req: Request)
    where
        Request: Any,
    {
        let Some(svc) = self.map.get_mut(&TypeId::of::<Request>()) else { return };
        
        let resp = svc.call(Box::new(req)).await;
        println!("{resp:?}");
    }
}

Example

Here's an example that demonstrates the code we wrote above. It relies on the tokio crate for the #[tokio::main] runtime and can be run in the Rust playground.

use std::future::{ready, Ready};

#[tokio::main]
async fn main() {
    let mut service_map = ServiceMap::default();
    service_map.register(Foo {});
    service_map.register(Bar {});

    // prints: "got i64: 123"
    service_map.call(123i64).await;

    // prints: Point(123.0, 456.0)
    service_map.call((123.0f64, 456.0f64)).await;

    // prints nothing, since no service has been registered for `Request = &str`
    service_map.call("str").await;
}

struct Foo {}

impl Service<i64> for Foo {
    type Response = String;
    type Future = Ready<String>;

    fn call(&mut self, req: i64) -> Ready<String> {
        ready(format!("got i64: {req:?}"))
    }
}

#[derive(Debug)]
struct Point(f64, f64);

struct Bar {}

impl Service<(f64, f64)> for Bar {
    type Response = Point;
    type Future = Ready<Point>;

    fn call(&mut self, req: (f64, f64)) -> Ready<Point> {
        ready(Point(req.0, req.1))
    }
}
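For readers who want to run the whole thing without tokio, here is a condensed, std-only sketch of the same program. Several details differ from the post and are assumptions of this sketch: UntypedService takes its response type from S::Response, the boxed service type is written out in the ServiceMap field instead of going through BoxUntypedService, ServiceMap::call returns the Debug output instead of printing it, Foo and Bar are unit structs, and the hypothetical run_ready helper polls the future only once, which is enough here because every future involved completes immediately (a real executor would wait and poll again).

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::{ready, Future, Ready};
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

trait Service<Request> {
    type Response;
    type Future: Future<Output = Self::Response>;
    fn call(&mut self, req: Request) -> Self::Future;
}

// Condensed `UntypedService`: the response type comes from `S::Response`.
struct UntypedService<'a, S, Request> {
    svc: S,
    _phantom: PhantomData<&'a fn(Request)>,
}
impl<'a, S, Request> Service<Box<dyn Any>> for UntypedService<'a, S, Request>
where
    S: Service<Request>,
    S::Future: 'a,
    S::Response: 'a + Debug,
    Request: Any,
{
    type Response = Box<dyn 'a + Debug>;
    type Future = Pin<Box<dyn 'a + Future<Output = Self::Response>>>;
    fn call(&mut self, req: Box<dyn Any>) -> Self::Future {
        let req: Request = *req.downcast::<Request>().expect("wrong request type");
        let future = self.svc.call(req);
        Box::pin(async move { Box::new(future.await) as Box<dyn 'a + Debug> })
    }
}

#[derive(Default)]
struct ServiceMap {
    map: HashMap<TypeId, Box<dyn Service<Box<dyn Any>, Response = Box<dyn Debug>,
        Future = Pin<Box<dyn Future<Output = Box<dyn Debug>>>>>>>,
}

impl ServiceMap {
    fn register<S, Request>(&mut self, svc: S)
    where
        S: 'static + Service<Request>,
        S::Future: 'static,
        S::Response: 'static + Debug,
        Request: Any,
    {
        let svc = UntypedService::<'static, S, Request> { svc, _phantom: PhantomData };
        self.map.insert(TypeId::of::<Request>(), Box::new(svc));
    }
    // Variant of `call` that returns the `Debug` output instead of printing it.
    async fn call<Request: Any>(&mut self, req: Request) -> Option<String> {
        let svc = self.map.get_mut(&TypeId::of::<Request>())?;
        Some(format!("{:?}", svc.call(Box::new(req)).await))
    }
}

struct Foo;
impl Service<i64> for Foo {
    type Response = String;
    type Future = Ready<String>;
    fn call(&mut self, req: i64) -> Ready<String> { ready(format!("got i64: {req:?}")) }
}
#[derive(Debug)]
struct Point(f64, f64);
struct Bar;
impl Service<(f64, f64)> for Bar {
    type Response = Point;
    type Future = Ready<Point>;
    fn call(&mut self, req: (f64, f64)) -> Ready<Point> { ready(Point(req.0, req.1)) }
}

// Stand-in for tokio: a single poll suffices because every future here is immediately ready.
struct NoopWaker;
impl Wake for NoopWaker { fn wake(self: Arc<Self>) {} }
fn run_ready<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    match Box::pin(fut).as_mut().poll(&mut Context::from_waker(&waker)) {
        Poll::Ready(out) => out,
        Poll::Pending => panic!("future was not immediately ready"),
    }
}

fn demo() -> Vec<Option<String>> {
    let mut map = ServiceMap::default();
    map.register(Foo);
    map.register(Bar);
    run_ready(async {
        let a = map.call(123i64).await;
        let b = map.call((123.0f64, 456.0f64)).await;
        let c = map.call("str").await; // nothing is registered for `&str`
        vec![a, b, c]
    })
}

fn main() { println!("{:?}", demo()); }
```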

Footnotes

  1. Or, as of Rust 1.64, the IntoFuture trait, which is automatically implemented for all Future types.

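  2. Or futures::future::BoxFuture, a convenient alias for Pin<Box<dyn Future<Output = T> + Send + 'a>> from the futures crate (note the extra Send bound; LocalBoxFuture is the variant without it).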