package payload import ( "bufio" "encoding/base64" "io" "io/ioutil" "github.com/googollee/go-socket.io/engineio/frame" "github.com/googollee/go-socket.io/engineio/packet" ) type byteReader interface { ReadByte() (byte, error) io.Reader } type readerFeeder interface { getReader() (io.Reader, bool, error) putReader(error) error } type decoder struct { b64Reader io.Reader limitReader io.LimitedReader rawReader byteReader feeder readerFeeder ft frame.Type pt packet.Type supportBinary bool } func (d *decoder) NextReader() (frame.Type, packet.Type, io.ReadCloser, error) { if d.rawReader == nil { r, supportBinary, err := d.feeder.getReader() if err != nil { return 0, 0, nil, err } br, ok := r.(byteReader) if !ok { br = bufio.NewReader(r) } if err := d.setNextReader(br, supportBinary); err != nil { return 0, 0, nil, d.sendError(err) } } return d.ft, d.pt, d, nil } func (d *decoder) Read(p []byte) (int, error) { if d.b64Reader != nil { return d.b64Reader.Read(p) } return d.limitReader.Read(p) } func (d *decoder) Close() error { if _, err := io.Copy(ioutil.Discard, d); err != nil { return d.sendError(err) } err := d.setNextReader(d.rawReader, d.supportBinary) if err != nil { if err != io.EOF { return d.sendError(err) } d.rawReader = nil d.limitReader.R = nil d.limitReader.N = 0 d.b64Reader = nil err = d.sendError(nil) } return err } func (d *decoder) setNextReader(r byteReader, supportBinary bool) error { var read func(byteReader) (frame.Type, packet.Type, int64, error) if supportBinary { read = d.binaryRead } else { read = d.textRead } ft, pt, l, err := read(r) if err != nil { return err } d.ft = ft d.pt = pt d.rawReader = r d.limitReader.R = r d.limitReader.N = l d.supportBinary = supportBinary if !supportBinary && ft == frame.Binary { d.b64Reader = base64.NewDecoder(base64.StdEncoding, &d.limitReader) } else { d.b64Reader = nil } return nil } func (d *decoder) sendError(err error) error { if e := d.feeder.putReader(err); e != nil { return e } return err } func (d *decoder) textRead(r byteReader) (frame.Type, packet.Type, int64, error) { l, err := readTextLen(r) if err != nil { return 0, 0, 0, err } ft := frame.String b, err := r.ReadByte() if err != nil { return 0, 0, 0, err } l-- if b == 'b' { ft = frame.Binary b, err = r.ReadByte() if err != nil { return 0, 0, 0, err } l-- } pt := packet.ByteToPacketType(b, frame.String) return ft, pt, l, nil } func (d *decoder) binaryRead(r byteReader) (frame.Type, packet.Type, int64, error) { b, err := r.ReadByte() if err != nil { return 0, 0, 0, err } if b > 1 { return 0, 0, 0, errInvalidPayload } ft := frame.ByteToFrameType(b) l, err := readBinaryLen(r) if err != nil { return 0, 0, 0, err } b, err = r.ReadByte() if err != nil { return 0, 0, 0, err } pt := packet.ByteToPacketType(b, ft) l-- return ft, pt, l, nil }