Hi!
This will be a rather technical article — we will implement the Dancing Links algorithm in Rust. I became aware of this beautiful problem when working with Calendly — they are using Dancing Links (and Rust!) to find the best time slot for an event to satisfy participants' preferences.
In this post, we will present a simplified implementation of DLX, which is excellent material for learning some advanced Rust. Specifically, we'll learn about arenas, linked lists, ECS, and data structure design. This will be hands-on learning — expect to read quite a bit of dense Rust!
If you find that Learn Rust With Entirely Too Many Linked Lists contains too few linked lists, this post is for you!
The code for the blog post is available at https://github.com/matklad/dlx.
If you want to challenge yourself and also get the most out of this post, write your own implementation of DLX and compare notes afterwards!
Problem Statement
Before diving into Rustspecific implementation details, let’s discuss the overall problem.
The input to the problem is a matrix with zeros and ones:

```
0 0 1 0 1 1 0
1 0 0 1 0 0 1
0 1 1 0 0 1 0
1 0 0 1 0 0 0
0 1 0 0 0 0 1
0 0 0 1 1 0 1
```

The task is to find a subset of rows such that each column of the selected subset contains exactly one 1.
In the example above, rows 0, 3, and 4 give such a subset:

```
0 0 1 0 1 1 0 \
1 0 0 1 0 0 0  +-> 1 1 1 1 1 1 1
0 1 0 0 0 0 1 /
```
The solution to the problem is not really exciting algorithmically — a boring trial and error. Specifically, we:

1. Choose a column to eliminate.
2. For each row that covers this column:
   - Try removing the row.
   - Remove the columns it covers.
   - Remove the rows that are covered by the removed columns.
   - Recursively solve the residual problem.
   - If we find a solution, we are done.
   - Otherwise, try the next row.
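The steps above can be sketched as a naive recursive solver over a dense boolean matrix. This is a hypothetical baseline for illustration only — the function name and the `Vec<Vec<bool>>` representation are mine, not part of the implementation we build below:

```rust
/// Counts the exact covers of a dense 0/1 matrix by trial and error.
/// Deliberately naive: it rebuilds the whole residual matrix at every step.
fn solve_naive(n_cols: usize, rows: &[Vec<bool>]) -> usize {
    // Base case: no columns left to cover means we found a solution.
    if n_cols == 0 {
        return 1;
    }
    let mut n_solutions = 0;
    // Choose a column to eliminate (simply the first one here).
    // For each row that covers this column:
    for row in rows.iter().filter(|row| row[0]) {
        // The columns this row covers.
        let covered: Vec<usize> = (0..n_cols).filter(|&c| row[c]).collect();
        // Remove the covered columns and every row that intersects a
        // covered column (including the chosen row itself).
        let residual: Vec<Vec<bool>> = rows
            .iter()
            .filter(|r| covered.iter().all(|&c| !r[c]))
            .map(|r| {
                (0..n_cols)
                    .filter(|c| !covered.contains(c))
                    .map(|c| r[c])
                    .collect()
            })
            .collect();
        // Recursively solve the residual problem.
        n_solutions += solve_naive(n_cols - covered.len(), &residual);
    }
    n_solutions
}

fn main() {
    // The example matrix from above; it has exactly one exact cover.
    let rows: Vec<Vec<bool>> = [
        [0, 0, 1, 0, 1, 1, 0],
        [1, 0, 0, 1, 0, 0, 1],
        [0, 1, 1, 0, 0, 1, 0],
        [1, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 1],
        [0, 0, 0, 1, 1, 0, 1],
    ]
    .iter()
    .map(|r| r.iter().map(|&x| x == 1).collect())
    .collect();
    assert_eq!(solve_naive(7, &rows), 1);
    println!("found exactly one cover");
}
```

The point of the rest of the post is to do the "remove and recurse" part without rebuilding (or cloning) anything.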
For the example problem, if we try removing the first row (marked with `*`),

```
* 0 0 1 0 1 1 0
  1 0 0 1 0 0 1
  0 1 1 0 0 1 0
  1 0 0 1 0 0 0
  0 1 0 0 0 0 1
  0 0 0 1 1 0 1
```

we need to remove the following rows and columns as well (crossed out with `x`),

```
x x x x x x x
1 0 x 1 x x 1
x x x x x x x
1 0 x 1 x x 0
0 1 x 0 x x 1
x x x x x x x
```

leaving the following residual matrix:

```
1 0 1 1
1 0 1 0
0 1 0 1
```
That’s all there is to the algorithm!
Efficient Implementation
The trick to implementing it efficiently is fast rollback: restoring deleted rows and columns after realizing that the selected row doesn't lead to a solution.
The naïve solution is to represent the matrix as a `Vec<Vec<bool>>` and just clone it before removing rows and columns.
This involves a lot of moving vectors around, and is expected to be slow.
Instead, the dancing links approach suggests using a sparse, doubly-linked-list-based representation of the matrix.
Specifically, each 1 in the matrix will be a node, with links to its neighbors in the row and in the column. So a matrix like this

```
0 0 1 0 0
1 0 1 0 1
1 0 1 0 1
```

will look like this (where each `-` and `|` is a bidirectional link, a pair of pointers):

```
        1
        |
1 ----- 1 ----- 1
|       |       |
1 ----- 1 ----- 1
```
In other words, we arrange all ones in a two-dimensional sparse grid, linked across both dimensions (yes, this is the tricky bit). Note how this representation already allows for faster removal than nested vectors: to remove a row, we traverse all the row's ones and unlink them from their vertical lists. The work we need to do is proportional to the number of ones in the row and is independent of the overall size of the matrix.
Finally, there is an additional trick. When we unlink the node `b` from a doubly-linked list

```
a <-> b <-> c
```

we will only fix up the `a` and `c` pointers. We will leave the pointers from `b` to `a` and `c` intact:

```
    b
   / \
a <-> c
```

By not zeroing them out, we gain the ability to cheaply reverse the operation — knowing only `b`, we can get to `a` and `c` and make them point back at `b`! That is, in this representation, we don't need to `.clone` the matrix before a trial deletion!
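To make the trick concrete, here is the idea in isolation, on a tiny circular list stored as index pairs. This is a standalone sketch with names of my own choosing; it mirrors, but is not, the `LinkedList` we build later:

```rust
// A circular doubly-linked list stored as (prev, next) index pairs.
#[derive(Clone, Copy)]
struct Link {
    prev: usize,
    next: usize,
}

/// Unlinks node `b`, but leaves `b`'s own pointers intact.
fn remove(links: &mut [Link], b: usize) {
    let (a, c) = (links[b].prev, links[b].next);
    links[a].next = c;
    links[c].prev = a;
    // Crucially, links[b] still remembers its old neighbors.
}

/// Reverses `remove`, using the pointers `b` kept.
fn restore(links: &mut [Link], b: usize) {
    let (a, c) = (links[b].prev, links[b].next);
    links[a].next = b;
    links[c].prev = b;
}

fn main() {
    // Circular list 0 <-> 1 <-> 2.
    let mut links = [
        Link { prev: 2, next: 1 },
        Link { prev: 0, next: 2 },
        Link { prev: 1, next: 0 },
    ];
    remove(&mut links, 1);
    assert_eq!(links[0].next, 2); // now 0 <-> 2
    restore(&mut links, 1);
    assert_eq!(links[0].next, 1); // back to 0 <-> 1 <-> 2
    println!("remove/restore round-trips");
}
```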
Finally, there's one more algorithmic/implementation twist. We have a certain liberty in choosing the order in which to try columns. In general, it makes sense to start with the columns that have fewer ones, to keep the recursion tree smaller. But that means that, when covering rows and columns, we also need to maintain column sizes. To track those, we add an extra header row which stores the sizes.
Rust Implementation
Let's get to the meat of the post — how to express all this in Rust. This section will have a bit of a literate programming vibe to it.
So far, it sounds like we need some kind of `Cell` object:

```rust
struct Cell {
    right: &Cell,
    left: &Cell,
    up: &Cell,
    down: &Cell,
    column: &Cell,
    // `Some` for header cells.
    size: Option<usize>,
}
```
Seasoned Rustaceans will see that this won't end well. Bidirectional linked lists require that all `Cell`s are aliased, which would make mutating the links hard. Specifically, writing a function like this will be a problem:

```rust
fn link(a: &mut Cell, b: &mut Cell) {
    a.right = &*b;
    b.left = &*a;
}
```

We need exclusive references to the `Cell`s to change the `right` and `left` fields, but there might be other cells pointing to (aliasing) `a` and `b`.
Instead, we can use a setup like this, using indices instead of references:

```rust
use std::ops;

// `Copy` and `Eq` make the index type cheap to pass around and compare.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Cell(usize);

struct CellData {
    right: Cell,
    left: Cell,
    up: Cell,
    down: Cell,
    column: Cell,
    // `Some` for header cells.
    size: Option<usize>,
}

impl ops::Index<Cell> for Vec<CellData> {
    type Output = CellData;
    fn index(&self, index: Cell) -> &CellData {
        &self[index.0]
    }
}

impl ops::IndexMut<Cell> for Vec<CellData> {
    fn index_mut(&mut self, index: Cell) -> &mut CellData {
        &mut self[index.0]
    }
}
```
With this setup, the above function becomes possible:

```rust
fn link(cells: &mut Vec<CellData>, a: Cell, b: Cell) {
    cells[a].right = b;
    cells[b].left = a;
}
```
Here are some nice little details about the setup above that I like and typically use:

- `Cell`/`CellData` naming. The alternative, `CellIdx`/`Cell`, turns out to be noisier down the line.
- Semi-qualified `impl ops::Index` for implementing `ops` traits. Usually, you want to implement several traits from that module, so importing the whole module is tidier and less typing. An additional nice detail is that, because we don't have the `Index` trait in scope, we don't unlock the `.index` method call. This emphasizes that we implement, rather than use, the trait.
- `impl ops::Index<Cell> for Vec<CellData>` works. At first sight, this should fail the coherence check. Both the `Index` trait and the `Vec` type are from the stdlib, so how come we can write this impl in our crate? The key is that `Index` is generic — it is a sort of template for making traits. Only when we apply it to a type parameter do we get a real trait, `Index<Cell>`. And that trait is "ours", because the `Cell` type is local.
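A minimal standalone demonstration of this coherence point (the `Idx` newtype here is a made-up name for the example):

```rust
use std::ops;

// A local newtype index.
struct Idx(usize);

// Allowed by coherence: even though both `Index` and `Vec` come from
// std, the instantiated trait `Index<Idx>` mentions the local type
// `Idx`, so this impl may live in our crate.
impl ops::Index<Idx> for Vec<String> {
    type Output = String;
    fn index(&self, index: Idx) -> &String {
        // Delegates to the stdlib `Index<usize>` impl.
        &self[index.0]
    }
}

fn main() {
    let v = vec!["a".to_string(), "b".to_string()];
    assert_eq!(&v[Idx(1)], "b");
    println!("coherence check passes");
}
```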
Now, let's zoom in on that `size: Option<usize>` field.
Remember, for cells that represent column headers (and only for them), we store column sizes.
In an OOP language, we'd use inheritance and downcasting to add this field only to the headers.
In C, we'd do the same, but manually:

```c
struct cell {
    struct cell *right, *left;
    struct cell *up, *down;
    struct cell *column;
};

struct header_cell {
    struct cell base;
    size_t size;
};
```

In Rust, as we are allocating everything in a `Vec`, we need to carry this semi-useless `Option`.
There's a way to get rid of it, though, while preserving the `Vec`ness — the Entity Component System pattern:
```rust
struct Matrix {
    data: Vec<CellData>,
    size: Vec<usize>,
}

impl ops::Index<Cell> for Vec<usize> {
    type Output = usize;
    fn index(&self, index: Cell) -> &usize {
        &self[index.0]
    }
}
```
We will make sure that header cells are assigned small indices, and will store column sizes in a separate array.
This is the core of ECS — it is a flexible way to extend entities with additional data by using a side table.
In some sense, this isn't better — we are just trading options for index checks.
But it is more flexible — now the `data` and `size` vectors could, for example, be defined in separate modules, instead of being fields of a single struct.
And, while we are at it, let's also flatten the other fields in a similar way:

```rust
struct Matrix {
    /// Links along the horizontal dimension.
    x: LinkedList,
    /// Links along the vertical dimension.
    y: LinkedList,
    /// Pointers to column headers.
    c: Vec<Cell>,
    /// For column headers, the size of the column.
    size: Vec<u32>,
}
```
After this, we can generically express linked lists across both dimensions:

```rust
struct Link {
    prev: Cell,
    next: Cell,
}

#[derive(Default, Debug)]
struct LinkedList {
    data: Vec<Link>,
}

impl ops::Index<Cell> for LinkedList {
    type Output = Link;
    fn index(&self, index: Cell) -> &Link {
        &self.data[index.0]
    }
}

impl ops::IndexMut<Cell> for LinkedList {
    fn index_mut(&mut self, index: Cell) -> &mut Link {
        &mut self.data[index.0]
    }
}

impl LinkedList {
    fn with_capacity(cap: usize) -> LinkedList {
        LinkedList { data: Vec::with_capacity(cap) }
    }
    fn alloc(&mut self) -> Cell {
        let cell = Cell(self.data.len());
        self.data.push(Link { prev: cell, next: cell });
        cell
    }
    /// Inserts `b` into `a <-> c` to get `a <-> b <-> c`.
    fn insert(&mut self, a: Cell, b: Cell) {
        let c = self[a].next;
        self[b].prev = a;
        self[b].next = c;
        self[a].next = b;
        self[c].prev = b;
    }
    /// Removes `b` from `a <-> b <-> c` to get `a <-> c`.
    fn remove(&mut self, b: Cell) {
        let a = self[b].prev;
        let c = self[b].next;
        self[a].next = self[b].next;
        self[c].prev = self[b].prev;
    }
    /// Restores previously removed `b` to get `a <-> b <-> c`.
    fn restore(&mut self, b: Cell) {
        let a = self[b].prev;
        let c = self[b].next;
        self[a].next = b;
        self[c].prev = b;
    }
}
```
Things to note:

- `next`/`prev` is a nice symmetric pair of names.
- `alloc` is a common name for a factory function in similar arena setups.
- In `alloc`, each cell starts out pointing to itself. This is a common pattern when implementing circular linked lists: there are no edge cases for the first/last nodes.
- The `a`, `b`, `c` naming gives an intuitive sense of the relative positions of the nodes.
- The `remove` method does not change the `next` and `prev` fields of `b`. This is the main idea of dancing links, and it is what allows implementing `restore`.
Now we can also mount a generic iteration scheme:

```rust
impl LinkedList {
    fn cursor(&self, head: Cell) -> Cursor {
        Cursor { head, curr: head }
    }
}

struct Cursor {
    head: Cell,
    curr: Cell,
}

impl Cursor {
    fn next(&mut self, list: &LinkedList) -> Option<Cell> {
        self.curr = list[self.curr].next;
        if self.curr == self.head {
            return None;
        }
        Some(self.curr)
    }
    fn prev(&mut self, list: &LinkedList) -> Option<Cell> {
        self.curr = list[self.curr].prev;
        if self.curr == self.head {
            return None;
        }
        Some(self.curr)
    }
}
```
Notice the duplication between the `next` and `prev` methods. Usually it can be removed by changing the definition of `Link` to

```rust
type Link = [Cell; 2];
```

This representation allows writing algorithms generic over direction (`0` or `1`; `d ^ 1` flips the direction), but it doesn't buy us much in this case.
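For illustration, the direction-generic traversal could look like this. This is a standalone sketch, not part of the post's implementation; here `Cell` is just a plain `usize`:

```rust
// Index 0 is "prev", index 1 is "next", and `d ^ 1` flips direction.
type Cell = usize;
type Link = [Cell; 2];

/// Steps from `curr` in direction `d` (0 = backwards, 1 = forwards).
fn step(links: &[Link], curr: Cell, d: usize) -> Cell {
    links[curr][d]
}

fn main() {
    // Circular list 0 <-> 1 <-> 2, stored as [prev, next] pairs.
    let links = [[2, 1], [0, 2], [1, 0]];
    assert_eq!(step(&links, 0, 1), 1);     // next of 0
    assert_eq!(step(&links, 0, 1 ^ 1), 2); // flipping gives prev of 0
    println!("direction-generic traversal works");
}
```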
Now, let's write the routines to create a matrix with `n` columns.
We will populate the matrix with rows separately; here we only add the header row.

```rust
const H: Cell = Cell(0);

impl Matrix {
    fn new(n_cols: usize) -> Matrix {
        let mut res = Matrix {
            size: Vec::with_capacity(n_cols + 1),
            c: Vec::with_capacity(n_cols + 1),
            x: LinkedList::with_capacity(n_cols + 1),
            y: LinkedList::with_capacity(n_cols + 1),
        };
        assert_eq!(res.alloc_column(), H);
        for _ in 0..n_cols {
            res.add_column();
        }
        res
    }
    fn add_column(&mut self) {
        let new_col = self.alloc_column();
        self.x.insert(self.x[H].prev, new_col);
    }
    fn alloc_column(&mut self) -> Cell {
        let cell = self.alloc(H);
        self.c[cell] = cell;
        self.size.push(0);
        cell
    }
    fn alloc(&mut self, c: Cell) -> Cell {
        self.c.push(c);
        let cell = self.x.alloc();
        assert_eq!(self.y.alloc(), cell);
        cell
    }
}
```
Things to note:

- We set aside a special cell, `H`, to serve as the header cell of the header row.
- `n_` is a nice naming convention to use to mean "number of things".
- `res` is a nice conventional name for the result variable.
- We use `assert`s to verify that the parallel indices in the two lists match.
Now let's add a function to populate the matrix with rows:

```rust
impl Matrix {
    fn add_row(&mut self, row: &[bool]) {
        assert_eq!(row.len(), self.size.len() - 1);
        let mut c = H;
        let mut prev = None;
        for &is_filled in row {
            c = self.x[c].next;
            if is_filled {
                self.size[c] += 1;
                let new_cell = self.alloc(c);
                self.y.insert(self.y[c].prev, new_cell);
                if let Some(prev) = prev {
                    self.x.insert(prev, new_cell);
                }
                prev = Some(new_cell);
            }
        }
    }
}
```
Now that we have all the code to construct the two-dimensional linked lists, it's useful to visualize the state of the matrix. In general, when implementing complex algorithms, putting time into visualization code pays off well.
```rust
use std::fmt;

impl fmt::Display for Matrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "s: ")?;
        for s in &self.size {
            write!(f, "{:^5}", s)?;
        }
        writeln!(f)?;
        write!(f, "c: ")?;
        for &Cell(c) in &self.c {
            write!(f, "{:^5}", c.saturating_sub(1))?;
        }
        writeln!(f)?;
        write!(f, "x: ")?;
        for link in &self.x.data {
            write!(f, " {:>1}{:<1} ", link.prev.0, link.next.0)?
        }
        writeln!(f)?;
        write!(f, "y: ")?;
        for link in &self.y.data {
            write!(f, " {:>1}{:<1} ", link.prev.0, link.next.0)?
        }
        writeln!(f)?;
        write!(f, "i: ")?;
        for i in 0..self.x.data.len() {
            write!(f, "{:^5}", i)?;
        }
        writeln!(f)?;
        Ok(())
    }
}
```
For a matrix like this:

```rust
let mut m = Matrix::new(3);
m.add_row(&[true, false, true]);
```

we'll get the following representation:

```
s:   0    1    0    1
c:   0    0    1    2    0    2
x:  31   02   13   20   55   44
y:  00   44   22   55   11   33
i:   0    1    2    3    4    5
```

`i = 0` is `H`, the matrix root.
`i = 1, 2, 3` are the column headers, with sizes `1, 0, 1`.
`i = 4, 5` are the two filled cells of the matrix.
At this stage, we've fully designed and implemented the data structure underpinning the algorithm. Let's get to solving the problem.
We start with the code to cover and uncover columns.
The `cover` operation takes the cell of a column header and unlinks this column, and all of its rows, from the matrix.
The `uncover` operation reverses this, taking advantage of the fact that we don't zero out the links when removing a node from a linked list.
```rust
impl Matrix {
    fn cover(&mut self, c: Cell) {
        self.x.remove(c);
        let mut i = self.y.cursor(c);
        while let Some(i) = i.next(&self.y) {
            let mut j = self.x.cursor(i);
            while let Some(j) = j.next(&self.x) {
                self.y.remove(j);
                self.size[self.c[j]] -= 1;
            }
        }
    }
    fn uncover(&mut self, c: Cell) {
        let mut i = self.y.cursor(c);
        while let Some(i) = i.prev(&self.y) {
            let mut j = self.x.cursor(i);
            while let Some(j) = j.prev(&self.x) {
                self.size[self.c[j]] += 1;
                self.y.restore(j);
            }
        }
        self.x.restore(c);
    }
}
```
Finally, let's implement the recursive trial-and-error algorithm which counts the number of solutions (see the full source for the code to recover the solution itself):

```rust
pub fn solve(mut m: Matrix) -> usize {
    let mut n_answers = 0;
    go(&mut m, &mut n_answers);
    n_answers
}

fn go(m: &mut Matrix, n_answers: &mut usize) {
    let c = {
        let mut i = m.x.cursor(H);
        let mut c = match i.next(&m.x) {
            Some(it) => it,
            None => {
                *n_answers += 1;
                return;
            }
        };
        while let Some(next_c) = i.next(&m.x) {
            if m.size[next_c] < m.size[c] {
                c = next_c;
            }
        }
        c
    };

    m.cover(c);
    let mut r = m.y.cursor(c);
    while let Some(r) = r.next(&m.y) {
        let mut j = m.x.cursor(r);
        while let Some(j) = j.next(&m.x) {
            m.cover(m.c[j]);
        }
        go(m, n_answers);
        let mut j = m.x.cursor(r);
        while let Some(j) = j.prev(&m.x) {
            m.uncover(m.c[j]);
        }
    }
    m.uncover(c);
}
```
Things to note:

- As usual with recursive algorithms, it's better to provide a public non-recursive API which calls into a private helper.
- A convenient name for the recursive helper is just `go`.
- The helper accepts context arguments by reference. If there are too many context arguments, it makes sense to introduce a `Ctx` struct and make `go` a method of it.
- Inside `go`, we first try to find the smallest column. This is also the place where we handle the base case of the empty matrix and increment `n_answers`.
- After we've picked the column, we do a trial deletion of rows and recur. After the recursion, we take care to restore the old state of the matrix, repeating the steps in reverse order.
So it seems we are done? How do we test our implementation to make sure it is correct? We surely can start by throwing the example from the paper at it:

```rust
#[test]
fn sample_problem() {
    let f = false;
    let t = true;
    let mut m = Matrix::new(7);
    m.add_row(&[f, f, t, f, t, t, f]);
    m.add_row(&[t, f, f, t, f, f, t]);
    m.add_row(&[f, t, t, f, f, t, f]);
    m.add_row(&[t, f, f, t, f, f, f]);
    m.add_row(&[f, t, f, f, f, f, t]);
    m.add_row(&[f, f, f, t, t, f, t]);
    let solutions = solve(m);
    assert_eq!(solutions, 1);
}
```
Given the complexity of the algorithm, though, a single test is not really encouraging. Indeed, we can do much better! Let's just check that the algorithm correctly handles all 4x4 matrices. There are only 65536 of them, so this should be relatively fast:
```rust
#[test]
fn exhaustive_test() {
    'matrix: for bits in 0..=0b1111_1111_1111_1111 {
        let mut rows = [0u32; 4];
        for (i, row) in rows.iter_mut().enumerate() {
            *row = (bits >> (i * 4)) & 0b1111;
            if *row == 0 {
                continue 'matrix;
            }
        }

        let brute_force = {
            let mut n_solutions = 0;
            for mask in 0..=0b1111 {
                let mut or = 0;
                let mut n_ones = 0;
                for (i, &row) in rows.iter().enumerate() {
                    if mask & (1 << i) != 0 {
                        or |= row;
                        n_ones += row.count_ones();
                    }
                }
                if or == 0b1111 && n_ones == 4 {
                    n_solutions += 1;
                }
            }
            n_solutions
        };

        let dlx = {
            let mut m = Matrix::new(4);
            for row_bits in rows.iter() {
                let mut row = [false; 4];
                for i in 0..4 {
                    row[i] = row_bits & (1 << i) != 0;
                }
                m.add_row(&row);
            }
            solve(m)
        };

        assert_eq!(brute_force, dlx);
    }
}
```
Things to note:

- We heavily rely on the bit-vector representation of vectors and matrices. For example, `for bits in 0..=0b1111_1111_1111_1111` enumerates all of the matrices. This is a nice hack to enumerate "all subsets" in a few lines of code, without writing a recursion by hand.
- We need to skip matrices with empty rows, as the dlx algorithm doesn't see them at all.
- For the brute-force solution, we use bit operations again — we check that the bitwise or (`|`) of the selected rows gives all ones. We also use popcount to check that each column is covered only once.
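As a standalone illustration of the bit trick used above, here is the "all subsets via a counter" pattern on a toy set of my own (not taken from the test):

```rust
fn main() {
    // Enumerate all subsets of a 3-element set via bit masks.
    let items = ["a", "b", "c"];
    let mut n_subsets = 0;
    for mask in 0u32..(1 << items.len()) {
        // Bit `i` of `mask` decides whether `items[i]` is in the subset.
        let subset: Vec<&str> = items
            .iter()
            .enumerate()
            .filter(|&(i, _)| mask & (1 << i) != 0)
            .map(|(_, &s)| s)
            .collect();
        // `count_ones` (popcount) gives the subset's size directly.
        assert_eq!(subset.len() as u32, mask.count_ones());
        n_subsets += 1;
    }
    assert_eq!(n_subsets, 8); // 2^3 subsets
    println!("enumerated {} subsets", n_subsets);
}
```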
Running this test gives us much more confidence in the correctness of the implementation! Full source, including the code to restore the answer, can be found at https://github.com/matklad/dlx.
It's also interesting to reflect on the unusual effectiveness of linked lists for this problem.
Remember that, on modern hardware, a `Vec` beats a `LinkedList` for the overwhelming majority of problems.
While linked lists have better theoretical complexity for insertion and removal in the middle, most benchmarks are dominated by the traversal time needed to get to that middle.
A linked list only wins if you have some additional structure in place to quickly get to the interesting element.
And this is exactly what happens in this case — because we have a two-dimensional web of interleaving lists, we can do a bunch of removals from the middle in one dimension, while doing the traversal in the other dimension.