// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tls

import (
	"crypto"
	"crypto/ecdh"
	"crypto/hmac"
	"crypto/internal/fips140/tls13"
	"crypto/mlkem"
	"errors"
	"hash"
	"io"
)

// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.

// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func ( *cipherSuiteTLS13) ( []byte) []byte {
	return tls13.ExpandLabel(.hash.New, , "traffic upd", nil, .hash.Size())
}

// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func ( *cipherSuiteTLS13) ( []byte) (,  []byte) {
	 = tls13.ExpandLabel(.hash.New, , "key", nil, .keyLen)
	 = tls13.ExpandLabel(.hash.New, , "iv", nil, aeadNonceLength)
	return
}

// finishedHash generates the Finished verify_data or PskBinderEntry according
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
// selection.
func ( *cipherSuiteTLS13) ( []byte,  hash.Hash) []byte {
	 := tls13.ExpandLabel(.hash.New, , "finished", nil, .hash.Size())
	 := hmac.New(.hash.New, )
	.Write(.Sum(nil))
	return .Sum(nil)
}

// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func ( *cipherSuiteTLS13) ( *tls13.MasterSecret,  hash.Hash) func(string, []byte, int) ([]byte, error) {
	 := .ExporterMasterSecret()
	return func( string,  []byte,  int) ([]byte, error) {
		return .Exporter(, , ), nil
	}
}

// keySharePrivateKeys holds the private key material produced by a
// keyExchange's keyShares method, for later use by its clientSharedSecret.
type keySharePrivateKeys struct {
	// ecdhe is the ECDH private key. Both key exchange types in this file set it.
	ecdhe *ecdh.PrivateKey
	// mlkem is the ML-KEM decapsulation key. Within this file only the hybrid
	// key exchange sets it; it stays nil for plain ECDH groups.
	mlkem crypto.Decapsulator
}

// A keyExchange implements a TLS 1.3 KEM.
//
// A client calls keyShares and later passes the returned private keys, along
// with the server's key share, to clientSharedSecret. A server needs only a
// single call to serverSharedSecret with the client's key share.
type keyExchange interface {
	// keyShares generates one or two key shares.
	//
	// The first one will match the id, the second (if present) reuses the
	// traditional component of the requested hybrid, as allowed by
	// draft-ietf-tls-hybrid-design-09, Section 3.2.
	keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error)

	// serverSharedSecret computes the shared secret and the server's key share.
	serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error)

	// clientSharedSecret computes the shared secret given the server's key
	// share and the keys generated by keyShares.
	clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error)
}

func keyExchangeForCurveID( CurveID) (keyExchange, error) {
	 := func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey768()
	}
	 := func( []byte) (crypto.Decapsulator, error) {
		return mlkem.NewDecapsulationKey1024()
	}
	 := func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey768()
	}
	 := func( []byte) (crypto.Encapsulator, error) {
		return mlkem.NewEncapsulationKey1024()
	}
	switch  {
	case X25519:
		return &ecdhKeyExchange{, ecdh.X25519()}, nil
	case CurveP256:
		return &ecdhKeyExchange{, ecdh.P256()}, nil
	case CurveP384:
		return &ecdhKeyExchange{, ecdh.P384()}, nil
	case CurveP521:
		return &ecdhKeyExchange{, ecdh.P521()}, nil
	case X25519MLKEM768:
		return &hybridKeyExchange{, ecdhKeyExchange{X25519, ecdh.X25519()},
			32, mlkem.EncapsulationKeySize768, mlkem.CiphertextSize768,
			, }, nil
	case SecP256r1MLKEM768:
		return &hybridKeyExchange{, ecdhKeyExchange{CurveP256, ecdh.P256()},
			65, mlkem.EncapsulationKeySize768, mlkem.CiphertextSize768,
			, }, nil
	case SecP384r1MLKEM1024:
		return &hybridKeyExchange{, ecdhKeyExchange{CurveP384, ecdh.P384()},
			97, mlkem.EncapsulationKeySize1024, mlkem.CiphertextSize1024,
			, }, nil
	default:
		return nil, errors.New("tls: unsupported key exchange")
	}
}

type ecdhKeyExchange struct {
	id    CurveID
	curve ecdh.Curve
}

func ( *ecdhKeyExchange) ( io.Reader) (*keySharePrivateKeys, []keyShare, error) {
	,  := .curve.GenerateKey()
	if  != nil {
		return nil, nil, 
	}
	return &keySharePrivateKeys{ecdhe: }, []keyShare{{.id, .PublicKey().Bytes()}}, nil
}

func ( *ecdhKeyExchange) ( io.Reader,  []byte) ([]byte, keyShare, error) {
	,  := .curve.GenerateKey()
	if  != nil {
		return nil, keyShare{}, 
	}
	,  := .curve.NewPublicKey()
	if  != nil {
		return nil, keyShare{}, 
	}
	,  := .ECDH()
	if  != nil {
		return nil, keyShare{}, 
	}
	return , keyShare{.id, .PublicKey().Bytes()}, nil
}

func ( *ecdhKeyExchange) ( *keySharePrivateKeys,  []byte) ([]byte, error) {
	,  := .curve.NewPublicKey()
	if  != nil {
		return nil, 
	}
	,  := .ecdhe.ECDH()
	if  != nil {
		return nil, 
	}
	return , nil
}

type hybridKeyExchange struct {
	id   CurveID
	ecdh ecdhKeyExchange

	ecdhElementSize     int
	mlkemPublicKeySize  int
	mlkemCiphertextSize int

	newMLKEMPrivateKey func([]byte) (crypto.Decapsulator, error)
	newMLKEMPublicKey  func([]byte) (crypto.Encapsulator, error)
}

func ( *hybridKeyExchange) ( io.Reader) (*keySharePrivateKeys, []keyShare, error) {
	, ,  := .ecdh.keyShares()
	if  != nil {
		return nil, nil, 
	}
	 := make([]byte, mlkem.SeedSize)
	if ,  := io.ReadFull(, );  != nil {
		return nil, nil, 
	}
	.mlkem,  = .newMLKEMPrivateKey()
	if  != nil {
		return nil, nil, 
	}
	var  []byte
	// For X25519MLKEM768, the ML-KEM-768 encapsulation key comes first.
	// For SecP256r1MLKEM768 and SecP384r1MLKEM1024, the ECDH share comes first.
	// See draft-ietf-tls-ecdhe-mlkem-02, Section 4.1.
	if .id == X25519MLKEM768 {
		 = append(.mlkem.Encapsulator().Bytes(), [0].data...)
	} else {
		 = append([0].data, .mlkem.Encapsulator().Bytes()...)
	}
	return , []keyShare{{.id, }, [0]}, nil
}

func ( *hybridKeyExchange) ( io.Reader,  []byte) ([]byte, keyShare, error) {
	if len() != .ecdhElementSize+.mlkemPublicKeySize {
		return nil, keyShare{}, errors.New("tls: invalid client key share length for hybrid key exchange")
	}
	var ,  []byte
	if .id == X25519MLKEM768 {
		 = [:.mlkemPublicKeySize]
		 = [.mlkemPublicKeySize:]
	} else {
		 = [:.ecdhElementSize]
		 = [.ecdhElementSize:]
	}
	, ,  := .ecdh.serverSharedSecret(, )
	if  != nil {
		return nil, keyShare{}, 
	}
	,  := .newMLKEMPublicKey()
	if  != nil {
		return nil, keyShare{}, 
	}
	,  := .Encapsulate()
	var  []byte
	if .id == X25519MLKEM768 {
		 = append(, ...)
		.data = append(, .data...)
	} else {
		 = append(, ...)
		.data = append(.data, ...)
	}
	.group = .id
	return , , nil
}

func ( *hybridKeyExchange) ( *keySharePrivateKeys,  []byte) ([]byte, error) {
	if len() != .ecdhElementSize+.mlkemCiphertextSize {
		return nil, errors.New("tls: invalid server key share length for hybrid key exchange")
	}
	var ,  []byte
	if .id == X25519MLKEM768 {
		 = [:.mlkemCiphertextSize]
		 = [.mlkemCiphertextSize:]
	} else {
		 = [:.ecdhElementSize]
		 = [.ecdhElementSize:]
	}
	,  := .ecdh.clientSharedSecret(, )
	if  != nil {
		return nil, 
	}
	,  := .mlkem.Decapsulate()
	if  != nil {
		return nil, 
	}
	var  []byte
	if .id == X25519MLKEM768 {
		 = append(, ...)
	} else {
		 = append(, ...)
	}
	return , nil
}