// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ecdsa

import (
	
	
	
	
	
	
	
	
)

// PrivateKey and PublicKey are not generic to make it possible to use them
// in other types without instantiating them with a specific point type.
// They are tied to one of the Curve types below through the curveID field.

// PrivateKey is an ECDSA private key: the secret scalar d together with the
// matching public key. The curve the key belongs to is recorded in pub.curve,
// and the signing functions reject a key whose curve differs from the Curve
// they are called with.
type PrivateKey struct {
	pub PublicKey
	d   []byte // bigmod.(*Nat).Bytes output (fixed length)
}

// Returns the fixed-length, big-endian encoding of the secret scalar d.
// The returned slice is the key's internal storage, not a copy, so callers
// must not modify it.
func ( *PrivateKey) () []byte {
	return .d
}

// Returns a pointer to the public key embedded in this private key. The
// pointer aliases the private key's own field rather than a copy.
func ( *PrivateKey) () *PublicKey {
	return &.pub
}

// PublicKey is an ECDSA public key: the identifier of the curve it belongs to
// and the encoding of the public point Q.
type PublicKey struct {
	curve curveID
	q     []byte // uncompressed nistec Point.Bytes output
}

// Returns the uncompressed encoding of the public point Q. The returned slice
// is the key's internal storage, not a copy, so callers must not modify it.
func ( *PublicKey) () []byte {
	return .q
}

// curveID names the NIST curve a key belongs to. Keys record it when they are
// created. Sign, SignDeterministic and Verify compare it with the Curve they
// are called with and return an error on a mismatch.
type curveID string

// Identifiers for the supported curves. The values are the curves' standard
// names.
const (
	p224 curveID = "P-224"
	p256 curveID = "P-256"
	p384 curveID = "P-384"
	p521 curveID = "P-521"
)

// Curve holds the per-curve parameters shared by every operation on one NIST
// curve. Each instance is built once, lazily, by the _P224/_P256/_P384/_P521
// initializers below.
//
// Fields:
//   - curve: the identifier copied into every key created for this curve.
//   - newPoint: constructor for a fresh point of the curve's nistec type.
//   - ordInverse: optional dedicated inversion modulo N. It is nil for curves
//     that don't set it, and on some platforms it always returns an error.
//     In both cases inverse falls back to exponentiation by nMinus2.
//   - N: the group order, as a bigmod modulus.
//   - nMinus2: the fixed-length encoding of N - 2, the Fermat-inversion exponent.
type Curve[ Point[]] struct {
	curve      curveID
	newPoint   func() 
	ordInverse func([]byte) ([]byte, error)
	N          *bigmod.Modulus
	nMinus2    []byte
}

// Point is a generic constraint for the [nistec] Point types.
//
// The type set is limited to the four nistec point pointer types. The method
// list is the subset of their API this package uses. The type parameter is the
// point type itself, so methods such as SetBytes and Add return that same type.
type Point[ any] interface {
	*nistec.P224Point | *nistec.P256Point | *nistec.P384Point | *nistec.P521Point
	Bytes() []byte
	// BytesX returns the encoding of the x-coordinate. It fails for the point
	// at infinity, which verifyGeneric relies on.
	BytesX() ([]byte, error)
	// SetBytes decodes a point encoding and checks that it is a valid point
	// on the curve with reduced coordinates.
	SetBytes([]byte) (, error)
	ScalarMult(, []byte) (, error)
	ScalarBaseMult([]byte) (, error)
	Add(p1, p2 ) 
}

// precomputeParams fills in the N and nMinus2 fields of the given Curve from
// the encoded group order. nMinus2 (N - 2) is the exponent used by inverse's
// Fermat fallback.
//
// It is only called with the hard-coded pXXXOrder constants below. An error
// from NewModulus would therefore indicate a programming mistake, so it panics.
func precomputeParams[ Point[]]( *Curve[],  []byte) {
	var  error
	.N,  = bigmod.NewModulus()
	if  != nil {
		panic()
	}
	// 2 is smaller than every supported order, so this SetBytes cannot fail.
	,  := bigmod.NewNat().SetBytes([]byte{2}, .N)
	// Subtracting 2 from a fresh, N-sized Nat wraps around to N - 2. This
	// assumes Sub reduces modulo N, as the field name implies.
	.nMinus2 = bigmod.NewNat().ExpandFor(.N).Sub(, .N).Bytes(.N)
}

// Returns the P-224 Curve. It is built on first use and the same instance is
// shared by all callers afterwards.
func () *Curve[*nistec.P224Point] { return _P224() }

// _P224 builds the P-224 Curve exactly once (sync.OnceValue). No ordInverse is
// set, so inverse always uses Fermat exponentiation for this curve.
var _P224 = sync.OnceValue(func() *Curve[*nistec.P224Point] {
	 := &Curve[*nistec.P224Point]{
		curve:    p224,
		newPoint: nistec.NewP224Point,
	}
	precomputeParams(, p224Order)
	return 
})

// p224Order is the group order N of P-224, encoded big-endian (28 bytes).
var p224Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x16, 0xa2,
	0xe0, 0xb8, 0xf0, 0x3e, 0x13, 0xdd, 0x29, 0x45,
	0x5c, 0x5c, 0x2a, 0x3d,
}

// Returns the P-256 Curve. It is built on first use and the same instance is
// shared by all callers afterwards.
func () *Curve[*nistec.P256Point] { return _P256() }

// _P256 builds the P-256 Curve exactly once (sync.OnceValue). It is the only
// curve with a dedicated ordInverse (nistec.P256OrdInverse). Where that
// returns an error, inverse falls back to Fermat exponentiation.
var _P256 = sync.OnceValue(func() *Curve[*nistec.P256Point] {
	 := &Curve[*nistec.P256Point]{
		curve:      p256,
		newPoint:   nistec.NewP256Point,
		ordInverse: nistec.P256OrdInverse,
	}
	precomputeParams(, p256Order)
	return 
})

// p256Order is the group order N of P-256, encoded big-endian (32 bytes).
var p256Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
	0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51}

// Returns the P-384 Curve. It is built on first use and the same instance is
// shared by all callers afterwards.
func () *Curve[*nistec.P384Point] { return _P384() }

// _P384 builds the P-384 Curve exactly once (sync.OnceValue). No ordInverse is
// set, so inverse always uses Fermat exponentiation for this curve.
var _P384 = sync.OnceValue(func() *Curve[*nistec.P384Point] {
	 := &Curve[*nistec.P384Point]{
		curve:    p384,
		newPoint: nistec.NewP384Point,
	}
	precomputeParams(, p384Order)
	return 
})

// p384Order is the group order N of P-384, encoded big-endian (48 bytes).
var p384Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
	0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
	0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73}

// Returns the P-521 Curve. It is built on first use and the same instance is
// shared by all callers afterwards.
func () *Curve[*nistec.P521Point] { return _P521() }

// _P521 builds the P-521 Curve exactly once (sync.OnceValue). No ordInverse is
// set, so inverse always uses Fermat exponentiation for this curve.
var _P521 = sync.OnceValue(func() *Curve[*nistec.P521Point] {
	 := &Curve[*nistec.P521Point]{
		curve:    p521,
		newPoint: nistec.NewP521Point,
	}
	precomputeParams(, p521Order)
	return 
})

// p521Order is the group order N of P-521, encoded big-endian (66 bytes).
// N is 521 bits long, so its 66-byte encoding has 7 spare leading bits. That
// is why randomPoint and hashToNat must shift right for this curve only.
var p521Order = []byte{0x01, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfa,
	0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f, 0x96, 0x6b,
	0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09, 0xa5, 0xd0,
	0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c, 0x47, 0xae,
	0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38, 0x64, 0x09}

// Builds a PrivateKey for the given curve from two byte-slice arguments:
//   - the encoding of the public point, validated by NewPublicKey;
//   - the encoding of the secret scalar, parsed with bigmod SetBytes, which
//     rejects values that are not less than N.
//
// NOTE(review): parameter names and their order are elided in this view.
// Confirm which slice is which against the callers.
//
// The assembled key goes through fipsPCT before it is returned, so a scalar
// that does not match the public point is reported as an error.
func [ Point[]]( *Curve[], ,  []byte) (*PrivateKey, error) {
	fips140.RecordApproved()
	,  := NewPublicKey(, )
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewNat().SetBytes(, .N)
	if  != nil {
		return nil, 
	}
	// Re-encode the scalar with Bytes(N) so d always has the fixed length
	// documented on PrivateKey.
	 := &PrivateKey{pub: *, d: .Bytes(.N)}
	if  := fipsPCT(, );  != nil {
		// This can happen if the application went out of its way to make an
		// ecdsa.PrivateKey with a mismatching PublicKey.
		return nil, 
	}
	return , nil
}

// Returns a PublicKey for the given curve from the uncompressed encoding of the
// public point Q. It first checks that the encoding decodes to a valid point.
// A decoding error from SetBytes is returned unchanged.
func [ Point[]]( *Curve[],  []byte) (*PublicKey, error) {
	// SetBytes checks that Q is a valid point on the curve, and that its
	// coordinates are reduced modulo p, fulfilling the requirements of SP
	// 800-89, Section 5.3.2.
	,  := .newPoint().SetBytes()
	if  != nil {
		return nil, 
	}
	return &PublicKey{curve: .curve, q: }, nil
}

// GenerateKey generates a new ECDSA private key pair for the specified curve.
//
// randomPoint draws the secret scalar by rejection sampling. It reads random
// bytes through drbg.ReadWithReader with the supplied io.Reader, and any read
// error is returned. The public point is the matching scalar base
// multiplication, also computed by randomPoint.
//
// As FIPS 140-3 requires, the new key is checked with fipsPCT. A failure there
// panics, because it would mean an internal error rather than bad input.
func [ Point[]]( *Curve[],  io.Reader) (*PrivateKey, error) {
	fips140.RecordApproved()

	, ,  := randomPoint(, func( []byte) error {
		return drbg.ReadWithReader(, )
	})
	if  != nil {
		return nil, 
	}

	 := &PrivateKey{
		pub: PublicKey{
			curve: .curve,
			q:     .Bytes(),
		},
		d: .Bytes(.N),
	}
	if  := fipsPCT(, );  != nil {
		// This clearly can't happen, but FIPS 140-3 mandates that we check it.
		panic()
	}
	return , nil
}

// randomPoint returns a random scalar and the corresponding point using a
// procedure equivalent to FIPS 186-5, Appendix A.2.2 (ECDSA Key Pair Generation
// by Rejection Sampling) and to Appendix A.3.2 (Per-Message Secret Number
// Generation of Private Keys by Rejection Sampling) or Appendix A.3.3
// (Per-Message Secret Number Generation for Deterministic ECDSA) followed by
// Step 5 of Section 6.4.1.
//
// The callback fills a buffer of N.Size() bytes with candidate randomness.
// Any error it returns is passed through unchanged. The returned scalar k is
// in [1, N-1], and the returned point is the base point multiplied by k.
func randomPoint[ Point[]]( *Curve[],  func([]byte) error) ( *bigmod.Nat,  ,  error) {
	// Keep drawing fresh candidates until one is accepted.
	for {
		 := make([]byte, .N.Size())
		if  := ();  != nil {
			return nil, nil, 
		}

		// Take only the leftmost bits of the generated random value. This is
		// both necessary to increase the chance of the random value being in
		// the correct range and to match the specification. It's unfortunate
		// that we need to do a shift instead of a mask, but see the comment on
		// rightShift.
		//
		// These are the most dangerous lines in the package and maybe in the
		// library: a single bit of bias in the selection of nonces would likely
		// lead to key recovery, but no tests would fail. Look but DO NOT TOUCH.
		if  := len()*8 - .N.BitLen();  > 0 {
			// Just to be safe, assert that this only happens for the one curve that
			// doesn't have a round number of bits.
			if .curve != p521 {
				panic("ecdsa: internal error: unexpectedly masking off bits")
			}
			 = rightShift(, )
		}

		// FIPS 186-5, Appendix A.4.2 makes us check x <= N - 2 and then return
		// x + 1. Note that it follows that 0 < x + 1 < N. Instead, SetBytes
		// checks that k < N, and we explicitly check 0 != k. Since k can't be
		// negative, this is strictly equivalent. None of this matters anyway
		// because the chance of selecting zero is cryptographically negligible.
		if ,  := bigmod.NewNat().SetBytes(, .N);  == nil && .IsZero() == 0 {
			,  := .newPoint().ScalarBaseMult(.Bytes(.N))
			return , , 
		}

		// The candidate was rejected (not less than N, or zero). Notify the
		// test hook, if any, and draw again.
		if testingOnlyRejectionSamplingLooped != nil {
			testingOnlyRejectionSamplingLooped()
		}
	}
}

// testingOnlyRejectionSamplingLooped, when non-nil, is called each time
// rejection sampling in randomPoint rejects a candidate. A candidate is
// rejected when it is not less than the order N or when it is zero. The hook
// is nil unless set, presumably only by tests as its name suggests.
var testingOnlyRejectionSamplingLooped func()

// Signature is an ECDSA signature, where r and s are represented as big-endian
// fixed-length byte slices.
//
// signGeneric encodes both values with Bytes(N), so each is as long as the
// byte length of the curve order. verifyGeneric rejects a value that is zero
// or not less than N.
type Signature struct {
	R, S []byte
}

// Sign signs a hash (which shall be the result of hashing a larger message with
// the hash function H) using the private key, priv. If the hash is longer than
// the bit-length of the private key's curve order, the hash will be truncated
// to that length.
//
// It returns an error if priv belongs to a different curve than the Curve
// argument, or if reading randomness from the io.Reader argument fails. The
// hash-constructor argument is passed to newDRBG to instantiate the DRBG that
// produces the nonce.
func [ Point[],  fips140.Hash]( *Curve[],  func() ,  *PrivateKey,  io.Reader,  []byte) (*Signature, error) {
	if .pub.curve != .curve {
		return nil, errors.New("ecdsa: private key does not match curve")
	}
	fips140.RecordApproved()
	fipsSelfTest()

	// Random ECDSA is dangerous, because a failure of the RNG would immediately
	// leak the private key. Instead, we use a "hedged" approach, as specified
	// in draft-irtf-cfrg-det-sigs-with-noise-04, Section 4. This has also the
	// advantage of closely resembling Deterministic ECDSA.

	// Z: fresh entropy, as long as the private scalar encoding.
	 := make([]byte, len(.d))
	if  := drbg.ReadWithReader(, );  != nil {
		return nil, 
	}

	// See https://github.com/cfrg/draft-irtf-cfrg-det-sigs-with-noise/issues/6
	// for the FIPS compliance of this method. In short Z is entropy from the
	// main DRBG, of length 3/2 of security_strength, so the nonce is optional
	// per SP 800-90Ar1, Section 8.6.7, and the rest is a personalization
	// string, which per SP 800-90Ar1, Section 8.7.1 may contain secret
	// information.
	 := newDRBG(, , nil, blockAlignedPersonalizationString{.d, bits2octets(, )})

	return sign(, , , )
}

// SignDeterministic signs a hash (which shall be the result of hashing a
// larger message with the hash function H) using the private key, priv. If the
// hash is longer than the bit-length of the private key's curve order, the hash
// will be truncated to that length. This applies Deterministic ECDSA as
// specified in FIPS 186-5 and RFC 6979.
//
// The DRBG is seeded only from the private scalar and bits2octets(hash), with
// no external randomness. The same key and hash therefore always produce the
// same signature. It returns an error if priv belongs to a different curve
// than the Curve argument.
func [ Point[],  fips140.Hash]( *Curve[],  func() ,  *PrivateKey,  []byte) (*Signature, error) {
	if .pub.curve != .curve {
		return nil, errors.New("ecdsa: private key does not match curve")
	}
	fips140.RecordApproved()
	fipsSelfTestDeterministic()
	 := newDRBG(, .d, bits2octets(, ), nil) // RFC 6979, Section 3.3
	return sign(, , , )
}

// bits2octets as specified in FIPS 186-5, Appendix B.2.4 or RFC 6979,
// Section 2.3.4. See RFC 6979, Section 3.5 for the rationale.
//
// It converts the hash to an integer modulo N (via hashToNat) and returns it
// as a fixed-length encoding produced by Bytes(N).
func bits2octets[ Point[]]( *Curve[],  []byte) []byte {
	 := bigmod.NewNat()
	hashToNat(, , )
	return .Bytes(.N)
}

// signGeneric is the portable ECDSA signing routine. It draws the nonce k from
// the given HMAC_DRBG through randomPoint and computes (r, s).
// NOTE(review): it is presumably reached through the sign call in Sign and
// SignDeterministic, possibly behind a platform-specific implementation —
// confirm.
func signGeneric[ Point[]]( *Curve[],  *PrivateKey,  *hmacDRBG,  []byte) (*Signature, error) {
	// FIPS 186-5, Section 6.4.1

	// Nonce k and R = [k]G. The callback fills the candidate buffer from the
	// DRBG and never reports an error.
	, ,  := randomPoint(, func( []byte) error {
		.Generate()
		return nil
	})
	if  != nil {
		return nil, 
	}

	// kInv = k⁻¹
	 := bigmod.NewNat()
	inverse(, , )

	// r = x(R) mod N. SetOverflowingBytes reduces values that are not below N.
	,  := .BytesX()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewNat().SetOverflowingBytes(, .N)
	if  != nil {
		return nil, 
	}

	// The spec wants us to retry here, but the chance of hitting this condition
	// on a large prime-order group like the NIST curves we support is
	// cryptographically negligible. If we hit it, something is awfully wrong.
	if .IsZero() == 1 {
		return nil, errors.New("ecdsa: internal error: r is zero")
	}

	// e = leftmost bits of the hash, as an integer modulo N.
	 := bigmod.NewNat()
	hashToNat(, , )

	// s = k⁻¹ · (e + r·d) mod N. It starts from d and is updated in place by
	// Mul and Add.
	,  := bigmod.NewNat().SetBytes(.d, .N)
	if  != nil {
		return nil, 
	}
	.Mul(, .N)
	.Add(, .N)
	.Mul(, .N)

	// Again, the chance of this happening is cryptographically negligible.
	if .IsZero() == 1 {
		return nil, errors.New("ecdsa: internal error: s is zero")
	}

	return &Signature{.Bytes(.N), .Bytes(.N)}, nil
}

// inverse sets kInv to the inverse of k modulo the order of the curve.
//
// It is used for the nonce when signing and for s when verifying. Both callers
// ensure the value to invert is nonzero: randomPoint never returns a zero
// scalar, and verifyGeneric rejects s == 0. When the curve provides ordInverse
// and it succeeds, that result is used. Otherwise the inverse is computed by
// exponentiation with the precomputed nMinus2.
func inverse[ Point[]]( *Curve[], ,  *bigmod.Nat) {
	if .ordInverse != nil {
		,  := .ordInverse(.Bytes(.N))
		// Some platforms don't implement ordInverse, and always return an error.
		// In that case fall through to the generic path below.
		if  == nil {
			,  := .SetBytes(, .N)
			if  != nil {
				panic("ecdsa: internal error: ordInverse produced an invalid value")
			}
			return
		}
	}

	// Calculate the inverse of k in GF(N) using Fermat's little theorem: since
	// N is prime, k⁻¹ = k^(N-2) mod N.
	.Exp(, .nMinus2, .N)
}

// hashToNat sets e to the left-most bits of hash, according to
// FIPS 186-5, Section 6.4.1, point 2 and Section 6.4.2, point 3.
//
// A hash shorter than N's byte length is used whole. A longer hash is
// truncated to N's byte length and, for curves whose order is not a whole
// number of bytes, shifted right by the spare bits. The result is then reduced
// modulo N.
func hashToNat[ Point[]]( *Curve[],  *bigmod.Nat,  []byte) {
	// ECDSA asks us to take the left-most log2(N) bits of hash, and use them as
	// an integer modulo N. This is the absolute worst of all worlds: we still
	// have to reduce, because the result might still overflow N, but to take
	// the left-most bits for P-521 we have to do a right shift.
	if  := .N.Size(); len() >=  {
		 = [:]
		if  := len()*8 - .N.BitLen();  > 0 {
			 = rightShift(, )
		}
	}
	// SetOverflowingBytes reduces modulo N. It can only fail for input longer
	// than N's byte length, which the truncation above rules out.
	,  := .SetOverflowingBytes(, .N)
	if  != nil {
		panic("ecdsa: internal error: truncated hash is too long")
	}
}

// rightShift implements the right shift necessary for bits2int, which takes the
// leftmost bits of either the hash or HMAC_DRBG output.
//
// Note how taking the rightmost bits would have been as easy as masking the
// first byte, but we can't have nice things.
//
// It treats the input slice as a big-endian integer and returns a shifted
// copy; the input itself is never modified. The low bits of each byte carry
// into the high bits of the next byte, and the bits shifted out of the last
// byte are dropped. The shift amount must be between 1 and 7 inclusive; any
// other value panics.
func ( []byte,  int) []byte {
	if  <= 0 ||  >= 8 {
		panic("ecdsa: internal error: shift can only be by 1 to 7 bits")
	}
	 = bytes.Clone()
	// Walk from the last byte to the first so each byte still sees its
	// predecessor's original value when borrowing carry bits.
	for  := len() - 1;  >= 0; -- {
		[] >>= 
		if  > 0 {
			[] |= [-1] << (8 - )
		}
	}
	return 
}

// Verify verifies the signature, sig, of hash (which should be the result of
// hashing a larger message) using the public key, pub. If the hash is longer
// than the bit-length of the public key's curve order, the hash will be
// truncated to that length.
//
// It returns nil if the signature is valid and an error otherwise. It also
// returns an error when pub belongs to a different curve than the Curve
// argument.
//
// The inputs are not considered confidential, and may leak through timing side
// channels, or if an attacker has control of part of the inputs.
func [ Point[]]( *Curve[],  *PublicKey,  []byte,  *Signature) error {
	if .curve != .curve {
		return errors.New("ecdsa: public key does not match curve")
	}
	fips140.RecordApproved()
	fipsSelfTest()
	return verify(, , , )
}

// verifyGeneric is the portable ECDSA verification routine. It accepts the
// signature when x([e·s⁻¹]G + [r·s⁻¹]Q) mod N equals r. Every failure, from
// decoding to the final comparison, is returned as an error.
// NOTE(review): it is presumably reached through the verify call in Verify,
// possibly behind a platform-specific implementation — confirm.
func verifyGeneric[ Point[]]( *Curve[],  *PublicKey,  []byte,  *Signature) error {
	// FIPS 186-5, Section 6.4.2

	// Decode and validate the public point Q.
	,  := .newPoint().SetBytes(.q)
	if  != nil {
		return 
	}

	// r and s must be in [1, N-1]. SetBytes rejects values that are not less
	// than N, and the IsZero checks reject zero.
	,  := bigmod.NewNat().SetBytes(.R, .N)
	if  != nil {
		return 
	}
	if .IsZero() == 1 {
		return errors.New("ecdsa: invalid signature: r is zero")
	}
	,  := bigmod.NewNat().SetBytes(.S, .N)
	if  != nil {
		return 
	}
	if .IsZero() == 1 {
		return errors.New("ecdsa: invalid signature: s is zero")
	}

	// e = leftmost bits of the hash, as an integer modulo N.
	 := bigmod.NewNat()
	hashToNat(, , )

	// w = s⁻¹
	 := bigmod.NewNat()
	inverse(, , )

	// Note that Mul updates its receiver in place; each product below is
	// consumed immediately.
	// p₁ = [e * s⁻¹]G
	,  := .newPoint().ScalarBaseMult(.Mul(, .N).Bytes(.N))
	if  != nil {
		return 
	}
	// p₂ = [r * s⁻¹]Q
	,  := .ScalarMult(, .Mul(, .N).Bytes(.N))
	if  != nil {
		return 
	}
	// BytesX returns an error for the point at infinity.
	,  := .Add(, ).BytesX()
	if  != nil {
		return 
	}

	// v = x(p₁ + p₂) mod N
	,  := bigmod.NewNat().SetOverflowingBytes(, .N)
	if  != nil {
		return 
	}

	// The signature is valid iff v == r.
	if .Equal() != 1 {
		return errors.New("ecdsa: signature did not verify")
	}
	return nil
}