// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package hpke

import (
	
	
	
	
	
	
	
	
	

	
)

// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
// When non-nil, the KEM's encapsulation step calls it to obtain the
// ephemeral private key instead of calling the curve's GenerateKey with
// rand.Reader; an error it returns aborts encapsulation.
var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)

// hkdfKDF is the HKDF-based key derivation function used by HPKE,
// parameterized only by its underlying hash function.
type hkdfKDF struct {
	hash crypto.Hash // hash.New is the constructor handed to hkdf.Extract/Expand
}

func ( *hkdfKDF) ( []byte,  []byte,  string,  []byte) []byte {
	 := make([]byte, 0, 7+len()+len()+len())
	 = append(, []byte("HPKE-v1")...)
	 = append(, ...)
	 = append(, ...)
	 = append(, ...)
	return hkdf.Extract(.hash.New, , )
}

func ( *hkdfKDF) ( []byte,  []byte,  string,  []byte,  uint16) []byte {
	 := make([]byte, 0, 2+7+len()+len()+len())
	 = byteorder.BEAppendUint16(, )
	 = append(, []byte("HPKE-v1")...)
	 = append(, ...)
	 = append(, ...)
	 = append(, ...)
	return hkdf.Expand(.hash.New, , string(), int())
}

// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
type dhKEM struct {
	dh  ecdh.Curve // generates ephemeral keys and parses encapsulated keys
	kdf hkdfKDF    // runs the KEM's labeled extract/expand steps

	suiteID []byte // KEM-level suite_id: "KEM" followed by the 2-byte KEM id
	nSecret uint16 // Nsecret: length in bytes of the derived shared secret
}

// KemID identifies an HPKE key encapsulation mechanism (RFC 9180, Section 7.1).
//
// NOTE(review): SupportedKEMs and the constant below use plain uint16
// rather than this type.
type KemID uint16

// DHKEM_X25519_HKDF_SHA256 is the identifier of DHKEM(X25519, HKDF-SHA256).
const DHKEM_X25519_HKDF_SHA256 = 0x0020

// SupportedKEMs maps a KEM identifier to the curve used for Diffie-Hellman,
// the hash that backs the KEM's internal HKDF, and nSecret (Nsecret), the
// length in bytes of the shared secret the KEM produces.
var SupportedKEMs = map[uint16]struct {
	curve   ecdh.Curve
	hash    crypto.Hash
	nSecret uint16
}{
	// RFC 9180 Section 7.1
	DHKEM_X25519_HKDF_SHA256: {
		curve:   ecdh.X25519(),
		hash:    crypto.SHA256,
		nSecret: 32,
	},
}

func newDHKem( uint16) (*dhKEM, error) {
	,  := SupportedKEMs[]
	if ! {
		return nil, errors.New("unsupported suite ID")
	}
	return &dhKEM{
		dh:      .curve,
		kdf:     hkdfKDF{.hash},
		suiteID: byteorder.BEAppendUint16([]byte("KEM"), ),
		nSecret: .nSecret,
	}, nil
}

func ( *dhKEM) (,  []byte) []byte {
	 := .kdf.LabeledExtract(.suiteID[:], nil, "eae_prk", )
	return .kdf.LabeledExpand(.suiteID[:], , "shared_secret", , .nSecret)
}

func ( *dhKEM) ( *ecdh.PublicKey) ( []byte,  []byte,  error) {
	var  *ecdh.PrivateKey
	if testingOnlyGenerateKey != nil {
		,  = testingOnlyGenerateKey()
	} else {
		,  = .dh.GenerateKey(rand.Reader)
	}
	if  != nil {
		return nil, nil, 
	}
	,  := .ECDH()
	if  != nil {
		return nil, nil, 
	}
	 := .PublicKey().Bytes()

	 := .Bytes()
	 := append(, ...)

	return .ExtractAndExpand(, ), , nil
}

func ( *dhKEM) ( []byte,  *ecdh.PrivateKey) ([]byte, error) {
	,  := .dh.NewPublicKey()
	if  != nil {
		return nil, 
	}
	,  := .ECDH()
	if  != nil {
		return nil, 
	}
	 := append(, .PublicKey().Bytes()...)

	return .ExtractAndExpand(, ), nil
}

// context holds the state produced by the HPKE key schedule. Sender and
// Receipient both embed it.
type context struct {
	aead cipher.AEAD // AEAD instance keyed with key; seals/opens messages

	sharedSecret []byte // KEM shared secret the key schedule was derived from

	suiteID []byte // "HPKE" followed by the 2-byte KEM, KDF and AEAD ids

	key            []byte // AEAD key (Nk bytes)
	baseNonce      []byte // XORed with the sequence number to form each nonce
	exporterSecret []byte // exporter secret (Nh bytes); unused in the code shown

	seqNum uint128 // incremented after each successfully sealed/opened message
}

// Sender is the encrypting half of an HPKE context. It embeds the shared
// key-schedule state, so every sealed message advances the sequence number.
type Sender struct {
	*context
}

// Receipient is the decrypting half of an HPKE context. It embeds the shared
// key-schedule state, so every successfully opened message advances the
// sequence number.
//
// NOTE(review): "Receipient" is a misspelling of "Recipient"; the name is
// exported, so correcting it would break existing callers.
type Receipient struct {
	*context
}

// aesGCMNew returns an AES-GCM AEAD with the standard nonce and tag sizes,
// keyed with key. The AES variant (128/192/256) is chosen by len(key); any
// other key length is an error.
var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// AEADID identifies an HPKE AEAD algorithm (RFC 9180, Section 7.3).
//
// NOTE(review): the constants and SupportedAEADs below use plain uint16
// rather than this type.
type AEADID uint16

// AEAD algorithm identifiers from RFC 9180, Section 7.3.
const (
	AEAD_AES_128_GCM      = 0x0001
	AEAD_AES_256_GCM      = 0x0002
	AEAD_ChaCha20Poly1305 = 0x0003
)

// SupportedAEADs maps an AEAD identifier to its key size (Nk), its nonce
// size (Nn), and a constructor returning a cipher.AEAD keyed with a
// keySize-byte key.
var SupportedAEADs = map[uint16]struct {
	keySize   int
	nonceSize int
	aead      func([]byte) (cipher.AEAD, error)
}{
	// RFC 9180, Section 7.3
	AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
	AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
	AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}

// KDFID identifies an HPKE key derivation function (RFC 9180, Section 7.2).
//
// NOTE(review): SupportedKDFs and the constant below use plain uint16
// rather than this type.
type KDFID uint16

// KDF_HKDF_SHA256 is the identifier of HKDF-SHA256.
const KDF_HKDF_SHA256 = 0x0001

// SupportedKDFs maps a KDF identifier to a constructor for the matching
// HKDF instance. Each call returns a fresh *hkdfKDF.
var SupportedKDFs = map[uint16]func() *hkdfKDF{
	// RFC 9180, Section 7.2
	KDF_HKDF_SHA256: func() *hkdfKDF {
		return &hkdfKDF{hash: crypto.SHA256}
	},
}

func newContext( []byte, , ,  uint16,  []byte) (*context, error) {
	 := suiteID(, , )

	,  := SupportedKDFs[]
	if ! {
		return nil, errors.New("unsupported KDF id")
	}
	 := ()

	,  := SupportedAEADs[]
	if ! {
		return nil, errors.New("unsupported AEAD id")
	}

	 := .LabeledExtract(, nil, "psk_id_hash", nil)
	 := .LabeledExtract(, nil, "info_hash", )
	 := append([]byte{0}, ...)
	 = append(, ...)

	 := .LabeledExtract(, , "secret", nil)

	 := .LabeledExpand(, , "key", , uint16(.keySize) /* Nk - key size for AEAD */)
	 := .LabeledExpand(, , "base_nonce", , uint16(.nonceSize) /* Nn - nonce size for AEAD */)
	 := .LabeledExpand(, , "exp", , uint16(.hash.Size()) /* Nh - hash output size of the kdf*/)

	,  := .aead()
	if  != nil {
		return nil, 
	}

	return &context{
		aead:           ,
		sharedSecret:   ,
		suiteID:        ,
		key:            ,
		baseNonce:      ,
		exporterSecret: ,
	}, nil
}

func (, ,  uint16,  *ecdh.PublicKey,  []byte) ([]byte, *Sender, error) {
	,  := newDHKem()
	if  != nil {
		return nil, nil, 
	}
	, ,  := .Encap()
	if  != nil {
		return nil, nil, 
	}

	,  := newContext(, , , , )
	if  != nil {
		return nil, nil, 
	}

	return , &Sender{}, nil
}

func (, ,  uint16,  *ecdh.PrivateKey, ,  []byte) (*Receipient, error) {
	,  := newDHKem()
	if  != nil {
		return nil, 
	}
	,  := .Decap(, )
	if  != nil {
		return nil, 
	}

	,  := newContext(, , , , )
	if  != nil {
		return nil, 
	}

	return &Receipient{}, nil
}

func ( *context) () []byte {
	 := .seqNum.bytes()[16-.aead.NonceSize():]
	for  := range .baseNonce {
		[] ^= .baseNonce[]
	}
	return 
}

func ( *context) () {
	// Message limit is, according to the RFC, 2^95+1, which
	// is somewhat confusing, but we do as we're told.
	if .seqNum.bitLen() >= (.aead.NonceSize()*8)-1 {
		panic("message limit reached")
	}
	.seqNum = .seqNum.addOne()
}

func ( *Sender) (,  []byte) ([]byte, error) {
	 := .aead.Seal(nil, .nextNonce(), , )
	.incrementNonce()
	return , nil
}

func ( *Receipient) (,  []byte) ([]byte, error) {
	,  := .aead.Open(nil, .nextNonce(), , )
	if  != nil {
		return nil, 
	}
	.incrementNonce()
	return , nil
}

func suiteID(, ,  uint16) []byte {
	 := make([]byte, 0, 4+2+2+2)
	 = append(, []byte("HPKE")...)
	 = byteorder.BEAppendUint16(, )
	 = byteorder.BEAppendUint16(, )
	 = byteorder.BEAppendUint16(, )
	return 
}

func ( uint16,  []byte) (*ecdh.PublicKey, error) {
	,  := SupportedKEMs[]
	if ! {
		return nil, errors.New("unsupported KEM id")
	}
	return .curve.NewPublicKey()
}

func ( uint16,  []byte) (*ecdh.PrivateKey, error) {
	,  := SupportedKEMs[]
	if ! {
		return nil, errors.New("unsupported KEM id")
	}
	return .curve.NewPrivateKey()
}

// uint128 is an unsigned 128-bit integer split into two 64-bit words. It is
// used as the HPKE message sequence number.
type uint128 struct {
	hi, lo uint64
}

// addOne returns u+1, carrying from lo into hi. It wraps to zero after the
// maximum 128-bit value.
func (u uint128) addOne() uint128 {
	lo, carry := bits.Add64(u.lo, 1, 0)
	return uint128{u.hi + carry, lo}
}

// bitLen returns the minimum number of bits needed to represent u; the
// result is 0 for u == 0.
func (u uint128) bitLen() int {
	if u.hi != 0 {
		// The high word is significant, so all 64 bits of lo count too.
		// (Summing the two word lengths would undercount, e.g. {1, 0} -> 1.)
		return 64 + bits.Len64(u.hi)
	}
	return bits.Len64(u.lo)
}

func ( uint128) () []byte {
	 := make([]byte, 16)
	byteorder.BEPutUint64([0:], .hi)
	byteorder.BEPutUint64([8:], .lo)
	return 
}