// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package mldsa

import (
	
	
	
)

// FIPS 204 defines a needless semi-expanded format for private keys. This is
// not a good format for key storage and exchange, because it is large and
// requires careful parsing to reject malformed keys. Seeds instead are just 32
// bytes, are always valid, and always expand to valid keys in memory. It is
// *also* a poor in-memory format, because it defers computing the NTT of s1,
// s2, and t0 and the expansion of A until signing time, which is inefficient.
// For a hot second, it looked like we could have all agreed to only use seeds,
// but unfortunately OpenSSL and BouncyCastle lobbied hard against that during
// the WGLC of the LAMPS IETF working group. Also, ACVP tests provide and expect
// semi-expanded keys, so we implement them here for testing purposes.

// semiExpandedPrivKeySize returns the length in bytes of the semi-expanded
// private key encoding for the given parameter set.
//
// The total is three fixed-size fields of 32, 32, and 64 bytes followed by
// three groups of bit-packed polynomials. The parameter set's k, l, and η
// fields are the only inputs read; a per-coefficient width of
// bits.Len(η) + 1 bits is derived from η, and the last group uses 13 bits per
// coefficient.
//
// NOTE(review): the local identifiers are not visible in this view of the
// file, so which factor multiplies which term in the return expression has to
// be taken from the comment above it — confirm against the original.
func semiExpandedPrivKeySize( parameters) int {
	,  := .k, .l
	// Width in bits of one packed coefficient derived from η.
	 := bits.Len(uint(.η)) + 1
	// ρ + K + tr + l × n × η-bit coefficients of s₁ +
	// k × n × η-bit coefficients of s₂ + k × n × 13-bit coefficients of t₀
	return 32 + 32 + 64 + *n*/8 + *n*/8 + *n*13/8
}

// TestingOnlyNewPrivateKeyFromSemiExpanded creates a PrivateKey from a
// semi-expanded private key encoding, for testing purposes. It rejects
// inconsistent keys.
//
// The parameter set is chosen from the input length alone. An error is
// returned if the length matches none of the three supported parameter sets,
// if skDecode fails, if the consistency check against t0 fails, or if the
// hash of the re-encoded public key differs from the decoded one.
//
// [PrivateKey.Bytes] must NOT be called on the resulting key, as it will
// produce a random value.
//
// NOTE(review): the function name and local identifiers are not visible on
// the lines below in this view of the file; the comments describe only what
// the visible calls and error messages establish.
func ( []byte) (*PrivateKey, error) {
	// Select the parameter set whose semi-expanded encoding has exactly the
	// length of the input; any other length is rejected.
	var  parameters
	switch len() {
	case semiExpandedPrivKeySize(params44):
		 = params44
	case semiExpandedPrivKeySize(params65):
		 = params65
	case semiExpandedPrivKeySize(params87):
		 = params87
	default:
		return nil, errors.New("mldsa: invalid semi-expanded private key size")
	}
	,  := .k, .l

	// Split the encoding into its fields. A decoding error is returned to the
	// caller unchanged.
	, , , , , ,  := skDecode(, )
	if  != nil {
		return nil, 
	}

	// Build the in-memory key: computeMatrixA fills a prefix of pub.a, and the
	// three loops store the NTT of each decoded polynomial in s1, s2, and t0.
	 := &PrivateKey{pub: PublicKey{p: }}
	.k = 
	.pub.tr = 
	 := .pub.a[:*]
	computeMatrixA(, [:], )
	for  := range  {
		.s1[] = ntt([])
	}
	for  := range  {
		.s2[] = ntt([])
	}
	for  := range  {
		.t0[] = ntt([])
	}

	// We need to put something in priv.seed, and putting random bytes feels
	// safer than putting anything predictable.
	drbg.Read(.seed[:])

	// Making this format *even more* annoying, we need to recompute t1 from ρ,
	// s1, and s2 if we want to generate the public key. This is essentially as
	// much work as regenerating everything from seed.
	//
	// You might also notice that the semi-expanded format also stores t0 and a
	// hash of the public key, though. How are we supposed to check they are
	// consistent without regenerating the public key? Do we even need to check?
	// Who knows! FIPS 204 says
	//
	//  > Note that there exist malformed inputs that can cause skDecode to
	//  > return values that are not in the correct range. Hence, skDecode
	//  > should only be run on inputs that come from trusted sources.
	//
	// so it sounds like it doesn't even want us to check the coefficients are
	// within bounds, but especially if using this format for key exchange, that
	// sounds like a bad idea. So we check everything.

	// For each row: start from the stored s2 entry, add the nttMul products
	// with the stored s1 entries, leave the NTT domain with inverseNTT, and
	// split every coefficient with power2Round. One output is kept for the
	// public key; the other is compared and a mismatch is reported as an
	// inconsistency with t0.
	// NOTE(review): presumably the comparison is against the t0 decoded by
	// skDecode rather than the NTT-domain copy stored above — confirm.
	 := make([][n]uint16, , maxK)
	for  := range  {
		 := .s2[]
		for  := range  {
			 = polyAdd(, nttMul([*+], .s1[]))
		}
		 := inverseNTT()
		for  := range n {
			,  := power2Round([])
			[][] = 
			if  != [][] {
				return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0")
			}
		}
	}

	// Re-encode the public key into pub.raw and check that its hash matches
	// the hash carried in the private key encoding.
	 := pkEncode(.pub.raw[:0], [:], , )
	if computePublicKeyHash() !=  {
		return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash")
	}
	computeT1Hat(.pub.t1[:], ) // NTT(t₁ ⋅ 2ᵈ)

	return , nil
}

// Encodes a PrivateKey in the semi-expanded private key format and returns
// the encoding as a newly allocated byte slice.
//
// The output is, in order: the first 32 bytes of pub.raw (ρ), the key's k
// field (K), pub.tr (tr), and then the bit-packed polynomials of s1, s2, and
// t0. The stored polynomials are passed through inverseNTT before packing.
// The slice is allocated with capacity semiExpandedPrivKeySize for the key's
// parameter set.
//
// NOTE(review): the function name and local identifiers are not visible in
// this view of the file; the bounds passed to bitPackSlow for s1 and s2 are
// presumably both η — confirm against the original.
func ( *PrivateKey) []byte {
	, ,  := .pub.p.k, .pub.p.l, .pub.p.η
	// Allocate the full encoding up front so the appends below do not
	// reallocate.
	 := make([]byte, 0, semiExpandedPrivKeySize(.pub.p))
	 = append(, .pub.raw[:32]...) // ρ
	 = append(, .k[:]...)         // K
	 = append(, .pub.tr[:]...)    // tr
	// s1 polynomials, converted out of the NTT domain before packing.
	for  := range  {
		 = bitPackSlow(, inverseNTT(.s1[]), , )
	}
	// s2 polynomials, handled the same way.
	for  := range  {
		 = bitPackSlow(, inverseNTT(.s2[]), , )
	}
	// t0 polynomials are packed with asymmetric bounds derived from 2^(d-1)
	// with d = 13.
	const  = 1 << (13 - 1) // 2^(d-1)
	for  := range  {
		 = bitPackSlow(, inverseNTT(.t0[]), -1, )
	}
	return 
}

// skDecode parses a semi-expanded private key encoding for the given
// parameter set.
//
// It returns two 32-byte values and one 64-byte value copied from the front
// of the input, followed by three slices of ring elements unpacked with
// bitUnpackSlow. The first two groups use n × bits.Len(2 × x) / 8 bytes per
// polynomial and the third uses n × 13 / 8 bytes per polynomial.
//
// An error is returned if the input length is not semiExpandedPrivKeySize for
// the parameter set, or if bitUnpackSlow rejects any polynomial. Results are
// named and returned with a bare return, so on error whatever was decoded up
// to that point is returned alongside the error.
//
// NOTE(review): the result and local identifiers are not visible in this view
// of the file. Presumably the fixed-size fields are ρ, K, and tr, the groups
// are s1, s2, and t0, and x above is η, matching the layout described for
// semiExpandedPrivKeySize — confirm against the original.
func skDecode( []byte,  parameters) (,  [32]byte,  [64]byte, , ,  []ringElement,  error) {
	, ,  := .k, .l, .η
	// Check the total length before slicing into the input.
	if len() != semiExpandedPrivKeySize() {
		 = errors.New("mldsa: invalid semi-expanded private key size")
		return
	}
	// Fixed-size header: 32 + 32 + 64 bytes, consumed front to back.
	copy([:], [:32])
	 = [32:]
	copy([:], [:32])
	 = [32:]
	copy([:], [:64])
	 = [64:]

	// First group of polynomials; the input is advanced past each one after
	// it is unpacked.
	 = make([]ringElement, )
	for  := range  {
		 := n * bits.Len(uint()*2) / 8
		[],  = bitUnpackSlow([:], , )
		if  != nil {
			return
		}
		 = [:]
	}

	// Second group, same per-polynomial size as the first.
	 = make([]ringElement, )
	for  := range  {
		 := n * bits.Len(uint()*2) / 8
		[],  = bitUnpackSlow([:], , )
		if  != nil {
			return
		}
		 = [:]
	}

	// Third group: 13 bits per coefficient, with asymmetric bounds derived
	// from 2^(d-1).
	const  = 1 << (13 - 1) // 2^(d-1)
	 = make([]ringElement, )
	for  := range  {
		 := n * 13 / 8
		[],  = bitUnpackSlow([:], -1, )
		if  != nil {
			return
		}
		 = [:]
	}

	return
}

// bitPackSlow appends the bit-packed encoding of a ring element to a byte
// slice and returns the extended slice.
//
// The per-coefficient bit length is bits.Len of the sum of the two int
// arguments; the function panics if that length is not in [1, 16]. The output
// is extended with sliceForAppend and coefficients are written least
// significant bits first through a 32-bit accumulator.
//
// NOTE(review): the parameter and local identifiers are not visible in this
// view of the file. Presumably the value packed for each coefficient is the
// second bound minus the centered coefficient, as in FIPS 204 BitPack —
// confirm against the original.
func bitPackSlow( []byte,  ringElement, ,  int) []byte {
	 := bits.Len(uint( + ))
	if  <= 0 ||  > 16 {
		panic("mldsa: internal error: invalid bitlen")
	}
	// Grow the destination and get a window onto the newly added bytes.
	,  := sliceForAppend(, n*/8)
	// Bit accumulator and the count of valid bits currently held in it.
	var  uint32
	var  uint
	for  := range  {
		// Map the coefficient to the value to be packed, then push it into
		// the accumulator above the bits already pending.
		 := int32() - fieldCenteredMod([])
		 |= uint32() << 
		 += uint()
		// Emit every complete byte, lowest byte first.
		for  >= 8 {
			[0] = byte()
			 = [1:]
			 >>= 8
			 -= 8
		}
	}
	// Flush any bits still pending after the last coefficient.
	// NOTE(review): presumably never taken when the total bit count is a
	// multiple of 8 — confirm.
	if  > 0 {
		[0] = byte()
	}
	return 
}

// bitUnpackSlow decodes one bit-packed ring element from a byte slice.
//
// The per-coefficient bit length is bits.Len of the sum of the two int
// arguments; the function panics if that length is not in [1, 16]. It returns
// an error if the input length is not exactly n × bitlen / 8 bytes, or if any
// decoded value exceeds the sum of the two int arguments. Each accepted value
// is converted with fieldSubToMontgomery before being stored.
//
// NOTE(review): the parameter and local identifiers are not visible in this
// view of the file. Presumably this is the inverse of bitPackSlow, with each
// coefficient recovered as the second bound minus the decoded value —
// confirm against the original.
func bitUnpackSlow( []byte, ,  int) (ringElement, error) {
	 := bits.Len(uint( + ))
	if  <= 0 ||  > 16 {
		panic("mldsa: internal error: invalid bitlen")
	}
	if len() != n*/8 {
		return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow")
	}

	// Mask selecting one coefficient's bits, and the largest value a
	// well-formed encoding may contain.
	 := uint32((1 << ) - 1)
	 := uint32( + )

	// Result, bit accumulator, count of valid bits in the accumulator, and
	// read position in the input.
	var  ringElement
	var  uint32
	var  uint
	 := 0

	for  := range  {
		// Refill the accumulator one byte at a time until it holds a full
		// coefficient.
		// NOTE(review): if the read position reached the end of the input
		// this loop would not make progress; presumably the length check
		// above guarantees that cannot happen — confirm.
		for  < uint() {
			if  < len() {
				 |= uint32([]) << 
				++
				 += 8
			}
		}
		// Extract the low bits and reject values outside the allowed range.
		 :=  & 
		if  >  {
			return ringElement{}, errors.New("mldsa: coefficient out of range")
		}
		[] = fieldSubToMontgomery(uint32(), )
		// Drop the consumed bits.
		 >>= 
		 -= uint()
	}

	return , nil
}