// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ecdh

import (
	
	
	
	
	
	
	
	
)

// PrivateKey and PublicKey are not generic to make it possible to use them
// in other types without instantiating them with a specific point type.
// They are tied to one of the Curve types below through the curveID field.

// All this is duplicated from crypto/internal/fips/ecdsa, but the standards are
// different and FIPS 140 does not allow reusing keys across them.

// PrivateKey holds a private scalar together with the public key derived
// from it.
//
// pub is the matching public key and also carries the curve identifier.
// d is the fixed-length scalar encoding; the private key constructor in this
// file stores a clone of the caller's slice in it.
type PrivateKey struct {
	pub PublicKey
	d   []byte // bigmod.(*Nat).Bytes output (fixed length)
}

// Returns the private key's d field, the fixed-length scalar encoding. The
// slice is returned as stored, without copying, so it shares its backing
// array with the key.
func ( *PrivateKey) () []byte {
	return .d
}

// Returns a pointer to the public key embedded in the private key. The
// pointer refers to the key's own pub field; no copy is made.
func ( *PrivateKey) () *PublicKey {
	return &.pub
}

// PublicKey holds an encoded curve point together with the identifier of the
// curve it belongs to.
//
// curve is compared against the Curve and the private key before a key
// agreement is computed. q is the point encoding.
type PublicKey struct {
	curve curveID
	q     []byte // uncompressed nistec Point.Bytes output
}

// Returns the public key's q field, the uncompressed point encoding. The
// slice is returned as stored, without copying.
func ( *PublicKey) () []byte {
	return .q
}

// curveID names one of the supported curves. It is what ties the non-generic
// PrivateKey and PublicKey types to a specific Curve.
type curveID string

// Identifiers of the four supported curves.
const (
	p224 curveID = "P-224"
	p256 curveID = "P-256"
	p384 curveID = "P-384"
	p521 curveID = "P-521"
)

// Curve bundles the per-curve parameters used by the functions in this file.
//
// curve is the identifier stamped on keys created for this curve. newPoint
// returns a fresh point value of the curve's point type. N is the order that
// private scalars are checked against: a scalar must have the same length as
// N and, read as a big-endian integer, be nonzero and less than N.
type Curve[ Point[]] struct {
	curve    curveID
	newPoint func() 
	N        []byte
}

// Point is a generic constraint for the [nistec] Point types.
//
// As relied on elsewhere in this file: Bytes yields the uncompressed
// encoding (a single 0x00 byte for the identity), BytesX yields the
// x-coordinate and fails for the identity, SetBytes rejects coordinates that
// are not reduced or not on the curve, and ScalarBaseMult only fails when its
// input has the wrong size.
type Point[ any] interface {
	*nistec.P224Point | *nistec.P256Point | *nistec.P384Point | *nistec.P521Point
	Bytes() []byte
	BytesX() ([]byte, error)
	SetBytes([]byte) (, error)
	ScalarMult(, []byte) (, error)
	ScalarBaseMult([]byte) (, error)
}

// Returns the Curve parameters for P-224: the p224 identifier, the nistec
// P-224 point constructor, and the order p224Order. Every call returns a
// newly allocated Curve value.
func () *Curve[*nistec.P224Point] {
	return &Curve[*nistec.P224Point]{
		curve:    p224,
		newPoint: nistec.NewP224Point,
		N:        p224Order,
	}
}

// p224Order is the P-224 group order n as 28 big-endian bytes. It is used as
// the exclusive upper bound for private scalars.
var p224Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x16, 0xa2,
	0xe0, 0xb8, 0xf0, 0x3e, 0x13, 0xdd, 0x29, 0x45,
	0x5c, 0x5c, 0x2a, 0x3d,
}

// Returns the Curve parameters for P-256: the p256 identifier, the nistec
// P-256 point constructor, and the order p256Order. Every call returns a
// newly allocated Curve value.
func () *Curve[*nistec.P256Point] {
	return &Curve[*nistec.P256Point]{
		curve:    p256,
		newPoint: nistec.NewP256Point,
		N:        p256Order,
	}
}

// p256Order is the P-256 group order n as 32 big-endian bytes. It is used as
// the exclusive upper bound for private scalars.
var p256Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
	0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51,
}

// Returns the Curve parameters for P-384: the p384 identifier, the nistec
// P-384 point constructor, and the order p384Order. Every call returns a
// newly allocated Curve value.
func () *Curve[*nistec.P384Point] {
	return &Curve[*nistec.P384Point]{
		curve:    p384,
		newPoint: nistec.NewP384Point,
		N:        p384Order,
	}
}

// p384Order is the P-384 group order n as 48 big-endian bytes. It is used as
// the exclusive upper bound for private scalars.
var p384Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
	0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
	0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73,
}

// Returns the Curve parameters for P-521: the p521 identifier, the nistec
// P-521 point constructor, and the order p521Order. Every call returns a
// newly allocated Curve value.
func () *Curve[*nistec.P521Point] {
	return &Curve[*nistec.P521Point]{
		curve:    p521,
		newPoint: nistec.NewP521Point,
		N:        p521Order,
	}
}

// p521Order is the P-521 group order n as 66 big-endian bytes. It is used as
// the exclusive upper bound for private scalars. Only the lowest bit of the
// leading byte is set, which is why key generation masks the top byte of
// P-521 candidates.
var p521Order = []byte{0x01, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfa,
	0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f, 0x96, 0x6b,
	0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09, 0xa5, 0xd0,
	0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c, 0x47, 0xae,
	0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38, 0x64, 0x09,
}

// GenerateKey generates a new ECDH private key pair for the specified curve.
//
// It fills a candidate scalar of len(N) bytes through drbg.ReadWithReader and
// retries with a fresh candidate until NewPrivateKey accepts one. The only
// error it returns is the one reported by drbg.ReadWithReader.
func [ Point[]]( *Curve[],  io.Reader) (*PrivateKey, error) {
	fips140.RecordApproved()
	// This procedure is equivalent to Key Pair Generation by Testing
	// Candidates, specified in NIST SP 800-56A Rev. 3, Section 5.6.1.2.2.

	for {
		 := make([]byte, len(.N))
		if  := drbg.ReadWithReader(, );  != nil {
			return nil, 
		}
		// In tests, rand will return all zeros and NewPrivateKey will reject
		// the zero key as it generates the identity as a public key. This also
		// makes this function consistent with crypto/elliptic.GenerateKey.
		[1] ^= 0x42

		// Mask off any excess bits if the size of the underlying field is not a
		// whole number of bytes, which is only the case for P-521.
		// The order's leading byte has no bit above the lowest one set, so a
		// candidate with any of those bits set could never be less than N.
		if .curve == p521 && .N[0]&0b1111_1110 == 0 {
			[0] &= 0b0000_0001
		}

		// A rejected candidate (zero, or not less than N) is discarded and a
		// new one is drawn.
		,  := NewPrivateKey(, )
		if  != nil {
			continue
		}
		return , nil
	}
}

// NewPrivateKey validates a private scalar for the given curve and derives
// the matching public key.
//
// The scalar must be exactly len(N) bytes and, read as a big-endian integer,
// satisfy 0 < d < N; otherwise an error is returned. The returned key holds a
// clone of the scalar, so later changes to the caller's slice do not affect
// it. The function panics if an internal invariant is violated or if the
// pairwise consistency test fails.
func [ Point[]]( *Curve[],  []byte) (*PrivateKey, error) {
	// SP 800-56A Rev. 3, Section 5.6.1.2.2 checks that c <= n - 2 and then
	// returns d = c + 1. Note that it follows that 0 < d < n. Equivalently,
	// we check that 0 < d < n, and return d.
	if len() != len(.N) || isZero() || !isLess(, .N) {
		return nil, errors.New("crypto/ecdh: invalid private key")
	}

	,  := .newPoint().ScalarBaseMult()
	if  != nil {
		// This is unreachable because the only error condition of
		// ScalarBaseMult is if the input is not the right size.
		panic("crypto/ecdh: internal error: nistec ScalarBaseMult failed for a fixed-size input")
	}

	 := .Bytes()
	if len() == 1 {
		// The encoding of the identity is a single 0x00 byte. This is
		// unreachable because the only scalar that generates the identity is
		// zero, which is rejected above.
		panic("crypto/ecdh: internal error: public key is the identity element")
	}

	// A "Pairwise Consistency Test" makes no sense if we just generated the
	// public key from an ephemeral private key. Moreover, there is no way to
	// check it aside from redoing the exact same computation again. SP 800-56A
	// Rev. 3, Section 5.6.2.1.4 acknowledges that, and doesn't require it.
	// However, ISO 19790:2012, Section 7.10.3.3 has a blanket requirement for a
	// PCT for all generated keys (AS10.35) and FIPS 140-3 IG 10.3.A, Additional
	// Comment 1 goes out of its way to say that "the PCT shall be performed
	// consistent [...], even if the underlying standard does not require a
	// PCT". So we do it. And make ECDH nearly 50% slower (only) in FIPS mode.
	if  := fips140.PCT("ECDH PCT", func() error {
		// Recompute the public key and compare it with the encoding above.
		,  := .newPoint().ScalarBaseMult()
		if  != nil {
			return 
		}
		if !bytes.Equal(.Bytes(), ) {
			return errors.New("crypto/ecdh: public key does not match private key")
		}
		return nil
	});  != nil {
		panic()
	}

	// Store a clone of the scalar; the public encoding is kept as produced by
	// Bytes above.
	 := &PrivateKey{d: bytes.Clone(), pub: PublicKey{curve: .curve, q: }}
	return , nil
}

// Validates an encoded public key for the given curve and returns a PublicKey
// holding a clone of the encoding.
//
// Only encodings whose first byte is 4 (uncompressed) are accepted; an empty
// input, the point at infinity, and compressed encodings are rejected with an
// error. Any error from SetBytes is returned unchanged.
func [ Point[]]( *Curve[],  []byte) (*PublicKey, error) {
	// Reject the point at infinity and compressed encodings.
	if len() == 0 || [0] != 4 {
		return nil, errors.New("crypto/ecdh: invalid public key")
	}

	// SetBytes checks that x and y are in the interval [0, p - 1], and that
	// the point is on the curve. Along with the rejection of the point at
	// infinity (the identity element) above, this fulfills the requirements
	// of NIST SP 800-56A Rev. 3, Section 5.6.2.3.4.
	if ,  := .newPoint().SetBytes();  != nil {
		return nil, 
	}

	return &PublicKey{curve: .curve, q: bytes.Clone()}, nil
}

// Performs the key agreement between a private key and a peer public key on
// the given curve. It calls fipsSelfTest and fips140.RecordApproved first and
// then returns the result of ecdh unchanged.
func [ Point[]]( *Curve[],  *PrivateKey,  *PublicKey) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	return ecdh(, , )
}

// ecdh computes the shared secret between a private key and a peer public
// key and returns the x-coordinate of the resulting point.
//
// It returns an error if the curve, the private key and the public key do
// not all carry the same curve identifier, if SetBytes rejects the peer
// encoding, if ScalarMult fails, or if BytesX rejects the result.
func ecdh[ Point[]]( *Curve[],  *PrivateKey,  *PublicKey) ([]byte, error) {
	if .curve != .pub.curve {
		return nil, errors.New("crypto/ecdh: mismatched curves")
	}
	if .pub.curve != .curve {
		return nil, errors.New("crypto/ecdh: mismatched curves")
	}

	// This applies the Shared Secret Computation of the Ephemeral Unified Model
	// scheme specified in NIST SP 800-56A Rev. 3, Section 6.1.2.2.

	// Per Section 5.6.2.3.4, Step 1, reject the identity element (0x00).
	// NOTE(review): this inspects the public key embedded in the private key
	// (its pub.q), not the peer encoding decoded by SetBytes below. Confirm
	// that is intended; an identity peer is presumably still rejected later
	// by BytesX.
	if len(.pub.q) == 1 {
		return nil, errors.New("crypto/ecdh: public key is the identity element")
	}

	// SetBytes checks that (x, y) are reduced modulo p, and that they are on
	// the curve, performing Steps 2-3 of Section 5.6.2.3.4.
	,  := .newPoint().SetBytes(.q)
	if  != nil {
		return nil, 
	}

	// Compute P according to Section 5.7.1.2.
	if ,  := .ScalarMult(, .d);  != nil {
		return nil, 
	}

	// BytesX checks that the result is not the identity element, and returns the
	// x-coordinate of the result, performing Steps 2-5 of Section 5.7.1.2.
	return .BytesX()
}

// isZero reports whether x is all zeroes in constant time.
//
// Every byte is OR-ed into an accumulator without branching on the data, so
// the work done depends only on len(x). An empty or nil slice is reported as
// zero.
func isZero(x []byte) bool {
	var acc byte
	for _, b := range x {
		acc |= b
	}
	return acc == 0
}

// isLess reports whether a < b, where a and b are big-endian buffers of the
// same length and shorter than 72 bytes.
func isLess(,  []byte) bool {
	if len() != len() {
		panic("crypto/ecdh: internal error: mismatched isLess inputs")
	}

	// Copy the values into a fixed-size preallocated little-endian buffer.
	// 72 bytes is enough for every scalar in this package, and having a fixed
	// size lets us avoid heap allocations.
	if len() > 72 {
		panic("crypto/ecdh: internal error: isLess input too large")
	}
	,  := make([]byte, 72), make([]byte, 72)
	for  := range  {
		[], [] = [len()--1], [len()--1]
	}

	// Perform a subtraction with borrow.
	var  uint64
	for  := 0;  < len();  += 8 {
		,  := byteorder.LEUint64([:]), byteorder.LEUint64([:])
		_,  = bits.Sub64(, , )
	}

	// If there is a borrow at the end of the operation, then a < b.
	return  == 1
}