go の middleware で status code を取得する

はじめに

go の net/http の middleware 内で status code を取得する方法をまとめます。
go の ResponseWriter という interface を status code 断面で深掘りするお話となります。

golang.org

go の middleware で status code を取得する

go の net/http の middleware として rs/zerolog を使う で書いた通り logger などは status code に応じて処理を変えたいことがあります。

net/http の middleware は以下のような型を持っている必要があるので http.ResponseWriter を Wrap して status code を取得してみます。

func Middleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        next.ServeHTTP(w, r)
    })
}

chi/middleware の NewWrapResponseWriter()

chi/middleware には NewWrapResponseWriter() という関数があり ResponseWriter を引数にとって status code を含む WrapResponseWriter という構造体を返してくれます。

この関数を使って status code を取得する流れを見ていきます。

ww := middleware.NewWrapResponseWriter(http.ResponseWriter, *http.Request.ProtoMajor)
fmt.Println("status code: ", ww.Status())

ResponseWriter

そもそも http.ResponseWriter には status code というプロパティを持ってません。
ResponseWriter は以下のような interface になってます。

type ResponseWriter interface {
    Header() Header
    Write([]byte) (int, error)
    WriteHeader(statusCode int)
}

以下のような ResponseWriter を Wrap して取り出しやすい位置に status code を置いた構造体 basicWriter があります。

type basicWriter struct {
    http.ResponseWriter
    wroteHeader bool
    code        int
    bytes       int
    tee         io.Writer
}

この構造体に以下のように WriteHeader() を実装してあげると Wrap した構造体側にも status code を格納することができます。

func (b *basicWriter) WriteHeader(code int) {
    if !b.wroteHeader {
        b.code = code
        b.wroteHeader = true
        b.ResponseWriter.WriteHeader(code)
    }
}

通常は WriteHeader は以下のように response に対して生えているメソッドになっていて w.status = code として response へ status code を格納するような形で使われているようです。

func (w *response) WriteHeader(code int) {
    if w.conn.hijacked() {
        caller := relevantCaller()
        w.conn.server.logf("http: response.WriteHeader on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
        return
    }
    if w.wroteHeader {
        caller := relevantCaller()
        w.conn.server.logf("http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
        return
    }
    checkWriteHeaderCode(code)
    w.wroteHeader = true
    w.status = code

    if w.calledHeader && w.cw.header == nil {
        w.cw.header = w.handlerHeader.Clone()
    }

    if cl := w.handlerHeader.get("Content-Length"); cl != "" {
        v, err := strconv.ParseInt(cl, 10, 64)
        if err == nil && v >= 0 {
            w.contentLength = v
        } else {
            w.conn.server.logf("http: invalid Content-Length of %q", cl)
            w.handlerHeader.Del("Content-Length")
        }
    }
}

chi/middlewarebasicWriter では Status() というメソッドが生えているので、先ほど格納した status code が簡単に取得できるという流れでした。

func (b *basicWriter) Status() int {
    return b.code
}

まとめ

go の middleware で status code を取得する流れについてまとめました。
net/http は何度読んでも勉強になります。