玩轉 Go 生態|Hertz WebSocket 擴充套件簡析

2022-12-14 18:00:36

WebSocket 是一種可以在單個 TCP 連線上進行全雙工通訊,位於 OSI 模型的應用層。WebSocket 使得使用者端和伺服器之間的資料交換變得更加簡單,允許伺服器端主動向使用者端推播資料。在 WebSocket API 中,瀏覽器和伺服器只需要完成一次握手,兩者之間就可以建立永續性的連線,並進行雙向資料傳輸。

Hertz 提供了 WebSocket 的支援,參考 gorilla/websocket 庫使用 hijack 的方式在 Hertz 進行了適配,用法和引數基本保持一致。

安裝

go get github.com/hertz-contrib/websocket

範例程式碼

package main
​
import (
    "context"
    "flag"
    "html/template"
    "log"
​
    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/app/server"
    "github.com/hertz-contrib/websocket"
)
​
var addr = flag.String("addr", "localhost:8080", "http service address")
​
var upgrader = websocket.HertzUpgrader{} // use default options
​
func echo(_ context.Context, c *app.RequestContext) {
    err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
        for {
            mt, message, err := conn.ReadMessage()
            if err != nil {
                log.Println("read:", err)
                break
            }
            log.Printf("recv: %s", message)
            err = conn.WriteMessage(mt, message)
            if err != nil {
                log.Println("write:", err)
                break
            }
        }
    })
    if err != nil {
        log.Print("upgrade:", err)
        return
    }
}
​
func home(_ context.Context, c *app.RequestContext) {
    c.SetContentType("text/html; charset=utf-8")
    homeTemplate.Execute(c, "ws://"+string(c.Host())+"/echo")
}
​
func main() {
    flag.Parse()
    h := server.Default(server.WithHostPorts(*addr))
    // https://github.com/cloudwego/hertz/issues/121
    h.NoHijackConnPool = true
    h.GET("/", home)
    h.GET("/echo", echo)
    h.Spin()
}
​
var homeTemplate = template.Must(template.New("").Parse(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script>  
window.addEventListener("load", function(evt) {
​
    var output = document.getElementById("output");
    var input = document.getElementById("input");
    var ws;
​
    var print = function(message) {
        var d = document.createElement("div");
        d.textContent = message;
        output.appendChild(d);
        output.scroll(0, output.scrollHeight);
    };
​
    document.getElementById("open").onclick = function(evt) {
        if (ws) {
            return false;
        }
        ws = new WebSocket("{{.}}");
        ws.onopen = function(evt) {
            print("OPEN");
        }
        ws.onclose = function(evt) {
            print("CLOSE");
            ws = null;
        }
        ws.onmessage = function(evt) {
            print("RESPONSE: " + evt.data);
        }
        ws.onerror = function(evt) {
            print("ERROR: " + evt.data);
        }
        return false;
    };
​
    document.getElementById("send").onclick = function(evt) {
        if (!ws) {
            return false;
        }
        print("SEND: " + input.value);
        ws.send(input.value);
        return false;
    };
​
    document.getElementById("close").onclick = function(evt) {
        if (!ws) {
            return false;
        }
        ws.close();
        return false;
    };
​
});
</script>
</head>
<body>
<table>
<tr><td valign="top" width="50%">
<p>Click "Open" to create a connection to the server, 
"Send" to send a message to the server and "Close" to close the connection. 
You can change the message and send multiple times.
<p>
<form>
<button id="open">Open</button>
<button id="close">Close</button>
<p><input id="input" type="text" value="Hello world!">
<button id="send">Send</button>
</form>
</td><td valign="top" width="50%">
<div id="output" style="max-height: 70vh;overflow-y: scroll;"></div>
</td></tr></table>
</body>
</html>
`))

執行 server:

go run server.go

上述範例程式碼中,伺服器包括一個簡單的網路使用者端。要使用該使用者端,在瀏覽器中開啟 http://127.0.0.1:8080,並按照頁面上的指示操作。

Upgrade

websocket.Conn 型別代表一個 WebSocket 連線。伺服器應用程式從 HTTP 請求處理程式中呼叫 HertzUpgrader.Upgrade 方法,將 HTTP 協定的連線請求升級為 WebSocket 協定的連線請求。

這部分邏輯對應著範例程式碼的 echo() 函數,此處著重介紹 HertzUpgrader.Upgrade

函數簽名:

func (u *HertzUpgrader) Upgrade(ctx *app.RequestContext, handler HertzHandler) error

內部處理邏輯:

func (u *HertzUpgrader) Upgrade(ctx *app.RequestContext, handler HertzHandler) error {
    if !ctx.IsGet() {
        return u.returnError(ctx, consts.StatusMethodNotAllowed, fmt.Sprintf("%s request method is not GET", badHandshake))
    }
    // 校驗 requsetHeader 中與 websocket 相關的欄位(此處省略部分邏輯程式碼)
​
    subprotocol := u.selectSubprotocol(ctx)
    compress := u.isCompressionEnable(ctx)
​
    ctx.SetStatusCode(consts.StatusSwitchingProtocols)
    // 構造協定升級後的響應頭部資訊
    ctx.Response.Header.Set("Upgrade", "websocket")
    ctx.Response.Header.Set("Connection", "Upgrade")
    ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyBytes(challengeKey))
    // 「無上下文接管」模式
    if compress {
        ctx.Response.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
    }
    if subprotocol != nil {
        ctx.Response.Header.SetBytesV("Sec-WebSocket-Protocol", subprotocol)
    }
​
    // 通過 Hijack 的方式,實現 websocket 全雙工的通訊
    ctx.Hijack(func(netConn network.Conn) {
        writeBuf := poolWriteBuffer.Get().([]byte)
        c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, nil, writeBuf)
        if subprotocol != nil {
            c.subprotocol = b2s(subprotocol)
        }
​
        if compress {
            c.newCompressionWriter = compressNoContextTakeover
            c.newDecompressionReader = decompressNoContextTakeover
        }
​
        netConn.SetDeadline(time.Time{})
​
        handler(c)
​
        writeBuf = writeBuf[0:0]
        poolWriteBuffer.Put(writeBuf)
    })
​
    return nil
}

HertzHandler

HertzHandler 是上述 HertzUpgrader.Upgrade 函數的第二個引數。HertzHandler 在握手完成後接收一個 websocket 連線,通過劫持這個連線,完成全雙工的通訊。

HertzHandler 必須由使用者提供,內部定義了 WebSocket 請求和響應的具體流程。

函數簽名:

type HertzHandler func(*Conn)

上述 echo 伺服器的 websocket 處理流程:

err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
    for {
        // 讀取使用者端傳送的資訊
        mt, message, err := conn.ReadMessage()
        if err != nil {
            log.Println("read:", err)
            break
        }
        log.Printf("recv: %s", message)
        // 向用戶端傳送資訊
        err = conn.WriteMessage(mt, message)
        if err != nil {
            log.Println("write:", err)
            break
        }
    }
})

設定

上述檔案已經講述了Hertz WebSocket 最核心的協定升級連線劫持的邏輯,下面將羅列 Hertz WebSocket 使用過程中可選的設定引數。

這部分將圍繞 websocket.HertzUpgrader 結構展開說明。

引數 介紹
ReadBufferSize 用於設定輸入緩衝區的大小,單位為位元組。如果緩衝區大小為零,那麼就使用 HTTP 伺服器分配的大小。輸入緩衝區大小並不限制可以接收的資訊的大小。
WriteBufferSize 用於設定輸出緩衝區的大小,單位為位元組。如果緩衝區大小為零,那麼就使用 HTTP 伺服器分配的大小。輸出緩衝區大小並不限制可以傳送的資訊的大小。
WriteBufferPool 用於設定寫操作的緩衝池。
Subprotocols 用於按優先順序設定伺服器支援的協定。如果這個欄位不是 nil,那麼 Upgrade 方法通過選擇這個列表中與使用者端請求的協定的第一個匹配來協商一個子協定。如果沒有匹配,那麼就不協商協定(Sec-Websocket-Protocol 頭不包括在握手響應中)。
Error 用於設定生成 HTTP 錯誤響應的函數。
CheckOrigin 用於設定針對請求的 Origin 頭的校驗函數, 如果請求的 Origin 頭是可接受的,CheckOrigin 返回 true。
EnableCompression 用於設定伺服器是否應該嘗試協商每個訊息的壓縮(RFC 7692)。將此值設定為 true 並不能保證壓縮會被支援。

WriteBufferPool

如果該值沒有被設定,則額外初始化寫緩衝區,並在當前生命週期內分配給該連線。當應用程式在大量的連線上有適度的寫入量時,緩衝池是最有用的。

應用程式應該使用一個單一的緩衝池來為不同的連線分配緩衝區。

介面簽名:

// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface.  The type of the value stored in a pool is not specified.
type BufferPool interface {
    // Get gets a value from the pool or returns nil if the pool is empty.
    Get() interface{}
    // Put adds a value to the pool.
    Put(interface{})
}

範例程式碼:

type simpleBufferPool struct {
    v interface{}
}
​
func (p *simpleBufferPool) Get() interface{} {
    v := p.v
    p.v = nil
    return v
}
​
func (p *simpleBufferPool) Put(v interface{}) {
    p.v = v
}
​
var upgrader = websocket.HertzUpgrader{
    WriteBufferPool: &simpleBufferPool{},
}

Subprotocols

WebSocket 只是定義了一種交換任意訊息的機制。這些訊息是什麼意思,使用者端在任何特定的時間點可以期待什麼樣的訊息,或者他們被允許傳送什麼樣的訊息,完全取決於實現應用程式。

所以你需要在伺服器和使用者端之間就這些事情達成協定。子協定引數只是讓使用者端和伺服器端正式地交換這些資訊。你可以為你想要的任何協定編造任何名字。伺服器可以簡單地檢查客戶在握手過程中是否遵守了該協定。

Error

如果 Error 為 nil,則使用 Hertz 提供的 API 來生成 HTTP 錯誤響應。

函數簽名:

func(ctx *app.RequestContext, status int, reason error)

範例程式碼:

var upgrader = websocket.HertzUpgrader{
    Error: func(ctx *app.RequestContext, status int, reason error) {
        ctx.Response.Header.Set("Sec-Websocket-Version", "13")
        ctx.AbortWithMsg(reason.Error(), status)
    },
}

CheckOrigin

如果 CheckOrigin 為nil,則使用一個安全的預設值:如果Origin請求頭存在,並且源主機不等於請求主機頭,則返回false。CheckOrigin 函數應該仔細驗證請求的來源,以防止跨站請求偽造。

函數簽名:

func(ctx *app.RequestContext) bool

預設實現:

func fastHTTPCheckSameOrigin(ctx *app.RequestContext) bool {
    origin := ctx.Request.Header.Peek("Origin")
    if len(origin) == 0 {
        return true
    }
    u, err := url.Parse(b2s(origin))
    if err != nil {
        return false
    }
    return equalASCIIFold(u.Host, b2s(ctx.Host()))
}

EnableCompression

伺服器端接受一個或者多個擴充套件欄位,這些擴充套件欄位是包含使用者端請求的 Sec-WebSocket-Extensions 頭欄位擴充套件中的。當 EnableCompression 為 true 時,伺服器端根據當前自身支援的擴充套件與其進行匹配,如果匹配成功則支援壓縮。

校驗邏輯:

var strPermessageDeflate = []byte("permessage-deflate")
​
func (u *HertzUpgrader) isCompressionEnable(ctx *app.RequestContext) bool {
    extensions := parseDataHeader(ctx.Request.Header.Peek("Sec-WebSocket-Extensions"))
​
    // Negotiate PMCE
    if u.EnableCompression {
        for _, ext := range extensions {
            if bytes.HasPrefix(ext, strPermessageDeflate) {
                return true
            }
        }
    }
​
    return false
}

目前僅支援「無上下文接管」模式,詳見上述 HertzUpgrader.Upgrade 程式碼部分。

Set Deadline

當使用 websocket 進行讀寫的時候,可以通過類似如下方式設定超時時間(在每次讀寫過程中都會生效)。

範例程式碼:

func echo(_ context.Context, c *app.RequestContext) {
    err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
        defer conn.Close()
        // "github.com/cloudwego/hertz/pkg/network"
        conn.NetConn().(network.Conn).SetReadTimeout(1 * time.Second)
        ...
    })
    if err != nil {
        log.Print("upgrade:", err)
        return
    }
}

更多用法範例詳見 examples