Unit testing gRPC streams in Go

One of the powerful features of gRPC is its ability to stream data from the server, from the client, or in both directions (bidirectional streaming). This opens up a lot of use cases and enables more asynchronous communication, which can speed up your application.

However, streams add complexity to unit testing, as you need a good understanding of what happens behind the scenes in your client or server.

In Go, the official gRPC library (grpc-go) exposes each stream as an interface with blocking methods such as Send() and Recv(), depending on the stream direction. In a unit test we do not have the actual server or client available, so we need to mock the stream and simulate the other side of the interaction from the test itself.
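To make the idea of mocking a stream concrete without any generated code, here is a minimal hand-written fake in plain Go. It assumes a hypothetical Todo type and a hypothetical, trimmed-down TodoSender interface with only a Send method; a real generated stream interface also embeds grpc.ServerStream (Context, SetHeader and so on), which is exactly why a generated mock is convenient.

```go
package main

import "fmt"

// Todo is a hypothetical stand-in for the generated rpc.Todo message.
type Todo struct {
	ID int
}

// TodoSender is a hypothetical, trimmed-down stream interface:
// only the Send method a server-streaming handler calls.
type TodoSender interface {
	Send(*Todo) error
}

// fakeSender records every message instead of writing to a network stream.
type fakeSender struct {
	sent []*Todo
}

func (f *fakeSender) Send(t *Todo) error {
	f.sent = append(f.sent, t)
	return nil
}

// sendAll is a hypothetical handler body: it streams one Todo per ID.
func sendAll(ids []int, stream TodoSender) error {
	for _, id := range ids {
		if err := stream.Send(&Todo{ID: id}); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	fake := &fakeSender{}
	if err := sendAll([]int{1, 2}, fake); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(len(fake.sent)) // prints 2
}
```

Because this fake only appends to a slice, the test can inspect the messages after the handler returns; no second goroutine is involved yet.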

Mocking the stream

To avoid needing a real stream, we will mock it using mockery. Say we have the following gRPC handler that streams Todo items to the client; the gRPC types are defined in the generated package rpc.

type TodoServer struct {
    rpc.TodoServiceServer // our todo server implements the gRPC server methods

    TodoService TodoService // we also inject our TodoService, which retrieves the todos from our storage
}

// the stream is an interface value, so it is passed directly, not as a pointer
func (s *TodoServer) ListTodos(stream rpc.TodoServiceServer_ListTodosServer) error {
    // we will handle sending messages here
    return nil
}

Let's take a closer look at the signature of ListTodos. It accepts the stream argument, whose type is the interface rpc.TodoServiceServer_ListTodosServer. We want mockery to create a mock specifically for this interface, and we can use go:generate to have mockery generate it automatically. However, we cannot put the //go:generate directive in the generated rpc package, because it would be overwritten the next time that package is regenerated.

To solve this, let's declare our own interface in the server package that embeds the generated interface. Go has no classes or inheritance, but an interface that embeds another one has exactly the same method set.

type ListTodosStream interface {
    rpc.TodoServiceServer_ListTodosServer
}

type TodoServer struct {
    rpc.TodoServiceServer // our todo server implements the gRPC server methods

    TodoService TodoService // injected service that retrieves the todos from our storage
}

...
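In Go terms, the ListTodosStream declaration above is interface embedding: any value that satisfies the embedded interface satisfies ListTodosStream, and vice versa. The sketch below uses a hypothetical GeneratedStream interface in place of rpc.TodoServiceServer_ListTodosServer, plus hypothetical stubStream and handler names, to show that a value typed as the local interface can be passed wherever the generated one is expected.

```go
package main

import (
	"context"
	"fmt"
)

// GeneratedStream is a hypothetical stand-in for the interface in the
// generated rpc package, which we must not edit.
type GeneratedStream interface {
	Context() context.Context
	Send(msg string) error
}

// ListTodosStream embeds GeneratedStream, so it has exactly the same
// method set; a //go:generate directive can live next to this one instead.
type ListTodosStream interface {
	GeneratedStream
}

// stubStream implements the methods once and satisfies both interfaces.
type stubStream struct{}

func (stubStream) Context() context.Context { return context.Background() }

func (stubStream) Send(string) error { return nil }

// handler accepts the generated interface, like the real ListTodos does.
func handler(stream GeneratedStream) error {
	return stream.Send("todo")
}

func main() {
	var local ListTodosStream = stubStream{}
	// A ListTodosStream value is assignable to a GeneratedStream parameter.
	fmt.Println(handler(local) == nil) // prints true
}
```

That equivalence is why the mock generated for ListTodosStream can be passed straight into ListTodos.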

Now let's add our mockery generate command:

//go:generate mockery --name ListTodosStream --with-expecter
type ListTodosStream interface {
    rpc.TodoServiceServer_ListTodosServer
}

...

Awesome! After running go generate, the mock is available in the mocks package next to our server Go file. Let's finish our code by implementing the ListTodos method:

func (s *TodoServer) ListTodos(stream rpc.TodoServiceServer_ListTodosServer) error {
    // pass the stream's context on to our service
    ctx := stream.Context()

    todos, err := s.TodoService.ListTodos(ctx)
    if err != nil {
        return err
    }

    for _, todo := range todos {
        todoPb := &rpc.Todo{
            // map our application Todo to our protobuf model Todo
        }
        // send each todo to the client; stop on the first send error
        if err := stream.Send(todoPb); err != nil {
            return err
        }
    }

    return nil
}

Done! Well, not quite: we have not written a test for this function yet.

Writing the tests

Let's open up our server_test.go file.

func TestTodoServer_ListTodos(t *testing.T) {
    t.Run("Retrieve Todos", func(t *testing.T) {
        // setup our TodoServer; mockery mocks use pointer receivers
        todoService := &mocks.TodoService{}

        todoServer := TodoServer{
            TodoService: todoService,
        }

        stream := &mocks.ListTodosStream{}

        err := todoServer.ListTodos(stream)
        assert.NoError(t, err)
    })
}

Using mockery, we can have our mocks expect certain parameters and return fixed values, or even run a function to compute a dynamic result using RunAndReturn.

First, we use mockery to have our service return two Todo items. We now need both mocks: the stream has to return a context, and our TodoService has to return the two Todo items.

func TestTodoServer_ListTodos(t *testing.T) {
    t.Run("Retrieve Todos", func(t *testing.T) {
        // setup our TodoServer
        todoService := &mocks.TodoService{}

        todos := []TodoItem{
            {
                ID: 1,
            },
            {
                ID: 2,
            },
        }

        // ListTodos returns ([]TodoItem, error), so the mock returns both values
        todoService.EXPECT().ListTodos(context.Background()).Return(todos, nil)

        // the server asks the stream for its context first
        stream := &mocks.ListTodosStream{}
        stream.EXPECT().Context().Return(context.Background())

        todoServer := TodoServer{
            TodoService: todoService,
        }

        // Run the function we want to test
        err := todoServer.ListTodos(stream)
        assert.NoError(t, err)
    })
}

We now have our base test setup. Let's get back to how gRPC works and how we proceed in our test.

In this test we are testing the server, which means the test itself acts as the client. It receives the messages the server submits through the Send() method on the server stream.

Because the client listens for messages while the server sends them, we run the server function in a separate goroutine and let the test's own goroutine act as the client.

We also set up a channel that receives a message every time the server calls Send() on the stream. Once the server goroutine has started, the test loops over the channel and collects messages until the server closes it after the final message.
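Stripped of gRPC and mocks, the goroutine-and-channel pattern described above can be sketched as follows. Todo, listTodos and collect are hypothetical names; the handler receives a send callback instead of a stream, the test side reads the channel until the server goroutine closes it, and the handler's error is passed back over a second channel rather than asserted inside the goroutine.

```go
package main

import "fmt"

// Todo is a hypothetical stand-in for the generated message type.
type Todo struct{ ID int }

// listTodos plays the server: it sends one message per ID via the callback.
func listTodos(ids []int, send func(*Todo) error) error {
	for _, id := range ids {
		if err := send(&Todo{ID: id}); err != nil {
			return err
		}
	}
	return nil
}

// collect plays the client: it runs the server in its own goroutine and
// reads every message until the channel is closed.
func collect(ids []int) ([]*Todo, error) {
	// Unbuffered on purpose: each send blocks until the reader below receives,
	// which only works because server and client run in different goroutines.
	data := make(chan *Todo)
	errc := make(chan error, 1) // buffered so the server never blocks on it

	go func() {
		errc <- listTodos(ids, func(t *Todo) error {
			data <- t
			return nil
		})
		close(data) // tells the reader that the server is done
	}()

	var got []*Todo
	for t := range data { // ends once data is closed
		got = append(got, t)
	}
	return got, <-errc
}

func main() {
	todos, err := collect([]int{1, 2})
	fmt.Println(len(todos), err) // prints 2 <nil>
}
```

Calling listTodos on the test goroutine with this unbuffered channel would deadlock on the first send, which is the reason the server runs in a goroutine of its own.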

Let's update our test first to create a separate routine to run our server function:

func TestTodoServer_ListTodos(t *testing.T) {
    t.Run("Retrieve Todos", func(t *testing.T) {
        // setup our TodoServer
        todoService := &mocks.TodoService{}

        todos := []TodoItem{
            {
                ID: 1,
            },
            {
                ID: 2,
            },
        }

        todoService.EXPECT().ListTodos(context.Background()).Return(todos, nil)

        stream := &mocks.ListTodosStream{}
        stream.EXPECT().Context().Return(context.Background())

        todoServer := TodoServer{
            TodoService: todoService,
        }

        // Run the function we want to test in a separate goroutine
        go func() {
            err := todoServer.ListTodos(stream)
            assert.NoError(t, err)
        }()
    })
}

Using go func() { ... }(), we start an anonymous goroutine that runs the server code, so the test can read from the channel concurrently. Let's create that channel now:

func TestTodoServer_ListTodos(t *testing.T) {
    t.Run("Retrieve Todos", func(t *testing.T) {
        // setup our TodoServer
        todoService := &mocks.TodoService{}

        // Send takes *rpc.Todo, so the channel carries pointers
        streamData := make(chan *rpc.Todo, 2)

        todos := []TodoItem{
            {
                ID: 1,
            },
            {
                ID: 2,
            },
        }

        todoService.EXPECT().ListTodos(context.Background()).Return(todos, nil)

        stream := &mocks.ListTodosStream{}
        stream.EXPECT().Context().Return(context.Background())

        todoServer := TodoServer{
            TodoService: todoService,
        }

        // Run the function we want to test in a separate goroutine
        go func() {
            err := todoServer.ListTodos(stream)
            assert.NoError(t, err)

            // after the server is done, close the channel
            close(streamData)
        }()

        // read from the channel until the server closes it
        var todosPb []*rpc.Todo
        for todo := range streamData {
            todosPb = append(todosPb, todo)
        }

        // assert we received 2 todos
        assert.Len(t, todosPb, 2)
    })
}

We are almost there! Running the test now fails, because the server calls Send on our stream mock, but we have not defined any behaviour for that method yet.

This time we do not only want to return a value (a nil error); we also want to run a function that pushes whatever our server sends onto the channel of our test "client".

Let's finish up our test case by having it write to the channel:

func TestTodoServer_ListTodos(t *testing.T) {
    t.Run("Retrieve Todos", func(t *testing.T) {
        // setup our TodoServer
        todoService := &mocks.TodoService{}

        // Send takes *rpc.Todo, so the channel carries pointers
        streamData := make(chan *rpc.Todo, 2)

        todos := []TodoItem{
            {
                ID: 1,
            },
            {
                ID: 2,
            },
        }

        todoService.EXPECT().ListTodos(context.Background()).Return(todos, nil)

        stream := &mocks.ListTodosStream{}
        stream.EXPECT().Context().Return(context.Background())

        todoServer := TodoServer{
            TodoService: todoService,
        }

        // every message the server sends is pushed to our test "client"
        stream.EXPECT().Send(mock.AnythingOfType("*rpc.Todo")).RunAndReturn(func(todo *rpc.Todo) error {
            streamData <- todo
            return nil
        })

        // Run the function we want to test in a separate goroutine
        go func() {
            err := todoServer.ListTodos(stream)
            assert.NoError(t, err)

            // after the server is done, close the channel
            close(streamData)
        }()

        // read from the channel until the server closes it
        var todosPb []*rpc.Todo
        for todo := range streamData {
            todosPb = append(todosPb, todo)
        }

        // assert we received 2 todos
        assert.Len(t, todosPb, 2)
    })
}

That's it! We have written our first test for a streaming function on our gRPC server. You can reverse this example to test a gRPC client by simulating the server side of the stream, or even combine both approaches for bidirectional streams.
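As a hedged sketch of the reverse direction: when testing client code that consumes a server stream, the fake plays the server by handing out messages through Recv and returning io.EOF once it runs out, which is how a generated client stream signals that the server finished the stream normally. fakeClientStream, readAll and Todo are hypothetical names; a real generated client stream interface also embeds grpc.ClientStream.

```go
package main

import (
	"errors"
	"fmt"
	"io"
)

// Todo is a hypothetical stand-in for the generated message type.
type Todo struct{ ID int }

// fakeClientStream simulates the server side: Recv hands out messages
// from a channel and returns io.EOF once the channel is closed.
type fakeClientStream struct {
	msgs <-chan *Todo
}

func (f *fakeClientStream) Recv() (*Todo, error) {
	t, ok := <-f.msgs
	if !ok {
		return nil, io.EOF
	}
	return t, nil
}

// readAll is hypothetical client code under test: it drains the stream.
func readAll(stream interface{ Recv() (*Todo, error) }) ([]*Todo, error) {
	var out []*Todo
	for {
		t, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return out, nil // the server finished the stream normally
		}
		if err != nil {
			return out, err
		}
		out = append(out, t)
	}
}

func main() {
	msgs := make(chan *Todo)
	go func() {
		// play the server in its own goroutine, then end the stream
		for id := 1; id <= 2; id++ {
			msgs <- &Todo{ID: id}
		}
		close(msgs)
	}()

	todos, err := readAll(&fakeClientStream{msgs: msgs})
	fmt.Println(len(todos), err) // prints 2 <nil>
}
```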

If you are not yet familiar with gRPC streams, I really encourage you to start experimenting with them, as they open up many possibilities for new use cases and performance optimizations.