// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tls

import (
	
	
	
	
	
	

	
)

// sortedSupportedAEADs is just a sorted version of hpke.SupportedAEADS.
// We need this so that when we insert them into ECHConfigs the ordering
// is stable.
var sortedSupportedAEADs []uint16

func init() {
	for  := range hpke.SupportedAEADs {
		sortedSupportedAEADs = append(sortedSupportedAEADs, )
	}
	slices.Sort(sortedSupportedAEADs)
}

// echCipher is one HPKE symmetric cipher suite: a KDF identifier and an AEAD
// identifier. It is used both for the suites listed in an ECHConfig and for
// the suite a client selected in an encrypted_client_hello extension.
type echCipher struct {
	// KDFID is looked up in hpke.SupportedKDFs.
	KDFID  uint16
	// AEADID is looked up in hpke.SupportedAEADs.
	AEADID uint16
}

// echExtension is a single entry from an ECHConfig's extensions list.
type echExtension struct {
	// Type is the extension type. When its high-order bit is set the
	// extension is mandatory, and pickECHConfig skips the config.
	Type uint16
	// Data is the extension body, without its uint16 length prefix.
	Data []byte
}

// echConfig is a parsed draft-ietf-tls-esni-18 ECHConfig.
type echConfig struct {
	// raw is the encoded ECHConfig this was parsed from, truncated to its
	// Version and Length fields plus Length bytes of contents.
	raw []byte

	Version uint16
	// Length is the number of content bytes following Version and Length.
	Length  uint16

	ConfigID             uint8
	KemID                uint16
	PublicKey            []byte
	SymmetricCipherSuite []echCipher

	// MaxNameLength sizes the server-name padding of the inner ClientHello
	// (see encodeInnerClientHello).
	MaxNameLength uint8
	// PublicName must pass validDNSName for the config to be picked.
	PublicName    []byte
	Extensions    []echExtension
}

// errMalformedECHConfig is returned when an ECHConfig or ECHConfigList cannot
// be parsed.
var errMalformedECHConfig = errors.New("tls: malformed ECHConfigList")

func parseECHConfig( []byte) ( bool,  echConfig,  error) {
	 := cryptobyte.String()
	.raw = []byte()
	if !.ReadUint16(&.Version) {
		return false, echConfig{}, errMalformedECHConfig
	}
	if !.ReadUint16(&.Length) {
		return false, echConfig{}, errMalformedECHConfig
	}
	if len(.raw) < int(.Length)+4 {
		return false, echConfig{}, errMalformedECHConfig
	}
	.raw = .raw[:.Length+4]
	if .Version != extensionEncryptedClientHello {
		.Skip(int(.Length))
		return true, echConfig{}, nil
	}
	if !.ReadUint8(&.ConfigID) {
		return false, echConfig{}, errMalformedECHConfig
	}
	if !.ReadUint16(&.KemID) {
		return false, echConfig{}, errMalformedECHConfig
	}
	if !readUint16LengthPrefixed(&, &.PublicKey) {
		return false, echConfig{}, errMalformedECHConfig
	}
	var  cryptobyte.String
	if !.ReadUint16LengthPrefixed(&) {
		return false, echConfig{}, errMalformedECHConfig
	}
	for !.Empty() {
		var  echCipher
		if !.ReadUint16(&.KDFID) {
			return false, echConfig{}, errMalformedECHConfig
		}
		if !.ReadUint16(&.AEADID) {
			return false, echConfig{}, errMalformedECHConfig
		}
		.SymmetricCipherSuite = append(.SymmetricCipherSuite, )
	}
	if !.ReadUint8(&.MaxNameLength) {
		return false, echConfig{}, errMalformedECHConfig
	}
	var  cryptobyte.String
	if !.ReadUint8LengthPrefixed(&) {
		return false, echConfig{}, errMalformedECHConfig
	}
	.PublicName = 
	var  cryptobyte.String
	if !.ReadUint16LengthPrefixed(&) {
		return false, echConfig{}, errMalformedECHConfig
	}
	for !.Empty() {
		var  echExtension
		if !.ReadUint16(&.Type) {
			return false, echConfig{}, errMalformedECHConfig
		}
		if !.ReadUint16LengthPrefixed((*cryptobyte.String)(&.Data)) {
			return false, echConfig{}, errMalformedECHConfig
		}
		.Extensions = append(.Extensions, )
	}

	return false, , nil
}

// parseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
// slice of parsed ECHConfigs, in the same order they were parsed, or an error
// if the list is malformed.
func parseECHConfigList( []byte) ([]echConfig, error) {
	 := cryptobyte.String()
	var  uint16
	if !.ReadUint16(&) {
		return nil, errMalformedECHConfig
	}
	if  != uint16(len()-2) {
		return nil, errMalformedECHConfig
	}
	var  []echConfig
	for len() > 0 {
		if len() < 4 {
			return nil, errors.New("tls: malformed ECHConfig")
		}
		 := uint16([2])<<8 | uint16([3])
		, ,  := parseECHConfig()
		if  != nil {
			return nil, 
		}
		 = [+4:]
		if ! {
			 = append(, )
		}
	}
	return , nil
}

func pickECHConfig( []echConfig) *echConfig {
	for ,  := range  {
		if ,  := hpke.SupportedKEMs[.KemID]; ! {
			continue
		}
		var  bool
		for ,  := range .SymmetricCipherSuite {
			if ,  := hpke.SupportedAEADs[.AEADID]; ! {
				continue
			}
			if ,  := hpke.SupportedKDFs[.KDFID]; ! {
				continue
			}
			 = true
			break
		}
		if ! {
			continue
		}
		if !validDNSName(string(.PublicName)) {
			continue
		}
		var  bool
		for ,  := range .Extensions {
			// If high order bit is set to 1 the extension is mandatory.
			// Since we don't support any extensions, if we see a mandatory
			// bit, we skip the config.
			if .Type&uint16(1<<15) != 0 {
				 = true
			}
		}
		if  {
			continue
		}
		return &
	}
	return nil
}

func pickECHCipherSuite( []echCipher) (echCipher, error) {
	for ,  := range  {
		// NOTE: all of the supported AEADs and KDFs are fine, rather than
		// imposing some sort of preference here, we just pick the first valid
		// suite.
		if ,  := hpke.SupportedAEADs[.AEADID]; ! {
			continue
		}
		if ,  := hpke.SupportedKDFs[.KDFID]; ! {
			continue
		}
		return , nil
	}
	return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH")
}

func encodeInnerClientHello( *clientHelloMsg,  int) ([]byte, error) {
	,  := .marshalMsg(true)
	if  != nil {
		return nil, 
	}
	 = [4:] // strip four byte prefix

	var  int
	if .serverName != "" {
		 = max(0, -len(.serverName))
	} else {
		 =  + 9
	}
	 = 31 - ((len() +  - 1) % 32)

	return append(, make([]byte, )...), nil
}

func skipUint8LengthPrefixed( *cryptobyte.String) bool {
	var  uint8
	if !.ReadUint8(&) {
		return false
	}
	return .Skip(int())
}

func skipUint16LengthPrefixed( *cryptobyte.String) bool {
	var  uint16
	if !.ReadUint16(&) {
		return false
	}
	return .Skip(int())
}

// rawExtension is one ClientHello extension as it appears on the wire: its
// type and its body with the uint16 length prefix removed.
type rawExtension struct {
	extType uint16
	data    []byte
}

func extractRawExtensions( *clientHelloMsg) ([]rawExtension, error) {
	 := cryptobyte.String(.original)
	if !.Skip(4+2+32) || // header, version, random
		!skipUint8LengthPrefixed(&) || // session ID
		!skipUint16LengthPrefixed(&) || // cipher suites
		!skipUint8LengthPrefixed(&) { // compression methods
		return nil, errors.New("tls: malformed outer client hello")
	}
	var  []rawExtension
	var  cryptobyte.String
	if !.ReadUint16LengthPrefixed(&) {
		return nil, errors.New("tls: malformed outer client hello")
	}

	for !.Empty() {
		var  uint16
		var  cryptobyte.String
		if !.ReadUint16(&) ||
			!.ReadUint16LengthPrefixed(&) {
			return nil, errors.New("tls: invalid inner client hello")
		}
		 = append(, rawExtension{, })
	}
	return , nil
}

// decodeInnerClientHello rebuilds the full ClientHelloInner from its decrypted
// EncodedClientHelloInner form and the ClientHelloOuter it arrived in. It then
// unmarshals the result and checks the ECH-specific requirements on it.
//
// An error is returned when:
//   - the encoding is malformed or carries a non-empty session ID;
//   - the trailing padding is not all zeros;
//   - an ech_outer_extensions entry names encrypted_client_hello, or cannot be
//     matched, in order, against the outer hello's extensions;
//   - the reconstructed hello fails to unmarshal;
//   - it lacks an inner-type encrypted_client_hello extension;
//   - it does not offer exactly TLS 1.3 in supported_versions.
func decodeInnerClientHello( *clientHelloMsg,  []byte) (*clientHelloMsg, error) {
	// Reconstructing the inner client hello from its encoded form is somewhat
	// complicated. It is missing its header (message type and length), session
	// ID, and the extensions may be compressed. Since we need to put the
	// extensions back in the same order as they were in the raw outer hello,
	// and since we don't store the raw extensions, or the order we parsed them
	// in, we need to reparse the raw extensions from the outer hello in order
	// to properly insert them into the inner hello. This _should_ result in raw
	// bytes which match the hello as it was generated by the client.
	 := cryptobyte.String()
	var , , ,  []byte
	var  cryptobyte.String
	// Read version+random, the (required to be empty) session ID, cipher
	// suites, compression methods and the extensions block. Whatever is left
	// in the reader afterwards is padding.
	if !.ReadBytes(&, 2+32) ||
		!readUint8LengthPrefixed(&, &) ||
		len() != 0 ||
		!readUint16LengthPrefixed(&, &) ||
		!readUint8LengthPrefixed(&, &) ||
		!.ReadUint16LengthPrefixed(&) {
		return nil, errors.New("tls: invalid inner client hello")
	}

	// The specification says we must verify that the trailing padding is all
	// zeros. This is kind of weird for TLS messages, where we generally just
	// throw away any trailing garbage.
	for ,  := range  {
		if  != 0 {
			return nil, errors.New("tls: invalid inner client hello")
		}
	}

	,  := extractRawExtensions()
	if  != nil {
		return nil, 
	}

	// Rebuild a complete ClientHello handshake message: type, uint24 length,
	// version+random, the session ID taken from the outer hello, cipher
	// suites, compression methods, then the extensions.
	 := cryptobyte.NewBuilder(nil)
	.AddUint8(typeClientHello)
	.AddUint24LengthPrefixed(func( *cryptobyte.Builder) {
		.AddBytes()
		.AddUint8LengthPrefixed(func( *cryptobyte.Builder) {
			.AddBytes(.sessionId)
		})
		.AddUint16LengthPrefixed(func( *cryptobyte.Builder) {
			.AddBytes()
		})
		.AddUint8LengthPrefixed(func( *cryptobyte.Builder) {
			.AddBytes()
		})
		.AddUint16LengthPrefixed(func( *cryptobyte.Builder) {
			for !.Empty() {
				var  uint16
				var  cryptobyte.String
				if !.ReadUint16(&) ||
					!.ReadUint16LengthPrefixed(&) {
					.SetError(errors.New("tls: invalid inner client hello"))
					return
				}
				if  == extensionECHOuterExtensions {
					// ech_outer_extensions: a uint8-length-prefixed list of
					// extension types to copy verbatim from the outer hello.
					if !.ReadUint8LengthPrefixed(&) {
						.SetError(errors.New("tls: invalid inner client hello"))
						return
					}
					// Cursor into the outer extensions. It only moves
					// forward, so referenced extensions must appear in the
					// same relative order as in the outer hello.
					var  int
					for !.Empty() {
						var  uint16
						if !.ReadUint16(&) {
							.SetError(errors.New("tls: invalid inner client hello"))
							return
						}
						// encrypted_client_hello itself may not be referenced.
						if  == extensionEncryptedClientHello {
							.SetError(errors.New("tls: invalid outer extensions"))
							return
						}
						// Scan forward for the referenced type; reaching the
						// end means the reference cannot be satisfied.
						// NOTE(review): the cursor is not advanced past a
						// match, so the same outer extension could be copied
						// twice; presumably the duplicate-extension check in
						// unmarshal rejects that — TODO confirm.
						for ;  <= len(); ++ {
							if  == len() {
								.SetError(errors.New("tls: invalid outer extensions"))
								return
							}
							if [].extType ==  {
								break
							}
						}
						.AddUint16([].extType)
						.AddUint16LengthPrefixed(func( *cryptobyte.Builder) {
							.AddBytes([].data)
						})
					}
				} else {
					// Any other extension is copied through unchanged.
					.AddUint16()
					.AddUint16LengthPrefixed(func( *cryptobyte.Builder) {
						.AddBytes()
					})
				}
			}
		})
	})

	,  := .Bytes()
	if  != nil {
		return nil, 
	}
	 := &clientHelloMsg{}
	if !.unmarshal() {
		return nil, errors.New("tls: invalid reconstructed inner client hello")
	}

	// The inner hello must carry an encrypted_client_hello extension whose
	// body is exactly the inner type byte.
	if !bytes.Equal(.encryptedClientHello, []byte{uint8(innerECHExt)}) {
		return nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
	}

	// It must offer exactly one version, TLS 1.3. (The len >= 1 term is
	// implied by the first clause; it is redundant but harmless.)
	if len(.supportedVersions) != 1 || (len(.supportedVersions) >= 1 && .supportedVersions[0] != VersionTLS13) {
		return nil, errors.New("tls: client sent encrypted_client_hello extension and offered incompatible versions")
	}

	return , nil
}

func decryptECHPayload( *hpke.Receipient, ,  []byte) ([]byte, error) {
	 := bytes.Replace([4:], , make([]byte, len()), 1)
	return .Open(, )
}

func generateOuterECHExt( uint8, ,  uint16,  []byte,  []byte) ([]byte, error) {
	var  cryptobyte.Builder
	.AddUint8(0) // outer
	.AddUint16()
	.AddUint16()
	.AddUint8()
	.AddUint16LengthPrefixed(func( *cryptobyte.Builder) { .AddBytes() })
	.AddUint16LengthPrefixed(func( *cryptobyte.Builder) { .AddBytes() })
	return .Bytes()
}

func computeAndUpdateOuterECHExtension(,  *clientHelloMsg,  *echClientContext,  bool) error {
	var  []byte
	if  {
		 = .encapsulatedKey
	}
	,  := encodeInnerClientHello(, int(.config.MaxNameLength))
	if  != nil {
		return 
	}
	// NOTE: the tag lengths for all of the supported AEADs are the same (16
	// bytes), so we have hardcoded it here. If we add support for another AEAD
	// with a different tag length, we will need to change this.
	 := len() + 16 // AEAD tag length
	.encryptedClientHello,  = generateOuterECHExt(.config.ConfigID, .kdfID, .aeadID, , make([]byte, ))
	if  != nil {
		return 
	}
	,  := .marshal()
	if  != nil {
		return 
	}
	 = [4:] // strip the four byte prefix
	,  := .hpkeContext.Seal(, )
	if  != nil {
		return 
	}
	.encryptedClientHello,  = generateOuterECHExt(.config.ConfigID, .kdfID, .aeadID, , )
	if  != nil {
		return 
	}
	return nil
}

// validDNSName is a rather rudimentary check for the validity of a DNS name.
// This is used to check if the public_name in an ECHConfig is valid when we
// are picking a config. This can be somewhat lax because even if we pick a
// valid-looking name, the DNS layer will later reject it anyway.
//
// A name is accepted when all of the following hold:
//   - it is at most 253 bytes long;
//   - it has at least two dot-separated labels;
//   - each label is 1 to 63 bytes of ASCII letters, digits and hyphens, and
//     does not start or end with a hyphen.
//
// As the ECH draft recommends for public_name, a name is also rejected when
// its final label looks numeric: all digits, or "0x"/"0X" followed by hex
// digits. Such a name could be interpreted as an IPv4 literal.
func validDNSName(name string) bool {
	if len(name) > 253 {
		return false
	}
	labels := strings.Split(name, ".")
	if len(labels) <= 1 {
		return false
	}
	for _, label := range labels {
		n := len(label)
		// RFC 1035 limits labels to 63 octets.
		if n == 0 || n > 63 {
			return false
		}
		for i := 0; i < n; i++ {
			c := label[i]
			if c == '-' && (i == 0 || i == n-1) {
				return false
			}
			if (c < '0' || c > '9') && (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && c != '-' {
				return false
			}
		}
	}
	return !isNumericDNSLabel(labels[len(labels)-1])
}

// isNumericDNSLabel reports whether label consists only of ASCII digits, or
// is "0x" or "0X" followed by zero or more ASCII hex digits.
func isNumericDNSLabel(label string) bool {
	if len(label) >= 2 && label[0] == '0' && (label[1] == 'x' || label[1] == 'X') {
		return strings.Trim(label[2:], "0123456789abcdefABCDEF") == ""
	}
	return strings.Trim(label, "0123456789") == ""
}

// ECHRejectionError is the error type returned when ECH is rejected by a remote
// server. If the server offered an ECHConfigList to use for retries, the
// RetryConfigList field will contain this list.
//
// The client may treat an ECHRejectionError with an empty RetryConfigList as a
// secure signal from the server.
type ECHRejectionError struct {
	RetryConfigList []byte
}

// Error implements the error interface.
func (e *ECHRejectionError) Error() string {
	const msg = "tls: server rejected ECH"
	return msg
}

// errMalformedECHExt is returned by parseECHExt when an encrypted_client_hello
// extension body cannot be parsed.
var errMalformedECHExt = errors.New("tls: malformed encrypted_client_hello extension")

// echExtType is the leading type byte of an encrypted_client_hello extension
// body. It distinguishes two forms:
//   - outer: carries the cipher suite, config ID, encapsulated key and
//     encrypted payload;
//   - inner: the type byte alone, expected inside the decrypted inner hello.
type echExtType uint8

const (
	innerECHExt echExtType = 1
	outerECHExt echExtType = 0
)

func parseECHExt( []byte) ( echExtType,  echCipher,  uint8,  []byte,  []byte,  error) {
	 := make([]byte, len())
	copy(, )
	 := cryptobyte.String()
	var  uint8
	if !.ReadUint8(&) {
		 = errMalformedECHExt
		return
	}
	 = echExtType()
	if  == innerECHExt {
		if !.Empty() {
			 = errMalformedECHExt
			return
		}
		return , , 0, nil, nil, nil
	}
	if  != outerECHExt {
		 = errMalformedECHExt
		return
	}
	if !.ReadUint16(&.KDFID) {
		 = errMalformedECHExt
		return
	}
	if !.ReadUint16(&.AEADID) {
		 = errMalformedECHExt
		return
	}
	if !.ReadUint8(&) {
		 = errMalformedECHExt
		return
	}
	if !readUint16LengthPrefixed(&, &) {
		 = errMalformedECHExt
		return
	}
	if !readUint16LengthPrefixed(&, &) {
		 = errMalformedECHExt
		return
	}

	// NOTE: clone encap and payload so that mutating them does not mutate the
	// raw extension bytes.
	return , , , bytes.Clone(), bytes.Clone(), nil
}

func marshalEncryptedClientHelloConfigList( []EncryptedClientHelloKey) ([]byte, error) {
	 := cryptobyte.NewBuilder(nil)
	.AddUint16LengthPrefixed(func( *cryptobyte.Builder) {
		for ,  := range  {
			.AddBytes(.Config)
		}
	})
	return .Bytes()
}

func ( *Conn) ( *clientHelloMsg) (*clientHelloMsg, *echServerContext, error) {
	, , , , ,  := parseECHExt(.encryptedClientHello)
	if  != nil {
		.sendAlert(alertDecodeError)
		return nil, nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
	}

	if  == innerECHExt {
		return , &echServerContext{inner: true}, nil
	}

	if len(.config.EncryptedClientHelloKeys) == 0 {
		return , nil, nil
	}

	for ,  := range .config.EncryptedClientHelloKeys {
		, ,  := parseECHConfig(.Config)
		if  != nil ||  {
			.sendAlert(alertInternalError)
			return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys Config: %s", )
		}
		if  {
			continue
		}
		,  := hpke.ParseHPKEPrivateKey(.KemID, .PrivateKey)
		if  != nil {
			.sendAlert(alertInternalError)
			return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys PrivateKey: %s", )
		}
		 := append([]byte("tls ech\x00"), .Config...)
		,  := hpke.SetupReceipient(hpke.DHKEM_X25519_HKDF_SHA256, .KDFID, .AEADID, , , )
		if  != nil {
			// attempt next trial decryption
			continue
		}

		,  := decryptECHPayload(, .original, )
		if  != nil {
			// attempt next trial decryption
			continue
		}

		// NOTE: we do not enforce that the sent server_name matches the ECH
		// configs PublicName, since this is not particularly important, and
		// the client already had to know what it was in order to properly
		// encrypt the payload. This is only a MAY in the spec, so we're not
		// doing anything revolutionary.

		,  := decodeInnerClientHello(, )
		if  != nil {
			.sendAlert(alertIllegalParameter)
			return nil, nil, errors.New("tls: client sent invalid encrypted_client_hello extension")
		}

		.echAccepted = true

		return , &echServerContext{
			hpkeContext: ,
			configID:    ,
			ciphersuite: ,
		}, nil
	}

	return , nil, nil
}

func buildRetryConfigList( []EncryptedClientHelloKey) ([]byte, error) {
	var  bool
	var  cryptobyte.Builder
	.AddUint16LengthPrefixed(func( *cryptobyte.Builder) {
		for ,  := range  {
			if !.SendAsRetry {
				continue
			}
			 = true
			.AddBytes(.Config)
		}
	})
	if ! {
		return nil, nil
	}
	return .Bytes()
}