// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package hpke

import (
	"crypto"
	"crypto/aes"
	"crypto/cipher"
	"crypto/ecdh"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"math/bits"

	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/hkdf"
)

var (
	// testingOnlyGenerateKey, when non-nil, is called instead of generating a
	// random ephemeral key during encapsulation. Tests set it so that the
	// fixed keys from the RFC 9180 test vectors are used.
	testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
)

type hkdfKDF struct {
	hash crypto.Hash
}

func ( *hkdfKDF) ( []byte,  []byte,  string,  []byte) []byte {
	 := make([]byte, 0, 7+len()+len()+len())
	 = append(, []byte("HPKE-v1")...)
	 = append(, ...)
	 = append(, ...)
	 = append(, ...)
	return hkdf.Extract(.hash.New, , )
}

func ( *hkdfKDF) ( []byte,  []byte,  string,  []byte,  uint16) []byte {
	 := make([]byte, 0, 2+7+len()+len()+len())
	 = binary.BigEndian.AppendUint16(, )
	 = append(, []byte("HPKE-v1")...)
	 = append(, ...)
	 = append(, ...)
	 = append(, ...)
	 := make([]byte, )
	,  := hkdf.Expand(.hash.New, , ).Read()
	if  != nil ||  != int() {
		panic("hpke: LabeledExpand failed unexpectedly")
	}
	return 
}

// dhKEM is the Diffie-Hellman based KEM (DHKEM) described in RFC 9180,
// Section 4.1.
type dhKEM struct {
	dh  ecdh.Curve // group used to generate the ephemeral key and run ECDH
	kdf hkdfKDF    // KDF used by ExtractAndExpand

	suiteID []byte // concat("KEM", I2OSP(kem_id, 2)), as built by newDHKem
	nSecret uint16 // length in bytes of the derived shared secret (Nsecret)
}

// SupportedKEMs maps each supported RFC 9180 KEM identifier to the curve, KDF
// hash and shared-secret length (Nsecret) used to construct its dhKEM.
var SupportedKEMs = map[uint16]struct {
	curve   ecdh.Curve
	hash    crypto.Hash
	nSecret uint16
}{
	// DHKEM(X25519, HKDF-SHA256), RFC 9180 Section 7.1.
	0x0020: {
		curve:   ecdh.X25519(),
		hash:    crypto.SHA256,
		nSecret: 32,
	},
}

func newDHKem( uint16) (*dhKEM, error) {
	,  := SupportedKEMs[]
	if ! {
		return nil, errors.New("unsupported suite ID")
	}
	return &dhKEM{
		dh:      .curve,
		kdf:     hkdfKDF{.hash},
		suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), ),
		nSecret: .nSecret,
	}, nil
}

func ( *dhKEM) (,  []byte) []byte {
	 := .kdf.LabeledExtract(.suiteID[:], nil, "eae_prk", )
	return .kdf.LabeledExpand(.suiteID[:], , "shared_secret", , .nSecret)
}

func ( *dhKEM) ( *ecdh.PublicKey) ( []byte,  []byte,  error) {
	var  *ecdh.PrivateKey
	if testingOnlyGenerateKey != nil {
		,  = testingOnlyGenerateKey()
	} else {
		,  = .dh.GenerateKey(rand.Reader)
	}
	if  != nil {
		return nil, nil, 
	}
	,  := .ECDH()
	if  != nil {
		return nil, nil, 
	}
	 := .PublicKey().Bytes()

	 := .Bytes()
	 := append(, ...)

	return .ExtractAndExpand(, ), , nil
}

// Sender holds an HPKE sender context produced by the base-mode key schedule
// (RFC 9180, Section 5.1). It is used to encrypt a sequence of messages to a
// single recipient.
type Sender struct {
	aead cipher.AEAD // AEAD instance keyed with key
	kem  *dhKEM      // KEM used to encapsulate the shared secret

	sharedSecret []byte // KEM shared secret produced by encapsulation

	suiteID []byte // concat("HPKE", kem_id, kdf_id, aead_id)

	key            []byte // AEAD key derived by the key schedule
	baseNonce      []byte // base nonce, XORed with seqNum for each message
	exporterSecret []byte // exporter secret derived by the key schedule

	seqNum uint128 // message sequence number, advanced per nonce
}

// aesGCMNew returns an AES-GCM AEAD with the standard 12-byte nonce for key.
// The key length selects the AES variant. It returns an error if the key
// length is not accepted by aes.NewCipher.
var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// SupportedAEADs maps each supported RFC 9180 AEAD identifier to its key size
// (Nk), nonce size (Nn) and constructor.
var SupportedAEADs = map[uint16]struct {
	keySize   int
	nonceSize int
	aead      func([]byte) (cipher.AEAD, error)
}{
	// Identifiers from RFC 9180, Section 7.3.
	0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, // AES-128-GCM
	0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, // AES-256-GCM
	0x0003: { // ChaCha20Poly1305
		keySize:   chacha20poly1305.KeySize,
		nonceSize: chacha20poly1305.NonceSize,
		aead:      chacha20poly1305.New,
	},
}

// SupportedKDFs maps each supported RFC 9180 KDF identifier to a function
// that constructs the corresponding hkdfKDF.
var SupportedKDFs = map[uint16]func() *hkdfKDF{
	// HKDF-SHA256, RFC 9180 Section 7.2.
	0x0001: func() *hkdfKDF {
		return &hkdfKDF{hash: crypto.SHA256}
	},
}

func (, ,  uint16,  crypto.PublicKey,  []byte) ([]byte, *Sender, error) {
	 := SuiteID(, , )

	,  := newDHKem()
	if  != nil {
		return nil, nil, 
	}
	,  := .(*ecdh.PublicKey)
	if ! {
		return nil, nil, errors.New("incorrect public key type")
	}
	, ,  := .Encap()
	if  != nil {
		return nil, nil, 
	}

	,  := SupportedKDFs[]
	if ! {
		return nil, nil, errors.New("unsupported KDF id")
	}
	 := ()

	,  := SupportedAEADs[]
	if ! {
		return nil, nil, errors.New("unsupported AEAD id")
	}

	 := .LabeledExtract(, nil, "psk_id_hash", nil)
	 := .LabeledExtract(, nil, "info_hash", )
	 := append([]byte{0}, ...)
	 = append(, ...)

	 := .LabeledExtract(, , "secret", nil)

	 := .LabeledExpand(, , "key", , uint16(.keySize) /* Nk - key size for AEAD */)
	 := .LabeledExpand(, , "base_nonce", , uint16(.nonceSize) /* Nn - nonce size for AEAD */)
	 := .LabeledExpand(, , "exp", , uint16(.hash.Size()) /* Nh - hash output size of the kdf*/)

	,  := .aead()
	if  != nil {
		return nil, nil, 
	}

	return , &Sender{
		kem:            ,
		aead:           ,
		sharedSecret:   ,
		suiteID:        ,
		key:            ,
		baseNonce:      ,
		exporterSecret: ,
	}, nil
}

func ( *Sender) () []byte {
	 := .seqNum.bytes()[16-.aead.NonceSize():]
	for  := range .baseNonce {
		[] ^= .baseNonce[]
	}
	// Message limit is, according to the RFC, 2^95+1, which
	// is somewhat confusing, but we do as we're told.
	if .seqNum.bitLen() >= (.aead.NonceSize()*8)-1 {
		panic("message limit reached")
	}
	.seqNum = .seqNum.addOne()
	return 
}

func ( *Sender) (,  []byte) ([]byte, error) {

	 := .aead.Seal(nil, .nextNonce(), , )
	return , nil
}

// SuiteID returns the HPKE suite_id of RFC 9180, Section 5.1:
//
//	concat("HPKE", I2OSP(kem_id, 2), I2OSP(kdf_id, 2), I2OSP(aead_id, 2))
func SuiteID(kemID, kdfID, aeadID uint16) []byte {
	suiteID := make([]byte, 0, 4+2+2+2)
	suiteID = append(suiteID, "HPKE"...)
	suiteID = binary.BigEndian.AppendUint16(suiteID, kemID)
	suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID)
	suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID)
	return suiteID
}

func ( uint16,  []byte) (*ecdh.PublicKey, error) {
	,  := SupportedKEMs[]
	if ! {
		return nil, errors.New("unsupported KEM id")
	}
	return .curve.NewPublicKey()
}

// uint128 is an unsigned 128-bit integer used as the HPKE message sequence
// number. hi holds the most significant 64 bits.
type uint128 struct {
	hi, lo uint64
}

// addOne returns u+1, carrying from lo into hi. The result wraps to zero when
// u is the maximum 128-bit value.
func (u uint128) addOne() uint128 {
	lo, carry := bits.Add64(u.lo, 1, 0)
	return uint128{u.hi + carry, lo}
}

// bitLen returns the minimum number of bits needed to represent u. The result
// is 0 when u is zero.
func (u uint128) bitLen() int {
	// The high word is weighted by 2^64, so when it is non-zero it alone
	// determines the length. Adding the two word lengths would undercount,
	// e.g. 1<<64 would report 1 bit instead of 65.
	if u.hi != 0 {
		return 64 + bits.Len64(u.hi)
	}
	return bits.Len64(u.lo)
}

// bytes returns the 16-byte big-endian encoding of u.
func (u uint128) bytes() []byte {
	b := make([]byte, 16)
	binary.BigEndian.PutUint64(b[0:], u.hi)
	binary.BigEndian.PutUint64(b[8:], u.lo)
	return b
}