// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package rsa

import (
	"bytes"
	"crypto/internal/fips140"
	"crypto/internal/fips140/bigmod"
	"errors"
)

// PublicKey is an RSA public key: the modulus N and the public exponent E.
//
// Encrypt and the private key constructors validate it with checkPublicKey,
// which rejects a nil or even N, an even E, and E outside [2, 2³¹-1].
type PublicKey struct {
	N *bigmod.Modulus
	E int
}

// Size returns the modulus size in bytes. Raw signatures and ciphertexts
// for or by this public key will have the same size.
func ( *PublicKey) () int {
	return (.N.BitLen() + 7) / 8
}

// PrivateKey is an RSA private key together with its public half.
//
// Values are built by the NewPrivateKey* constructors, which run
// checkPrivateKey before returning. A zero PrivateKey has no modulus and is
// not a valid key.
type PrivateKey struct {
	// pub has already been checked with checkPublicKey.
	pub PublicKey
	d   *bigmod.Nat
	// The following values are not set for deprecated multi-prime keys.
	//
	// Since they are always set for keys in FIPS mode, for SP 800-56B Rev. 2
	// purposes we always use the Chinese Remainder Theorem (CRT) format.
	p, q *bigmod.Modulus // p × q = n
	// dP and dQ are used as exponents, so we store them as big-endian byte
	// slices to be passed to [bigmod.Nat.Exp].
	dP   []byte      // d mod (p - 1)
	dQ   []byte      // d mod (q - 1)
	qInv *bigmod.Nat // qInv = q⁻¹ mod p
	// fipsApproved is false if this key does not comply with FIPS 186-5 or
	// SP 800-56B Rev. 2.
	fipsApproved bool
}

func ( *PrivateKey) () *PublicKey {
	return &.pub
}

// NewPrivateKey creates a new RSA private key from the given parameters.
//
// All values are in big-endian byte slice format, and may have leading zeros
// or be shorter if leading zeroes were trimmed.
func ( []byte,  int, , ,  []byte) (*PrivateKey, error) {
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewNat().SetBytes(, )
	if  != nil {
		return nil, 
	}
	return newPrivateKey(, , , , )
}

func newPrivateKey( *bigmod.Modulus,  int,  *bigmod.Nat, ,  *bigmod.Modulus) (*PrivateKey, error) {
	 := .Nat().SubOne()
	,  := bigmod.NewModulus(.Bytes())
	if  != nil {
		return nil, 
	}
	 := bigmod.NewNat().Mod(, ).Bytes()

	 := .Nat().SubOne()
	,  := bigmod.NewModulus(.Bytes())
	if  != nil {
		return nil, 
	}
	 := bigmod.NewNat().Mod(, ).Bytes()

	// Constant-time modular inversion with prime modulus by Fermat's Little
	// Theorem: qInv = q⁻¹ mod p = q^(p-2) mod p.
	if .Nat().IsOdd() == 0 {
		// [bigmod.Nat.Exp] requires an odd modulus.
		return nil, errors.New("crypto/rsa: p is even")
	}
	 := .Nat().SubOne().SubOne().Bytes()
	 := bigmod.NewNat().Mod(.Nat(), )
	.Exp(, , )

	 := &PrivateKey{
		pub: PublicKey{
			N: , E: ,
		},
		d: , p: , q: ,
		dP: , dQ: , qInv: ,
	}
	if  := checkPrivateKey();  != nil {
		return nil, 
	}
	return , nil
}

// NewPrivateKeyWithPrecomputation creates a new RSA private key from the given
// parameters, which include precomputed CRT values.
func ( []byte,  int, , , , , ,  []byte) (*PrivateKey, error) {
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewNat().SetBytes(, )
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewNat().SetBytes(, )
	if  != nil {
		return nil, 
	}

	 := &PrivateKey{
		pub: PublicKey{
			N: , E: ,
		},
		d: , p: , q: ,
		dP: , dQ: , qInv: ,
	}
	if  := checkPrivateKey();  != nil {
		return nil, 
	}
	return , nil
}

// NewPrivateKeyWithoutCRT creates a new RSA private key from the given parameters.
//
// This is meant for deprecated multi-prime keys, and is not FIPS 140 compliant.
func ( []byte,  int,  []byte) (*PrivateKey, error) {
	,  := bigmod.NewModulus()
	if  != nil {
		return nil, 
	}
	,  := bigmod.NewNat().SetBytes(, )
	if  != nil {
		return nil, 
	}
	 := &PrivateKey{
		pub: PublicKey{
			N: , E: ,
		},
		d: ,
	}
	if  := checkPrivateKey();  != nil {
		return nil, 
	}
	return , nil
}

// Export returns the key parameters in big-endian byte slice format.
//
// P, Q, dP, dQ, and qInv may be nil if the key was created with
// NewPrivateKeyWithoutCRT.
func ( *PrivateKey) () ( []byte,  int, , , , , ,  []byte) {
	 = .pub.N.Nat().Bytes(.pub.N)
	 = .pub.E
	 = .d.Bytes(.pub.N)
	if .dP == nil {
		return
	}
	 = .p.Nat().Bytes(.p)
	 = .q.Nat().Bytes(.q)
	 = bytes.Clone(.dP)
	 = bytes.Clone(.dQ)
	 = .qInv.Bytes(.p)
	return
}

// checkPrivateKey is called by the NewPrivateKey and GenerateKey functions, and
// is allowed to modify priv.fipsApproved.
func checkPrivateKey( *PrivateKey) error {
	.fipsApproved = true

	if ,  := checkPublicKey(&.pub);  != nil {
		return 
	} else if ! {
		.fipsApproved = false
	}

	if .dP == nil {
		// Legacy and deprecated multi-prime keys.
		.fipsApproved = false
		return nil
	}

	 := .pub.N
	 := .p
	 := .q

	// FIPS 186-5, Section 5.1 requires "that p and q be of the same bit length."
	if .BitLen() != .BitLen() {
		.fipsApproved = false
	}

	// Check that pq ≡ 1 mod N (and that p < N and q < N).
	 := bigmod.NewNat().ExpandFor()
	if ,  := .SetBytes(.Nat().Bytes(), );  != nil {
		return errors.New("crypto/rsa: invalid prime")
	}
	 := bigmod.NewNat().ExpandFor()
	if ,  := .SetBytes(.Nat().Bytes(), );  != nil {
		return errors.New("crypto/rsa: invalid prime")
	}
	if .Mul(, ).IsZero() != 1 {
		return errors.New("crypto/rsa: p * q != n")
	}

	// Check that de ≡ 1 mod p-1, and de ≡ 1 mod q-1.
	//
	// This implies that e is coprime to each p-1 as e has a multiplicative
	// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = exponent(ℤ/nℤ).
	// It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 mod p. Thus a^de ≡ a
	// mod n for all a coprime to n, as required.
	//
	// This checks dP, dQ, and e. We don't check d because it is not actually
	// used in the RSA private key operation.
	,  := bigmod.NewModulus(.Nat().SubOne().Bytes())
	if  != nil {
		return errors.New("crypto/rsa: invalid prime")
	}
	,  := bigmod.NewNat().SetBytes(.dP, )
	if  != nil {
		return errors.New("crypto/rsa: invalid CRT exponent")
	}
	 := bigmod.NewNat()
	.SetUint(uint(.pub.E)).ExpandFor()
	.Mul(, )
	if .IsOne() != 1 {
		return errors.New("crypto/rsa: invalid CRT exponent")
	}

	,  := bigmod.NewModulus(.Nat().SubOne().Bytes())
	if  != nil {
		return errors.New("crypto/rsa: invalid prime")
	}
	,  := bigmod.NewNat().SetBytes(.dQ, )
	if  != nil {
		return errors.New("crypto/rsa: invalid CRT exponent")
	}
	.SetUint(uint(.pub.E)).ExpandFor()
	.Mul(, )
	if .IsOne() != 1 {
		return errors.New("crypto/rsa: invalid CRT exponent")
	}

	// Check that qInv * q ≡ 1 mod p.
	,  := bigmod.NewNat().SetOverflowingBytes(.Nat().Bytes(), )
	if  != nil {
		// q >= 2^⌈log2(p)⌉
		 = bigmod.NewNat().Mod(.Nat(), )
	}
	if .Mul(.qInv, ).IsOne() != 1 {
		return errors.New("crypto/rsa: invalid CRT coefficient")
	}

	// Check that |p - q| > 2^(nlen/2 - 100).
	//
	// If p and q are very close to each other, then N=pq can be trivially
	// factored using Fermat's factorization method. Broken RSA implementations
	// do generate such keys. See Hanno Böck, Fermat Factorization in the Wild,
	// https://eprint.iacr.org/2023/026.pdf.
	 := bigmod.NewNat()
	if ,  := bigmod.NewNat().SetBytes(.Nat().Bytes(), );  != nil {
		// q > p
		,  := bigmod.NewNat().SetBytes(.Nat().Bytes(), )
		if  != nil {
			return errors.New("crypto/rsa: p == q")
		}
		// diff = 0 - p mod q = q - p
		.ExpandFor().Sub(, )
	} else {
		// p > q
		// diff = 0 - q mod p = p - q
		.ExpandFor().Sub(, )
	}
	// A tiny bit of leakage is acceptable because it's not adaptive, an
	// attacker only learns the magnitude of p - q.
	if .BitLenVarTime() <= .BitLen()/2-100 {
		return errors.New("crypto/rsa: |p - q| too small")
	}

	// Check that d > 2^(nlen/2).
	//
	// See section 3 of https://crypto.stanford.edu/~dabo/papers/RSA-survey.pdf
	// for more details about attacks on small d values.
	//
	// Likewise, the leakage of the magnitude of d is not adaptive.
	if .d.BitLenVarTime() <= .BitLen()/2 {
		return errors.New("crypto/rsa: d too small")
	}

	// If the key is still in scope for FIPS mode, perform a Pairwise
	// Consistency Test.
	if .fipsApproved {
		if  := fips140.PCT("RSA sign and verify PCT", func() error {
			 := []byte{
				0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
				0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
				0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
				0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
			}
			,  := signPKCS1v15(, "SHA-256", )
			if  != nil {
				return 
			}
			return verifyPKCS1v15(.PublicKey(), "SHA-256", , )
		});  != nil {
			return 
		}
	}

	return nil
}

func checkPublicKey( *PublicKey) ( bool,  error) {
	 = true
	if .N == nil {
		return false, errors.New("crypto/rsa: missing public modulus")
	}
	if .N.Nat().IsOdd() == 0 {
		return false, errors.New("crypto/rsa: public modulus is even")
	}
	// FIPS 186-5, Section 5.1: "This standard specifies the use of a modulus
	// whose bit length is an even integer and greater than or equal to 2048
	// bits."
	if .N.BitLen() < 2048 {
		 = false
	}
	if .N.BitLen()%2 == 1 {
		 = false
	}
	if .E < 2 {
		return false, errors.New("crypto/rsa: public exponent too small or negative")
	}
	// e needs to be coprime with p-1 and q-1, since it must be invertible
	// modulo λ(pq). Since p and q are prime, this means e needs to be odd.
	if .E&1 == 0 {
		return false, errors.New("crypto/rsa: public exponent is even")
	}
	// FIPS 186-5, Section 5.5(e): "The exponent e shall be an odd, positive
	// integer such that 2¹⁶ < e < 2²⁵⁶."
	if .E <= 1<<16 {
		 = false
	}
	// We require pub.E to fit into a 32-bit integer so that we
	// do not have different behavior depending on whether
	// int is 32 or 64 bits. See also
	// https://www.imperialviolet.org/2012/03/16/rsae.html.
	if .E > 1<<31-1 {
		return false, errors.New("crypto/rsa: public exponent too large")
	}
	return , nil
}

// Encrypt performs the RSA public key operation.
func ( *PublicKey,  []byte) ([]byte, error) {
	fips140.RecordNonApproved()
	if ,  := checkPublicKey();  != nil {
		return nil, 
	}
	return encrypt(, )
}

func encrypt( *PublicKey,  []byte) ([]byte, error) {
	,  := bigmod.NewNat().SetBytes(, .N)
	if  != nil {
		return nil, 
	}
	return bigmod.NewNat().ExpShortVarTime(, uint(.E), .N).Bytes(.N), nil
}

var (
	// ErrMessageTooLong reports that a message is too long for the RSA key size.
	ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA key size")
	// ErrDecryption reports a failed private key operation. decrypt returns it
	// both for out-of-range ciphertexts and for a failed consistency check, so
	// that the two cases are indistinguishable to callers.
	ErrDecryption = errors.New("crypto/rsa: decryption error")
	// ErrVerification reports that a signature failed to verify.
	ErrVerification = errors.New("crypto/rsa: verification error")
)

// withCheck and noCheck name the check argument of decrypt.
const (
	withCheck = true
	noCheck   = false
)

// DecryptWithoutCheck performs the RSA private key operation.
func ( *PrivateKey,  []byte) ([]byte, error) {
	fips140.RecordNonApproved()
	return decrypt(, , noCheck)
}

// DecryptWithCheck performs the RSA private key operation and checks the
// result to defend against errors in the CRT computation.
func ( *PrivateKey,  []byte) ([]byte, error) {
	fips140.RecordNonApproved()
	return decrypt(, , withCheck)
}

// decrypt performs an RSA decryption of ciphertext into out. If check is true,
// m^e is calculated and compared with ciphertext, in order to defend against
// errors in the CRT computation.
func decrypt( *PrivateKey,  []byte,  bool) ([]byte, error) {
	if !.fipsApproved {
		fips140.RecordNonApproved()
	}

	var  *bigmod.Nat
	,  := .pub.N, .pub.E

	,  := bigmod.NewNat().SetBytes(, )
	if  != nil {
		return nil, ErrDecryption
	}

	if .dP == nil {
		// Legacy codepath for deprecated multi-prime keys.
		fips140.RecordNonApproved()
		 = bigmod.NewNat().Exp(, .d.Bytes(), )

	} else {
		,  := .p, .q
		 := bigmod.NewNat()
		// m = c ^ Dp mod p
		 = bigmod.NewNat().Exp(.Mod(, ), .dP, )
		// m2 = c ^ Dq mod q
		 := bigmod.NewNat().Exp(.Mod(, ), .dQ, )
		// m = m - m2 mod p
		.Sub(.Mod(, ), )
		// m = m * Qinv mod p
		.Mul(.qInv, )
		// m = m * q mod N
		.ExpandFor().Mul(.Mod(.Nat(), ), )
		// m = m + m2 mod N
		.Add(.ExpandFor(), )
	}

	if  {
		 := bigmod.NewNat().ExpShortVarTime(, uint(), )
		if .Equal() != 1 {
			return nil, ErrDecryption
		}
	}

	return .Bytes(), nil
}