// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package hpke

import (
	"crypto/hkdf"
	"crypto/sha256"
	"crypto/sha3"
	"crypto/sha512"
	"errors"
	"fmt"
	"hash"
	"internal/byteorder"
)

// A KDF implements key derivation, one of the three components of an HPKE
// ciphersuite.
//
// A KDF is either two-stage (labeledExtract followed by labeledExpand) or
// one-stage (labeledDerive alone), as reported by oneStage. The
// implementations in this package return an internal error from the methods
// that do not match their kind.
type KDF interface {
	// ID returns the KDF identifier, e.g. 0x0001 for HKDF-SHA256.
	ID() uint16

	// size returns Nh for this KDF, in bytes.
	size() int // Nh

	// oneStage reports whether labeledDerive, rather than the
	// labeledExtract/labeledExpand pair, is the supported derivation path.
	oneStage() bool

	// labeledExtract is the first step of two-stage derivation.
	labeledExtract(suiteID, salt []byte, label string, inputKey []byte) ([]byte, error)

	// labeledExpand is the second step of two-stage derivation.
	labeledExpand(suiteID, randomKey []byte, label string, info []byte, length uint16) ([]byte, error)

	// labeledDerive is the single step of one-stage derivation.
	labeledDerive(suiteID, inputKey []byte, label string, context []byte, length uint16) ([]byte, error)
}

// NewKDF returns the KDF implementation for the given KDF ID.
//
// Applications are encouraged to use specific implementations like [HKDFSHA256]
// instead, unless runtime agility is required.
func ( uint16) (KDF, error) {
	switch  {
	case 0x0001: // HKDF-SHA256
		return HKDFSHA256(), nil
	case 0x0002: // HKDF-SHA384
		return HKDFSHA384(), nil
	case 0x0003: // HKDF-SHA512
		return HKDFSHA512(), nil
	case 0x0010: // SHAKE128
		return SHAKE128(), nil
	case 0x0011: // SHAKE256
		return SHAKE256(), nil
	default:
		return nil, fmt.Errorf("unsupported KDF %04x", )
	}
}

// HKDFSHA256 returns an HKDF-SHA256 KDF implementation.
func () KDF { return hkdfSHA256 }

// HKDFSHA384 returns an HKDF-SHA384 KDF implementation.
func () KDF { return hkdfSHA384 }

// HKDFSHA512 returns an HKDF-SHA512 KDF implementation.
func () KDF { return hkdfSHA512 }

// hkdfKDF is a two-stage KDF built on HKDF over a fixed hash function.
type hkdfKDF struct {
	hash func() hash.Hash // constructor for the underlying hash
	id   uint16           // KDF identifier reported by ID
	nH   int              // Nh: the hash output size in bytes
}

// Package-level instances, one per supported HKDF hash.
var (
	hkdfSHA256 = &hkdfKDF{id: 0x0001, hash: sha256.New, nH: sha256.Size}
	hkdfSHA384 = &hkdfKDF{id: 0x0002, hash: sha512.New384, nH: sha512.Size384}
	hkdfSHA512 = &hkdfKDF{id: 0x0003, hash: sha512.New, nH: sha512.Size}
)

func ( *hkdfKDF) () uint16 {
	return .id
}

func ( *hkdfKDF) () int {
	return .nH
}

func ( *hkdfKDF) () bool {
	return false
}

func ( *hkdfKDF) (,  []byte,  string,  []byte,  uint16) ([]byte, error) {
	return nil, errors.New("hpke: internal error: labeledDerive called on two-stage KDF")
}

func ( *hkdfKDF) ( []byte,  []byte,  string,  []byte) ([]byte, error) {
	 := make([]byte, 0, 7+len()+len()+len())
	 = append(, []byte("HPKE-v1")...)
	 = append(, ...)
	 = append(, ...)
	 = append(, ...)
	return hkdf.Extract(.hash, , )
}

func ( *hkdfKDF) ( []byte,  []byte,  string,  []byte,  uint16) ([]byte, error) {
	 := make([]byte, 0, 2+7+len()+len()+len())
	 = byteorder.BEAppendUint16(, )
	 = append(, []byte("HPKE-v1")...)
	 = append(, ...)
	 = append(, ...)
	 = append(, ...)
	return hkdf.Expand(.hash, , string(), int())
}

// SHAKE128 returns a SHAKE128 KDF implementation.
func () KDF {
	return shake128KDF
}

// SHAKE256 returns a SHAKE256 KDF implementation.
func () KDF {
	return shake256KDF
}

// shakeKDF is a one-stage KDF built on a SHAKE extendable-output function.
type shakeKDF struct {
	hash func() *sha3.SHAKE // constructor for a new SHAKE instance
	id   uint16             // KDF identifier reported by ID
	nH   int                // Nh value reported by size, in bytes
}

// Package-level instances, one per supported SHAKE variant.
var (
	shake128KDF = &shakeKDF{id: 0x0010, hash: sha3.NewSHAKE128, nH: 32}
	shake256KDF = &shakeKDF{id: 0x0011, hash: sha3.NewSHAKE256, nH: 64}
)

func ( *shakeKDF) () uint16 {
	return .id
}

func ( *shakeKDF) () int {
	return .nH
}

func ( *shakeKDF) () bool {
	return true
}

func ( *shakeKDF) (,  []byte,  string,  []byte,  uint16) ([]byte, error) {
	 := .hash()
	.Write()
	.Write([]byte("HPKE-v1"))
	.Write()
	.Write([]byte{byte(len() >> 8), byte(len())})
	.Write([]byte())
	.Write([]byte{byte( >> 8), byte()})
	.Write()
	 := make([]byte, )
	.Read()
	return , nil
}

func ( *shakeKDF) (,  []byte,  string,  []byte) ([]byte, error) {
	return nil, errors.New("hpke: internal error: labeledExtract called on one-stage KDF")
}

func ( *shakeKDF) (,  []byte,  string,  []byte,  uint16) ([]byte, error) {
	return nil, errors.New("hpke: internal error: labeledExpand called on one-stage KDF")
}