// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mldsa

import (
	
	
	
	
	
	
	
)

// parameters describes one ML-DSA parameter set (FIPS 204, Table 1). The
// rejection bound β used while signing and verifying is derived as τ·η.
type parameters struct {
	k, l int // dimensions of A
	η    int // bound for secret coefficients
	γ1   int // log₂(γ₁), where [-γ₁+1, γ₁] is the bound of y
	γ2   int // denominator of γ₂ = (q - 1) / γ2
	λ    int // collision strength
	τ    int // number of non-zero coefficients in challenge
	ω    int // max number of hints in MakeHint
}

// The three standardized parameter sets, named ML-DSA-44, ML-DSA-65 and
// ML-DSA-87 respectively.
var (
	params44 = parameters{k: 4, l: 4, η: 2, γ1: 17, γ2: 88, λ: 128, τ: 39, ω: 80}
	params65 = parameters{k: 6, l: 5, η: 4, γ1: 19, γ2: 32, λ: 192, τ: 49, ω: 55}
	params87 = parameters{k: 8, l: 7, η: 2, γ1: 19, γ2: 32, λ: 256, τ: 60, ω: 75}
)

func pubKeySize( parameters) int {
	// ρ + k × n × 10-bit coefficients of t₁
	return 32 + .k*n*10/8
}

func sigSize( parameters) int {
	// challenge + l × n × (γ₁+1)-bit coefficients of z + hint
	return (.λ / 4) + .l*n*(.γ1+1)/8 + .ω + .k
}

const (
	// PrivateKeySize is the size in bytes of an encoded ML-DSA private key,
	// which is the 32-byte seed the key is expanded from. It is the same for
	// every parameter set.
	PrivateKeySize = 32

	// PublicKeySize44, PublicKeySize65 and PublicKeySize87 are the sizes in
	// bytes of encoded ML-DSA-44, ML-DSA-65 and ML-DSA-87 public keys: the
	// 32-byte seed ρ followed by k polynomials with n 10-bit coefficients.
	// They must agree with pubKeySize for the matching parameter set.
	PublicKeySize44 = 32 + 4*n*10/8
	PublicKeySize65 = 32 + 6*n*10/8
	PublicKeySize87 = 32 + 8*n*10/8

	// SignatureSize44, SignatureSize65 and SignatureSize87 are the sizes in
	// bytes of ML-DSA-44, ML-DSA-65 and ML-DSA-87 signatures: a λ/4-byte
	// commitment hash, l polynomials with n (γ1+1)-bit coefficients, and ω+k
	// bytes of hints. They must agree with sigSize for the matching
	// parameter set.
	SignatureSize44 = 128/4 + 4*n*(17+1)/8 + 80 + 4
	SignatureSize65 = 192/4 + 5*n*(19+1)/8 + 55 + 6
	SignatureSize87 = 256/4 + 7*n*(19+1)/8 + 75 + 8
)

// Upper bounds over all supported parameter sets (they are ML-DSA-87's
// values). They size fixed arrays and buffer capacities so that any
// parameter set fits.
const maxK, maxL, maxλ, maxγ1 = 8, 7, 256, 19

// maxPubKeySize is the largest encoded public key size (ML-DSA-87).
const maxPubKeySize = PublicKeySize87

// PrivateKey is an ML-DSA private key. It keeps the 32-byte seed it was
// expanded from, together with the derived values that signing needs.
// Fixed-size arrays are sized for the largest parameter set (maxK, maxL).
type PrivateKey struct {
	// seed is the 32-byte seed the rest of the key was expanded from.
	seed [32]byte
	// pub is the matching public key, computed during key generation.
	pub PublicKey
	// s1 and s2 hold the secret vectors s₁ and s₂ in the NTT domain.
	s1 [maxL]nttElement
	s2 [maxK]nttElement
	// t0 holds NTT(t₀), where t₀ is the low part of t from Power2Round.
	t0 [maxK]nttElement
	// k is 32 bytes squeezed during key generation. It is absorbed again when
	// deriving each signature's mask seed.
	k [32]byte
}

// Reports whether two private keys are equal: they must use the same
// parameter set and have identical 32-byte seeds. The seeds are compared with
// subtle.ConstantTimeCompare, so the comparison time does not depend on the
// secret seed contents.
//
// NOTE(review): the method, receiver and parameter names are elided in this
// copy of the source; restore them before building.
func ( *PrivateKey) ( *PrivateKey) bool {
	return .pub.p == .pub.p && subtle.ConstantTimeCompare(.seed[:], .seed[:]) == 1
}

// Returns the 32-byte seed the private key was expanded from. The seed array
// is copied into a local before slicing, so the returned slice does not alias
// the key and callers cannot modify it through the result.
//
// NOTE(review): the method and local names are elided in this copy of the
// source; restore them before building.
func ( *PrivateKey) () []byte {
	 := .seed
	return [:]
}

// Returns the public key corresponding to this private key, as a pointer to
// the PublicKey stored inside the PrivateKey. No new value is allocated.
//
// NOTE(review): the method and receiver names are elided in this copy of the
// source; restore them before building.
func ( *PrivateKey) () *PublicKey {
	// Note that this is likely to keep the entire PrivateKey reachable for
	// the lifetime of the PublicKey, which may be undesirable.
	return &.pub
}

// PublicKey is an ML-DSA public key. Alongside its encoding it caches the
// values that verification derives from the key, so they are computed once.
// Fixed-size arrays are sized for the largest parameter set.
type PublicKey struct {
	// raw is the encoded key (ρ followed by the packed t₁). Only the first
	// pubKeySize(p) bytes are used; any remaining bytes stay zero.
	raw [maxPubKeySize]byte
	// p is the parameter set this key belongs to.
	p parameters
	// a is the expanded matrix Â in the NTT domain.
	a [maxK * maxL]nttElement
	t1 [maxK]nttElement // NTT(t₁ ⋅ 2ᵈ)
	tr [64]byte         // public key hash
}

// Reports whether two public keys use the same parameter set and have the
// same encoded bytes. The encodings are compared with
// subtle.ConstantTimeCompare.
//
// NOTE(review): identifiers are elided in this copy of the source. The local
// holds pubKeySize of the receiver's parameters; presumably the elided slice
// bounds are [:size], so only the encoded prefix of raw is compared. Confirm
// when restoring.
func ( *PublicKey) ( *PublicKey) bool {
	 := pubKeySize(.p)
	return .p == .p && subtle.ConstantTimeCompare(.raw[:], .raw[:]) == 1
}

// Returns a fresh copy (via bytes.Clone) of the encoded public key, so
// callers cannot modify the key through the result.
//
// NOTE(review): identifiers are elided in this copy of the source. Presumably
// the slice bound is the pubKeySize local, so only the encoded prefix of raw
// is returned. Confirm when restoring.
func ( *PublicKey) () []byte {
	 := pubKeySize(.p)
	return bytes.Clone(.raw[:])
}

// Returns the name of the key's parameter set: "ML-DSA-44", "ML-DSA-65" or
// "ML-DSA-87". Any other parameters value indicates an internal error and
// causes a panic.
//
// NOTE(review): the method and receiver names are elided in this copy of the
// source; restore them before building.
func ( *PublicKey) () string {
	switch .p {
	case params44:
		return "ML-DSA-44"
	case params65:
		return "ML-DSA-65"
	case params87:
		return "ML-DSA-87"
	default:
		panic("mldsa: internal error: unknown parameters")
	}
}

// Generates a new ML-DSA-44 private key from a fresh 32-byte seed read from
// the DRBG. It runs the FIPS self test, records the operation as approved,
// and calls fipsPCT (presumably the FIPS 140-3 pairwise consistency test on
// the new key) before returning the key.
//
// NOTE(review): the function name, locals and the fipsPCT argument are elided
// in this copy of the source; restore them before building.
func () *PrivateKey {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	drbg.Read([:])
	 := newPrivateKey(&, params44)
	fipsPCT()
	return 
}

// Generates a new ML-DSA-65 private key from a fresh 32-byte seed read from
// the DRBG. It runs the FIPS self test, records the operation as approved,
// and calls fipsPCT (presumably the FIPS 140-3 pairwise consistency test on
// the new key) before returning the key.
//
// NOTE(review): the function name, locals and the fipsPCT argument are elided
// in this copy of the source; restore them before building.
func () *PrivateKey {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	drbg.Read([:])
	 := newPrivateKey(&, params65)
	fipsPCT()
	return 
}

// Generates a new ML-DSA-87 private key from a fresh 32-byte seed read from
// the DRBG. It runs the FIPS self test, records the operation as approved,
// and calls fipsPCT (presumably the FIPS 140-3 pairwise consistency test on
// the new key) before returning the key.
//
// NOTE(review): the function name, locals and the fipsPCT argument are elided
// in this copy of the source; restore them before building.
func () *PrivateKey {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	drbg.Read([:])
	 := newPrivateKey(&, params87)
	fipsPCT()
	return 
}

// errInvalidSeedLength is returned when a private key seed is not exactly
// 32 bytes long.
var (
	errInvalidSeedLength = errors.New("mldsa: invalid seed length")
)

// Expands a 32-byte seed into an ML-DSA-44 private key. It returns
// errInvalidSeedLength if the seed is not exactly 32 bytes. The FIPS self
// test runs, and the operation is recorded as approved, before the length is
// checked.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( []byte) (*PrivateKey, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len() != 32 {
		return nil, errInvalidSeedLength
	}
	return newPrivateKey((*[32]byte)(), params44), nil
}

// Expands a 32-byte seed into an ML-DSA-65 private key. It returns
// errInvalidSeedLength if the seed is not exactly 32 bytes. The FIPS self
// test runs, and the operation is recorded as approved, before the length is
// checked.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( []byte) (*PrivateKey, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len() != 32 {
		return nil, errInvalidSeedLength
	}
	return newPrivateKey((*[32]byte)(), params65), nil
}

// Expands a 32-byte seed into an ML-DSA-87 private key. It returns
// errInvalidSeedLength if the seed is not exactly 32 bytes. The FIPS self
// test runs, and the operation is recorded as approved, before the length is
// checked.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( []byte) (*PrivateKey, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len() != 32 {
		return nil, errInvalidSeedLength
	}
	return newPrivateKey((*[32]byte)(), params87), nil
}

// newPrivateKey deterministically expands a 32-byte seed into a full private
// key for the given parameter set, following ML-DSA key generation in
// FIPS 204. It precomputes and stores everything signing and verification
// need:
//   - the matrix Â;
//   - the secrets ŝ₁ and ŝ₂ and t̂₀, all in NTT form;
//   - the encoded public key and its hash tr;
//   - NTT(t₁ ⋅ 2ᵈ).
//
// NOTE(review): local and parameter identifiers are elided in this copy of
// the source; restore them before building.
func newPrivateKey( *[32]byte,  parameters) *PrivateKey {
	,  := .k, .l

	 := &PrivateKey{pub: PublicKey{p: }}
	.seed = *

	// Absorb the seed followed by the k and l dimensions, then squeeze a
	// 32-byte buffer (ρ, the matrix seed), a 64-byte buffer (ρ′, the
	// secret-sampling seed) and finally the key's 32-byte k field.
	 := sha3.NewShake256()
	.Write([:])
	.Write([]byte{byte(), byte()})
	,  := make([]byte, 32), make([]byte, 64)
	.Read()
	.Read()
	.Read(.k[:])

	 := .pub.a[:*]
	computeMatrixA(, , )

	// s₁ and s₂ are sampled with sampleBoundedPoly using distinct nonces
	// (0..l−1 for s₁, then l..l+k−1 for s₂) and are stored in NTT form.
	 := .s1[:]
	for  := range  {
		[] = ntt(sampleBoundedPoly(, byte(), ))
	}
	 := .s2[:]
	for  := range  {
		[] = ntt(sampleBoundedPoly(, byte(+), ))
	}

	// ˆt = Â ∘ ŝ₁ + ŝ₂
	 := make([]nttElement, , maxK)
	for  := range  {
		[] = []
		for  := range  {
			[] = polyAdd([], nttMul([*+], []))
		}
	}
	// t = NTT⁻¹(ˆt)
	 := make([]ringElement, , maxK)
	for  := range  {
		[] = inverseNTT([])
	}
	// (t₁, _) = Power2Round(t)
	// (_, ˆt₀) = NTT(Power2Round(t))
	,  := make([][n]uint16, , maxK), .t0[:]
	for  := range  {
		var  ringElement
		for  := range [] {
			[][], [] = power2Round([][])
		}
		[] = ntt()
	}

	// The computations below (and their storage in the PrivateKey struct) are
	// not strictly necessary and could be deferred to PrivateKey.PublicKey().
	// That would require keeping or re-deriving ρ and t/t1, though.

	// Encode the public key into the fixed raw buffer (appending to a
	// zero-length slice of it), then hash that encoding to obtain tr.
	 := pkEncode(.pub.raw[:0], , , )
	.pub.tr = computePublicKeyHash()
	computeT1Hat(.pub.t1[:], ) // NTT(t₁ ⋅ 2ᵈ)

	return 
}

// computeMatrixA expands the 32-byte seed ρ into the k×l matrix Â. It writes
// each entry, already in the NTT domain as returned by sampleNTT, into the
// first argument, presumably at index row·l + column.
//
// NOTE(review): local identifiers are elided in this copy of the source.
// FIPS 204 ExpandA samples entry (r, s) from ρ ‖ s ‖ r. When restoring the
// two byte arguments to sampleNTT, check their order against sampleNTT's
// signature.
func computeMatrixA( []nttElement,  []byte,  parameters) {
	,  := .k, .l
	for  := range  {
		for  := range  {
			[*+] = sampleNTT(, byte(), byte())
		}
	}
}

func computePublicKeyHash( []byte) [64]byte {
	 := sha3.NewShake256()
	.Write()
	var  [64]byte
	.Read([:])
	return 
}

// computeT1Hat computes NTT(t₁ ⋅ 2ᵈ) for each polynomial of t₁ and stores it
// in the first argument. Each 10-bit coefficient is shifted left by 13
// (d = 13), converted to Montgomery form with fieldToMontgomery, and the
// polynomial is then transformed with ntt. Callers cache the result in the
// PublicKey so verification can use it directly.
//
// NOTE(review): local and parameter identifiers are elided in this copy of
// the source; restore them before building.
func computeT1Hat( []nttElement,  [][n]uint16) {
	for  := range  {
		var  ringElement
		for  := range [] {
			// t₁ <= 2¹⁰ - 1
			// t₁ ⋅ 2ᵈ <= 2ᵈ(2¹⁰ - 1) = 2²³ - 2¹³ < q = 2²³ - 2¹³ + 1
			// Given this bound the input is always below q, which is why the
			// error from fieldToMontgomery is discarded (presumably it only
			// reports out-of-range inputs; confirm).
			,  := fieldToMontgomery(uint32([][]) << 13)
			[] = 
		}
		[] = ntt()
	}
}

// pkEncode appends an encoded public key to the first argument and returns
// the extended slice. The encoding is the seed ρ followed by the first p.k
// polynomials of t₁. Each coefficient is packed into 10 bits, low bits
// first, so every four coefficients occupy five bytes. Coefficients are
// assumed to be at most 2¹⁰ − 1.
//
// NOTE(review): local and parameter identifiers are elided in this copy of
// the source; restore them before building.
func pkEncode( []byte,  []byte,  [][n]uint16,  parameters) []byte {
	 := append(, ...)
	for ,  := range [:.k] {
		// Encode four at a time into 4 * 10 bits = 5 bytes.
		for  := 0;  < n;  += 4 {
			 := []
			 := [+1]
			 := [+2]
			 := [+3]
			 := byte( >> 0)
			 := byte(( >> 8) | ( << 2))
			 := byte(( >> 6) | ( << 4))
			 := byte(( >> 4) | ( << 6))
			 := byte( >> 2)
			 = append(, , , , , )
		}
	}
	return 
}

// pkDecode parses an encoded public key for the parameter set p and is the
// inverse of pkEncode.
//   - It checks that the input is exactly pubKeySize(p) bytes.
//   - It returns the leading 32-byte seed ρ, which is a subslice of the
//     input, not a copy.
//   - It unpacks the remaining bytes into t1, five bytes to four 10-bit
//     coefficients.
//
// Every 10-bit value is a valid coefficient, so the only error is
// errInvalidPublicKeyLength. Presumably one polynomial is decoded per entry
// of t1, and the caller sizes t1 to p.k (this is not checked here).
//
// NOTE(review): local and parameter identifiers are elided in this copy of
// the source; restore them before building.
func pkDecode( []byte,  [][n]uint16,  parameters) ( []byte,  error) {
	if len() != pubKeySize() {
		return nil, errInvalidPublicKeyLength
	}
	,  = [:32], [32:]
	for  := range  {
		// Decode four at a time from 4 * 10 bits = 5 bytes.
		for  := 0;  < n;  += 4 {
			, , , ,  := [0], [1], [2], [3], [4]
			[][+0] = uint16(>>0) | uint16(&0b0000_0011)<<8
			[][+1] = uint16(>>2) | uint16(&0b0000_1111)<<6
			[][+2] = uint16(>>4) | uint16(&0b0011_1111)<<4
			[][+3] = uint16(>>6) | uint16(&0b1111_1111)<<2
			 = [5:]
		}
	}
	return , nil
}

// errInvalidPublicKeyLength is returned when an encoded public key does not
// have the exact length required by its parameter set.
var (
	errInvalidPublicKeyLength = errors.New("mldsa: invalid public key length")
)

// Parses an encoded ML-DSA-44 public key by delegating to newPublicKey with
// params44. It returns errInvalidPublicKeyLength if the length is wrong.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( []byte) (*PublicKey, error) {
	return newPublicKey(, params44)
}

// Parses an encoded ML-DSA-65 public key by delegating to newPublicKey with
// params65. It returns errInvalidPublicKeyLength if the length is wrong.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( []byte) (*PublicKey, error) {
	return newPublicKey(, params65)
}

// Parses an encoded ML-DSA-87 public key by delegating to newPublicKey with
// params87. It returns errInvalidPublicKeyLength if the length is wrong.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( []byte) (*PublicKey, error) {
	return newPublicKey(, params87)
}

// newPublicKey decodes an encoded public key for the given parameter set and
// precomputes the values verification needs: the expanded matrix Â, the
// public key hash tr and NTT(t₁ ⋅ 2ᵈ). Its only error is
// errInvalidPublicKeyLength, reported by pkDecode when the length is wrong.
//
// NOTE(review): local and parameter identifiers (including slice bounds) are
// elided in this copy of the source; restore them before building.
func newPublicKey( []byte,  parameters) (*PublicKey, error) {
	,  := .k, .l

	 := make([][n]uint16, , maxK)
	,  := pkDecode(, , )
	if  != nil {
		return nil, 
	}

	// Keep a copy of the encoding. pkDecode has already checked its length,
	// so it fits in the fixed maxPubKeySize buffer.
	 := &PublicKey{p: }
	copy(.raw[:], )
	computeMatrixA(.a[:*], , )
	.tr = computePublicKeyHash()
	computeT1Hat(.t1[:], ) // NTT(t₁ ⋅ 2ᵈ)

	return , nil
}

// errContextTooLong is returned when a context string exceeds 255 bytes.
var errContextTooLong = errors.New("mldsa: context too long")

// errMessageHashLength is returned when an externally computed message
// representative μ is not exactly 64 bytes.
var errMessageHashLength = errors.New("mldsa: invalid message hash length")

// errRandomLength is returned when caller-supplied signing randomness is not
// exactly 32 bytes.
var errRandomLength = errors.New("mldsa: invalid random length")

// Signs a message under the given context string with hedged randomness:
// 32 fresh bytes are read from the DRBG and passed to signInternal alongside
// the message representative μ. It returns errContextTooLong if the context
// is longer than 255 bytes.
//
// NOTE(review): the function, parameter and local names are elided in this
// copy of the source; restore them before building.
func ( *PrivateKey,  []byte,  string) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	drbg.Read([:])
	,  := computeMessageHash(.pub.tr[:], , )
	if  != nil {
		return nil, 
	}
	return signInternal(, &, &), nil
}

// Deterministic variant of signing. The 32-byte randomness passed to
// signInternal is left all zero, so the same key, message and context always
// produce the same signature. It returns errContextTooLong if the context is
// longer than 255 bytes.
//
// NOTE(review): the function, parameter and local names are elided in this
// copy of the source; restore them before building.
func ( *PrivateKey,  []byte,  string) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	,  := computeMessageHash(.pub.tr[:], , )
	if  != nil {
		return nil, 
	}
	return signInternal(, &, &), nil
}

// Signs a message under the given context string using caller-supplied
// randomness, which must be exactly 32 bytes (otherwise errRandomLength).
// The context length is checked first, so an over-long context yields
// errContextTooLong even when the randomness is also malformed.
//
// NOTE(review): the function, parameter and local names are elided in this
// copy of the source; restore them before building.
func ( *PrivateKey,  []byte,  string,  []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	,  := computeMessageHash(.pub.tr[:], , )
	if  != nil {
		return nil, 
	}
	if len() != 32 {
		return nil, errRandomLength
	}
	return signInternal(, &, (*[32]byte)()), nil
}

// Signs a caller-computed 64-byte message representative μ directly,
// skipping computeMessageHash, with 32 bytes of DRBG randomness. It returns
// errMessageHashLength if μ is not exactly 64 bytes. The randomness is read
// before the length check.
//
// NOTE(review): the function, parameter and local names are elided in this
// copy of the source; restore them before building.
func ( *PrivateKey,  []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	drbg.Read([:])
	if len() != 64 {
		return nil, errMessageHashLength
	}
	return signInternal(, (*[64]byte)(), &), nil
}

// Deterministically signs a caller-computed 64-byte message representative μ.
// The randomness passed to signInternal is all zero. It returns
// errMessageHashLength if μ is not exactly 64 bytes.
//
// NOTE(review): the function, parameter and local names are elided in this
// copy of the source; restore them before building.
func ( *PrivateKey,  []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	var  [32]byte
	if len() != 64 {
		return nil, errMessageHashLength
	}
	return signInternal(, (*[64]byte)(), &), nil
}

// Signs a caller-computed message representative μ using caller-supplied
// randomness. μ must be exactly 64 bytes (otherwise errMessageHashLength) and
// the randomness exactly 32 bytes (otherwise errRandomLength); μ is checked
// first.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( *PrivateKey,  []byte,  []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len() != 64 {
		return nil, errMessageHashLength
	}
	if len() != 32 {
		return nil, errRandomLength
	}
	return signInternal(, (*[64]byte)(), (*[32]byte)()), nil
}

// computeMessageHash computes the 64-byte message representative
// μ = SHAKE256(tr ‖ M′) for pure ML-DSA, where M′ = 0 ‖ len(ctx) ‖ ctx ‖ M.
// The leading 0 byte separates ML-DSA from HashML-DSA. The context's length
// is encoded in a single byte, so contexts longer than 255 bytes are rejected
// with errContextTooLong.
//
// NOTE(review): argument and local names are elided in this copy of the
// source. Callers pass the public key hash tr first; presumably the first
// Write absorbs tr and the last absorbs the message, as FIPS 204 requires.
// Confirm when restoring.
func computeMessageHash( []byte,  []byte,  string) ([64]byte, error) {
	if len() > 255 {
		return [64]byte{}, errContextTooLong
	}
	 := sha3.NewShake256()
	.Write()
	.Write([]byte{0}) // ML-DSA / HashML-DSA domain separator
	.Write([]byte{byte(len())})
	.Write([]byte())
	.Write()
	var  [64]byte
	.Read([:])
	return , nil
}

// signInternal runs the ML-DSA signing loop for the message representative μ
// and 32 bytes of randomness (all zero for deterministic signing). Each
// iteration:
//   - samples a mask y;
//   - computes the commitment w and the challenge c̃;
//   - rejects the candidate if z, LowBits(w − cs₂), ct₀ or the hint count
//     exceeds its bound.
//
// The first candidate that passes is returned, encoded by sigEncode.
//
// NOTE(review): local identifiers, including the loop label targeted by the
// `continue` statements, are elided in this copy of the source.
func signInternal( *PrivateKey,  *[64]byte,  *[32]byte) []byte {
	, ,  := .pub.p, .pub.p.k, .pub.p.l
	, , ,  := .pub.a[:*], .s1[:], .s2[:], .t0[:]

	// Bounds: β = τ·η, γ₁ = 2^γ1, γ₂ = (q−1)/γ2, and the rejection
	// thresholds γ₁−β and γ₂−β used by the checks below.
	 := .τ * .η
	 := uint32(1 << .γ1)
	 :=  - uint32()
	 := (q - 1) / uint32(.γ2)
	 :=  - uint32()

	// Derive the 64-byte mask seed ρ″ from SHAKE256 over the key's k field,
	// the randomness and μ (presumably in the FIPS 204 order K ‖ rnd ‖ μ;
	// the last two arguments are elided here).
	 := sha3.NewShake256()
	.Write(.k[:])
	.Write([:])
	.Write([:])
	 := make([]byte, 64)
	.Read()

	// Mask counter κ. It is incremented once per sampled polynomial of y and
	// carries over between loop iterations, so each polynomial gets a
	// distinct counter value.
	 := 0
:
	for {
		// Main rejection sampling loop. Note that leaking rejected signatures
		// leaks information about the private key. However, as explained in
		// https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
		// Section 5.5, we are free to leak rejected ch values, as well as which
		// check causes the rejection and which coefficient failed the check
		// (but not the value or sign of the coefficient).

		// y = ExpandMask(ρ″, κ): each polynomial is bit-unpacked from
		// (γ1+1)·n/8 bytes of SHAKE256 over ρ″ and κ, with κ encoded as two
		// little-endian bytes.
		 := make([]ringElement, , maxL)
		for  := range  {
			 := make([]byte, 2)
			byteorder.LEPutUint16(, uint16())
			++

			.Reset()
			.Write()
			.Write()
			 := make([]byte, (.γ1+1)*n/8, (maxγ1+1)*n/8)
			.Read()

			[] = bitUnpack(, )
		}

		// w = NTT⁻¹(Â ∘ NTT(y))
		 := make([]nttElement, , maxL)
		for  := range  {
			[] = ntt([])
		}
		 := make([]ringElement, , maxK)
		for  := range  {
			var  nttElement
			for  := range  {
				 = polyAdd(, nttMul([*+], []))
			}
			[] = inverseNTT()
		}

		// c̃ = SHAKE256(μ ‖ w1Encode(HighBits(w))), truncated to λ/4 bytes.
		.Reset()
		.Write([:])
		for  := range  {
			w1Encode(, highBits([], ), )
		}
		 := make([]byte, .λ/4, maxλ/4)
		.Read()

		// sampleInBall is not constant time, but see comment above about
		// leaking rejected ch values being acceptable.
		 := ntt(sampleInBall(, ))

		// cs₁ and cs₂, brought back out of the NTT domain.
		 := make([]ringElement, , maxL)
		for  := range  {
			[] = inverseNTT(nttMul(, []))
		}
		 := make([]ringElement, , maxK)
		for  := range  {
			[] = inverseNTT(nttMul(, []))
		}

		 := make([]ringElement, , maxL)
		for  := range  {
			[] = polyAdd([], [])

			// Reject if ||z||∞ ≥ γ1 − β
			if coefficientsExceedBound([], ) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("z")
				}
				continue 
			}
		}

		for  := range  {
			 := polySub([], [])

			// Reject if ||LowBits(r0)||∞ ≥ γ2 − β
			if lowBitsExceedBound(, , ) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("r0")
				}
				continue 
			}
		}

		 := make([]ringElement, , maxK)
		for  := range  {
			[] = inverseNTT(nttMul(, []))

			// Reject if ||ct0||∞ ≥ γ2
			if coefficientsExceedBound([], ) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("ct0")
				}
				continue 
			}
		}

		// Build the hint vector with makeHint and count its set bits.
		 := 0
		 := make([][n]byte, , maxK)
		for  := range  {
			var  int
			[],  = makeHint([], [], [], )
			 += 
		}
		// Reject if number of hints > ω
		if  > .ω {
			if testingOnlyRejectionReason != nil {
				testingOnlyRejectionReason("h")
			}
			continue 
		}

		return sigEncode(, , , )
	}
}

// testingOnlyRejectionReason is a hook that lets tests confirm every
// rejection path of the signing loop is exercised. It is nil unless a test
// sets it. When set, it is called with "z", "r0", "ct0" or "h" to identify
// the check that rejected a candidate signature.
var testingOnlyRejectionReason func(string)

// w1Encode implements w1Encode from FIPS 204, writing directly into H.
//
// Rather than returning the packed bytes, it writes them to the SHAKE256
// instance passed as the first argument. The packing width depends on p.γ2:
//   - γ2 = 32: coefficients are at most 15 and use 4 bits, two per byte;
//   - γ2 = 88: coefficients are at most 43 and use 6 bits, four per 3 bytes.
//
// Any other γ2 is an internal error and panics.
func w1Encode( *sha3.SHAKE,  [n]byte,  parameters) {
	switch .γ2 {
	case 32:
		// Coefficients are <= (q − 1)/(2γ2) − 1 = 15, four bits each.
		 := make([]byte, 4*n/8)
		for  := 0;  < n;  += 2 {
			 := []
			 := [+1]
			[/2] =  | <<4
		}
		.Write()
	case 88:
		// Coefficients are <= (q − 1)/(2γ2) − 1 = 43, six bits each.
		 := make([]byte, 6*n/8)
		for  := 0;  < n;  += 4 {
			 := []
			 := [+1]
			 := [+2]
			 := [+3]
			[3*/4+0] = ( >> 0) | ( << 6)
			[3*/4+1] = ( >> 2) | ( << 4)
			[3*/4+2] = ( >> 4) | ( << 2)
		}
		.Write()
	default:
		panic("mldsa: internal error: unsupported γ2")
	}
}

func coefficientsExceedBound( ringElement,  uint32) bool {
	// If this function appears in profiles, it might be possible to deduplicate
	// the work of fieldFromMontgomery inside fieldInfinityNorm with the
	// subsequent encoding of w.
	for  := range  {
		if fieldInfinityNorm([]) >=  {
			return true
		}
	}
	return false
}

// lowBitsExceedBound reports whether any coefficient of w has a low part
// whose absolute value (computed with constantTimeAbs) is at least bound.
// The coefficient is split with decompose32 or decompose88 depending on p.γ2;
// any other γ2 panics. It returns at the first offending coefficient.
//
// NOTE(review): local identifiers are elided in this copy of the source.
// Presumably the second result of decompose is the low part r₀ passed to
// constantTimeAbs. Confirm when restoring.
func lowBitsExceedBound( ringElement,  uint32,  parameters) bool {
	switch .γ2 {
	case 32:
		for  := range  {
			,  := decompose32([])
			if constantTimeAbs() >=  {
				return true
			}
		}
	case 88:
		for  := range  {
			,  := decompose88([])
			if constantTimeAbs() >=  {
				return true
			}
		}
	default:
		panic("mldsa: internal error: unsupported γ2")
	}
	return false
}

// errInvalidSignatureLength is returned when a signature is not exactly
// sigSize bytes for the key's parameter set.
var errInvalidSignatureLength = errors.New("mldsa: invalid signature length")

// Signatures that decode but fail verification. Both errors carry the same
// message, but they are distinct values so internal code and tests can tell
// the failing check apart.
var (
	errInvalidSignatureCoeffBounds = errors.New("mldsa: invalid signature")
	errInvalidSignatureChallenge   = errors.New("mldsa: invalid signature")
)

// Malformed hint encodings. These also share one message while remaining
// distinct values.
var (
	errInvalidSignatureHintLimits       = errors.New("mldsa: invalid signature encoding")
	errInvalidSignatureHintIndexOrder   = errors.New("mldsa: invalid signature encoding")
	errInvalidSignatureHintExtraIndices = errors.New("mldsa: invalid signature encoding")
)

// Verifies a signature over a message under the given context string. It
// returns nil if the signature is valid and a non-nil error otherwise,
// including errContextTooLong for contexts longer than 255 bytes.
//
// NOTE(review): the function, parameter and local names are elided in this
// copy of the source; restore them before building.
func ( *PublicKey, ,  []byte,  string) error {
	fipsSelfTest()
	fips140.RecordApproved()
	,  := computeMessageHash(.tr[:], , )
	if  != nil {
		return 
	}
	return verifyInternal(, &, )
}

// Verifies a signature against a caller-computed 64-byte message
// representative μ, skipping computeMessageHash. It returns
// errMessageHashLength if μ is not exactly 64 bytes.
//
// NOTE(review): the function and parameter names are elided in this copy of
// the source; restore them before building.
func ( *PublicKey,  []byte,  []byte) error {
	fipsSelfTest()
	fips140.RecordApproved()
	if len() != 64 {
		return errMessageHashLength
	}
	return verifyInternal(, (*[64]byte)(), )
}

// verifyInternal checks a signature against the message representative μ
// using the cached values of the public key. It:
//   - decodes c̃, z and h;
//   - recomputes w′ = Â·z − c·t₁·2ᵈ in the NTT domain;
//   - applies the hints to recover w₁;
//   - recomputes the challenge hash.
//
// It accepts only if ||z||∞ < γ₁ − β and the recomputed challenge equals c̃.
//
// NOTE(review): local identifiers are elided in this copy of the source;
// restore them before building.
func verifyInternal( *PublicKey,  *[64]byte,  []byte) error {
	, ,  := .p, .p.k, .p.l
	,  := .t1[:], .a[:*]

	// Bound for ||z||∞: γ₁ − β, where β = τ·η.
	 := .τ * .η
	 := uint32(1 << .γ1)
	 :=  - uint32()

	// Decode c̃, z and h. This also rejects malformed hint encodings.
	 := make([]ringElement, , maxL)
	 := make([][n]byte, , maxK)
	,  := sigDecode(, , , )
	if  != nil {
		return 
	}

	 := ntt(sampleInBall(, ))

	// w = Â ∘ NTT(z) − NTT(c) ∘ NTT(t₁ ⋅ 2ᵈ)
	 := make([]nttElement, , maxL)
	for  := range  {
		[] = ntt([])
	}
	 := make([]ringElement, , maxK)
	for  := range  {
		var  nttElement
		for  := range  {
			 = polyAdd(, nttMul([*+], []))
		}
		 = polySub(, nttMul(, []))
		[] = inverseNTT()
	}

	// Use hints h to compute w₁ from w(approx).
	 := make([][n]byte, , maxK)
	for  := range  {
		[] = useHint([], [], )
	}

	// Recompute c̃′ = SHAKE256(μ ‖ w1Encode(w₁)), truncated to λ/4 bytes.
	 := sha3.NewShake256()
	.Write([:])
	for  := range  {
		w1Encode(, [], )
	}
	 := make([]byte, .λ/4, maxλ/4)
	.Read()

	for  := range  {
		if coefficientsExceedBound([], ) {
			return errInvalidSignatureCoeffBounds
		}
	}

	// Only public values are compared here, so a variable-time comparison
	// is acceptable.
	if !bytes.Equal(, ) {
		return errInvalidSignatureChallenge
	}

	return nil
}

func sigEncode( []byte,  []ringElement,  [][n]byte,  parameters) []byte {
	 := make([]byte, 0, sigSize())
	 = append(, ...)
	for  := range  {
		 = bitPack(, [], )
	}
	 = hintEncode(, , )
	return 
}

func sigDecode( []byte,  []ringElement,  [][n]byte,  parameters) ( []byte,  error) {
	if len() != sigSize() {
		return nil, errInvalidSignatureLength
	}
	,  = [:.λ/4], [.λ/4:]
	for  := range  {
		 := (.γ1 + 1) * n / 8
		[] = bitUnpack([:], )
		 = [:]
	}
	if  := hintDecode(, , );  != nil {
		return nil, 
	}
	return , nil
}

func hintEncode( []byte,  [][n]byte,  parameters) []byte {
	,  := .ω, .k
	,  := sliceForAppend(, +)
	var  byte
	for  := range  {
		for  := range n {
			if [][] != 0 {
				[] = byte()
				++
			}
		}
		[+] = 
	}
	return 
}

// hintDecode implements HintBitUnpack from FIPS 204. It decodes the ω+k byte
// hint encoding into h, which must be zero on entry because only set bits are
// written. It accepts only the canonical encoding:
//   - each cumulative count byte (at offset ω+i) must not decrease and must
//     be at most ω (errInvalidSignatureHintLimits);
//   - within one polynomial, indices must be strictly increasing
//     (errInvalidSignatureHintIndexOrder);
//   - every index slot after the last used one must be zero
//     (errInvalidSignatureHintExtraIndices).
//
// A wrong total length is reported as an internal error, because sigDecode
// has already checked the overall signature length.
//
// NOTE(review): local and parameter identifiers are elided in this copy of
// the source; restore them before building.
func hintDecode( []byte,  [][n]byte,  parameters) error {
	,  := .ω, .k
	if len() != + {
		return errors.New("mldsa: internal error: invalid signature hint length")
	}
	var  byte
	for  := range  {
		 := [+]
		if  <  ||  > byte() {
			return errInvalidSignatureHintLimits
		}
		 := 
		for  <  {
			if  >  && [-1] >= [] {
				return errInvalidSignatureHintIndexOrder
			}
			[][[]] = 1
			++
		}
	}
	for  := ;  < byte(); ++ {
		if [] != 0 {
			return errInvalidSignatureHintExtraIndices
		}
	}
	return nil
}