Implementing Middlewares with Go

Why use middlewares?

Middlewares are helpful when multiple endpoints share the same code. For example, if you want to log every request a server receives, you don't need to add logging code to each endpoint; you can create a LoggingMiddleware responsible for it instead. This keeps each handler focused on one job, which fits the Single Responsibility principle when building an API. Beyond code reuse, middlewares make the request/response pipeline easy to extend and modify: they let you plug in cross-cutting logic such as authentication, data validation, error handling, and caching.

Design

The Router holds middlewares that apply to the whole server. Some endpoints, however, need different behavior: the /login endpoint, for example, shouldn't require authentication. At the same time, I want to avoid repeating the same middleware on every endpoint. For flexibility and ease of use, I introduced Groups: the Router has a list of Groups, and each Group has its own middlewares and a list of Routes. The image below illustrates this concept.

[Figure: middleware-architecture.png]

Implementation

Route

Let's start with the lowest abstraction of our server, the Route.

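package server

import "net/http"

// Route pairs a URL pattern with the handler that serves it.
type Route struct {
	pattern string
	handler http.HandlerFunc
}

// NewRoute creates a Route for the given pattern and handler.
func NewRoute(pattern string, handler http.HandlerFunc) Route {
	return Route{pattern: pattern, handler: handler}
}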

Group

Now let's create our groups and the Middleware type:

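package server

import "net/http"

// Middleware wraps a handler and returns a new handler that runs
// extra logic before and/or after calling the original one.
type Middleware func(http.HandlerFunc) http.HandlerFunc

// Group bundles routes that share the same middlewares.
type Group struct {
	Name        string
	middlewares []Middleware
	routes      []*Route
}

// Use adds a middleware that applies to every route in the group.
func (g *Group) Use(middleware Middleware) {
	g.middlewares = append(g.middlewares, middleware)
}

// AddRoute stores a pointer to a copy of route so that
// chainMiddlewares can later replace its handler in place.
func (g *Group) AddRoute(route Route) {
	g.routes = append(g.routes, &route)
}

// chainMiddlewares wraps each route handler with all group middlewares.
// Iterating backward makes the first middleware added the outermost
// wrapper, so middlewares run in the order they were added.
func (g *Group) chainMiddlewares() {
	for _, route := range g.routes {
		for i := len(g.middlewares) - 1; i >= 0; i-- {
			route.handler = g.middlewares[i](route.handler)
		}
	}
}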

As its name says, the chainMiddlewares method wraps each route's handler with all of the group's middlewares. If the route handler is the function foo and the group's only middleware is the function bar, the chained handler becomes bar(foo). Note that we iterate over the middleware list backward: with middlewares [a, b], the handler becomes a(b(foo)), so a runs first. This ensures that the middlewares execute in the same order they were added, which matters whenever the order of execution affects the behavior of the API.
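The effect of the backward loop can be checked with a small experiment. The sketch below assumes two hypothetical tracing middlewares, A and B, added in that order, and compares the post's backward loop (`chainBackward`) with a naive forward loop (`chainForward`); `trace`, `runOrder` and both chain helpers are illustrative names only.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

type Middleware func(http.HandlerFunc) http.HandlerFunc

// trace is a hypothetical middleware factory that records when it runs,
// both before and after the wrapped handler.
func trace(name string, calls *[]string) Middleware {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			*calls = append(*calls, name+" in")
			next(w, r)
			*calls = append(*calls, name+" out")
		}
	}
}

// chainBackward wraps like chainMiddlewares: last to first, so mws[0] is outermost.
func chainBackward(h http.HandlerFunc, mws []Middleware) http.HandlerFunc {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

// chainForward wraps first to last, so the last middleware is outermost.
func chainForward(h http.HandlerFunc, mws []Middleware) http.HandlerFunc {
	for _, mw := range mws {
		h = mw(h)
	}
	return h
}

// runOrder chains A then B around a handler using the given strategy,
// serves one request, and returns the recorded call order.
func runOrder(chain func(http.HandlerFunc, []Middleware) http.HandlerFunc) string {
	var calls []string
	handler := func(w http.ResponseWriter, r *http.Request) {
		calls = append(calls, "handler")
	}
	mws := []Middleware{trace("A", &calls), trace("B", &calls)}
	h := chain(handler, mws)
	h(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return strings.Join(calls, ", ")
}

func main() {
	fmt.Println(runOrder(chainBackward)) // A in, B in, handler, B out, A out
	fmt.Println(runOrder(chainForward))  // B in, A in, handler, A out, B out
}
```

With the forward loop, the last middleware added becomes the outermost wrapper, so the execution order is reversed relative to the order of the Use calls.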

Router

And finally our server, the Router:

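package server

import (
	"log"
	"net/http"
)

// Router owns the ServeMux, the server-wide middlewares, and the groups.
type Router struct {
	mux         *http.ServeMux
	middlewares []Middleware
	groups      []*Group
}

// NewRouter returns a Router backed by a fresh http.ServeMux.
func NewRouter() *Router {
	return &Router{
		mux: http.NewServeMux(),
	}
}

// Use adds a middleware that applies to every route in every group.
func (r *Router) Use(middleware Middleware) {
	r.middlewares = append(r.middlewares, middleware)
}

// AddGroup attaches a group of routes to the router.
func (r *Router) AddGroup(group *Group) {
	r.groups = append(r.groups, group)
}

// registerHandlers applies group middlewares first, then wraps the result
// with the router middlewares (the outermost layer), and registers each route.
func (r *Router) registerHandlers() {
	for _, group := range r.groups {
		log.Println("Group:", group.Name)
		group.chainMiddlewares()
		for _, route := range group.routes {
			// Add common middlewares
			for i := len(r.middlewares) - 1; i >= 0; i-- {
				route.handler = r.middlewares[i](route.handler)
			}
			// Register the handler
			log.Println(route.pattern)
			r.mux.HandleFunc(route.pattern, route.handler)
		}
	}
}

// Start registers all handlers and blocks while serving HTTP on addr.
func (r *Router) Start(addr string) error {
	r.registerHandlers()
	log.Println("Listening...")
	return http.ListenAndServe(addr, r.mux)
}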

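Similarly, the registerHandlers method wraps each already-chained route handler with the Router's middlewares and registers the result on the http.ServeMux instance. Because the Router's middlewares are applied last, they form the outermost layer and run before the group's middlewares. Phew, now we are ready to use this code and implement some middlewares.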

Usage

I've created a middlewares package that contains the logic behind each middleware. Let's start with the Logging middleware, since it's the most straightforward. We want to log the HTTP status code of each request, but http.ResponseWriter offers no way to read the status back once a handler has set it, so we need to create a wrapper that records it.

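package middlewares

import (
	"log"
	"net/http"
)

// StatusRecorder wraps an http.ResponseWriter and remembers the status
// code passed to WriteHeader.
type StatusRecorder struct {
	http.ResponseWriter
	Status int
}

// WriteHeader records the status, then forwards it to the wrapped writer.
func (r *StatusRecorder) WriteHeader(status int) {
	r.Status = status
	r.ResponseWriter.WriteHeader(status)
}

// LoggingMiddleware logs the path and status code of every request.
func LoggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Default to 200: handlers that never call WriteHeader send 200 OK.
		recorder := &StatusRecorder{
			ResponseWriter: w,
			Status:         http.StatusOK,
		}

		next(recorder, r)
		log.Printf("%s - %d", r.URL.Path, recorder.Status)
	}
}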

StatusRecorder embeds the http.ResponseWriter interface, so all of its methods (Header, Write, and WriteHeader) are promoted to StatusRecorder. By defining its own WriteHeader(status int) method, StatusRecorder overrides the promoted one: it records the status and then forwards the call to the embedded ResponseWriter. This technique is known as embedding (or type embedding) in Go.

Now we can use the middleware we created in our main.go file:

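package main

import (
	"log"

	// Placeholder import paths: replace them with your module's actual paths.
	"example.com/app/middlewares"
	"example.com/app/server"
)

func main() {
	router := server.NewRouter()

	// Common middlewares, applied to every route
	router.Use(middlewares.LoggingMiddleware)

	// Public group: no group-specific middlewares
	publicGroup := server.Group{Name: "public"}
	// userHandler is defined elsewhere in the project (not shown in this post)
	publicGroup.AddRoute(server.NewRoute("/login", userHandler.LoginHandler))
	router.AddGroup(&publicGroup)

	// Start blocks while serving; log the error if the server stops.
	log.Fatal(router.Start(":8080"))
}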

Our Router now has the LoggingMiddleware and a Group called public. This group does not have any specific middleware. Later, we can add a private Group with an authentication middleware, but I'll leave that for a separate post.

Thanks for reading!