
Storing borrowed data in trait objects

This is the third – and last – post of a series where we go through the performance work we did in rustls, a modern TLS library in Rust. Today we avoid and/or defer heap allocations using Cow-like enums embedded in trait objects.


    Recap of the second post

    In the second post, we finished working through the pipeline that decodes and decrypts incoming TLS records. To remove the heap allocations in the second decoding pass, we added a generic lifetime parameter to the Codec trait. This change lets us borrow data from the input bytes and put those slices in the decoded data structure returned by the read method, instead of copying those slices into fresh heap allocations.

    struct Reader<'a> {
        buffer: &'a [u8], 
        // ..
    }
    
    trait Codec<'r>: Sized {
        //      ^^ added this lifetime parameter
        fn read(_: &mut Reader<'r>) -> Result<Self, /* .. */>;
        //                     ^^ and used it here
    }
    
    // BEFORE
    // struct PayloadU8(Vec<u8>); 
    
    // AFTER
    struct PayloadU8<'a>(&'a [u8]); 
    
    // we can now tie the lifetime ..
    impl<'a> Codec<'a> for PayloadU8<'a> {
    //   ^^        ^^                ^^ .. in Self ..
        fn read(r: &mut Reader<'a>) -> Result<Self, /* ... */> { 
            // ..              ^^ .. to the lifetime in Reader
            let len = u8::read(r)? as usize;
            let mut sub = r.sub(len)?;
    
            // BEFORE
            // let body = sub.rest().to_vec(); 
    
            // AFTER: no allocation
            let body = sub.rest(); 
    
            Ok(Self(body))
        }
    }
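
    For readers who want to compile the recap, here is a minimal, self-contained sketch of the same idea. The `take` method, the `cursor` field, the `decode_two` helper and the use of `Option` instead of a proper error type are assumptions made for this example; they are not the rustls API.

```rust
// Hypothetical, trimmed-down reader: a cursor over borrowed input bytes.
struct Reader<'a> {
    buffer: &'a [u8],
    cursor: usize,
}

impl<'a> Reader<'a> {
    fn new(buffer: &'a [u8]) -> Self {
        Reader { buffer, cursor: 0 }
    }

    // Returns the next `len` bytes. The slice borrows from the input
    // (lifetime 'a), not from the `Reader` value itself.
    fn take(&mut self, len: usize) -> Option<&'a [u8]> {
        let end = self.cursor.checked_add(len)?;
        let slice = self.buffer.get(self.cursor..end)?;
        self.cursor = end;
        Some(slice)
    }
}

trait Codec<'a>: Sized {
    fn read(r: &mut Reader<'a>) -> Option<Self>;
}

impl<'a> Codec<'a> for u8 {
    fn read(r: &mut Reader<'a>) -> Option<Self> {
        r.take(1).map(|bytes| bytes[0])
    }
}

#[derive(Debug, PartialEq)]
struct PayloadU8<'a>(&'a [u8]);

impl<'a> Codec<'a> for PayloadU8<'a> {
    fn read(r: &mut Reader<'a>) -> Option<Self> {
        // one length byte followed by that many payload bytes
        let len = u8::read(r)? as usize;
        let body = r.take(len)?; // no allocation: just a sub-slice
        Some(PayloadU8(body))
    }
}

// Decodes two consecutive payloads; both borrow from `input`.
fn decode_two(input: &[u8]) -> Option<(PayloadU8<'_>, PayloadU8<'_>)> {
    let mut reader = Reader::new(input);
    let first = PayloadU8::read(&mut reader)?;
    let second = PayloadU8::read(&mut reader)?;
    Some((first, second))
}
```

    Both decoded payloads point into the input bytes, so they stay valid only as long as the input does. That property is what the rest of this post has to work around.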
    

    The TLS State machine

    Now, let's see how the Connection uses the TLS records after decoding and decrypting them. Let's look at the process_new_packets method:

    struct Connection {
        incoming_tls: MessageDeframer,
        state: Result<Box<dyn State>, /* .. */>,
        // ..
    }
    
    impl Connection {
        fn process_new_packets(
            &mut self,
            /* .. */,
        ) -> /* .. */  {
            // take out state machine
            let mut state =
                match mem::replace(&mut self.state, /* .. */) {
                // ..
            };
            
            while let Some(msg) =
                self.incoming_tls.pop(/* .. */)?
            {
                match self.process_msg(msg, state) {
                    // advance state machine
                    Ok(new) => state = new, // <-
    
                    Err(e) => {
                        self.state = Err(e.clone());
                        return Err(e);
                    }
                }
            }
            
            self.state = Ok(state); // <- put it back in `self`
            // ..
        }
    }
    

    The msg variable that comes out of incoming_tls is a decoded and decrypted TLS record. This record is used to advance a state machine: process_msg takes the current state and the record and produces a new state.

    The state machine is stored in the Connection object. At the start of the function, the current state is moved out of the Connection object – this is not unlike our use of mem::take in the first post of this series. process_new_packets will use all the currently available TLS records to advance the state machine. When no more records are available in incoming_tls, the state gets put back in the Connection object.
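
    The take-out / put-back pattern can be sketched in isolation. In the simplified model below, messages are plain bytes, and the `Start` and `Done` states, the `Error` type, the `process` method and the placeholder error left behind by mem::replace are all assumptions for illustration; the original snippet elides what rustls actually uses there.

```rust
use std::mem;

#[derive(Clone, Debug, PartialEq)]
struct Error(String);

trait State {
    fn handle(self: Box<Self>, msg: u8) -> Result<Box<dyn State>, Error>;
    fn name(&self) -> &'static str;
}

struct Start;
struct Done;

impl State for Start {
    fn handle(self: Box<Self>, msg: u8) -> Result<Box<dyn State>, Error> {
        match msg {
            0 => Ok(self),           // no state change
            1 => Ok(Box::new(Done)), // advance
            other => Err(Error(format!("unexpected message {other}"))),
        }
    }
    fn name(&self) -> &'static str {
        "Start"
    }
}

impl State for Done {
    fn handle(self: Box<Self>, _msg: u8) -> Result<Box<dyn State>, Error> {
        Ok(self)
    }
    fn name(&self) -> &'static str {
        "Done"
    }
}

struct Connection {
    state: Result<Box<dyn State>, Error>,
}

impl Connection {
    fn process(&mut self, msgs: &[u8]) -> Result<(), Error> {
        // take the state out, leaving a placeholder error behind
        let placeholder = Err(Error("state temporarily moved out".into()));
        let mut state = match mem::replace(&mut self.state, placeholder) {
            Ok(state) => state,
            Err(e) => {
                // a previous call failed: keep reporting that error
                self.state = Err(e.clone());
                return Err(e);
            }
        };

        for &msg in msgs {
            match state.handle(msg) {
                Ok(new) => state = new, // advance the state machine
                Err(e) => {
                    self.state = Err(e.clone());
                    return Err(e);
                }
            }
        }

        self.state = Ok(state); // put it back in `self`
        Ok(())
    }
}
```

    Because handle consumes the boxed state by value, the state has to be moved out of the Connection before the loop and moved back in on every exit path.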

    For the curious: This state machine contains the TLS handshake logic. The handshake is the part of the TLS protocol that needs to happen before encrypted application data can be exchanged between two peers.

    Discarding old records

    Something I have not mentioned so far is that the original MessageDeframer::pop method discards bytes from the front of its internal bytes buffer (a Vec<u8>). That makes sense: once the handshake is complete, there's no need to keep those initial handshake records in the buffer. By discarding them, we have more space available to receive application data and/or we can minimize the capacity of the Vec buffer.

    However, discarding data from the buffer does not mesh well with the idea of representing all TLS records as borrowed data (references) because the discard operation will invalidate that borrowed data. To be able to represent TLS records as borrowed data, we'll need to modify MessageDeframer::pop.

    MessageDeframer refresher

    But first, if you don't quite remember the details about the MessageDeframer, which was the focus of the first post, here's a quick refresher:

    • The deframer is a newtype over a contiguous buffer of incoming TLS data (Vec<u8>)
    • The TLS data is made of runtime-sized records (think slices, not fixed-size arrays); the deframer's job is to yield one record at a time (pop method)
    • Some TLS records are encrypted; the deframer will decrypt them if needed as part of the pop operation

    Batched discard

    These are the constraints of the problem we want to solve:

    • MessageDeframer stores the incoming TLS data in a heap allocation.
    • To avoid new heap allocations, the pop method must return the decrypted record as a slice of the incoming TLS data – that was covered in the first post.
    • process_new_packets may pop more than one TLS record; therefore pop must not freeze the MessageDeframer.
    • Finally, when we are done processing all the records we popped, we want to remove all the used data from MessageDeframer.bytes and shrink it (in length, not capacity) – this is the "discard" operation.
    • That is, compared to the original logic, we want one explicit, big discard after several pops (more efficient), rather than one implicit, small discard inside each pop (more convenient).
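
    To make the "more efficient" claim in the last bullet concrete, here is a rough sketch that counts how many bytes get shifted towards the front of the buffer under each strategy. It assumes fixed-size records and a front-discard implemented with copy_within; the function names are made up for this example.

```rust
// Removes `num_bytes` from the front of `bytes` and returns how many
// trailing bytes had to be shifted to the front to do so.
fn discard(bytes: &mut Vec<u8>, num_bytes: usize) -> usize {
    let len = bytes.len();
    bytes.copy_within(num_bytes..len, 0);
    bytes.truncate(len - num_bytes);
    len - num_bytes
}

// Original behavior: one small discard for every popped record.
fn discard_per_pop(mut bytes: Vec<u8>, record_size: usize) -> (Vec<u8>, usize) {
    assert!(record_size > 0);
    let mut shifted = 0;
    while bytes.len() >= record_size {
        shifted += discard(&mut bytes, record_size);
    }
    (bytes, shifted)
}

// New behavior: pop every complete record first, then discard once.
fn batched_discard(mut bytes: Vec<u8>, record_size: usize) -> (Vec<u8>, usize) {
    assert!(record_size > 0);
    let used = (bytes.len() / record_size) * record_size;
    let shifted = discard(&mut bytes, used);
    (bytes, shifted)
}
```

    For a buffer holding two complete 3-byte records plus 2 trailing bytes, the per-pop strategy shifts 5 + 2 = 7 bytes while the batched strategy shifts only the 2 trailing bytes. Both leave the same 2 bytes in the buffer, and the gap grows with the number of records popped per call.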

    Given the constraints, I decided to create a borrowed variant of MessageDeframer and moved the pop method to the borrowed variant. The BorrowedMessageDeframer keeps track of how much TLS data has been yielded through its pop method. When we are done popping records, we get the number of bytes consumed from BorrowedMessageDeframer and explicitly discard those bytes from the front of MessageDeframer.bytes.

    A small example that uses the new API is shown below:

    // the incoming TLS buffer contains 2 complete 3-byte records and
    // one incomplete record
    let mut deframer = MessageDeframer {
        bytes: vec![0, 0, 0, 1, 1, 1, 2, 2],
    };
    
    let mut borrowed = deframer.borrow();
    
    let first = borrowed.pop();
    // we can call `pop` again while `first` is alive because
    // no freezing occurs
    let second = borrowed.pop();
    
    // these are complete records
    assert_eq!(Some(&[0, 0, 0][..]), first);
    assert_eq!(Some(&[1, 1, 1][..]), second);
    
    // the next record is incomplete
    assert_eq!(None, borrowed.pop());
    
    // this shrinks the internal `Vec` buffer
    let consumed = borrowed.done();
    deframer.discard(consumed);
    
    assert_eq!(2, deframer.bytes.len());
    

    A minimal (because the actual deframing logic is quite complex) implementation of the above semantics is shown below:

    use std::mem;

    struct MessageDeframer {
        bytes: Vec<u8>,
    }
    
    impl MessageDeframer {
        fn borrow(&mut self) -> BorrowedMessageDeframer<'_> {
            BorrowedMessageDeframer {
                bytes: &mut self.bytes,
                used: 0,
            }
        }
    
        // discards `num_bytes` from the front of `bytes`
        fn discard(&mut self, num_bytes: usize) {
            let len = self.bytes.len();
            self.bytes.copy_within(num_bytes..len, 0);
            self.bytes.resize(len - num_bytes, 0);
        }
    }
    
    struct BorrowedMessageDeframer<'a> {
        bytes: &'a mut [u8],
        used: usize,
    }
    
    impl<'a> BorrowedMessageDeframer<'a> {
        // NOTE: no freezing because the returned lifetime 'a is
        // unrelated to 's
        fn pop<'s>(&'s mut self) -> Option<&'a [u8]> {
            // for simplicity, we assume it's constant size
            // in the real world, the size will be read out from
            // the record header
            let record_size = 3;
    
            if self.bytes.len() < record_size {
                return None; // incomplete record
            }
    
            self.used += record_size;
            let (record, rest) =
                mem::take(&mut self.bytes).split_at_mut(record_size);
            self.bytes = rest;
            Some(decrypt(record))
        }
    
        fn done(self) -> usize {
            self.used
        }
    }
    
    // details omitted
    fn decrypt<'a>(bytes: &'a mut [u8]) -> &'a [u8] {
        // decryption happens "in-place" and mutates `bytes`
        // once decrypted, we can drop the mutability to
        // avoid further unintended mutation
        &*bytes
    }
    

    In practice, the actual API change was a bit more elaborate because we wanted to support "caller-side buffers" where the end-user manages the MessageDeframer.bytes buffer instead of it being managed by rustls. But, ultimately, the semantics were the same.
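
    The minimal implementation hard-codes a 3-byte record size. As a rough sketch of what reading the size from a header could look like, the variant below assumes a hypothetical framing where each record is a single length byte followed by that many payload bytes. This is not the real TLS record header, decryption is left out, and the `BorrowedDeframer` name and the free `discard` function are made up for this example.

```rust
use std::mem;

struct BorrowedDeframer<'a> {
    bytes: &'a mut [u8],
    used: usize,
}

impl<'a> BorrowedDeframer<'a> {
    // Yields the payload of the next complete record, if there is one.
    fn pop(&mut self) -> Option<&'a [u8]> {
        // hypothetical header: a single length byte
        let payload_len = *self.bytes.first()? as usize;
        let record_len = 1 + payload_len;

        if self.bytes.len() < record_len {
            return None; // incomplete record
        }

        let (record, rest) =
            mem::take(&mut self.bytes).split_at_mut(record_len);
        self.bytes = rest;
        self.used += record_len;

        // skip the header; drop the mutability on the way out
        Some(&record[1..])
    }

    // Number of bytes consumed by all the `pop` calls so far.
    fn done(self) -> usize {
        self.used
    }
}

// Discards `num_bytes` from the front of `bytes`.
fn discard(bytes: &mut Vec<u8>, num_bytes: usize) {
    bytes.drain(..num_bytes);
}
```

    The consumed count now includes the header bytes, so the caller can still discard everything that was yielded with a single call once it is done with the records.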

    What's important to highlight here is that all borrows from MessageDeframer.bytes must end before discard is called to avoid invalidation of references. The compiler will point out any such misuse:

    let mut borrowed = deframer.borrow();
    let record = borrowed.pop();
    
    let consumed = borrowed.done();
    // `record` cannot be used beyond this statement
    deframer.discard(consumed);
    
    // do_stuff_with(record); //~ error
    

    Updated process_new_packets

    Now let's update process_new_packets to use the explicit discard API. We need to perform the discard operation in the exit points of process_new_packets:

    • In the error path of process_msg, before we do an explicit return, and
    • After we have processed all the available records, that is, after the while let loop.

    impl Connection {
        fn process_new_packets(
            &mut self,
            /* .. */,
        ) -> /* .. */  {
            let mut incoming_tls = self.incoming_tls.borrow(); // <-
    
            let mut state =
                match mem::replace(&mut self.state, /* .. */) {
                    // ..
                };
            
            while let Some(msg) = incoming_tls.pop(/* .. */)? {
                match self.process_msg(msg, state) {
                    Ok(new) => state = new, 
                    Err(e) => {
                        self.state = Err(e.clone());
                        let used = incoming_tls.done();
                        self.incoming_tls.discard(used); // <-
                        return Err(e);
                    }
                }
            }
            
            self.state = Ok(state); 
            let used = incoming_tls.done();
            self.incoming_tls.discard(used); // <-
            // ..
        }
    }
    

    Persistent messages

    If you read through what process_msg does in the original code, you'll see that it decodes the payload of msg: PlainMessage using the Codec trait, which we covered in the second post. The decoded payload, a Message object, is used to advance the state machine using the State::handle method:

    // omitted associated type 
    trait State {
        fn handle(
            self: Box<Self>,
            message: Message,
            /* .. */
        ) -> Result<Box<dyn State>, /* .. */>;
        // ..
    }
    

    If you read through the many State::handle implementations, you'll see that some of them store parts of message into the State implementer they return. For example, the tls12::ExpectCertificate state stores the certificate chain included in message into the next state.

    impl State for ExpectCertificate {
        fn handle(
            mut self: Box<Self>,
            m: Message,
            // ..
        ) -> Result<Box<dyn State>, /* .. */> {
            // ..
            let server_cert_chain = require_handshake_msg_move!(
                m, // <-
                HandshakeType::Certificate,
                HandshakePayload::Certificate
            )?;
    
            if self.may_send_cert_status {
                Ok(Box::new(ExpectCertificateStatusOrServerKx {
                    server_cert_chain, // <- 
                    // ..
                }))
            } else {
                let server_cert = ServerCertDetails::new(
                    server_cert_chain, // <-
                    vec![],
                );
    
                Ok(Box::new(ExpectServerKx {
                    server_cert, // <-
                    // ..
                }))
            }
        }
    }
    

    In the second post, we converted most of these Message variants into borrowed types. The borrowed Message variants borrow from MessageDeframer.bytes so now we have a new borrow checker problem…

    In process_new_packets, the updated state may borrow from the bytes buffer, but in the previous section we said that all borrows of bytes must end before discard is invoked. Therefore, we need to "kill" the borrows inside the state machine before discard is called. The only way we have to end the borrow is to copy its contents into a fresh heap allocation.

    In other words, some fragments of the received TLS records need to be persisted in the state machine until sufficient TLS data is exchanged to complete the handshake process. In the common case, TLS implementations will pack as many TLS records as possible in a single TCP packet but even in that case the TLS handshake needs 1 or 2 "round trips" to complete, depending on the negotiated TLS version. In the worst-case scenario, each TCP packet will only contain one TLS record. This second scenario needs to persist more TLS records in the state machine between calls to process_new_packets.

    A flexible TLS implementation needs to handle both extremes as well as everything in between. Therefore, we need to update the state machine to hold TLS records in both borrowed and owned variants.

    Cows everywhere

    Every Message fragment that may potentially be stored in a state needs to support switching from a borrowed variant into an owning variant. All those types will become Cow-like enums. Here's the Cow-like version of the PayloadU8 type we used before:

    enum PayloadU8<'a> {
        Borrowed(&'a [u8]),
        Owned(Vec<u8>),
    }
    

    And to go from the borrowed variant to the owning one, we'll add into_owned methods:

    impl<'a> PayloadU8<'a> {
        fn into_owned(self) -> PayloadU8<'static> {
            let vec = match self {
                Self::Borrowed(slice) => slice.to_vec(),
                Self::Owned(vec) => vec,
            };
            PayloadU8::Owned(vec)
        }
    }
    

    This method ends the 'a borrow and returns a new type with a 'static lifetime, that is, a type that does not contain any non-'static reference.
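
    As a quick check that into_owned really ends the borrow, the sketch below converts a borrowed payload and then drops the buffer it came from. The `bytes` accessor and the `outlive_buffer` and `into_static` helpers are names invented for this example. The last function shows the same conversion written with the standard library's Cow, for comparison only.

```rust
use std::borrow::Cow;

#[derive(Debug, PartialEq)]
enum PayloadU8<'a> {
    Borrowed(&'a [u8]),
    Owned(Vec<u8>),
}

impl<'a> PayloadU8<'a> {
    fn into_owned(self) -> PayloadU8<'static> {
        let vec = match self {
            Self::Borrowed(slice) => slice.to_vec(),
            Self::Owned(vec) => vec,
        };
        PayloadU8::Owned(vec)
    }

    // hypothetical accessor, used to inspect the contents
    fn bytes(&self) -> &[u8] {
        match self {
            Self::Borrowed(slice) => *slice,
            Self::Owned(vec) => vec.as_slice(),
        }
    }
}

// Borrows everything after the first byte of `buffer`, converts the
// payload to its owning variant and then drops the buffer.
// Assumes `buffer` is not empty.
fn outlive_buffer(buffer: Vec<u8>) -> PayloadU8<'static> {
    let borrowed = PayloadU8::Borrowed(&buffer[1..]);
    let owned = borrowed.into_owned();
    drop(buffer); // fine: `owned` does not borrow from `buffer`
    owned
}

// The same operation expressed with `std::borrow::Cow`.
fn into_static(cow: Cow<'_, [u8]>) -> Cow<'static, [u8]> {
    Cow::Owned(cow.into_owned())
}
```

    Note that an already owned payload is passed through unchanged, so calling into_owned repeatedly does not allocate again.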

    Ultimately, we want this into_owned operation on the state itself and we'll use it just before we put the state back in the Connection object and perform the discard operation.

    impl Connection {
        fn process_new_packets(
            &mut self,
            /* .. */,
        ) -> /* .. */  {
            // ..
            while let Some(msg) = incoming_tls.pop(/* .. */)? {
                // ..
            }
            
            self.state = Ok(state.into_owned()); // <- here
            let used = incoming_tls.done();
            self.incoming_tls.discard(used);
            // ..
        }
    }
    

    Trait objects and lifetimes

    It's not obvious from the type signature of the field in the Connection struct but the trait object has an elided lifetime attached to it.

    struct Connection {
        incoming_tls: MessageDeframer,
        state: Result<Box<dyn State + 'static>, /* .. */>,
        //                            ^^^^^^^
        // ..
    }
    

    Box<dyn State + 'static> indicates that the State implementer behind the pointer does not contain any references (or that if it does contain references they are 'static references, but that's not the case here). If the State implementer contains references, then that will surface in the trait object as a non-static trait object: Box<dyn State + 'some_lifetime>.

    Let's illustrate that with an example:

    struct SomeState<'a> {
        payload: PayloadU8<'a>,
    }
    
    impl<'a> State for SomeState<'a> { /* .. */ }
    
    // type annotations in the body added for clarity
    fn create_non_static_trait_object<'a>(
       bytes: &'a [u8],
    ) -> Box<dyn State + 'a> {
       let payload: PayloadU8<'a> = PayloadU8::Borrowed(bytes);
       let state: SomeState<'a> = SomeState { payload };
       let boxed: Box<SomeState<'a>> = Box::new(state);
       // coercion
       let trait_object: Box<dyn State + 'a> = boxed;
       trait_object
    }
    

    If we want to create a 'static version of the trait object then we need to create a PayloadU8<'static> value first.

    // type annotations in the body added for clarity
    fn create_static_trait_object<'a>(
       bytes: &'a [u8],
    ) -> Box<dyn State + 'static> {
       let payload: PayloadU8<'static> =
           PayloadU8::Owned(bytes.to_vec()); // <- owned!
       let state: SomeState<'static> = SomeState { payload };
       let boxed: Box<SomeState<'static>> = Box::new(state);
       // coercion
       let trait_object: Box<dyn State + 'static> = boxed;
       trait_object
    }
    

    The State::into_owned method is going to encapsulate this conversion from a 'any version into a 'static version.

    Here's an example implementation using SomeState:

    trait State {
        fn into_owned(self: Box<Self>) -> Box<dyn State + 'static>;
        // ..
    }
    
    impl<'a> State for SomeState<'a> {
        fn into_owned(
            self: Box<Self>,
        ) -> Box<dyn State + 'static> {
            let payload = self.payload.into_owned();
            Box::new(SomeState {
                payload,
            })
        }
        // ..
    }
    

    (I haven't figured out a way to remove the re-boxing from the SomeState::into_owned method without using unsafe. SomeState<'a> and SomeState<'static> should have the same memory layout, so at the machine code level it should be possible to return the input heap allocation with an updated payload field. At the time of writing, the optimizer was unable to remove the seemingly unnecessary extra heap allocation.)
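
    Putting the pieces together, here is a small compilable sketch of State::into_owned being used right before a discard. The `payload_bytes` accessor and the `persist_then_discard` function are assumptions added so that the example can be run and inspected; they are not part of rustls.

```rust
enum PayloadU8<'a> {
    Borrowed(&'a [u8]),
    Owned(Vec<u8>),
}

impl<'a> PayloadU8<'a> {
    fn into_owned(self) -> PayloadU8<'static> {
        PayloadU8::Owned(match self {
            Self::Borrowed(slice) => slice.to_vec(),
            Self::Owned(vec) => vec,
        })
    }
}

trait State {
    fn into_owned(self: Box<Self>) -> Box<dyn State + 'static>;

    // hypothetical accessor, used to inspect the stored payload
    fn payload_bytes(&self) -> &[u8];
}

struct SomeState<'a> {
    payload: PayloadU8<'a>,
}

impl<'a> State for SomeState<'a> {
    fn into_owned(self: Box<Self>) -> Box<dyn State + 'static> {
        // moves the payload out of the old box and re-boxes it
        let payload = self.payload.into_owned();
        Box::new(SomeState { payload })
    }

    fn payload_bytes(&self) -> &[u8] {
        match &self.payload {
            PayloadU8::Borrowed(slice) => *slice,
            PayloadU8::Owned(vec) => vec.as_slice(),
        }
    }
}

// Builds a state that borrows the first `record_len` bytes of `buffer`,
// converts it to its owning form and only then discards those bytes.
// Assumes `record_len <= buffer.len()`.
fn persist_then_discard(
    buffer: &mut Vec<u8>,
    record_len: usize,
) -> Box<dyn State + 'static> {
    let borrowed: Box<dyn State + '_> = Box::new(SomeState {
        payload: PayloadU8::Borrowed(&buffer[..record_len]),
    });

    let owned = borrowed.into_owned(); // the borrow of `buffer` ends here
    buffer.drain(..record_len); // the "discard" operation
    owned
}
```

    Swapping the last two statements of persist_then_discard, that is, discarding before calling into_owned, is rejected by the compiler because `borrowed` would still hold a reference into `buffer`.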

    Problem: (lack of) lifetime constraints in handle

    To be able to store references in State implementers we need to modify the State trait. A first pass would be to tie the Message's lifetime to the return type.

    trait State {
        fn handle<'m>(
            self: Box<Self>,
            message: Message<'m>,
            //               ^^
            /* .. */
        ) -> Result<Box<dyn State + 'm>, /* .. */>;
        //                          ^^
        // ..
    }
    
    /// unveil this type because we'll use it in an example
    enum Message<'m> {
        Certificate(PayloadU8<'m>),
        CertificateVerify(PayloadU8<'m>),
        // ..
    }
    

    This version works fine when transitioning from an owning state into a borrowed state…

    struct InitialState;
    
    struct HasCertificate<'a> {
        certificate: PayloadU8<'a>,
    }
    
    impl State for InitialState {
        fn handle<'m>(
            self: Box<Self>,
            message: Message<'m>,
            /* .. */
        ) -> Result<Box<dyn State + 'm>, /* .. */> {
             let new_state: Box<dyn State + 'm> = match message {
                 // advance state
                 Message::Certificate(certificate) =>
                     Box::new(HasCertificate { certificate }),
    
                 // no state change
                 _ => self,
             };
    
             Ok(new_state)
        }
    }
    

    … but breaks down when the initial state contains borrowed data.

    impl<'a> State for HasCertificate<'a> {
        fn handle<'m>(
            self: Box<Self>,
            message: Message<'m>,
            /* .. */
        ) -> Result<Box<dyn State + 'm>, /* .. */> {
             let new_state: Box<dyn State + 'm> = match message {
                 Message::CertificateVerify(dsd) => {
                     // advance state
                 },
    
                 // no state change
                 _ => self, //~ error
             };
    
             Ok(new_state)
        }
    }
    

    The compiler errors with "method was supposed to return data with lifetime 'm but it is returning data with lifetime 'a".

    That makes sense: The lifetimes 'a and 'm are unrelated in this impl block, so the compiler cannot coerce self, which has type Box<HasCertificate<'a>>, into Box<dyn State + 'm> or Box<HasCertificate<'m>> as there's no guarantee that 'a outlives 'm. Here's an example that shows why this doesn't work: If 'm happens to be 'static and 'a is a non-static lifetime then the coercion cannot occur.
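
    The direction of the allowed coercion can be shown in isolation. In the sketch below, which uses a plain slice instead of PayloadU8 and a made-up `describe` method, the coercion compiles only because the function states that 'a outlives 'm. The commented-out variant without that bound is rejected.

```rust
trait State {
    // hypothetical method, only here to have something to call
    fn describe(&self) -> String;
}

struct HasCertificate<'a> {
    certificate: &'a [u8],
}

impl<'a> State for HasCertificate<'a> {
    fn describe(&self) -> String {
        format!("certificate of {} bytes", self.certificate.len())
    }
}

// OK: the bound `'a: 'm` guarantees that the references inside the
// state live at least as long as the trait object claims.
fn shorten<'a: 'm, 'm>(
    state: Box<HasCertificate<'a>>,
) -> Box<dyn State + 'm> {
    state
}

// Without the bound there is no relation between 'a and 'm and the
// coercion is rejected:
//
// fn unrelated<'a, 'm>(
//     state: Box<HasCertificate<'a>>,
// ) -> Box<dyn State + 'm> {
//     state //~ error
// }

fn demo() -> String {
    let bytes = vec![1u8, 2, 3];
    let boxed = Box::new(HasCertificate { certificate: &bytes[..] });
    let object: Box<dyn State + '_> = shorten(boxed);
    object.describe()
}
```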

    Solution: where Self: 'a

    How do we fix this error?

    I can tell you that adding a lifetime parameter to State, like we did with Codec, to connect the lifetime of HasCertificate to the lifetime of the Message parameter and to the return type of handle will NOT work. I went down that rabbit hole and it creates more problems: All trait objects need yet another lifetime (Box<dyn State<'a> + 'b>).

    There's a hint to the solution in the help message that accompanies the error message: "consider adding the following bound: 'a: 'm". That can be applied to the HasCertificate implementation because there's a 'a lifetime in scope, but there's no other lifetime parameter in the InitialState implementation.

    The answer is to put the lifetime constraint on the implementer itself, that is, on Self:

    trait State {
        fn handle<'m>(
            self: Box<Self>,
            message: Message<'m>,
            /* .. */
        ) -> Result<Box<dyn State + 'm>, /* .. */>
        where
            Self: 'm; // <-
        // ..
    }
    

    What the where Self: 'm constraint says is: to use this method, the Self type (e.g. HasCertificate) must contain no references OR only contain references that live at least as long as 'm.

    Updating the two example State implementations shown above to use this new handle signature fixes the compiler error in the HasCertificate case and does not break the InitialState implementation.
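
    For reference, here is a compilable sketch of both implementations with the where Self: 'm bound in place. To keep it short it uses the borrowed-only PayloadU8 from the recap. The `Verified` state, the `Other` message variant, the `describe` method, the unit `Error` type and the `run` driver are assumptions added for this example.

```rust
#[derive(Debug)]
struct Error;

struct PayloadU8<'a>(&'a [u8]);

enum Message<'m> {
    Certificate(PayloadU8<'m>),
    CertificateVerify(PayloadU8<'m>),
    Other, // hypothetical catch-all variant
}

trait State {
    fn handle<'m>(self: Box<Self>, message: Message<'m>) -> Result<Box<dyn State + 'm>, Error>
    where
        Self: 'm;

    // hypothetical helper so the current state can be observed
    fn describe(&self) -> String;
}

struct InitialState;
struct HasCertificate<'a> {
    certificate: PayloadU8<'a>,
}
struct Verified<'a> {
    certificate: PayloadU8<'a>,
    signature: PayloadU8<'a>,
}

impl State for InitialState {
    fn handle<'m>(self: Box<Self>, message: Message<'m>) -> Result<Box<dyn State + 'm>, Error>
    where
        Self: 'm,
    {
        let new_state: Box<dyn State + 'm> = match message {
            Message::Certificate(certificate) => Box::new(HasCertificate { certificate }),
            _ => self, // no state change
        };
        Ok(new_state)
    }
    fn describe(&self) -> String {
        "InitialState".to_string()
    }
}

impl<'a> State for HasCertificate<'a> {
    fn handle<'m>(self: Box<Self>, message: Message<'m>) -> Result<Box<dyn State + 'm>, Error>
    where
        Self: 'm, // implies 'a: 'm
    {
        let new_state: Box<dyn State + 'm> = match message {
            Message::CertificateVerify(signature) => Box::new(Verified {
                certificate: self.certificate, // PayloadU8<'a> shrinks to 'm
                signature,
            }),
            _ => self, // no state change; this arm used to be the error
        };
        Ok(new_state)
    }
    fn describe(&self) -> String {
        format!("HasCertificate({} bytes)", self.certificate.0.len())
    }
}

impl<'a> State for Verified<'a> {
    fn handle<'m>(self: Box<Self>, _: Message<'m>) -> Result<Box<dyn State + 'm>, Error>
    where
        Self: 'm,
    {
        Ok(self)
    }
    fn describe(&self) -> String {
        format!("Verified({} + {} bytes)", self.certificate.0.len(), self.signature.0.len())
    }
}

// Feeds all `messages` to a state machine that starts in `InitialState`.
fn run<'m>(messages: Vec<Message<'m>>) -> Result<String, Error> {
    let mut state: Box<dyn State + 'm> = Box::new(InitialState);
    for message in messages {
        state = state.handle(message)?;
    }
    Ok(state.describe())
}
```

    In the driver, the trait object itself is the Self type of each handle call. Since dyn State + 'm trivially satisfies Self: 'm, the loop can keep feeding Message<'m> values to whatever state comes back.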

    Limitations

    where Self: 'm is perhaps a stricter constraint than what's absolutely needed. For instance, to call handle with a Message<'static>, the implementer value must contain no references or contain only 'static references. If the implementer value happens to contain references that will not be overwritten with the contents of message then handle cannot be called, even though it would be safe to do so. This over-constraint is not a problem in practice because only non-'static Messages are used in process_new_packets; after all, the Messages borrow from the MessageDeframer's buffer.

    The over-constraint issue also occurs with non-static messages (Message<'m>): It prevents the implementation from holding references with smaller lifetimes that won't be overwritten. That situation does not arise in process_new_packets because the initial state is always Box<dyn State + 'static> (see the state field in the definition of struct Connection, in the "Trait objects and lifetimes" section). The first call to handle in the while loop will change the type of state to Box<dyn State + 'm> where 'm is the lifetime of MessageDeframer's buffer. Once state becomes Box<dyn State + 'm>, it's always possible to call its handle with other Message<'m> values.

    Wrap-up

    That concludes the last blog post of this series.

    We have covered how to deal with borrow checker problems in different contexts:

    • in the implicit reborrow that occurs on struct field access;
    • in generic functions that contain where clauses;
    • and in trait methods that involve trait objects.

    Where possible, we have discussed available alternatives and we have included commentary on the limitations of each solution.

    I hope that this series has given you some new tools to deal with the borrow checker issues that may come your way.


    If the borrow checker has you in a bind or if you want to optimize the memory usage of your project by switching from heap allocated objects to references, our consulting service can help you with that! Get in touch through our contact form.