// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package zstd provides a decompressor for zstd streams,
// described in RFC 8878. It does not support dictionaries.
package zstd import ( ) // fuzzing is a fuzzer hook set to true when fuzzing. // This is used to reject cases where we don't match zstd. var fuzzing = false // Reader implements [io.Reader] to read a zstd compressed stream. type Reader struct { // The underlying Reader. r io.Reader // Whether we have read the frame header. // This is of interest when buffer is empty. // If true we expect to see a new block. sawFrameHeader bool // Whether the current frame expects a checksum. hasChecksum bool // Whether we have read at least one frame. readOneFrame bool // True if the frame size is not known. frameSizeUnknown bool // The number of uncompressed bytes remaining in the current frame. // If frameSizeUnknown is true, this is not valid. remainingFrameSize uint64 // The number of bytes read from r up to the start of the current // block, for error reporting. blockOffset int64 // Buffered decompressed data. buffer []byte // Current read offset in buffer. off int // The current repeated offsets. repeatedOffset1 uint32 repeatedOffset2 uint32 repeatedOffset3 uint32 // The current Huffman tree used for compressing literals. huffmanTable []uint16 huffmanTableBits int // The window for back references. window window // A buffer available to hold a compressed block. compressedBuf []byte // A buffer for literals. literals []byte // Sequence decode FSE tables. seqTables [3][]fseBaselineEntry seqTableBits [3]uint8 // Buffers for sequence decode FSE tables. seqTableBuffers [3][]fseBaselineEntry // Scratch space used for small reads, to avoid allocation. scratch [16]byte // A scratch table for reading an FSE. Only temporarily valid. fseScratch []fseEntry // For checksum computation. checksum xxhash64 } // NewReader creates a new Reader that decompresses data from the given reader. func ( io.Reader) *Reader { := new(Reader) .Reset() return } // Reset discards the current state and starts reading a new stream from r. // This permits reusing a Reader rather than allocating a new one. 
func ( *Reader) ( io.Reader) { .r = // Several fields are preserved to avoid allocation. // Others are always set before they are used. .sawFrameHeader = false .hasChecksum = false .readOneFrame = false .frameSizeUnknown = false .remainingFrameSize = 0 .blockOffset = 0 .buffer = .buffer[:0] .off = 0 // repeatedOffset1 // repeatedOffset2 // repeatedOffset3 // huffmanTable // huffmanTableBits // window // compressedBuf // literals // seqTables // seqTableBits // seqTableBuffers // scratch // fseScratch } // Read implements [io.Reader]. func ( *Reader) ( []byte) (int, error) { if := .refillIfNeeded(); != nil { return 0, } := copy(, .buffer[.off:]) .off += return , nil } // ReadByte implements [io.ByteReader]. func ( *Reader) () (byte, error) { if := .refillIfNeeded(); != nil { return 0, } := .buffer[.off] .off++ return , nil } // refillIfNeeded reads the next block if necessary. func ( *Reader) () error { for .off >= len(.buffer) { if := .refill(); != nil { return } .off = 0 } return nil } // refill reads and decompresses the next block. func ( *Reader) () error { if !.sawFrameHeader { if := .readFrameHeader(); != nil { return } } return .readBlock() } // readFrameHeader reads the frame header and prepares to read a block. func ( *Reader) () error { : := 0 // Read magic number. RFC 3.1.1. if , := io.ReadFull(.r, .scratch[:4]); != nil { // We require that the stream contains at least one frame. if == io.EOF && !.readOneFrame { = io.ErrUnexpectedEOF } return .wrapError(, ) } if := binary.LittleEndian.Uint32(.scratch[:4]); != 0xfd2fb528 { if >= 0x184d2a50 && <= 0x184d2a5f { // This is a skippable frame. .blockOffset += int64() + 4 if := .skipFrame(); != nil { return } .readOneFrame = true goto } return .makeError(, "invalid magic number") } += 4 // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. if , := io.ReadFull(.r, .scratch[:1]); != nil { return .wrapNonEOFError(, ) } := .scratch[0] := &(1<<5) != 0 := 1 << ( >> 6) if == 1 && ! 
{ = 0 } var int if { = 0 } else { = 1 } if &(1<<3) != 0 { return .makeError(, "reserved bit set in frame header descriptor") } .hasChecksum = &(1<<2) != 0 if .hasChecksum { .checksum.reset() } // Dictionary_ID_Flag. RFC 3.1.1.1.1.6. := 0 if := & 3; != 0 { = 1 << ( - 1) } ++ := + + if , := io.ReadFull(.r, .scratch[:]); != nil { return .wrapNonEOFError(, ) } // Figure out the maximum amount of data we need to retain // for backreferences. var uint64 if ! { // Window descriptor. RFC 3.1.1.1.2. := .scratch[0] := uint64( >> 3) := uint64( & 7) := + 10 := uint64(1) << := ( / 8) * = + // Default zstd sets limits on the window size. if fuzzing && ( > 31 || > 1<<27) { return .makeError(, "windowSize too large") } } // Dictionary_ID. RFC 3.1.1.1.3. if != 0 { := .scratch[ : +] // Allow only zero Dictionary ID. for , := range { if != 0 { return .makeError(, "dictionaries are not supported") } } } // Frame_Content_Size. RFC 3.1.1.1.4. .frameSizeUnknown = false .remainingFrameSize = 0 := .scratch[+:] switch { case 0: .frameSizeUnknown = true case 1: .remainingFrameSize = uint64([0]) case 2: .remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16()) case 4: .remainingFrameSize = uint64(binary.LittleEndian.Uint32()) case 8: .remainingFrameSize = binary.LittleEndian.Uint64() default: panic("unreachable") } // RFC 3.1.1.1.2. // When Single_Segment_Flag is set, Window_Descriptor is not present. // In this case, Window_Size is Frame_Content_Size. if { = .remainingFrameSize } // RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size. const = 8 << 20 if > { = } += .sawFrameHeader = true .readOneFrame = true .blockOffset += int64() // Prepare to read blocks from the frame. .repeatedOffset1 = 1 .repeatedOffset2 = 4 .repeatedOffset3 = 8 .huffmanTableBits = 0 .window.reset(int()) .seqTables[0] = nil .seqTables[1] = nil .seqTables[2] = nil return nil } // skipFrame skips a skippable frame. RFC 3.1.2. 
func ( *Reader) () error { := 0 if , := io.ReadFull(.r, .scratch[:4]); != nil { return .wrapNonEOFError(, ) } += 4 := binary.LittleEndian.Uint32(.scratch[:4]) if == 0 { .blockOffset += int64() return nil } if , := .r.(io.Seeker); { .blockOffset += int64() // Implementations of Seeker do not always detect invalid offsets, // so check that the new offset is valid by comparing to the end. , := .Seek(0, io.SeekCurrent) if != nil { return .wrapError(0, ) } , := .Seek(0, io.SeekEnd) if != nil { return .wrapError(0, ) } if > -int64() { .blockOffset += - return .makeEOFError(0) } // The new offset is valid, so seek to it. _, = .Seek(+int64(), io.SeekStart) if != nil { return .wrapError(0, ) } .blockOffset += int64() return nil } var []byte const = 1 << 20 // 1M for >= { if len() == 0 { = make([]byte, ) } if , := io.ReadFull(.r, ); != nil { return .wrapNonEOFError(, ) } += -= } if > 0 { if len() == 0 { = make([]byte, ) } if , := io.ReadFull(.r, ); != nil { return .wrapNonEOFError(, ) } += int() } .blockOffset += int64() return nil } // readBlock reads the next block from a frame. func ( *Reader) () error { := 0 // Read Block_Header. RFC 3.1.1.2. if , := io.ReadFull(.r, .scratch[:3]); != nil { return .wrapNonEOFError(, ) } += 3 := uint32(.scratch[0]) | (uint32(.scratch[1]) << 8) | (uint32(.scratch[2]) << 16) := &1 != 0 := ( >> 1) & 3 := int( >> 3) // Maximum block size is smaller of window size and 128K. // We don't record the window size for a single segment frame, // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. if > 128<<10 || (.window.size > 0 && > .window.size) { return .makeError(, "block size too large") } // Handle different block types. RFC 3.1.1.2.2. 
switch { case 0: .setBufferSize() if , := io.ReadFull(.r, .buffer); != nil { return .wrapNonEOFError(, ) } += .blockOffset += int64() case 1: .setBufferSize() if , := io.ReadFull(.r, .scratch[:1]); != nil { return .wrapNonEOFError(, ) } ++ := .scratch[0] for := range .buffer { .buffer[] = } .blockOffset += int64() case 2: .blockOffset += int64() if := .compressedBlock(); != nil { return } .blockOffset += int64() case 3: return .makeError(, "invalid block type") } if !.frameSizeUnknown { if uint64(len(.buffer)) > .remainingFrameSize { return .makeError(, "too many uncompressed bytes in frame") } .remainingFrameSize -= uint64(len(.buffer)) } if .hasChecksum { .checksum.update(.buffer) } if ! { .window.save(.buffer) } else { if !.frameSizeUnknown && .remainingFrameSize != 0 { return .makeError(, "not enough uncompressed bytes for frame") } // Check for checksum at end of frame. RFC 3.1.1. if .hasChecksum { if , := io.ReadFull(.r, .scratch[:4]); != nil { return .wrapNonEOFError(0, ) } := binary.LittleEndian.Uint32(.scratch[:4]) := uint32(.checksum.digest()) if != { return .wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", , )) } .blockOffset += 4 } .sawFrameHeader = false } return nil } // setBufferSize sets the decompressed buffer size. // When this is called the buffer is empty. func ( *Reader) ( int) { if cap(.buffer) < { := - cap(.buffer) .buffer = append(.buffer[:cap(.buffer)], make([]byte, )...) } .buffer = .buffer[:] } // zstdError is an error while decompressing. 
type zstdError struct { offset int64 err error } func ( *zstdError) () string { return fmt.Sprintf("zstd decompression error at %d: %v", .offset, .err) } func ( *zstdError) () error { return .err } func ( *Reader) ( int) error { return .wrapError(, io.ErrUnexpectedEOF) } func ( *Reader) ( int, error) error { if == io.EOF { = io.ErrUnexpectedEOF } return .wrapError(, ) } func ( *Reader) ( int, string) error { return .wrapError(, errors.New()) } func ( *Reader) ( int, error) error { if == io.EOF { return } return &zstdError{.blockOffset + int64(), } }