// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package zstd

import (
	
)

// fseEntry is one entry in an FSE table.
// Each entry pairs the value it records with what is needed to move
// to the next state: a number of bits to read, and a base to which
// those bits are added.
type fseEntry struct {
	sym  uint8  // value that this entry records
	bits uint8  // number of bits to read to determine next state
	base uint16 // add those bits to this state to get the next state
}

// readFSE reads an FSE table from data starting at off.
// maxSym is the maximum symbol value.
// maxBits is the maximum number of bits permitted for symbols in the table.
// The FSE is written into table, which must be at least 1<<maxBits in size.
// This returns the number of bits in the FSE table and the new offset.
// Every failure path returns 0, 0 and an error.
// RFC 4.1.1.
//
// NOTE(review): the receiver, function, parameter and local variable
// names are blank in the source as seen here. The names used in these
// comments (readFSE, data, off, maxSym, maxBits, table, threshold) are
// taken from the pre-existing comments; confirm against the original file.
func ( *Reader) ( block, , ,  int,  []fseEntry) (,  int,  error) {
	 := .makeBitReader(, )
	if  := .moreBits();  != nil {
		return 0, 0, 
	}

	// The accuracy log is read as a val(4) field biased by 5.
	// It is rejected below if it is too large
	// (presumably compared against maxBits; confirm).
	 := int(.val(4)) + 5
	if  >  {
		return 0, 0, .makeError("FSE accuracy log too large")
	}

	// The number of remaining probabilities, plus 1.
	// This determines the number of bits to be read for the next value.
	 := (1 << ) + 1

	// The current difference between small and large values,
	// which depends on the number of remaining values.
	// Small values use 1 less bit.
	 := 1 << 

	// The number of bits needed to compute threshold.
	 :=  + 1

	// The next character value.
	 := 0

	// Whether the last count was 0.
	 := false

	// Scratch array of 16-bit counts indexed by symbol value.
	// A leading slice of it is handed to buildFSE at the end.
	var  [256]int16

	for  > 1 &&  <=  {
		if  := .moreBits();  != nil {
			return 0, 0, 
		}

		if  {
			// Previous count was 0, so there is a 2-bit
			// repeat flag. If the 2-bit flag is 0b11,
			// it adds 3 and then there is another repeat flag.
			 := 
			// Fast path: twelve set bits are six consecutive
			// 0b11 flags worth 3 each, handled in one step.
			for (.bits & 0xfff) == 0xfff {
				 += 3 * 6
				.bits >>= 12
				.cnt -= 12
				if  := .moreBits();  != nil {
					return 0, 0, 
				}
			}
			// Handle any remaining 0b11 flags one at a time.
			for (.bits & 3) == 3 {
				 += 3
				.bits >>= 2
				.cnt -= 2
				if  := .moreBits();  != nil {
					return 0, 0, 
				}
			}

			// We have at least 14 bits here,
			// no need to call moreBits

			// The last flag is not 0b11 (the loop above has exited),
			// so it ends the run after its value is added.
			 += int(.val(2))

			if  >  {
				return 0, 0, .makeError("FSE symbol index overflow")
			}

			// Every symbol covered by the run gets a count of 0.
			for ;  < ; ++ {
				[uint8()] = 0
			}

			 = false
			continue
		}

		// Cutoff between small and large encodings: values whose
		// low bits are below it are stored in one bit fewer.
		 := (2* - 1) - 
		var  int
		if int(.bits&uint32(-1)) <  {
			// A small value.
			 = int(.bits & uint32(( - 1)))
			.bits >>=  - 1
			.cnt -= uint32( - 1)
		} else {
			// A large value.
			 = int(.bits & uint32((2* - 1)))
			if  >=  {
				 -= 
			}
			.bits >>= 
			.cnt -= uint32()
		}

		// The value read is decremented before use, so the result
		// can be negative. A non-negative result is subtracted from
		// a running total; a negative one decrements a total by 1
		// (presumably the same remaining total; confirm).
		--
		if  >= 0 {
			 -= 
		} else {
			--
		}
		if  >= 256 {
			return 0, 0, .makeError("FSE sym overflow")
		}
		[uint8()] = int16()
		++

		 =  == 0

		// As the remaining total drops below the threshold,
		// later values are read with fewer bits.
		for  <  {
			--
			 >>= 1
		}
	}

	// The loop must stop with this total at exactly 1; any other
	// value is reported as an error. Per the comment on the
	// "remaining probabilities, plus 1" variable above, 1 would mean
	// nothing is left over (presumably the value tested here; confirm).
	if  != 1 {
		return 0, 0, .makeError("too many symbols in FSE table")
	}

	// Symbols not described by the header get a count of 0.
	for ;  <= ; ++ {
		[uint8()] = 0
	}

	// NOTE(review): backup presumably returns unread bits to the input
	// so that the offset returned below is accurate; confirm against
	// the bit reader implementation.
	.backup()

	if  := .buildFSE(, [:+1], , );  != nil {
		return 0, 0, 
	}

	return , int(.off), nil
}

// buildFSE builds an FSE decoding table from a list of probabilities.
// The probabilities are in norm. next is scratch space; in the code as
// seen here it is a 256-entry uint16 array declared inside the function
// rather than a parameter. The number of bits in the table is tableBits.
// It returns an error ("FSE count error" or "FSE state error") when a
// consistency check fails, and nil otherwise.
//
// NOTE(review): the receiver, function, parameter and local variable
// names are blank in the source as seen here. The names used in these
// comments are taken from the pre-existing comments; confirm against
// the original file.
func ( *Reader) ( int,  []int16,  []fseEntry,  int) error {
	 := 1 << 
	 :=  - 1

	var  [256]uint16

	// First pass over the probabilities. A non-negative value is
	// stored in the scratch array as is. A negative value is handled
	// separately: the symbol is written directly into a table entry,
	// a counter is decremented, and its scratch value is set to 1.
	for ,  := range  {
		if  >= 0 {
			[uint8()] = uint16()
		} else {
			[].sym = uint8()
			--
			[uint8()] = 1
		}
	}

	// Second pass: write each symbol into the table as many times as
	// its probability says, advancing by a fixed step and wrapping
	// with a mask. The step is x/2 + x/8 + 3 (presumably with x the
	// table size; confirm).
	 := 0
	 := ( >> 1) + ( >> 3) + 3
	 :=  - 1
	for ,  := range  {
		for  := 0;  < int(); ++ {
			[].sym = uint8()
			 = ( + ) & 
			// Keep stepping while the position is above a limit
			// (presumably the entries taken by negative
			// probabilities in the first pass; confirm).
			for  >  {
				 = ( + ) & 
			}
		}
	}
	// A nonzero value here is reported as a count error.
	if  != 0 {
		return .makeError(, "FSE count error")
	}

	// Final pass: for every table entry, take the current scratch
	// value for its symbol, increment the stored value, and derive
	// the entry's bit count and base from the value taken.
	for  := 0;  < ; ++ {
		 := [].sym
		 := []
		[]++

		// A zero value has no set bit, so it is rejected before the
		// high-bit computation below.
		if  == 0 {
			return .makeError(, "FSE state error")
		}

		// 15 - LeadingZeros16(x) is the index of the highest set bit
		// of a nonzero 16-bit x.
		// NOTE(review): bits here is presumably the math/bits package,
		// but the import block visible in this file is empty; confirm
		// that the import is present.
		 := 15 - bits.LeadingZeros16()

		 :=  - 
		[].bits = uint8()
		[].base = ( << ) - uint16()
	}

	return nil
}

// fseBaselineEntry is an entry in an FSE baseline table.
// We use these for literal/match/length values.
// Those require mapping the symbol to a baseline value,
// and then reading zero or more bits and adding the value to the baseline.
// Rather than looking these up in separate tables,
// we convert the FSE table to an FSE baseline table.
//
// The bits and base fields carry the same meaning as the fields of
// the same names in fseEntry; baseline and basebits replace sym.
type fseBaselineEntry struct {
	baseline uint32 // baseline for value that this entry represents
	basebits uint8  // number of bits to read to add to baseline
	bits     uint8  // number of bits to read to determine next state
	base     uint16 // add the bits to this base to get the next state
}

// A literal length code is turned into a value by reading some number
// of extra bits and adding them to a baseline. Codes 0 through 15 are
// their own baseline and take no extra bits; codes from
// literalLengthOffset upward are looked up in literalLengthBase.
// RFC 3.1.1.3.2.1.1.

const literalLengthOffset = 16

// literalLengthBase is indexed by literal length code minus
// literalLengthOffset. Each element packs the number of extra bits
// into the top 8 bits and the baseline into the low 24 bits.
var literalLengthBase = []uint32{
	1<<24 | 16,
	1<<24 | 18,
	1<<24 | 20,
	1<<24 | 22,
	2<<24 | 24,
	2<<24 | 28,
	3<<24 | 32,
	3<<24 | 40,
	4<<24 | 48,
	6<<24 | 64,
	7<<24 | 128,
	8<<24 | 256,
	9<<24 | 512,
	10<<24 | 1024,
	11<<24 | 2048,
	12<<24 | 4096,
	13<<24 | 8192,
	14<<24 | 16384,
	15<<24 | 32768,
	16<<24 | 65536,
}

// makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
// For every entry it copies bits and base unchanged and replaces the
// symbol with a baseline and a count of extra bits. It returns an
// error if a symbol is larger than 35, and nil otherwise.
//
// NOTE(review): the receiver, function, parameter and local variable
// names are blank in the source as seen here. The names used in these
// comments are taken from the pre-existing comments; confirm against
// the original file.
func ( *Reader) ( int,  []fseEntry,  []fseBaselineEntry) error {
	for ,  := range  {
		 := fseBaselineEntry{
			bits: .bits,
			base: .base,
		}
		if .sym < literalLengthOffset {
			// Small codes are their own baseline with no extra bits.
			.baseline = uint32(.sym)
			.basebits = 0
		} else {
			// literalLengthBase has 20 elements, covering codes
			// literalLengthOffset (16) through 35.
			if .sym > 35 {
				return .makeError(, "FSE baseline symbol overflow")
			}
			 := .sym - literalLengthOffset
			 := literalLengthBase[]
			// Low 24 bits: baseline. Top 8 bits: extra bit count.
			.baseline =  & 0xffffff
			.basebits = uint8( >> 24)
		}
		[] = 
	}
	return nil
}

// makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
// For every entry it copies bits and base unchanged, sets basebits to
// the symbol, and sets baseline to 1<<symbol, less 3 when the symbol
// is at least 2. It returns an error if a symbol is larger than 31,
// and nil otherwise.
//
// NOTE(review): the receiver, function, parameter and local variable
// names are blank in the source as seen here. The names used in these
// comments are taken from the pre-existing comments; confirm against
// the original file.
func ( *Reader) ( int,  []fseEntry,  []fseBaselineEntry) error {
	for ,  := range  {
		 := fseBaselineEntry{
			bits: .bits,
			base: .base,
		}
		// baseline is a uint32, so the shift below can only
		// represent symbols up to 31.
		if .sym > 31 {
			return .makeError(, "FSE offset symbol overflow")
		}

		// The simple way to write this is
		//     be.baseline = 1 << e.sym
		//     be.basebits = e.sym
		// That would give us an offset value that corresponds to
		// the one described in the RFC. However, for offsets > 3
		// we have to subtract 3. And for offset values 1, 2, 3
		// we use a repeated offset.
		//
		// The baseline is always a power of 2, and is never 0,
		// so for those low values we will see one entry that is
		// baseline 1, basebits 0, and one entry that is baseline 2,
		// basebits 1. All other entries will have baseline >= 4
		// basebits >= 2.
		//
		// So we can check for RFC offset <= 3 by checking for
		// basebits <= 1. That means that we can subtract 3 here
		// and not worry about doing it in the hot loop.

		.baseline = 1 << .sym
		if .sym >= 2 {
			.baseline -= 3
		}
		.basebits = .sym
		[] = 
	}
	return nil
}

// A match length code is turned into a value by reading some number
// of extra bits and adding them to a baseline. Codes 0 through 31 have
// a baseline of code+3 and take no extra bits; codes from
// matchLengthOffset upward are looked up in matchLengthBase.
// RFC 3.1.1.3.2.1.1.

const matchLengthOffset = 32

// matchLengthBase is indexed by match length code minus
// matchLengthOffset. Each element packs the number of extra bits
// into the top 8 bits and the baseline into the low 24 bits.
var matchLengthBase = []uint32{
	1<<24 | 35,
	1<<24 | 37,
	1<<24 | 39,
	1<<24 | 41,
	2<<24 | 43,
	2<<24 | 47,
	3<<24 | 51,
	3<<24 | 59,
	4<<24 | 67,
	4<<24 | 83,
	5<<24 | 99,
	7<<24 | 131,
	8<<24 | 259,
	9<<24 | 515,
	10<<24 | 1027,
	11<<24 | 2051,
	12<<24 | 4099,
	13<<24 | 8195,
	14<<24 | 16387,
	15<<24 | 32771,
	16<<24 | 65539,
}

// makeMatchBaselineFSE converts the match length fseTable to baselineTable.
// For every entry it copies bits and base unchanged and replaces the
// symbol with a baseline and a count of extra bits. It returns an
// error if a symbol is larger than 52, and nil otherwise.
//
// NOTE(review): the receiver, function, parameter and local variable
// names are blank in the source as seen here. The names used in these
// comments are taken from the pre-existing comments; confirm against
// the original file.
func ( *Reader) ( int,  []fseEntry,  []fseBaselineEntry) error {
	for ,  := range  {
		 := fseBaselineEntry{
			bits: .bits,
			base: .base,
		}
		if .sym < matchLengthOffset {
			// Small codes map to code+3 with no extra bits.
			.baseline = uint32(.sym) + 3
			.basebits = 0
		} else {
			// matchLengthBase has 21 elements, covering codes
			// matchLengthOffset (32) through 52.
			if .sym > 52 {
				return .makeError(, "FSE baseline symbol overflow")
			}
			 := .sym - matchLengthOffset
			 := matchLengthBase[]
			// Low 24 bits: baseline. Top 8 bits: extra bit count.
			.baseline =  & 0xffffff
			.basebits = uint8( >> 24)
		}
		[] = 
	}
	return nil
}

// predefinedLiteralTable is the predefined table to use for literal lengths.
// Generated from table in RFC 3.1.1.3.2.2.1.
// Checked by TestPredefinedTables.
//
// The table has 64 entries. Each entry is written positionally in
// fseBaselineEntry field order: baseline, basebits, bits, base.
var predefinedLiteralTable = [...]fseBaselineEntry{
	{0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
	{3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
	{7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
	{12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
	{20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
	{32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
	{128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
	{4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
	{2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
	{7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
	{18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
	{32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
	{64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
	{2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
	{2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
	{6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
	{11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
	{18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
	{28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
	{65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
	{8192, 13, 6, 0},
}

// predefinedOffsetTable is the predefined table to use for offsets.
// Generated from table in RFC 3.1.1.3.2.2.3.
// Checked by TestPredefinedTables.
//
// The table has 32 entries. Each entry is written positionally in
// fseBaselineEntry field order: baseline, basebits, bits, base.
// Baselines of entries with basebits >= 2 are a power of two less 3
// (for example 61 with basebits 6), while the entries with basebits
// 0 and 1 keep baselines 1 and 2.
var predefinedOffsetTable = [...]fseBaselineEntry{
	{1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
	{32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
	{125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
	{8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
	{16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
	{125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
	{4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
	{8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
	{61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
	{268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
	{33554429, 25, 5, 0}, {16777213, 24, 5, 0},
}

// predefinedMatchTable is the predefined table to use for match lengths.
// Generated from table in RFC 3.1.1.3.2.2.2.
// Checked by TestPredefinedTables.
//
// The table has 64 entries. Each entry is written positionally in
// fseBaselineEntry field order: baseline, basebits, bits, base.
var predefinedMatchTable = [...]fseBaselineEntry{
	{3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
	{6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
	{19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
	{28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
	{37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
	{59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
	{515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
	{6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
	{10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
	{18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
	{27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
	{35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
	{51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
	{259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
	{5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
	{10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
	{17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
	{26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
	{65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
	{8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
	{1027, 10, 6, 0},
}