// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package tls

import (
	"crypto"
	"crypto/ecdh"
	"crypto/md5"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
)

// keyAgreement is the contract for a TLS 1.0–1.2 key agreement protocol. One
// implementation covers both roles: it generates the key exchange messages for
// its own side and processes the ones sent by the peer.
type keyAgreement interface {
	// Server side: generateServerKeyExchange and processClientKeyExchange
	// are called in that order.

	// generateServerKeyExchange can return nil, nil when the key agreement
	// protocol doesn't use a ServerKeyExchange message.
	generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error)
	processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error)

	// Client side: processServerKeyExchange and generateClientKeyExchange
	// are called in that order.

	// processServerKeyExchange may not be called if the server doesn't send
	// a ServerKeyExchange message.
	processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error
	generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
}

// Sentinel errors returned when a peer's key exchange message is malformed or
// cannot be processed.
var (
	errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
	errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
)

// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key. It carries no
// state, so the zero value is ready to use.
type rsaKeyAgreement struct{}

func ( rsaKeyAgreement) ( *Config,  *Certificate,  *clientHelloMsg,  *serverHelloMsg) (*serverKeyExchangeMsg, error) {
	return nil, nil
}

func ( rsaKeyAgreement) ( *Config,  *Certificate,  *clientKeyExchangeMsg,  uint16) ([]byte, error) {
	if len(.ciphertext) < 2 {
		return nil, errClientKeyExchange
	}
	 := int(.ciphertext[0])<<8 | int(.ciphertext[1])
	if  != len(.ciphertext)-2 {
		return nil, errClientKeyExchange
	}
	 := .ciphertext[2:]

	,  := .PrivateKey.(crypto.Decrypter)
	if ! {
		return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
	}
	// Perform constant time RSA PKCS #1 v1.5 decryption
	,  := .Decrypt(.rand(), , &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
	if  != nil {
		return nil, 
	}
	// We don't check the version number in the premaster secret. For one,
	// by checking it, we would leak information about the validity of the
	// encrypted pre-master secret. Secondly, it provides only a small
	// benefit against a downgrade attack and some implementations send the
	// wrong version anyway. See the discussion at the end of section
	// 7.4.7.1 of RFC 4346.
	return , nil
}

func ( rsaKeyAgreement) ( *Config,  *clientHelloMsg,  *serverHelloMsg,  *x509.Certificate,  *serverKeyExchangeMsg) error {
	return errors.New("tls: unexpected ServerKeyExchange")
}

func ( rsaKeyAgreement) ( *Config,  *clientHelloMsg,  *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
	 := make([]byte, 48)
	[0] = byte(.vers >> 8)
	[1] = byte(.vers)
	,  := io.ReadFull(.rand(), [2:])
	if  != nil {
		return nil, nil, 
	}

	,  := .PublicKey.(*rsa.PublicKey)
	if ! {
		return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
	}
	,  := rsa.EncryptPKCS1v15(.rand(), , )
	if  != nil {
		return nil, nil, 
	}
	 := new(clientKeyExchangeMsg)
	.ciphertext = make([]byte, len()+2)
	.ciphertext[0] = byte(len() >> 8)
	.ciphertext[1] = byte(len())
	copy(.ciphertext[2:], )
	return , , nil
}

// sha1Hash calculates a SHA1 hash over the given byte slices. The slices are
// fed to the hash in order, so the result equals the SHA1 digest of their
// concatenation.
func sha1Hash(slices [][]byte) []byte {
	hsha1 := sha1.New()
	for _, slice := range slices {
		hsha1.Write(slice)
	}
	return hsha1.Sum(nil)
}

// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash( [][]byte) []byte {
	 := make([]byte, md5.Size+sha1.Size)
	 := md5.New()
	for ,  := range  {
		.Write()
	}
	copy(, .Sum(nil))
	copy([md5.Size:], sha1Hash())
	return 
}

// hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for TLS 1.2) or using a default based on
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange( uint8,  crypto.Hash,  uint16,  ...[]byte) []byte {
	if  == signatureEd25519 {
		var  []byte
		for ,  := range  {
			 = append(, ...)
		}
		return 
	}
	if  >= VersionTLS12 {
		 := .New()
		for ,  := range  {
			.Write()
		}
		 := .Sum(nil)
		return 
	}
	if  == signatureECDSA {
		return sha1Hash()
	}
	return md5SHA1Hash()
}

// ecdheKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// be ECDSA, Ed25519 or RSA.
type ecdheKeyAgreement struct {
	// version is compared against VersionTLS12 to decide whether the
	// signature scheme is carried explicitly in the ServerKeyExchange.
	version uint16
	// isRSA must match whether the signature type in use is an RSA one
	// (PKCS #1 v1.5 or PSS); a mismatch is rejected.
	isRSA bool
	// key is the ephemeral ECDH private key generated for this handshake.
	key *ecdh.PrivateKey

	// ckx and preMasterSecret are generated in processServerKeyExchange
	// and returned in generateClientKeyExchange.
	ckx             *clientKeyExchangeMsg
	preMasterSecret []byte
}

func ( *ecdheKeyAgreement) ( *Config,  *Certificate,  *clientHelloMsg,  *serverHelloMsg) (*serverKeyExchangeMsg, error) {
	var  CurveID
	for ,  := range .supportedCurves {
		if .supportsCurve(.version, ) {
			 = 
			break
		}
	}

	if  == 0 {
		return nil, errors.New("tls: no supported elliptic curves offered")
	}
	if ,  := curveForCurveID(); ! {
		return nil, errors.New("tls: CurvePreferences includes unsupported curve")
	}

	,  := generateECDHEKey(.rand(), )
	if  != nil {
		return nil, 
	}
	.key = 

	// See RFC 4492, Section 5.4.
	 := .PublicKey().Bytes()
	 := make([]byte, 1+2+1+len())
	[0] = 3 // named curve
	[1] = byte( >> 8)
	[2] = byte()
	[3] = byte(len())
	copy([4:], )

	,  := .PrivateKey.(crypto.Signer)
	if ! {
		return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", .PrivateKey)
	}

	var  SignatureScheme
	var  uint8
	var  crypto.Hash
	if .version >= VersionTLS12 {
		,  = selectSignatureScheme(.version, , .supportedSignatureAlgorithms)
		if  != nil {
			return nil, 
		}
		, ,  = typeAndHashFromSignatureScheme()
		if  != nil {
			return nil, 
		}
	} else {
		, ,  = legacyTypeAndHashFromPublicKey(.Public())
		if  != nil {
			return nil, 
		}
	}
	if ( == signaturePKCS1v15 ||  == signatureRSAPSS) != .isRSA {
		return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
	}

	 := hashForServerKeyExchange(, , .version, .random, .random, )

	 := crypto.SignerOpts()
	if  == signatureRSAPSS {
		 = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: }
	}
	,  := .Sign(.rand(), , )
	if  != nil {
		return nil, errors.New("tls: failed to sign ECDHE parameters: " + .Error())
	}

	 := new(serverKeyExchangeMsg)
	 := 0
	if .version >= VersionTLS12 {
		 = 2
	}
	.key = make([]byte, len()++2+len())
	copy(.key, )
	 := .key[len():]
	if .version >= VersionTLS12 {
		[0] = byte( >> 8)
		[1] = byte()
		 = [2:]
	}
	[0] = byte(len() >> 8)
	[1] = byte(len())
	copy([2:], )

	return , nil
}

func ( *ecdheKeyAgreement) ( *Config,  *Certificate,  *clientKeyExchangeMsg,  uint16) ([]byte, error) {
	if len(.ciphertext) == 0 || int(.ciphertext[0]) != len(.ciphertext)-1 {
		return nil, errClientKeyExchange
	}

	,  := .key.Curve().NewPublicKey(.ciphertext[1:])
	if  != nil {
		return nil, errClientKeyExchange
	}
	,  := .key.ECDH()
	if  != nil {
		return nil, errClientKeyExchange
	}

	return , nil
}

func ( *ecdheKeyAgreement) ( *Config,  *clientHelloMsg,  *serverHelloMsg,  *x509.Certificate,  *serverKeyExchangeMsg) error {
	if len(.key) < 4 {
		return errServerKeyExchange
	}
	if .key[0] != 3 { // named curve
		return errors.New("tls: server selected unsupported curve")
	}
	 := CurveID(.key[1])<<8 | CurveID(.key[2])

	 := int(.key[3])
	if +4 > len(.key) {
		return errServerKeyExchange
	}
	 := .key[:4+]
	 := [4:]

	 := .key[4+:]
	if len() < 2 {
		return errServerKeyExchange
	}

	if ,  := curveForCurveID(); ! {
		return errors.New("tls: server selected unsupported curve")
	}

	,  := generateECDHEKey(.rand(), )
	if  != nil {
		return 
	}
	.key = 

	,  := .Curve().NewPublicKey()
	if  != nil {
		return errServerKeyExchange
	}
	.preMasterSecret,  = .ECDH()
	if  != nil {
		return errServerKeyExchange
	}

	 := .PublicKey().Bytes()
	.ckx = new(clientKeyExchangeMsg)
	.ckx.ciphertext = make([]byte, 1+len())
	.ckx.ciphertext[0] = byte(len())
	copy(.ckx.ciphertext[1:], )

	var  uint8
	var  crypto.Hash
	if .version >= VersionTLS12 {
		 := SignatureScheme([0])<<8 | SignatureScheme([1])
		 = [2:]
		if len() < 2 {
			return errServerKeyExchange
		}

		if !isSupportedSignatureAlgorithm(, .supportedSignatureAlgorithms) {
			return errors.New("tls: certificate used with invalid signature algorithm")
		}
		, ,  = typeAndHashFromSignatureScheme()
		if  != nil {
			return 
		}
	} else {
		, ,  = legacyTypeAndHashFromPublicKey(.PublicKey)
		if  != nil {
			return 
		}
	}
	if ( == signaturePKCS1v15 ||  == signatureRSAPSS) != .isRSA {
		return errServerKeyExchange
	}

	 := int([0])<<8 | int([1])
	if +2 != len() {
		return errServerKeyExchange
	}
	 = [2:]

	 := hashForServerKeyExchange(, , .version, .random, .random, )
	if  := verifyHandshakeSignature(, .PublicKey, , , );  != nil {
		return errors.New("tls: invalid signature by the server certificate: " + .Error())
	}
	return nil
}

func ( *ecdheKeyAgreement) ( *Config,  *clientHelloMsg,  *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
	if .ckx == nil {
		return nil, nil, errors.New("tls: missing ServerKeyExchange message")
	}

	return .preMasterSecret, .ckx, nil
}