Constructing tests for legacy code

Every company has legacy code sitting somewhere that isn't well understood. Maybe it was written long ago and the original authors have moved on. Here's a high-ROI approach to getting that code under test.

The code I have in mind was written at a time when priorities were different, and it now sits firmly in the "if it ain't broke, don't fix it" bucket. Teams generally avoid modifying it: engineers are wary of breaking some edge case, and PMs know it's a time sink. But eventually you'll need to modify it...

Why add unit tests?

Unit tests remove one of those pain points. A well-tested function keeps a change here from breaking something on the other side of the codebase, and for similar reasons tests set you up for a future refactor. There's one problem: there are no unit tests, not by a long shot. And no one has the time to add them!

A partial solution

Since this code runs in production all the time, you can build regression tests from real calls:

  1. Make the code as functional as possible
  2. Collect example calls from production (record input and output)
  3. Turn those into unit tests.
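A minimal sketch of steps 2 and 3, assuming the function's arguments and results have `repr()`s that round-trip (plain numbers, strings, lists, dicts). `record_calls`, `to_assert_lines`, and the in-memory `RECORDED_CALLS` list are hypothetical names; in production the list would be a log file or database.

```python
import functools

# Hypothetical in-memory log standing in for a production log.
RECORDED_CALLS = []


def record_calls(fn):
    """Record (name, args, kwargs, result) for every call to fn."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        RECORDED_CALLS.append((fn.__name__, args, kwargs, result))
        return result
    return wrapper


def to_assert_lines(calls):
    """Turn recorded calls into regression-test assertions (step 3)."""
    lines = []
    for name, args, kwargs, result in calls:
        parts = [repr(a) for a in args]
        parts += [f"{k}={v!r}" for k, v in kwargs.items()]
        lines.append(f"assert {name}({', '.join(parts)}) == {result!r}")
    return lines


@record_calls
def sign(x):
    return "negative" if x < 0 else "non-negative"


sign(-3)
sign(5)
for line in to_assert_lines(RECORDED_CALLS):
    print(line)
```

This only works if step 1 holds: the output must depend on the inputs alone, otherwise the generated assertions will be flaky.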

This can be handy, but you'll run into one of two problems. Either you'll collect far too many test cases to manage or you'll collect too few to be useful.

If you collect too many, you'll be bogged down in test maintenance. Whenever you update the code, you'll need to update every test (add mocks, change expected output, etc.).

If you collect too few, you'll find you haven't really tested the function, and it's difficult to tell which side of the line you're on. The function you're testing may be a garden of forking paths, yet perhaps 90% of the calls take exactly the same path. Sure, you've collected 100 tests, but your test coverage is still 5%.

Selecting the best tests

We need to be clever about which tests we add to our codebase. Here's one way to do that:

  1. Every n calls, profile the execution line by line (deterministically).
  2. Record the profile, input, and output in a separate log.
  3. Deploy the change to production and leave it on until you've collected enough log data.
  4. Run a greedy maximum-coverage selection over the file and line numbers found in the profile data.
  5. Convert the k picked entries into unit tests.

The trick is to record the code coverage of each test, and see how much untested code each additional test brings to the table. For simplicity I'm using a greedy algorithm, but the sky's the limit.
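Stripped of the logging details, the selection step is the classic greedy heuristic for the maximum coverage problem, which is known to get within a factor of (1 - 1/e) of the best possible k-subset. A small standalone sketch, where `greedy_select` and the sample `calls` data are illustrative only:

```python
def greedy_select(coverage_by_test, k):
    """
    Greedy maximum-coverage selection.

    coverage_by_test maps a test id to the set of elements (e.g. "file:line")
    it covers. Returns up to k test ids, each chosen because it adds the most
    not-yet-covered elements at the moment it is picked, plus the union of
    everything covered.
    """
    remaining = dict(coverage_by_test)
    covered = set()
    picked = []
    while remaining and len(picked) < k:
        best = max(remaining, key=lambda t: len(remaining[t] - covered))
        if not remaining[best] - covered:
            break  # no remaining test adds new coverage
        picked.append(best)
        covered |= remaining.pop(best)
    return picked, covered


# Illustrative data: f(17) takes the same path as f(15), so it adds nothing.
calls = {
    "f(15)": {"f:1", "f:2", "f:3", "f:4"},
    "f(17)": {"f:1", "f:2", "f:3", "f:4"},
    "f(-1)": {"f:1", "f:5"},
}
print(greedy_select(calls, k=10))
```

Even with a budget of 10, only two tests come back, because the third adds no new lines.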

A toy example

Suppose we have a poorly understood function that categorizes a number: it takes in a number and returns a string. It doesn't matter whether it's "correct"; it's been running for years.

def complicated_function(e):
    """
    This function categorizes numbers
    """
    if e < 0:
        return "negative"
    if e == 42:
        return "answer"
    if is_even(e):
        return "even"
    if is_prime(e):
        return "prime"
    return "unknown"
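The post doesn't show `is_even` or `is_prime`. To run the toy example yourself, hypothetical definitions like these will do; they are assumptions, not the original helpers.

```python
def is_even(n):
    # Hypothetical helper: not shown in the original post.
    return n % 2 == 0


def is_prime(n):
    # Hypothetical helper: simple trial division, fine for small toy inputs.
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True
```

With these helpers, `complicated_function` returns "unknown" for 15 and 1, "prime" for 23, and "even" for 0. Note that 2 is classified as "even", because the even check runs before the prime check.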

I've set up the profiling logic in a decorator that records roughly 50% of calls:

@consider_input_for_tests(0.5)
def complicated_function(e):
    ...  # body unchanged from above

In this example we'll call the function with a range of values, but in production this function could be called with all sorts of values.

for e in range(-10, 50):
    complicated_function(e)

We then generate test cases using the recorded calls. Here we generate a maximum of 10 tests:

selected = select_tests(10)

This gives us:

assert complicated_function(15) == "unknown" # Covers 11 new lines
assert complicated_function(23) == "prime" # Covers 2 new lines
assert complicated_function(-9) == "negative" # Covers 1 new line
assert complicated_function(0) == "even" # Covers 1 new line
assert complicated_function(1) == "unknown" # Covers 1 new line
assert complicated_function(42) == "answer" # Covers 1 new line

Notice that only 6 tests were selected even though we allowed up to 10: the remaining recorded calls didn't add any new coverage. This particular set of tests gets us 95% coverage on the toy function. Not bad 😎.

Productionizing

There are a few questions to ask yourself before using this in a production environment:

  • What probability should you sample at?
  • Can your input and output be serialized? If not, add a layer that can serialize them.
  • Is your function executed across multiple servers? If so, you may need to use a database instead of a file.
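  • Are you worried about repeat calls? If so, deduplicate the observations. They hold lists and dicts, so they aren't hashable as-is; key them on their serialized form instead.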
  • Do you need to mock out data for your test? If so, extend the observation class to include it.
  • Will measuring coverage affect performance too much? You may need to measure coverage in a separate thread.
  • Do you anticipate the line numbers changing? You may need to include the current commit hash with each observation.
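For the repeat-call question, a minimal sketch of deduplication, assuming observations are plain dicts shaped like the output of `Observation.serialize()` and that the function is deterministic (so identical inputs mean identical behavior). `dedupe` and the sample `log` are hypothetical.

```python
import json


def dedupe(observations):
    """
    Keep the first observation for each distinct (args, kwargs) pair.
    Lists and dicts aren't hashable, so the key is a canonical JSON string.
    """
    seen = set()
    unique = []
    for obs in observations:
        key = json.dumps([obs["args"], obs["kwargs"]], sort_keys=True)
        if key not in seen:
            seen.add(key)
            unique.append(obs)
    return unique


log = [
    {"args": [15], "kwargs": {}, "output": "unknown", "coverage": {}},
    {"args": [15], "kwargs": {}, "output": "unknown", "coverage": {}},
    {"args": [23], "kwargs": {}, "output": "prime", "coverage": {}},
]
print(len(dedupe(log)))  # 2
```

If the line numbers can change between deploys, you'd likely want the commit hash in the key as well, so the same input recorded against two versions of the code isn't collapsed.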

Code

Selecting Observations for Tests

def select_tests(num_to_pick):
    """
    Selects observations that cover the most lines.
    """
    observations = FileObservationDAL("observations.jl").read()

    covered = set()
    picked = []
    num_uncovered = lambda o: len(o.elements() - covered)

    # Greedy: repeatedly take the observation that adds the most new lines.
    while observations and len(picked) < num_to_pick:
        pick = max(observations, key=num_uncovered)
        if num_uncovered(pick) == 0:
            print("No more progress to be made")
            break
        print(f"New lines covered: {num_uncovered(pick)}")
        observations.remove(pick)
        picked.append(pick)
        covered |= pick.elements()

    return picked

In the code above, elements() returns a set of "file:line" strings, but in principle it could return any hashable element. coverage.py also records arcs (transitions from one line to another) when branch measurement is enabled, so you can extend this approach to find untested arcs.
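A sketch of the arc variant mentioned above. It assumes coverage was collected with `coverage.Coverage(branch=True)` and stored per file from `CoverageData.arcs(filename)`, which yields `(from_line, to_line)` pairs. `arc_elements` is a hypothetical helper that would stand in for `elements()`, so `select_tests` maximizes arc coverage instead of line coverage.

```python
def arc_elements(file_to_arcs):
    """
    Like Observation.elements(), but for arcs: each element is a
    "file:from->to" string. coverage.py uses negative line numbers to mark
    entry into, or exit from, a code object.
    """
    return {
        f"{filename}:{start}->{end}"
        for filename, arcs in file_to_arcs.items()
        # JSON round-trips turn tuples into lists; unpacking handles both.
        for start, end in arcs
    }


print(sorted(arc_elements({"toy.py": [(-1, 2), (2, 3)]})))
```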

Collecting test data

Here's a decorator that measures the coverage of a particular function call and writes the observation to file.

import functools
import random

import coverage


# Note: FileObservationDAL must be defined before this line runs,
# because default arguments are evaluated at definition time.
def consider_input_for_tests(probability, dal=FileObservationDAL("observations.jl")):
    """
    A decorator for profiling and recording the execution of
    the decorated function. The decorator will record the execution
    with the given probability. A probability of 1 means every call is
    recorded. A probability of 0 means no call is recorded.
    """

    def decorator(fn):
        @functools.wraps(fn)  # keep fn's name and docstring
        def decorated_fn(*args, **kwargs):
            if random.random() >= probability:
                # Not sampled: call straight through with no overhead.
                return fn(*args, **kwargs)

            cov = coverage.Coverage()
            cov.start()
            try:
                res = fn(*args, **kwargs)
            finally:
                # Always stop measuring, even if fn raises.
                cov.stop()
            obs = Observation.from_coverage_data(args, kwargs, res, cov.get_data())
            dal.append(obs)
            return res

        return decorated_fn

    return decorator

Observation class

import json

import attr


@attr.s
class Observation:
    args = attr.ib()
    kwargs = attr.ib()
    output = attr.ib()
    file_to_lines = attr.ib()

    @classmethod
    def from_coverage_data(cls, args, kwargs, output, coverage_data):
        lines_by_file = {
                filename: coverage_data.lines(filename)
                for filename in coverage_data.measured_files()
                }
        return cls(args, kwargs, output, lines_by_file)

    @classmethod
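    def from_json(cls, data):
        # `data` is one parsed JSON line, as produced by serialize().
        # (Renamed from `json` so it doesn't shadow the json module.)
        return cls(data["args"], data["kwargs"], data["output"], data["coverage"])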

    def serialize(self):
        """
        Returns a json dict that represents the test case
        """
        return {
            "args": self.args,
            "kwargs": self.kwargs,
            "output": self.output,
            "coverage": self.file_to_lines
        }

    def elements(self):
        return {
            f"{filename}:{line}"
            for filename, lines in self.file_to_lines.items()
            for line in lines
        }


class ObservationDAL:
    """
    DAL = Data Access Layer
    """

    def append(self, observation):
        """
        Writes the given observation somewhere.
        """
        raise NotImplementedError("You need to implement this method.")

    def read(self):
        """
        Returns a list of all observations
        """
        raise NotImplementedError("You need to implement this method.")

class FileObservationDAL(ObservationDAL):
    """ Appends each observation to a file as a json object """

    def __init__(self, filename):
        self.filename = filename

    def append(self, observation):
        with open(self.filename, "a") as fd:
            fd.write(json.dumps(observation.serialize()))
            fd.write("\n")

    def read(self):
        with open(self.filename, "r") as fd:
            return [
                    Observation.from_json(json.loads(line))
                    for line in fd.readlines()
                    ]
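For the multi-server question, the DAL is the piece to swap. A minimal sketch using the standard library's `sqlite3`: `SQLiteObservationDAL` is a hypothetical class that mirrors the `append()`/`read()` interface but returns plain dicts rather than `Observation` objects, to stay self-contained. SQLite is still a single file on one machine; across servers you'd point the same interface at a networked database.

```python
import json
import sqlite3


class SQLiteObservationDAL:
    """Hypothetical DAL: one JSON-encoded observation per SQLite row."""

    def __init__(self, path):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS observations ("
            "id INTEGER PRIMARY KEY AUTOINCREMENT, body TEXT NOT NULL)"
        )

    def append(self, observation):
        # Accepts anything with a serialize() method, e.g. Observation.
        with self.conn:  # commits on success, rolls back on error
            self.conn.execute(
                "INSERT INTO observations (body) VALUES (?)",
                (json.dumps(observation.serialize()),),
            )

    def read(self):
        rows = self.conn.execute("SELECT body FROM observations ORDER BY id")
        return [json.loads(body) for (body,) in rows]


class _FakeObservation:
    """Hypothetical stand-in with the same serialize() shape as Observation."""

    def __init__(self, data):
        self.data = data

    def serialize(self):
        return self.data


dal = SQLiteObservationDAL(":memory:")
dal.append(_FakeObservation({"args": [15], "kwargs": {}, "output": "unknown", "coverage": {}}))
print(dal.read())
```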
