// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package hpke

import (
	"crypto/ecdh"
	"crypto/rand"
	"errors"
	"internal/byteorder"
	"slices"
)

// A KEM is a Key Encapsulation Mechanism, one of the three components of an
// HPKE ciphersuite.
//
// Because KEM has an unexported method (encSize), it can only be implemented
// by types in this package.
type KEM interface {
	// ID returns the HPKE KEM identifier.
	ID() uint16

	// GenerateKey generates a new key pair.
	GenerateKey() (PrivateKey, error)

	// NewPublicKey deserializes a public key from bytes.
	//
	// It implements DeserializePublicKey, as defined in RFC 9180.
	NewPublicKey([]byte) (PublicKey, error)

	// NewPrivateKey deserializes a private key from bytes.
	//
	// It implements DeserializePrivateKey, as defined in RFC 9180.
	NewPrivateKey([]byte) (PrivateKey, error)

	// DeriveKeyPair derives a key pair from the given input keying material.
	//
	// It implements DeriveKeyPair, as defined in RFC 9180.
	DeriveKeyPair(ikm []byte) (PrivateKey, error)

	// encSize returns the length in bytes of the encapsulated key produced
	// by this KEM (Nenc in RFC 9180).
	encSize() int
}

// NewKEM returns the KEM implementation for the given KEM ID.
//
// Applications are encouraged to use specific implementations like [DHKEM] or
// [MLKEM768X25519] instead, unless runtime agility is required.
func ( uint16) (KEM, error) {
	switch  {
	case 0x0010: // DHKEM(P-256, HKDF-SHA256)
		return DHKEM(ecdh.P256()), nil
	case 0x0011: // DHKEM(P-384, HKDF-SHA384)
		return DHKEM(ecdh.P384()), nil
	case 0x0012: // DHKEM(P-521, HKDF-SHA512)
		return DHKEM(ecdh.P521()), nil
	case 0x0020: // DHKEM(X25519, HKDF-SHA256)
		return DHKEM(ecdh.X25519()), nil
	case 0x0041: // ML-KEM-768
		return MLKEM768(), nil
	case 0x0042: // ML-KEM-1024
		return MLKEM1024(), nil
	case 0x647a: // MLKEM768-X25519
		return MLKEM768X25519(), nil
	case 0x0050: // MLKEM768-P256
		return MLKEM768P256(), nil
	case 0x0051: // MLKEM1024-P384
		return MLKEM1024P384(), nil
	default:
		return nil, errors.New("unsupported KEM")
	}
}

// A PublicKey is an instantiation of a KEM (one of the three components of an
// HPKE ciphersuite) with an encapsulation key (i.e. the public key).
//
// A PublicKey is usually obtained from a method of the corresponding [KEM] or
// [PrivateKey], such as [KEM.NewPublicKey] or [PrivateKey.PublicKey].
type PublicKey interface {
	// KEM returns the instantiated KEM.
	KEM() KEM

	// Bytes returns the public key as the output of SerializePublicKey.
	Bytes() []byte

	// encap implements Encap from RFC 9180: it returns a fresh shared secret
	// and enc, the encapsulated key to transmit to the holder of the
	// corresponding PrivateKey.
	encap() (sharedSecret, enc []byte, err error)
}

// A PrivateKey is an instantiation of a KEM (one of the three components of
// an HPKE ciphersuite) with a decapsulation key (i.e. the secret key).
//
// A PrivateKey is usually obtained from a method of the corresponding [KEM],
// such as [KEM.GenerateKey] or [KEM.NewPrivateKey].
type PrivateKey interface {
	// KEM returns the instantiated KEM.
	KEM() KEM

	// Bytes returns the private key as the output of SerializePrivateKey, as
	// defined in RFC 9180.
	//
	// Note that for X25519 this might not match the input to NewPrivateKey.
	// This is a requirement of RFC 9180, Section 7.1.2.
	//
	// Bytes returns an error when the key material cannot be exported, for
	// example for a DHKEM key wrapping an ecdh.KeyExchanger that is not an
	// *ecdh.PrivateKey.
	Bytes() ([]byte, error)

	// PublicKey returns the corresponding PublicKey.
	PublicKey() PublicKey

	// decap implements Decap from RFC 9180: it recovers the shared secret
	// from enc, the encapsulated key produced by the matching PublicKey's
	// encap.
	decap(enc []byte) (sharedSecret []byte, err error)
}

// dhKEM is the Diffie-Hellman based KEM (DHKEM) of RFC 9180, Section 4.1,
// instantiated over one elliptic curve. Only the package-level singletons
// (dhKEMP256, dhKEMP384, dhKEMP521, dhKEMX25519) are constructed, and code
// in this file identifies them by pointer comparison (e.g. kem == dhKEMX25519).
//
// NOTE: the singletons are built with positional composite literals, so the
// field order below must not change.
type dhKEM struct {
	// kdf is the KDF used for ExtractAndExpand and DeriveKeyPair.
	kdf     KDF
	// id is the HPKE KEM identifier returned by ID.
	id      uint16
	// curve is the underlying Diffie-Hellman group.
	curve   ecdh.Curve
	// Nsecret is the length in bytes of the KEM shared secret.
	Nsecret uint16
	// Nsk is the length in bytes of a serialized private key.
	Nsk     uint16
	// Nenc is the length in bytes of an encapsulated key (see encSize).
	Nenc    int
}

func ( *dhKEM) (,  []byte) ([]byte, error) {
	 := byteorder.BEAppendUint16([]byte("KEM"), .id)
	,  := .kdf.labeledExtract(, nil, "eae_prk", )
	if  != nil {
		return nil, 
	}
	return .kdf.labeledExpand(, , "shared_secret", , .Nsecret)
}

func ( *dhKEM) () uint16 {
	return .id
}

func ( *dhKEM) () int {
	return .Nenc
}

// The DHKEM instantiations, one per supported curve, with the parameters of
// RFC 9180, Section 7.1 (Table 2). DHKEM returns these singletons, and other
// code identifies them by pointer comparison.
var (
	dhKEMP256   = &dhKEM{kdf: HKDFSHA256(), id: 0x0010, curve: ecdh.P256(), Nsecret: 32, Nsk: 32, Nenc: 65}
	dhKEMP384   = &dhKEM{kdf: HKDFSHA384(), id: 0x0011, curve: ecdh.P384(), Nsecret: 48, Nsk: 48, Nenc: 97}
	dhKEMP521   = &dhKEM{kdf: HKDFSHA512(), id: 0x0012, curve: ecdh.P521(), Nsecret: 64, Nsk: 66, Nenc: 133}
	dhKEMX25519 = &dhKEM{kdf: HKDFSHA256(), id: 0x0020, curve: ecdh.X25519(), Nsecret: 32, Nsk: 32, Nenc: 32}
)

// DHKEM returns a KEM implementing one of
//
//   - DHKEM(P-256, HKDF-SHA256)
//   - DHKEM(P-384, HKDF-SHA384)
//   - DHKEM(P-521, HKDF-SHA512)
//   - DHKEM(X25519, HKDF-SHA256)
//
// depending on curve.
func ( ecdh.Curve) KEM {
	switch  {
	case ecdh.P256():
		return dhKEMP256
	case ecdh.P384():
		return dhKEMP384
	case ecdh.P521():
		return dhKEMP521
	case ecdh.X25519():
		return dhKEMX25519
	default:
		// The set of ecdh.Curve implementations is closed, because the
		// interface has unexported methods. Therefore, this default case is
		// only hit if a new curve is added that DHKEM doesn't support.
		return unsupportedCurveKEM{}
	}
}

type unsupportedCurveKEM struct{}

func (unsupportedCurveKEM) () uint16 {
	return 0
}
func (unsupportedCurveKEM) () (PrivateKey, error) {
	return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) ([]byte) (PublicKey, error) {
	return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) ([]byte) (PrivateKey, error) {
	return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) ([]byte) (PrivateKey, error) {
	return nil, errors.New("unsupported curve")
}
func (unsupportedCurveKEM) () int {
	return 0
}

// dhKEMPublicKey is a DHKEM public key: an ecdh public key paired with the
// dhKEM instantiation for its curve.
type dhKEMPublicKey struct {
	// kem is the DHKEM singleton matching pub's curve.
	kem *dhKEM
	// pub is the recipient's public key.
	pub *ecdh.PublicKey
}

// NewDHKEMPublicKey returns a PublicKey implementing
//
//   - DHKEM(P-256, HKDF-SHA256)
//   - DHKEM(P-384, HKDF-SHA384)
//   - DHKEM(P-521, HKDF-SHA512)
//   - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of pub ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of [DHKEM].
func ( *ecdh.PublicKey) (PublicKey, error) {
	,  := DHKEM(.Curve()).(*dhKEM)
	if ! {
		return nil, errors.New("unsupported curve")
	}
	return &dhKEMPublicKey{
		kem: ,
		pub: ,
	}, nil
}

func ( *dhKEM) ( []byte) (PublicKey, error) {
	,  := .curve.NewPublicKey()
	if  != nil {
		return nil, 
	}
	return NewDHKEMPublicKey()
}

func ( *dhKEMPublicKey) () KEM {
	return .kem
}

func ( *dhKEMPublicKey) () []byte {
	return .pub.Bytes()
}

// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() *ecdh.PrivateKey

func ( *dhKEMPublicKey) () ( []byte,  []byte,  error) {
	,  := .pub.Curve().GenerateKey(rand.Reader)
	if  != nil {
		return nil, nil, 
	}
	if testingOnlyGenerateKey != nil {
		 = testingOnlyGenerateKey()
	}
	,  := .ECDH(.pub)
	if  != nil {
		return nil, nil, 
	}
	 := .PublicKey().Bytes()

	 := .pub.Bytes()
	 := append(, ...)
	,  = .kem.extractAndExpand(, )
	if  != nil {
		return nil, nil, 
	}
	return , , nil
}

// dhKEMPrivateKey is a DHKEM private key. The key is held as an
// ecdh.KeyExchanger rather than an *ecdh.PrivateKey so that implementations
// such as hardware keys can be used (see NewDHKEMPrivateKey); such keys
// cannot be serialized by Bytes.
type dhKEMPrivateKey struct {
	// kem is the DHKEM singleton matching priv's curve.
	kem  *dhKEM
	// priv performs the Diffie-Hellman operation and exposes the public key.
	priv ecdh.KeyExchanger
}

// NewDHKEMPrivateKey returns a PrivateKey implementing
//
//   - DHKEM(P-256, HKDF-SHA256)
//   - DHKEM(P-384, HKDF-SHA384)
//   - DHKEM(P-521, HKDF-SHA512)
//   - DHKEM(X25519, HKDF-SHA256)
//
// depending on the underlying curve of priv ([ecdh.X25519], [ecdh.P256],
// [ecdh.P384], or [ecdh.P521]).
//
// This function is meant for applications that already have an instantiated
// crypto/ecdh private key, or another implementation of a [ecdh.KeyExchanger]
// (e.g. a hardware key). Otherwise, applications should use the
// [KEM.NewPrivateKey] method of [DHKEM].
func ( ecdh.KeyExchanger) (PrivateKey, error) {
	,  := DHKEM(.Curve()).(*dhKEM)
	if ! {
		return nil, errors.New("unsupported curve")
	}
	return &dhKEMPrivateKey{
		kem:  ,
		priv: ,
	}, nil
}

func ( *dhKEM) () (PrivateKey, error) {
	,  := .curve.GenerateKey(rand.Reader)
	if  != nil {
		return nil, 
	}
	return NewDHKEMPrivateKey()
}

func ( *dhKEM) ( []byte) (PrivateKey, error) {
	,  := .curve.NewPrivateKey()
	if  != nil {
		return nil, 
	}
	return NewDHKEMPrivateKey()
}

func ( *dhKEM) ( []byte) (PrivateKey, error) {
	// DeriveKeyPair from RFC 9180 Section 7.1.3.
	 := byteorder.BEAppendUint16([]byte("KEM"), .id)
	,  := .kdf.labeledExtract(, nil, "dkp_prk", )
	if  != nil {
		return nil, 
	}
	if  == dhKEMX25519 {
		,  := .kdf.labeledExpand(, , "sk", nil, .Nsk)
		if  != nil {
			return nil, 
		}
		return .NewPrivateKey()
	}
	var  uint8
	for  < 4 {
		,  := .kdf.labeledExpand(, , "candidate", []byte{}, .Nsk)
		if  != nil {
			return nil, 
		}
		if  == dhKEMP521 {
			[0] &= 0x01
		}
		,  := .NewPrivateKey()
		if  != nil {
			++
			continue
		}
		return , nil
	}
	panic("chance of four rejections is < 2^-128")
}

func ( *dhKEMPrivateKey) () KEM {
	return .kem
}

func ( *dhKEMPrivateKey) () ([]byte, error) {
	// Bizarrely, RFC 9180, Section 7.1.2 says SerializePrivateKey MUST clamp
	// the output, which I thought we all agreed to instead do as part of the DH
	// function, letting private keys be random bytes.
	//
	// At the same time, it says DeserializePrivateKey MUST also clamp, implying
	// that the input doesn't have to be clamped, so Bytes by spec doesn't
	// necessarily match the NewPrivateKey input.
	//
	// I'm sure this will not lead to any unexpected behavior or interop issue.
	,  := .priv.(*ecdh.PrivateKey)
	if ! {
		return nil, errors.New("ecdh: private key does not support Bytes")
	}
	if .kem == dhKEMX25519 {
		 := .Bytes()
		[0] &= 248
		[31] &= 127
		[31] |= 64
		return , nil
	}
	return .Bytes(), nil
}

func ( *dhKEMPrivateKey) () PublicKey {
	return &dhKEMPublicKey{
		kem: .kem,
		pub: .priv.PublicKey(),
	}
}

func ( *dhKEMPrivateKey) ( []byte) ([]byte, error) {
	,  := .priv.Curve().NewPublicKey()
	if  != nil {
		return nil, 
	}
	,  := .priv.ECDH()
	if  != nil {
		return nil, 
	}
	 := append(slices.Clip(), .priv.PublicKey().Bytes()...)
	return .kem.extractAndExpand(, )
}