// Copyright 2021 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nats import ( "bufio" "bytes" "compress/flate" "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/binary" "errors" "fmt" "io" "io/ioutil" mrand "math/rand" "net/http" "net/url" "strings" "time" "unicode/utf8" ) type wsOpCode int const ( // From https://tools.ietf.org/html/rfc6455#section-5.2 wsTextMessage = wsOpCode(1) wsBinaryMessage = wsOpCode(2) wsCloseMessage = wsOpCode(8) wsPingMessage = wsOpCode(9) wsPongMessage = wsOpCode(10) wsFinalBit = 1 << 7 wsRsv1Bit = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6 wsRsv2Bit = 1 << 5 wsRsv3Bit = 1 << 4 wsMaskBit = 1 << 7 wsContinuationFrame = 0 wsMaxFrameHeaderSize = 14 wsMaxControlPayloadSize = 125 // From https://tools.ietf.org/html/rfc6455#section-11.7 wsCloseStatusNormalClosure = 1000 wsCloseStatusNoStatusReceived = 1005 wsCloseStatusAbnormalClosure = 1006 wsCloseStatusInvalidPayloadData = 1007 wsScheme = "ws" wsSchemeTLS = "wss" wsPMCExtension = "permessage-deflate" // per-message compression wsPMCSrvNoCtx = "server_no_context_takeover" wsPMCCliNoCtx = "client_no_context_takeover" wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx ) // From https://tools.ietf.org/html/rfc6455#section-1.3 var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader // does not report unexpected EOF. var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} type websocketReader struct { r io.Reader pending [][]byte ib []byte ff bool fc bool dc io.ReadCloser nc *Conn } type websocketWriter struct { w io.Writer compress bool compressor *flate.Writer ctrlFrames [][]byte // pending frames that should be sent at the next Write() cm []byte // close message that needs to be sent when everything else has been sent cmDone bool // a close message has been added or sent (never going back to false) noMoreSend bool // if true, even if there is a Write() call, we should not send anything } type decompressorBuffer struct { buf []byte rem int off int final bool } func newDecompressorBuffer(buf []byte) *decompressorBuffer { return &decompressorBuffer{buf: buf, rem: len(buf)} } func (d *decompressorBuffer) Read(p []byte) (int, error) { if d.buf == nil { return 0, io.EOF } lim := d.rem if len(p) < lim { lim = len(p) } n := copy(p, d.buf[d.off:d.off+lim]) d.off += n d.rem -= n d.checkRem() return n, nil } func (d *decompressorBuffer) checkRem() { if d.rem != 0 { return } if !d.final { d.buf = compressFinalBlock d.off = 0 d.rem = len(d.buf) d.final = true } else { d.buf = nil } } func (d *decompressorBuffer) ReadByte() (byte, error) { if d.buf == nil { return 0, io.EOF } b := d.buf[d.off] d.off++ d.rem-- d.checkRem() return b, nil } func wsNewReader(r io.Reader) *websocketReader { return &websocketReader{r: r, ff: true} } func (r *websocketReader) Read(p []byte) (int, error) { var err error var buf []byte if l := len(r.ib); l > 0 { buf = r.ib r.ib = nil } else { if len(r.pending) > 0 { return r.drainPending(p), nil } // Get some data from the underlying reader. n, err := r.r.Read(p) if err != nil { return 0, err } buf = p[:n] } // Now parse this and decode frames. We will possibly read more to // ensure that we get a full frame. var ( tmpBuf []byte pos int max = len(buf) rem = 0 ) for pos < max { b0 := buf[pos] frameType := wsOpCode(b0 & 0xF) final := b0&wsFinalBit != 0 compressed := b0&wsRsv1Bit != 0 pos++ tmpBuf, pos, err = wsGet(r.r, buf, pos, 1) if err != nil { return 0, err } b1 := tmpBuf[0] // Store size in case it is < 125 rem = int(b1 & 0x7F) switch frameType { case wsPingMessage, wsPongMessage, wsCloseMessage: if rem > wsMaxControlPayloadSize { return 0, fmt.Errorf( fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", wsMaxControlPayloadSize)) } if compressed { return 0, errors.New("control frame should not be compressed") } if !final { return 0, errors.New("control frame does not have final bit set") } case wsTextMessage, wsBinaryMessage: if !r.ff { return 0, errors.New("new message started before final frame for previous message was received") } r.ff = final r.fc = compressed case wsContinuationFrame: // Compressed bit must be only set in the first frame if r.ff || compressed { return 0, errors.New("invalid continuation frame") } r.ff = final default: return 0, fmt.Errorf("unknown opcode %v", frameType) } // If the encoded size is <= 125, then `rem` is simply the remainder size of the // frame. If it is 126, then the actual size is encoded as a uint16. For larger // frames, `rem` will initially be 127 and the actual size is encoded as a uint64. switch rem { case 126: tmpBuf, pos, err = wsGet(r.r, buf, pos, 2) if err != nil { return 0, err } rem = int(binary.BigEndian.Uint16(tmpBuf)) case 127: tmpBuf, pos, err = wsGet(r.r, buf, pos, 8) if err != nil { return 0, err } rem = int(binary.BigEndian.Uint64(tmpBuf)) } // Handle control messages in place... if wsIsControlFrame(frameType) { pos, err = r.handleControlFrame(frameType, buf, pos, rem) if err != nil { return 0, err } rem = 0 continue } var b []byte b, pos, err = wsGet(r.r, buf, pos, rem) if err != nil { return 0, err } rem = 0 if r.fc { br := newDecompressorBuffer(b) if r.dc == nil { r.dc = flate.NewReader(br) } else { r.dc.(flate.Resetter).Reset(br, nil) } // TODO: When Go 1.15 support is dropped, replace with io.ReadAll() b, err = ioutil.ReadAll(r.dc) if err != nil { return 0, err } r.fc = false } r.pending = append(r.pending, b) } // At this point we should have pending slices. return r.drainPending(p), nil } func (r *websocketReader) drainPending(p []byte) int { var n int var max = len(p) for i, buf := range r.pending { if n+len(buf) <= max { copy(p[n:], buf) n += len(buf) } else { // Is there room left? if n < max { // Write the partial and update this slice. rem := max - n copy(p[n:], buf[:rem]) n += rem r.pending[i] = buf[rem:] } // These are the remaining slices that will need to be used at // the next Read() call. r.pending = r.pending[i:] return n } } r.pending = r.pending[:0] return n } func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { avail := len(buf) - pos if avail >= needed { return buf[pos : pos+needed], pos + needed, nil } b := make([]byte, needed) start := copy(b, buf[pos:]) for start != needed { n, err := r.Read(b[start:cap(b)]) start += n if err != nil { return b, start, err } } return b, pos + avail, nil } func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) { var payload []byte var err error statusPos := pos if rem > 0 { payload, pos, err = wsGet(r.r, buf, pos, rem) if err != nil { return pos, err } } switch frameType { case wsCloseMessage: status := wsCloseStatusNoStatusReceived body := "" // If there is a payload, it should contain 2 unsigned bytes // that represent the status code and then optional payload. if len(payload) >= 2 { status = int(binary.BigEndian.Uint16(buf[statusPos : statusPos+2])) body = string(buf[statusPos+2 : statusPos+len(payload)]) if body != "" && !utf8.ValidString(body) { // https://tools.ietf.org/html/rfc6455#section-5.5.1 // If body is present, it must be a valid utf8 status = wsCloseStatusInvalidPayloadData body = "invalid utf8 body in close frame" } } r.nc.wsEnqueueCloseMsg(status, body) // Return io.EOF so that readLoop will close the connection as ClientClosed // after processing pending buffers. return pos, io.EOF case wsPingMessage: r.nc.wsEnqueueControlMsg(wsPongMessage, payload) case wsPongMessage: // Nothing to do.. } return pos, nil } func (w *websocketWriter) Write(p []byte) (int, error) { if w.noMoreSend { return 0, nil } var total int var n int var err error // If there are control frames, they can be sent now. Actually spec says // that they should be sent ASAP, so we will send before any application data. if len(w.ctrlFrames) > 0 { n, err = w.writeCtrlFrames() if err != nil { return n, err } total += n } // Do the following only if there is something to send. // We will end with checking for need to send close message. if len(p) > 0 { if w.compress { buf := &bytes.Buffer{} if w.compressor == nil { w.compressor, _ = flate.NewWriter(buf, flate.BestSpeed) } else { w.compressor.Reset(buf) } w.compressor.Write(p) w.compressor.Close() b := buf.Bytes() p = b[:len(b)-4] } fh, key := wsCreateFrameHeader(w.compress, wsBinaryMessage, len(p)) wsMaskBuf(key, p) n, err = w.w.Write(fh) total += n if err == nil { n, err = w.w.Write(p) total += n } } if err == nil && w.cm != nil { n, err = w.writeCloseMsg() total += n } return total, err } func (w *websocketWriter) writeCtrlFrames() (int, error) { var ( n int total int i int err error ) for ; i < len(w.ctrlFrames); i++ { buf := w.ctrlFrames[i] n, err = w.w.Write(buf) total += n if err != nil { break } } if i != len(w.ctrlFrames) { w.ctrlFrames = w.ctrlFrames[i+1:] } else { w.ctrlFrames = w.ctrlFrames[:0] } return total, err } func (w *websocketWriter) writeCloseMsg() (int, error) { n, err := w.w.Write(w.cm) w.cm, w.noMoreSend = nil, true return n, err } func wsMaskBuf(key, buf []byte) { for i := 0; i < len(buf); i++ { buf[i] ^= key[i&3] } } // Create the frame header. // Encodes the frame type and optional compression flag, and the size of the payload. func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { fh := make([]byte, wsMaxFrameHeaderSize) n, key := wsFillFrameHeader(fh, compressed, frameType, l) return fh[:n], key } func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) { var n int b := byte(frameType) b |= wsFinalBit if compressed { b |= wsRsv1Bit } b1 := byte(wsMaskBit) switch { case l <= 125: n = 2 fh[0] = b fh[1] = b1 | byte(l) case l < 65536: n = 4 fh[0] = b fh[1] = b1 | 126 binary.BigEndian.PutUint16(fh[2:], uint16(l)) default: n = 10 fh[0] = b fh[1] = b1 | 127 binary.BigEndian.PutUint64(fh[2:], uint64(l)) } var key []byte var keyBuf [4]byte if _, err := io.ReadFull(rand.Reader, keyBuf[:4]); err != nil { kv := mrand.Int31() binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) } copy(fh[n:], keyBuf[:4]) key = fh[n : n+4] n += 4 return n, key } func (nc *Conn) wsInitHandshake(u *url.URL) error { compress := nc.Opts.Compression tlsRequired := u.Scheme == wsSchemeTLS || nc.Opts.Secure || nc.Opts.TLSConfig != nil // Do TLS here as needed. if tlsRequired { if err := nc.makeTLSConn(); err != nil { return err } } else { nc.bindToNewConn() } var err error // For http request, we need the passed URL to contain either http or https scheme. scheme := "http" if tlsRequired { scheme = "https" } ustr := fmt.Sprintf("%s://%s", scheme, u.Host) u, err = url.Parse(ustr) if err != nil { return err } req := &http.Request{ Method: "GET", URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), Host: u.Host, } wsKey, err := wsMakeChallengeKey() if err != nil { return err } req.Header["Upgrade"] = []string{"websocket"} req.Header["Connection"] = []string{"Upgrade"} req.Header["Sec-WebSocket-Key"] = []string{wsKey} req.Header["Sec-WebSocket-Version"] = []string{"13"} if compress { req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue) } if err := req.Write(nc.conn); err != nil { return err } var resp *http.Response br := bufio.NewReaderSize(nc.conn, 4096) nc.conn.SetReadDeadline(time.Now().Add(nc.Opts.Timeout)) resp, err = http.ReadResponse(br, req) if err == nil && (resp.StatusCode != 101 || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) { err = fmt.Errorf("invalid websocket connection") } // Check compression extension... if err == nil && compress { // Check that not only permessage-deflate extension is present, but that // we also have server and client no context take over. srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header) // If server does not support compression, then simply disable it in our side. if !srvCompress { compress = false } else if !noCtxTakeover { err = fmt.Errorf("compression negotiation error") } } if resp != nil { resp.Body.Close() } nc.conn.SetReadDeadline(time.Time{}) if err != nil { return err } wsr := wsNewReader(nc.br.r) wsr.nc = nc // We have to slurp whatever is in the bufio reader and copy to br.r if n := br.Buffered(); n != 0 { wsr.ib, _ = br.Peek(n) } nc.br.r = wsr nc.bw.w = &websocketWriter{w: nc.bw.w, compress: compress} nc.ws = true return nil } func (nc *Conn) wsClose() { nc.mu.Lock() defer nc.mu.Unlock() if !nc.ws { return } nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_) } func (nc *Conn) wsEnqueueCloseMsg(status int, payload string) { // In some low-level unit tests it will happen... if nc == nil { return } nc.mu.Lock() nc.wsEnqueueCloseMsgLocked(status, payload) nc.mu.Unlock() } func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) { wr, ok := nc.bw.w.(*websocketWriter) if !ok || wr.cmDone { return } statusAndPayloadLen := 2 + len(payload) frame := make([]byte, 2+4+statusAndPayloadLen) n, key := wsFillFrameHeader(frame, false, wsCloseMessage, statusAndPayloadLen) // Set the status binary.BigEndian.PutUint16(frame[n:], uint16(status)) // If there is a payload, copy if len(payload) > 0 { copy(frame[n+2:], payload) } // Mask status + payload wsMaskBuf(key, frame[n:n+statusAndPayloadLen]) wr.cm = frame wr.cmDone = true nc.bw.flush() } func (nc *Conn) wsEnqueueControlMsg(frameType wsOpCode, payload []byte) { // In some low-level unit tests it will happen... if nc == nil { return } fh, key := wsCreateFrameHeader(false, frameType, len(payload)) nc.mu.Lock() wr, ok := nc.bw.w.(*websocketWriter) if !ok { nc.mu.Unlock() return } wr.ctrlFrames = append(wr.ctrlFrames, fh) if len(payload) > 0 { wsMaskBuf(key, payload) wr.ctrlFrames = append(wr.ctrlFrames, payload) } nc.bw.flush() nc.mu.Unlock() } func wsPMCExtensionSupport(header http.Header) (bool, bool) { for _, extensionList := range header["Sec-Websocket-Extensions"] { extensions := strings.Split(extensionList, ",") for _, extension := range extensions { extension = strings.Trim(extension, " \t") params := strings.Split(extension, ";") for i, p := range params { p = strings.Trim(p, " \t") if strings.EqualFold(p, wsPMCExtension) { var snc bool var cnc bool for j := i + 1; j < len(params); j++ { p = params[j] p = strings.Trim(p, " \t") if strings.EqualFold(p, wsPMCSrvNoCtx) { snc = true } else if strings.EqualFold(p, wsPMCCliNoCtx) { cnc = true } if snc && cnc { return true, true } } return true, false } } } } return false, false } func wsMakeChallengeKey() (string, error) { p := make([]byte, 16) if _, err := io.ReadFull(rand.Reader, p); err != nil { return "", err } return base64.StdEncoding.EncodeToString(p), nil } func wsAcceptKey(key string) string { h := sha1.New() h.Write([]byte(key)) h.Write(wsGUID) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } // Returns true if the op code corresponds to a control frame. func wsIsControlFrame(frameType wsOpCode) bool { return frameType >= wsCloseMessage } func isWebsocketScheme(u *url.URL) bool { return u.Scheme == wsScheme || u.Scheme == wsSchemeTLS }