// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package hpke

import (
	
	
	
	
	
	
	
	
	
	
)

// mlkem768X25519 holds the parameters of the MLKEM768-X25519 hybrid (a.k.a.
// X-Wing): KEM identifier 0x647a, an X25519 traditional component and an
// ML-KEM-768 post-quantum component. Its label is the six-byte string \.//^\,
// split over two raw-string literals. The empty comments have no effect and
// appear to be there only to align the two halves. The label is absorbed into
// the shared-secret hash.
var mlkem768X25519 = &hybridKEM{
	id: 0x647a,
	label: /**/ `\./` +
		/*   */ `/^\`,
	curve: ecdh.X25519(),

	// Sizes in bytes: a 32-byte X25519 scalar seed and a 32-byte X25519 point,
	// plus the ML-KEM-768 encapsulation key and ciphertext sizes.
	curveSeedSize:    32,
	curvePointSize:   32,
	pqEncapsKeySize:  mlkem.EncapsulationKeySize768,
	pqCiphertextSize: mlkem.CiphertextSize768,

	// Parse an ML-KEM-768 encapsulation key, and build an ML-KEM-768
	// decapsulation key from its seed.
	pqNewPublicKey: func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey768()
	},
	pqNewPrivateKey: func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey768()
	},
}

// MLKEM768X25519 returns a KEM implementing MLKEM768-X25519 (a.k.a. X-Wing)
// from draft-ietf-hpke-pq. Every call returns the same shared instance.
func () KEM {
	return mlkem768X25519
}

// mlkem768P256 holds the parameters of the MLKEM768-P256 hybrid: KEM
// identifier 0x0050, a P-256 traditional component and an ML-KEM-768
// post-quantum component. The label string is absorbed into the
// shared-secret hash.
var mlkem768P256 = &hybridKEM{
	id:    0x0050,
	label: "MLKEM768-P256",
	curve: ecdh.P256(),

	// Sizes in bytes: a 32-byte scalar seed and a 65-byte uncompressed P-256
	// point, plus the ML-KEM-768 encapsulation key and ciphertext sizes.
	curveSeedSize:    32,
	curvePointSize:   65,
	pqEncapsKeySize:  mlkem.EncapsulationKeySize768,
	pqCiphertextSize: mlkem.CiphertextSize768,

	// Parse an ML-KEM-768 encapsulation key, and build an ML-KEM-768
	// decapsulation key from its seed.
	pqNewPublicKey: func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey768()
	},
	pqNewPrivateKey: func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey768()
	},
}

// MLKEM768P256 returns a KEM implementing MLKEM768-P256 from draft-ietf-hpke-pq.
// Every call returns the same shared instance.
func () KEM {
	return mlkem768P256
}

// mlkem1024P384 holds the parameters of the MLKEM1024-P384 hybrid: KEM
// identifier 0x0051, a P-384 traditional component and an ML-KEM-1024
// post-quantum component. The label string is absorbed into the
// shared-secret hash.
var mlkem1024P384 = &hybridKEM{
	id:    0x0051,
	label: "MLKEM1024-P384",
	curve: ecdh.P384(),

	// Sizes in bytes: a 48-byte scalar seed and a 97-byte uncompressed P-384
	// point, plus the ML-KEM-1024 encapsulation key and ciphertext sizes.
	curveSeedSize:    48,
	curvePointSize:   97,
	pqEncapsKeySize:  mlkem.EncapsulationKeySize1024,
	pqCiphertextSize: mlkem.CiphertextSize1024,

	// Parse an ML-KEM-1024 encapsulation key, and build an ML-KEM-1024
	// decapsulation key from its seed.
	pqNewPublicKey: func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey1024()
	},
	pqNewPrivateKey: func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey1024()
	},
}

// MLKEM1024P384 returns a KEM implementing MLKEM1024-P384 from draft-ietf-hpke-pq.
// Every call returns the same shared instance.
func () KEM {
	return mlkem1024P384
}

// hybridKEM describes a hybrid KEM that pairs an ML-KEM parameter set with an
// ECDH curve. Its public keys and ciphertexts are the ML-KEM encoding followed
// by the encoded curve point. Its shared secret is a SHA3-256 hash that also
// absorbs label.
type hybridKEM struct {
	// HPKE KEM identifier, combiner label, and the traditional (ECDH) curve.
	id    uint16
	label string
	curve ecdh.Curve

	// Encoding sizes in bytes, in this order: curve private-key seed, encoded
	// curve point, ML-KEM encapsulation key, and ML-KEM ciphertext.
	curveSeedSize, curvePointSize, pqEncapsKeySize, pqCiphertextSize int

	// Constructors for the ML-KEM half. The first parses an encoded
	// encapsulation key; the second builds a decapsulation key from its seed.
	pqNewPublicKey  func(encoded []byte) (crypto.Encapsulator, error)
	pqNewPrivateKey func(seed []byte) (crypto.Decapsulator, error)
}

// Returns the HPKE KEM identifier of this hybrid.
func ( *hybridKEM) () uint16 {
	return .id
}

// Returns the size in bytes of an encapsulated key (ciphertext): the ML-KEM
// ciphertext followed by one encoded curve point.
func ( *hybridKEM) () int {
	return .pqCiphertextSize + .curvePointSize
}

// Computes the hybrid shared secret as SHA3-256 over the four inputs, in
// argument order, followed by the KEM's label. The result is 32 bytes.
//
// Both call sites (encapsulation and decapsulation below) pass the recipient's
// encoded curve public key as the fourth input. The first three are presumably
// the ML-KEM shared secret, the ECDH shared secret and the encoded ephemeral
// curve point; confirm against the call sites.
func ( *hybridKEM) (, , ,  []byte) []byte {
	 := sha3.New256()
	// hash.Hash's Write never returns an error, so the results are ignored.
	.Write()
	.Write()
	.Write()
	.Write()
	.Write([]byte(.label))
	return .Sum(nil)
}

// hybridPublicKey is a public key for a hybridKEM. t is the traditional
// (ECDH) component and pq the ML-KEM encapsulation key. Values are built with
// positional composite literals in this file, so field order is significant.
type hybridPublicKey struct {
	kem *hybridKEM
	t   *ecdh.PublicKey
	pq  crypto.Encapsulator
}

// NewHybridPublicKey returns a PublicKey implementing one of
//
//   - MLKEM768-X25519 (a.k.a. X-Wing)
//   - MLKEM768-P256
//   - MLKEM1024-P384
//
// from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq (either
// *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// An error is returned if the curve is not one of these three, or if pq's
// concrete type is not the ML-KEM parameter set paired with that curve
// (ML-KEM-768 for X25519 and P-256, ML-KEM-1024 for P-384).
//
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem public keys. Otherwise, applications should use
// the [KEM.NewPublicKey] method of e.g. [MLKEM768X25519].
func ( crypto.Encapsulator,  *ecdh.PublicKey) (PublicKey, error) {
	// The curve selects the hybrid. The PQ key must then have the ML-KEM
	// parameter set that this hybrid uses.
	switch .Curve() {
	case ecdh.X25519():
		if ,  := .(*mlkem.EncapsulationKey768); ! {
			return nil, errors.New("invalid PQ KEM for X25519 hybrid")
		}
		return &hybridPublicKey{mlkem768X25519, , }, nil
	case ecdh.P256():
		if ,  := .(*mlkem.EncapsulationKey768); ! {
			return nil, errors.New("invalid PQ KEM for P-256 hybrid")
		}
		return &hybridPublicKey{mlkem768P256, , }, nil
	case ecdh.P384():
		if ,  := .(*mlkem.EncapsulationKey1024); ! {
			return nil, errors.New("invalid PQ KEM for P-384 hybrid")
		}
		return &hybridPublicKey{mlkem1024P384, , }, nil
	default:
		return nil, errors.New("unsupported curve")
	}
}

// Parses a public key encoded as the ML-KEM encapsulation key followed by the
// encoded curve point. The input must be exactly pqEncapsKeySize+curvePointSize
// bytes long. Each half is parsed by its own constructor, and any parse error is
// returned unchanged. The parsed components are then validated and wrapped by
// NewHybridPublicKey.
func ( *hybridKEM) ( []byte) (PublicKey, error) {
	if len() != .pqEncapsKeySize+.curvePointSize {
		return nil, errors.New("invalid public key size")
	}
	,  := .pqNewPublicKey([:.pqEncapsKeySize])
	if  != nil {
		return nil, 
	}
	var  *ecdh.PublicKey
	// FIPS 140 enforcement is lifted only for the curve operation.
	fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved.
		,  = .curve.NewPublicKey([.pqEncapsKeySize:])
	})
	if  != nil {
		return nil, 
	}
	return NewHybridPublicKey(, )
}

// Returns the hybrid KEM this public key belongs to.
func ( *hybridPublicKey) () KEM {
	return .kem
}

// Returns the encoded public key: the ML-KEM encapsulation key followed by the
// encoded curve point.
// NOTE(review): append writes into the slice returned by pq.Bytes(). This is
// safe only if that slice is freshly allocated, which is presumably true for
// crypto/mlkem keys; confirm before accepting other Encapsulators.
func ( *hybridPublicKey) () []byte {
	return append(.pq.Bytes(), .t.Bytes()...)
}

// testingOnlyEncapsulate is a test hook. When non-nil, its results replace the
// shared secret and ciphertext that ML-KEM encapsulation produced.
var testingOnlyEncapsulate func() (sharedSecret, ciphertext []byte)

// Encapsulates to this hybrid public key. It:
//   - generates an ephemeral key pair on the recipient's curve,
//   - runs ECDH between the ephemeral private key and the recipient's curve key,
//   - encapsulates to the recipient's ML-KEM key,
//   - combines the results with the KEM's SHA3-256 combiner.
//
// The ciphertext is the concatenation built at the end: presumably the ML-KEM
// ciphertext followed by the ephemeral public point, matching the split in
// decapsulation. Errors from key generation or ECDH are returned unchanged;
// otherwise the error is nil.
func ( *hybridPublicKey) () ( []byte,  []byte,  error) {
	var  *ecdh.PrivateKey
	fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved.
		,  = .t.Curve().GenerateKey(rand.Reader)
	})
	if  != nil {
		return nil, nil, 
	}
	// Test hook: replaces the ephemeral key after the real one was generated.
	if testingOnlyGenerateKey != nil {
		 = testingOnlyGenerateKey()
	}
	var  []byte
	fips140.WithoutEnforcement(func() {
		,  = .ECDH(.t)
	})
	if  != nil {
		return nil, nil, 
	}
	 := .PublicKey().Bytes()

	,  := .pq.Encapsulate()
	// Test hook: overrides the ML-KEM shared secret and ciphertext.
	if testingOnlyEncapsulate != nil {
		,  = testingOnlyEncapsulate()
	}

	// The fourth combiner input is the recipient's encoded curve public key.
	 := .kem.sharedSecret(, , , .t.Bytes())
	 := append(, ...)
	return , , nil
}

// hybridPrivateKey is a private key for a hybridKEM. t is the traditional
// (ECDH) key exchanger and pq the ML-KEM decapsulator. seed is the 32-byte
// seed the key was expanded from, or nil when the key was assembled from
// existing components. Values are built with positional composite literals in
// this file, so field order is significant.
type hybridPrivateKey struct {
	kem  *hybridKEM
	seed []byte // can be nil
	t    ecdh.KeyExchanger
	pq   crypto.Decapsulator
}

// NewHybridPrivateKey returns a PrivateKey implementing
//
//   - MLKEM768-X25519 (a.k.a. X-Wing)
//   - MLKEM768-P256
//   - MLKEM1024-P384
//
// from draft-ietf-hpke-pq, depending on the underlying curve of t
// ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
//
// The returned key has no seed, so serializing it to bytes returns an error.
//
// This function is meant for applications that already have instantiated
// crypto/ecdh and crypto/mlkem private keys, or another implementation of a
// [ecdh.KeyExchanger] and [crypto.Decapsulator] (e.g. a hardware key).
// Otherwise, applications should use the [KEM.NewPrivateKey] method of e.g.
// [MLKEM768X25519].
func ( crypto.Decapsulator,  ecdh.KeyExchanger) (PrivateKey, error) {
	return newHybridPrivateKey(, , nil)
}

// Generates a new private key. It draws a random 32-byte seed with drbg.Read
// and expands it with NewPrivateKey.
func ( *hybridKEM) () (PrivateKey, error) {
	 := make([]byte, 32)
	drbg.Read()
	return .NewPrivateKey()
}

// Expands a 32-byte seed into a hybrid private key. The seed is absorbed into
// SHAKE256, and the output stream is consumed in order:
//   - the first mlkem.SeedSize bytes build the ML-KEM decapsulation key;
//   - successive curveSeedSize-byte chunks are then tried as curve private
//     keys until the curve accepts one.
//
// Any chunk the curve rejects is skipped (rejection sampling). The loop places
// no bound on the number of attempts.
func ( *hybridKEM) ( []byte) (PrivateKey, error) {
	if len() != 32 {
		return nil, errors.New("hpke: invalid hybrid KEM secret length")
	}

	 := sha3.NewSHAKE256()
	.Write()

	 := make([]byte, mlkem.SeedSize)
	.Read()
	,  := .pqNewPrivateKey()
	if  != nil {
		return nil, 
	}

	 := make([]byte, .curveSeedSize)
	for {
		.Read()
		var  ecdh.KeyExchanger
		fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved.
			,  = .curve.NewPrivateKey()
		})
		if  != nil {
			// Invalid scalar: draw the next chunk from the XOF.
			continue
		}
		// The third argument is stored as the key's seed; presumably it is the
		// 32-byte input, so the key can be re-serialized later.
		return newHybridPrivateKey(, , )
	}
}

// newHybridPrivateKey selects the hybrid KEM from the curve of the key
// exchanger. It checks that the decapsulator's Encapsulator() has the ML-KEM
// parameter set paired with that curve, then wraps the components.
//
// The seed is copied with bytes.Clone, so later changes by the caller do not
// affect the key. A nil seed stays nil, so the key then reports its seed as
// unavailable.
func newHybridPrivateKey( crypto.Decapsulator,  ecdh.KeyExchanger,  []byte) (PrivateKey, error) {
	switch .Curve() {
	case ecdh.X25519():
		if ,  := .Encapsulator().(*mlkem.EncapsulationKey768); ! {
			return nil, errors.New("invalid PQ KEM for X25519 hybrid")
		}
		return &hybridPrivateKey{mlkem768X25519, bytes.Clone(), , }, nil
	case ecdh.P256():
		if ,  := .Encapsulator().(*mlkem.EncapsulationKey768); ! {
			return nil, errors.New("invalid PQ KEM for P-256 hybrid")
		}
		return &hybridPrivateKey{mlkem768P256, bytes.Clone(), , }, nil
	case ecdh.P384():
		if ,  := .Encapsulator().(*mlkem.EncapsulationKey1024); ! {
			return nil, errors.New("invalid PQ KEM for P-384 hybrid")
		}
		return &hybridPrivateKey{mlkem1024P384, bytes.Clone(), , }, nil
	default:
		return nil, errors.New("unsupported curve")
	}
}

// Deterministically derives a private key from input keying material. The
// suite identifier is "KEM" followed by the big-endian 16-bit KEM id. The
// package's SHAKE256 labeledDerive helper (defined elsewhere) is called with
// the label "DeriveKeyPair" to produce 32 bytes, the seed length that
// NewPrivateKey requires. That seed is then expanded by NewPrivateKey.
func ( *hybridKEM) ( []byte) (PrivateKey, error) {
	 := byteorder.BEAppendUint16([]byte("KEM"), .id)
	,  := SHAKE256().labeledDerive(, , "DeriveKeyPair", nil, 32)
	if  != nil {
		return nil, 
	}
	return .NewPrivateKey()
}

// Returns the hybrid KEM this private key belongs to.
func ( *hybridPrivateKey) () KEM {
	return .kem
}

// Returns the 32-byte seed this key was expanded from. It returns an error for
// keys assembled from existing components, which have no seed.
// NOTE(review): this returns the internal seed slice without copying. The
// constructor clones the seed on the way in (bytes.Clone), but not on the way
// out, so a caller that modifies the result changes what later calls return.
// Consider returning bytes.Clone of the seed.
func ( *hybridPrivateKey) () ([]byte, error) {
	if .seed == nil {
		return nil, errors.New("private key seed not available")
	}
	return .seed, nil
}

// Returns the matching hybrid public key. It is built from the curve key
// exchanger's public key and the decapsulator's encapsulation key.
func ( *hybridPrivateKey) () PublicKey {
	return &hybridPublicKey{
		kem: .kem,
		t:   .t.PublicKey(),
		pq:  .pq.Encapsulator(),
	}
}

// Decapsulates an encapsulated key. The input must be exactly
// pqCiphertextSize+curvePointSize bytes: the ML-KEM ciphertext followed by the
// sender's encoded ephemeral curve point. It:
//   - decapsulates the ML-KEM part,
//   - parses the curve point and runs ECDH against it,
//   - combines the results with the SHA3-256 combiner, passing this key's own
//     encoded curve public key as the fourth input.
//
// Any component error is returned unchanged.
func ( *hybridPrivateKey) ( []byte) ([]byte, error) {
	if len() != .kem.pqCiphertextSize+.kem.curvePointSize {
		return nil, errors.New("invalid encapsulated key size")
	}
	,  := [:.kem.pqCiphertextSize], [.kem.pqCiphertextSize:]
	,  := .pq.Decapsulate()
	if  != nil {
		return nil, 
	}
	var  *ecdh.PublicKey
	fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved.
		,  = .t.Curve().NewPublicKey()
	})
	if  != nil {
		return nil, 
	}
	var  []byte
	fips140.WithoutEnforcement(func() {
		,  = .t.ECDH()
	})
	if  != nil {
		return nil, 
	}
	 := .kem.sharedSecret(, , , .t.PublicKey().Bytes())
	return , nil
}

// mlkem768 holds the parameters of the pure ML-KEM-768 KEM (identifier 0x0041).
// It bundles constructors to parse an encapsulation key, to build a
// decapsulation key from its seed, and to generate a fresh key.
var mlkem768 = &mlkemKEM{
	id:             0x0041,
	ciphertextSize: mlkem.CiphertextSize768,
	newPublicKey: func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey768()
	},
	newPrivateKey: func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey768()
	},
	generateKey: func() (crypto.Decapsulator, error) {
		return mlkem.GenerateKey768()
	},
}

// MLKEM768 returns a KEM implementing ML-KEM-768 from draft-ietf-hpke-pq.
// Every call returns the same shared instance.
func () KEM {
	return mlkem768
}

// mlkem1024 holds the parameters of the pure ML-KEM-1024 KEM (identifier
// 0x0042). It bundles constructors to parse an encapsulation key, to build a
// decapsulation key from its seed, and to generate a fresh key.
var mlkem1024 = &mlkemKEM{
	id:             0x0042,
	ciphertextSize: mlkem.CiphertextSize1024,
	newPublicKey: func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey1024()
	},
	newPrivateKey: func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey1024()
	},
	generateKey: func() (crypto.Decapsulator, error) {
		return mlkem.GenerateKey1024()
	},
}

// MLKEM1024 returns a KEM implementing ML-KEM-1024 from draft-ietf-hpke-pq.
// Every call returns the same shared instance.
func () KEM {
	return mlkem1024
}

// mlkemKEM describes a pure (non-hybrid) ML-KEM parameter set. It holds the
// HPKE KEM identifier, the ciphertext size in bytes, and constructors to parse
// an encapsulation key, to build a decapsulation key from its seed, and to
// generate a fresh decapsulation key.
type mlkemKEM struct {
	id             uint16
	ciphertextSize int
	newPublicKey   func(data []byte) (crypto.Encapsulator, error)
	newPrivateKey  func(data []byte) (crypto.Decapsulator, error)
	generateKey    func() (crypto.Decapsulator, error)
}

// Returns the HPKE KEM identifier of this ML-KEM parameter set.
func ( *mlkemKEM) () uint16 {
	return .id
}

// Returns the size in bytes of an encapsulated key, which is just the ML-KEM
// ciphertext.
func ( *mlkemKEM) () int {
	return .ciphertextSize
}

// mlkemPublicKey is a public key for a pure ML-KEM KEM, wrapping an ML-KEM
// encapsulation key. Values are built with positional composite literals in
// this file, so field order is significant.
type mlkemPublicKey struct {
	kem *mlkemKEM
	pq  crypto.Encapsulator
}

// NewMLKEMPublicKey returns a PublicKey implementing
//
//   - ML-KEM-768
//   - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of pub
// (*[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
// It returns an error for any other Encapsulator implementation.
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem public key. Otherwise, applications should use the
// [KEM.NewPublicKey] method of e.g. [MLKEM768].
func ( crypto.Encapsulator) (PublicKey, error) {
	switch .(type) {
	case *mlkem.EncapsulationKey768:
		return &mlkemPublicKey{mlkem768, }, nil
	case *mlkem.EncapsulationKey1024:
		return &mlkemPublicKey{mlkem1024, }, nil
	default:
		return nil, errors.New("unsupported public key type")
	}
}

// Parses an encoded ML-KEM encapsulation key for this parameter set and wraps
// it with NewMLKEMPublicKey. Unlike the hybrid parser, there is no explicit
// length check here; validation presumably happens inside the crypto/mlkem
// constructor. Confirm before relying on it.
func ( *mlkemKEM) ( []byte) (PublicKey, error) {
	,  := .newPublicKey()
	if  != nil {
		return nil, 
	}
	return NewMLKEMPublicKey()
}

// Returns the ML-KEM KEM this public key belongs to.
func ( *mlkemPublicKey) () KEM {
	return .kem
}

// Returns the encoded ML-KEM encapsulation key.
func ( *mlkemPublicKey) () []byte {
	return .pq.Bytes()
}

// Encapsulates to the ML-KEM key and returns the shared secret and the
// ciphertext. ML-KEM encapsulation has no error result here, so the returned
// error is always nil.
func ( *mlkemPublicKey) () ( []byte,  []byte,  error) {
	,  := .pq.Encapsulate()
	// Test hook: overrides the shared secret and ciphertext.
	if testingOnlyEncapsulate != nil {
		,  = testingOnlyEncapsulate()
	}
	return , , nil
}

// mlkemPrivateKey is a private key for a pure ML-KEM KEM, wrapping an ML-KEM
// decapsulator. Values are built with positional composite literals in this
// file, so field order is significant.
type mlkemPrivateKey struct {
	kem *mlkemKEM
	pq  crypto.Decapsulator
}

// NewMLKEMPrivateKey returns a PrivateKey implementing
//
//   - ML-KEM-768
//   - ML-KEM-1024
//
// from draft-ietf-hpke-pq, depending on the type of priv.Encapsulator()
// (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
// It returns an error for any other encapsulation key type. The error message
// mentions the public key type because that is what the switch inspects.
//
// This function is meant for applications that already have an instantiated
// crypto/mlkem private key. Otherwise, applications should use the
// [KEM.NewPrivateKey] method of e.g. [MLKEM768].
func ( crypto.Decapsulator) (PrivateKey, error) {
	switch .Encapsulator().(type) {
	case *mlkem.EncapsulationKey768:
		return &mlkemPrivateKey{mlkem768, }, nil
	case *mlkem.EncapsulationKey1024:
		return &mlkemPrivateKey{mlkem1024, }, nil
	default:
		return nil, errors.New("unsupported public key type")
	}
}

// Generates a fresh ML-KEM decapsulation key for this parameter set and wraps
// it with NewMLKEMPrivateKey.
func ( *mlkemKEM) () (PrivateKey, error) {
	,  := .generateKey()
	if  != nil {
		return nil, 
	}
	return NewMLKEMPrivateKey()
}

// Builds an ML-KEM decapsulation key from its seed bytes and wraps it with
// NewMLKEMPrivateKey.
func ( *mlkemKEM) ( []byte) (PrivateKey, error) {
	,  := .newPrivateKey()
	if  != nil {
		return nil, 
	}
	return NewMLKEMPrivateKey()
}

// Deterministically derives a private key from input keying material. The
// suite identifier is "KEM" followed by the big-endian 16-bit KEM id. The
// package's SHAKE256 labeledDerive helper (defined elsewhere) is called with
// the label "DeriveKeyPair" to produce 64 bytes, the ML-KEM seed size
// (mlkem.SeedSize). The result is then passed to NewPrivateKey.
func ( *mlkemKEM) ( []byte) (PrivateKey, error) {
	 := byteorder.BEAppendUint16([]byte("KEM"), .id)
	,  := SHAKE256().labeledDerive(, , "DeriveKeyPair", nil, 64)
	if  != nil {
		return nil, 
	}
	return .NewPrivateKey()
}

// Returns the ML-KEM KEM this private key belongs to.
func ( *mlkemPrivateKey) () KEM {
	return .kem
}

// Returns the private key's seed encoding, if the underlying decapsulator
// exposes one. It uses a type assertion to an interface with a single
// no-argument method returning []byte; presumably this is the seed-returning
// Bytes method of crypto/mlkem decapsulation keys. Decapsulators without such
// a method (e.g. hardware keys) report the seed as unavailable.
func ( *mlkemPrivateKey) () ([]byte, error) {
	,  := .pq.(interface {
		() []byte
	})
	if ! {
		return nil, errors.New("private key seed not available")
	}
	return .(), nil
}

// Returns the matching ML-KEM public key, built from the decapsulator's
// encapsulation key.
func ( *mlkemPrivateKey) () PublicKey {
	return &mlkemPublicKey{
		kem: .kem,
		pq:  .pq.Encapsulator(),
	}
}

// Decapsulates an ML-KEM ciphertext and returns the shared secret. Size and
// validity checks are delegated to the underlying decapsulator.
func ( *mlkemPrivateKey) ( []byte) ([]byte, error) {
	return .pq.Decapsulate()
}