|
- package nats
- import (
- "bufio"
- "bytes"
- "compress/flate"
- "crypto/rand"
- "crypto/sha1"
- "encoding/base64"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- mrand "math/rand"
- "net/http"
- "net/url"
- "strings"
- "time"
- "unicode/utf8"
- )
// wsOpCode is the 4-bit opcode found in the low bits of the first byte of
// a websocket frame header (RFC 6455, section 5.2).
type wsOpCode int

// Frame opcodes understood by this implementation.
const (
	wsTextMessage   wsOpCode = 1
	wsBinaryMessage wsOpCode = 2
	wsCloseMessage  wsOpCode = 8
	wsPingMessage   wsOpCode = 9
	wsPongMessage   wsOpCode = 10
)

// Header bits, sizes and limits used when encoding and decoding frames.
const (
	wsFinalBit = 1 << 7 // FIN flag, first header byte
	wsRsv1Bit  = 1 << 6 // RSV1, set on compressed (permessage-deflate) messages
	wsRsv2Bit  = 1 << 5
	wsRsv3Bit  = 1 << 4
	wsMaskBit  = 1 << 7 // MASK flag, second header byte

	// Untyped on purpose: compared against wsOpCode values in switches.
	wsContinuationFrame = 0
	// 2 fixed bytes + 8-byte extended length + 4-byte masking key.
	wsMaxFrameHeaderSize    = 14
	wsMaxControlPayloadSize = 125
)

// Close status codes (RFC 6455, section 7.4.1).
const (
	wsCloseStatusNormalClosure      = 1000
	wsCloseStatusNoStatusReceived   = 1005
	wsCloseStatusAbnormalClosure    = 1006
	wsCloseStatusInvalidPayloadData = 1007
)

// URL schemes and permessage-deflate (RFC 7692) negotiation tokens.
const (
	wsScheme    = "ws"
	wsSchemeTLS = "wss"

	wsPMCExtension      = "permessage-deflate"
	wsPMCSrvNoCtx       = "server_no_context_takeover"
	wsPMCCliNoCtx       = "client_no_context_takeover"
	wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
)
var (
	// wsGUID is the fixed GUID that RFC 6455 concatenates with the client's
	// Sec-WebSocket-Key when computing Sec-WebSocket-Accept.
	wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

	// compressFinalBlock is appended after every compressed payload before it
	// is inflated: the 0x00 0x00 0xff 0xff tail that RFC 7692 senders strip,
	// followed by an empty final stored deflate block.
	compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}
)
- type websocketReader struct {
- r io.Reader
- pending [][]byte
- ib []byte
- ff bool
- fc bool
- dc io.ReadCloser
- nc *Conn
- }
- type websocketWriter struct {
- w io.Writer
- compress bool
- compressor *flate.Writer
- ctrlFrames [][]byte
- cm []byte
- cmDone bool
- noMoreSend bool
- }
- type decompressorBuffer struct {
- buf []byte
- rem int
- off int
- final bool
- }
- func newDecompressorBuffer(buf []byte) *decompressorBuffer {
- return &decompressorBuffer{buf: buf, rem: len(buf)}
- }
- func (d *decompressorBuffer) Read(p []byte) (int, error) {
- if d.buf == nil {
- return 0, io.EOF
- }
- lim := d.rem
- if len(p) < lim {
- lim = len(p)
- }
- n := copy(p, d.buf[d.off:d.off+lim])
- d.off += n
- d.rem -= n
- d.checkRem()
- return n, nil
- }
- func (d *decompressorBuffer) checkRem() {
- if d.rem != 0 {
- return
- }
- if !d.final {
- d.buf = compressFinalBlock
- d.off = 0
- d.rem = len(d.buf)
- d.final = true
- } else {
- d.buf = nil
- }
- }
- func (d *decompressorBuffer) ReadByte() (byte, error) {
- if d.buf == nil {
- return 0, io.EOF
- }
- b := d.buf[d.off]
- d.off++
- d.rem--
- d.checkRem()
- return b, nil
- }
// wsNewReader returns a websocketReader decoding frames from r. ff starts
// true because no fragmented message is in progress yet.
func wsNewReader(r io.Reader) *websocketReader {
	wr := new(websocketReader)
	wr.r = r
	wr.ff = true
	return wr
}
// Read implements io.Reader and returns websocket message payloads. Frame
// headers and control frames are consumed internally. Input comes first
// from r.ib (bytes buffered during the handshake), then from payloads left
// pending by an earlier call, and otherwise from r.r, read into p.
//
// Ping frames enqueue a pong. A close frame enqueues a close reply and makes
// Read return io.EOF. Compressed payloads are inflated before being queued.
//
// NOTE(review): a compressed message is inflated one frame at a time (r.fc
// is cleared after the first data frame). RFC 7692 compresses the message
// as a whole, so fragmented compressed messages are likely not decoded
// correctly; fixing that needs extra state in websocketReader.
func (r *websocketReader) Read(p []byte) (int, error) {
	var err error
	var buf []byte
	if l := len(r.ib); l > 0 {
		buf = r.ib
		r.ib = nil
	} else {
		if len(r.pending) > 0 {
			return r.drainPending(p), nil
		}
		// Get some data from the underlying reader.
		n, err := r.r.Read(p)
		if err != nil {
			return 0, err
		}
		buf = p[:n]
	}

	var (
		tmpBuf []byte
		pos    int
		end    = len(buf)
		rem    = 0
		// decompressed is set once an inflated payload has been queued.
		// drainPending copies pending payloads into p in order. An inflated
		// payload may be larger than the compressed bytes it came from, so
		// copying it can overwrite the part of buf (== p unless the data came
		// from r.ib) that holds a later payload before that payload is copied.
		// Payloads queued after an inflated one are therefore copied out of buf.
		decompressed bool
	)
	for pos < end {
		b0 := buf[pos]
		frameType := wsOpCode(b0 & 0xF)
		final := b0&wsFinalBit != 0
		compressed := b0&wsRsv1Bit != 0
		pos++

		tmpBuf, pos, err = wsGet(r.r, buf, pos, 1)
		if err != nil {
			return 0, err
		}
		b1 := tmpBuf[0]

		// 7-bit payload length; 126 and 127 announce an extended length.
		rem = int(b1 & 0x7F)

		switch frameType {
		case wsPingMessage, wsPongMessage, wsCloseMessage:
			if rem > wsMaxControlPayloadSize {
				return 0, fmt.Errorf("control frame length bigger than maximum allowed of %v bytes",
					wsMaxControlPayloadSize)
			}
			if compressed {
				return 0, errors.New("control frame should not be compressed")
			}
			if !final {
				return 0, errors.New("control frame does not have final bit set")
			}
		case wsTextMessage, wsBinaryMessage:
			if !r.ff {
				return 0, errors.New("new message started before final frame for previous message was received")
			}
			r.ff = final
			r.fc = compressed
		case wsContinuationFrame:
			// Only valid inside a fragmented message; never carries RSV1.
			if r.ff || compressed {
				return 0, errors.New("invalid continuation frame")
			}
			r.ff = final
		default:
			return 0, fmt.Errorf("unknown opcode %v", frameType)
		}

		switch rem {
		case 126:
			tmpBuf, pos, err = wsGet(r.r, buf, pos, 2)
			if err != nil {
				return 0, err
			}
			rem = int(binary.BigEndian.Uint16(tmpBuf))
		case 127:
			tmpBuf, pos, err = wsGet(r.r, buf, pos, 8)
			if err != nil {
				return 0, err
			}
			l := binary.BigEndian.Uint64(tmpBuf)
			rem = int(l)
			// RFC 6455 requires the most significant bit to be 0. Also reject
			// lengths that do not fit in an int: a negative size would make
			// wsGet panic.
			if rem < 0 || uint64(rem) != l {
				return 0, fmt.Errorf("invalid frame payload length %v", l)
			}
		}

		if wsIsControlFrame(frameType) {
			pos, err = r.handleControlFrame(frameType, buf, pos, rem)
			if err != nil {
				return 0, err
			}
			rem = 0
			continue
		}

		var b []byte
		b, pos, err = wsGet(r.r, buf, pos, rem)
		if err != nil {
			return 0, err
		}
		rem = 0
		if r.fc {
			br := newDecompressorBuffer(b)
			if r.dc == nil {
				r.dc = flate.NewReader(br)
			} else {
				r.dc.(flate.Resetter).Reset(br, nil)
			}
			b, err = ioutil.ReadAll(r.dc)
			if err != nil {
				return 0, err
			}
			r.fc = false
			decompressed = true
		} else if decompressed {
			// Detach from buf; see the comment on decompressed.
			b = append([]byte(nil), b...)
		}
		r.pending = append(r.pending, b)
	}

	return r.drainPending(p), nil
}
- func (r *websocketReader) drainPending(p []byte) int {
- var n int
- var max = len(p)
- for i, buf := range r.pending {
- if n+len(buf) <= max {
- copy(p[n:], buf)
- n += len(buf)
- } else {
-
- if n < max {
-
- rem := max - n
- copy(p[n:], buf[:rem])
- n += rem
- r.pending[i] = buf[rem:]
- }
-
-
- r.pending = r.pending[i:]
- return n
- }
- }
- r.pending = r.pending[:0]
- return n
- }
// wsGet returns the needed bytes that start at buf[pos], together with the
// position just past them. When buf holds fewer than needed bytes from pos
// on, the available ones are copied into a new slice that is completed by
// reading from r; the returned position is then len(buf). On a read error,
// it returns the partially filled slice, the number of bytes gathered and
// the error.
func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
	have := len(buf) - pos
	if have >= needed {
		end := pos + needed
		return buf[pos:end], end, nil
	}
	out := make([]byte, needed)
	filled := copy(out, buf[pos:])
	for filled != needed {
		n, err := r.Read(out[filled:])
		filled += n
		if err != nil {
			return out, filled, err
		}
	}
	return out, len(buf), nil
}
// handleControlFrame consumes the payload of a control frame whose header
// was already parsed by Read (rem is its payload length) and reacts to it:
//   - close: enqueue a close reply carrying the received status (1005 when
//     none was sent, 1007 when the reason is not valid UTF-8) and return
//     io.EOF;
//   - ping: enqueue a pong with the same payload;
//   - pong: ignored.
//
// It returns the position in buf just past the payload.
func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) {
	var payload []byte
	var err error
	if rem > 0 {
		payload, pos, err = wsGet(r.r, buf, pos, rem)
		if err != nil {
			return pos, err
		}
	}
	switch frameType {
	case wsCloseMessage:
		status := wsCloseStatusNoStatusReceived
		body := ""
		// A close payload is a 2-byte big-endian status code optionally
		// followed by a UTF-8 reason. Decode it from payload itself: when the
		// frame extends past the end of buf, wsGet returns a separate buffer
		// and buf does not hold these bytes.
		if len(payload) >= 2 {
			status = int(binary.BigEndian.Uint16(payload[:2]))
			body = string(payload[2:])
			if body != "" && !utf8.ValidString(body) {
				status = wsCloseStatusInvalidPayloadData
				body = "invalid utf8 body in close frame"
			}
		}
		r.nc.wsEnqueueCloseMsg(status, body)
		// The peer is closing: report end of stream to Read's caller.
		return pos, io.EOF
	case wsPingMessage:
		r.nc.wsEnqueueControlMsg(wsPongMessage, payload)
	case wsPongMessage:
		// Nothing to do.
	}
	return pos, nil
}
// Write sends p as a single masked binary websocket frame on the underlying
// writer. Queued control frames are written first. A queued close message is
// written last, after which every later call is silently dropped and returns
// 0, nil.
//
// When compression was negotiated, p is deflated and RSV1 is set in the
// header. The returned count is the number of bytes handed to the underlying
// writer, headers included, rather than the number of bytes of p consumed.
//
// NOTE(review): for uncompressed data the mask is applied to p in place,
// which io.Writer implementations are not supposed to do; this assumes
// callers never reuse p after Write. Confirm with the callers.
func (w *websocketWriter) Write(p []byte) (int, error) {
	if w.noMoreSend {
		return 0, nil
	}
	written := 0

	// Queued control frames always go out ahead of data.
	if len(w.ctrlFrames) > 0 {
		n, err := w.writeCtrlFrames()
		if err != nil {
			return n, err
		}
		written += n
	}

	var err error
	if len(p) > 0 {
		if w.compress {
			out := new(bytes.Buffer)
			if w.compressor != nil {
				w.compressor.Reset(out)
			} else {
				// BestSpeed is a valid level, so NewWriter cannot fail.
				w.compressor, _ = flate.NewWriter(out, flate.BestSpeed)
			}
			// Deflating into an in-memory buffer; results are not checked.
			_, _ = w.compressor.Write(p)
			_ = w.compressor.Close()
			deflated := out.Bytes()
			// Drop the last 4 bytes (expected to be 0x00 0x00 0xff 0xff, see
			// RFC 7692 section 7.2.1); the receiver restores them.
			p = deflated[:len(deflated)-4]
		}
		header, key := wsCreateFrameHeader(w.compress, wsBinaryMessage, len(p))
		wsMaskBuf(key, p)
		var n int
		n, err = w.w.Write(header)
		written += n
		if err == nil {
			n, err = w.w.Write(p)
			written += n
		}
	}

	// The close message goes last, and only when everything before succeeded.
	if err == nil && w.cm != nil {
		n, cerr := w.writeCloseMsg()
		written += n
		err = cerr
	}
	return written, err
}
- func (w *websocketWriter) writeCtrlFrames() (int, error) {
- var (
- n int
- total int
- i int
- err error
- )
- for ; i < len(w.ctrlFrames); i++ {
- buf := w.ctrlFrames[i]
- n, err = w.w.Write(buf)
- total += n
- if err != nil {
- break
- }
- }
- if i != len(w.ctrlFrames) {
- w.ctrlFrames = w.ctrlFrames[i+1:]
- } else {
- w.ctrlFrames = w.ctrlFrames[:0]
- }
- return total, err
- }
- func (w *websocketWriter) writeCloseMsg() (int, error) {
- n, err := w.w.Write(w.cm)
- w.cm, w.noMoreSend = nil, true
- return n, err
- }
// wsMaskBuf XORs buf in place with the 4-byte masking key, cycling through
// the key (RFC 6455, section 5.3). Applying it twice with the same key
// restores the original bytes.
func wsMaskBuf(key, buf []byte) {
	for i := range buf {
		buf[i] ^= key[i%4]
	}
}
- func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
- fh := make([]byte, wsMaxFrameHeaderSize)
- n, key := wsFillFrameHeader(fh, compressed, frameType, l)
- return fh[:n], key
- }
// wsFillFrameHeader writes into fh the header of a single frame (FIN always
// set, RSV1 set when compressed, MASK always set) for a payload of l bytes.
// The length uses the 7-bit, 16-bit or 64-bit encoding as needed. A random
// 4-byte masking key follows the length. fh must have room for up to
// wsMaxFrameHeaderSize bytes.
//
// It returns the header length and the masking key, which is a slice of fh.
// The key comes from crypto/rand, with a math/rand fallback if that fails.
func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) {
	first := byte(frameType) | wsFinalBit
	if compressed {
		first |= wsRsv1Bit
	}
	fh[0] = first

	var n int
	if l <= 125 {
		fh[1] = wsMaskBit | byte(l)
		n = 2
	} else if l < 65536 {
		fh[1] = wsMaskBit | 126
		binary.BigEndian.PutUint16(fh[2:], uint16(l))
		n = 4
	} else {
		fh[1] = wsMaskBit | 127
		binary.BigEndian.PutUint64(fh[2:], uint64(l))
		n = 10
	}

	var keyBuf [4]byte
	if _, err := io.ReadFull(rand.Reader, keyBuf[:]); err != nil {
		binary.LittleEndian.PutUint32(keyBuf[:], uint32(mrand.Int31()))
	}
	copy(fh[n:], keyBuf[:])
	return n + 4, fh[n : n+4]
}
- func (nc *Conn) wsInitHandshake(u *url.URL) error {
- compress := nc.Opts.Compression
- tlsRequired := u.Scheme == wsSchemeTLS || nc.Opts.Secure || nc.Opts.TLSConfig != nil
-
- if tlsRequired {
- if err := nc.makeTLSConn(); err != nil {
- return err
- }
- } else {
- nc.bindToNewConn()
- }
- var err error
-
- scheme := "http"
- if tlsRequired {
- scheme = "https"
- }
- ustr := fmt.Sprintf("%s://%s", scheme, u.Host)
- u, err = url.Parse(ustr)
- if err != nil {
- return err
- }
- req := &http.Request{
- Method: "GET",
- URL: u,
- Proto: "HTTP/1.1",
- ProtoMajor: 1,
- ProtoMinor: 1,
- Header: make(http.Header),
- Host: u.Host,
- }
- wsKey, err := wsMakeChallengeKey()
- if err != nil {
- return err
- }
- req.Header["Upgrade"] = []string{"websocket"}
- req.Header["Connection"] = []string{"Upgrade"}
- req.Header["Sec-WebSocket-Key"] = []string{wsKey}
- req.Header["Sec-WebSocket-Version"] = []string{"13"}
- if compress {
- req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue)
- }
- if err := req.Write(nc.conn); err != nil {
- return err
- }
- var resp *http.Response
- br := bufio.NewReaderSize(nc.conn, 4096)
- nc.conn.SetReadDeadline(time.Now().Add(nc.Opts.Timeout))
- resp, err = http.ReadResponse(br, req)
- if err == nil &&
- (resp.StatusCode != 101 ||
- !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
- !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
- resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) {
- err = fmt.Errorf("invalid websocket connection")
- }
-
- if err == nil && compress {
-
-
- srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header)
-
- if !srvCompress {
- compress = false
- } else if !noCtxTakeover {
- err = fmt.Errorf("compression negotiation error")
- }
- }
- if resp != nil {
- resp.Body.Close()
- }
- nc.conn.SetReadDeadline(time.Time{})
- if err != nil {
- return err
- }
- wsr := wsNewReader(nc.br.r)
- wsr.nc = nc
-
- if n := br.Buffered(); n != 0 {
- wsr.ib, _ = br.Peek(n)
- }
- nc.br.r = wsr
- nc.bw.w = &websocketWriter{w: nc.bw.w, compress: compress}
- nc.ws = true
- return nil
- }
- func (nc *Conn) wsClose() {
- nc.mu.Lock()
- defer nc.mu.Unlock()
- if !nc.ws {
- return
- }
- nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_)
- }
- func (nc *Conn) wsEnqueueCloseMsg(status int, payload string) {
-
- if nc == nil {
- return
- }
- nc.mu.Lock()
- nc.wsEnqueueCloseMsgLocked(status, payload)
- nc.mu.Unlock()
- }
// wsEnqueueCloseMsgLocked builds a masked close frame carrying status and
// payload (the close reason), stores it as the writer's pending close message
// and calls nc.bw.flush(). Only the first call has an effect (cmDone). It
// does nothing when the connection writer is not a websocketWriter. The
// caller must hold nc.mu.
//
// A control frame payload is limited to wsMaxControlPayloadSize bytes, 2 of
// which hold the status. A longer reason is therefore truncated, at a UTF-8
// rune boundary. This also keeps the frame header at 2 bytes, which the
// buffer size below relies on.
func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) {
	wr, ok := nc.bw.w.(*websocketWriter)
	if !ok || wr.cmDone {
		return
	}
	if maxReason := wsMaxControlPayloadSize - 2; len(payload) > maxReason {
		cut := maxReason
		// Back up to the start of a rune so the reason stays valid UTF-8.
		for cut > 0 && !utf8.RuneStart(payload[cut]) {
			cut--
		}
		payload = payload[:cut]
	}
	statusAndPayloadLen := 2 + len(payload)
	// 2-byte header (payload <= 125 bytes) + 4-byte masking key + payload.
	frame := make([]byte, 2+4+statusAndPayloadLen)
	n, key := wsFillFrameHeader(frame, false, wsCloseMessage, statusAndPayloadLen)

	binary.BigEndian.PutUint16(frame[n:], uint16(status))
	copy(frame[n+2:], payload)

	wsMaskBuf(key, frame[n:n+statusAndPayloadLen])
	wr.cm = frame
	wr.cmDone = true
	nc.bw.flush()
}
- func (nc *Conn) wsEnqueueControlMsg(frameType wsOpCode, payload []byte) {
-
- if nc == nil {
- return
- }
- fh, key := wsCreateFrameHeader(false, frameType, len(payload))
- nc.mu.Lock()
- wr, ok := nc.bw.w.(*websocketWriter)
- if !ok {
- nc.mu.Unlock()
- return
- }
- wr.ctrlFrames = append(wr.ctrlFrames, fh)
- if len(payload) > 0 {
- wsMaskBuf(key, payload)
- wr.ctrlFrames = append(wr.ctrlFrames, payload)
- }
- nc.bw.flush()
- nc.mu.Unlock()
- }
// wsPMCExtensionSupport inspects the Sec-Websocket-Extensions response
// headers. The first result reports whether the server accepted
// permessage-deflate. The second reports whether it also listed both
// server_no_context_takeover and client_no_context_takeover. Tokens are
// compared case-insensitively, and spaces and tabs around them are ignored.
func wsPMCExtensionSupport(header http.Header) (bool, bool) {
	const cutset = " \t"
	for _, value := range header["Sec-Websocket-Extensions"] {
		for _, ext := range strings.Split(value, ",") {
			params := strings.Split(strings.Trim(ext, cutset), ";")
			for idx := range params {
				if !strings.EqualFold(strings.Trim(params[idx], cutset), wsPMCExtension) {
					continue
				}
				srvNoCtx, cliNoCtx := false, false
				for _, opt := range params[idx+1:] {
					opt = strings.Trim(opt, cutset)
					switch {
					case strings.EqualFold(opt, wsPMCSrvNoCtx):
						srvNoCtx = true
					case strings.EqualFold(opt, wsPMCCliNoCtx):
						cliNoCtx = true
					}
					if srvNoCtx && cliNoCtx {
						return true, true
					}
				}
				return true, false
			}
		}
	}
	return false, false
}
- func wsMakeChallengeKey() (string, error) {
- p := make([]byte, 16)
- if _, err := io.ReadFull(rand.Reader, p); err != nil {
- return "", err
- }
- return base64.StdEncoding.EncodeToString(p), nil
- }
// wsAcceptKey computes the Sec-WebSocket-Accept value the server must return
// for the given client key: base64(SHA-1(key + wsGUID)).
func wsAcceptKey(key string) string {
	sum := sha1.Sum(append([]byte(key), wsGUID...))
	return base64.StdEncoding.EncodeToString(sum[:])
}
- func wsIsControlFrame(frameType wsOpCode) bool {
- return frameType >= wsCloseMessage
- }
// isWebsocketScheme reports whether u uses the "ws" or "wss" scheme
// (compared case-sensitively).
func isWebsocketScheme(u *url.URL) bool {
	switch u.Scheme {
	case wsScheme, wsSchemeTLS:
		return true
	}
	return false
}
|