// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Multiplication.

package big

// karatsubaThreshold selects the multiplication algorithm used by mul:
// when the shorter operand has fewer than karatsubaThreshold words the
// "grade school" method is used, otherwise the Karatsuba algorithm.
var (
	karatsubaThreshold = 40 // see calibrate_test.go
)

// mul sets z = x*y, using stk for temporary storage.
// The caller may pass stk == nil to request that mul obtain and release one itself.
func ( nat) ( *stack, ,  nat) nat {
	 := len()
	 := len()

	switch {
	case  < :
		return .(, , )
	case  == 0 ||  == 0:
		return [:0]
	case  == 1:
		return .mulAddWW(, [0], 0)
	}
	// m >= n > 1

	// determine if z can be reused
	if alias(, ) || alias(, ) {
		 = nil // z is an alias for x or y - cannot reuse
	}
	 = .make( + )

	// use basic multiplication if the numbers are small
	if  < karatsubaThreshold {
		basicMul(, , )
		return .norm()
	}

	if  == nil {
		 = getStack()
		defer .free()
	}

	// Let x = x1:x0 where x0 is the same length as y.
	// Compute z = x0*y and then add in x1*y in sections
	// if needed.
	karatsuba(, [:2*], [:], )

	if  <  {
		clear([2*:])
		defer .restore(.save())
		 := .nat(2 * )
		for  := ;  < ;  +=  {
			 = .(, [:min(+, len())], )
			addTo([:], )
		}
	}

	return .norm()
}

// Squaring thresholds used by sqr, in words of the operand x:
//   - len(x) < basicSqrThreshold: "grade school" multiplication (basicMul);
//   - basicSqrThreshold <= len(x) < karatsubaSqrThreshold: basicSqr,
//     which computes each cross product once and doubles it;
//   - len(x) >= karatsubaSqrThreshold: the Karatsuba algorithm
//     optimized for x == y (karatsubaSqr).
//
// If karatsubaSqrThreshold is configured below basicSqrThreshold,
// the Karatsuba rule takes precedence.
var (
	basicSqrThreshold     = 12 // see calibrate_test.go
	karatsubaSqrThreshold = 80 // see calibrate_test.go
)

// sqr sets z = x*x, using stk for temporary storage.
// The caller may pass stk == nil to request that sqr obtain and release one itself.
func ( nat) ( *stack,  nat) nat {
	 := len()
	switch {
	case  == 0:
		return [:0]
	case  == 1:
		 := [0]
		 = .make(2)
		[1], [0] = mulWW(, )
		return .norm()
	}

	if alias(, ) {
		 = nil // z is an alias for x - cannot reuse
	}
	 = .make(2 * )

	if  < basicSqrThreshold &&  < karatsubaSqrThreshold {
		basicMul(, , )
		return .norm()
	}

	if  == nil {
		 = getStack()
		defer .free()
	}

	if  < karatsubaSqrThreshold {
		basicSqr(, , )
		return .norm()
	}

	karatsubaSqr(, , )
	return .norm()
}

// basicSqr sets z = x*x and is asymptotically faster than basicMul
// by about a factor of 2, but slower for small arguments due to overhead.
// Requirements: len(x) > 0, len(z) == 2*len(x)
// The (non-normalized) result is placed in z.
func basicSqr( *stack, ,  nat) {
	 := len()
	if  < basicSqrThreshold {
		basicMul(, , )
		return
	}

	defer .restore(.save())
	 := .nat(2 * )
	clear()
	[1], [0] = mulWW([0], [0]) // the initial square
	for  := 1;  < ; ++ {
		 := []
		// z collects the squares x[i] * x[i]
		[2*+1], [2*] = mulWW(, )
		// t collects the products x[i] * x[j] where j < i
		[2*] = addMulVVWW([:2*], [:2*], [0:], , 0)
	}
	[2*-1] = lshVU([1:2*-1], [1:2*-1], 1) // double the j < i products
	addVV(, , )                              // combine the result
}

// mulAddWW returns z = x*y + r.
func ( nat) ( nat, ,  Word) nat {
	 := len()
	if  == 0 ||  == 0 {
		return .setWord() // result is r
	}
	// m > 0

	 = .make( + 1)
	[] = mulAddVWW([0:], , , )

	return .norm()
}

// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func basicMul(, ,  nat) {
	clear([0 : len()+len()]) // initialize z
	for ,  := range  {
		if  != 0 {
			[len()+] = addMulVVWW([:+len()], [:+len()], , , 0)
		}
	}
}

// karatsuba multiplies x and y,
// writing the (non-normalized) result to z.
// x and y must have the same length n,
// and z must have length twice that.
func karatsuba( *stack, , ,  nat) {
	 := len()
	if len() !=  || len() != 2* {
		panic("bad karatsuba length")
	}

	// Fall back to basic algorithm if small enough.
	if  < karatsubaThreshold ||  < 2 {
		basicMul(, , )
		return
	}

	// Let the notation x1:x0 denote the nat (x1<<N)+x0 for some N,
	// and similarly z2:z1:z0 = (z2<<2N)+(z1<<N)+z0.
	//
	// (Note that z0, z1, z2 might be ≥ 2**N, in which case the high
	// bits of, say, z0 are being added to the low bits of z1 in this notation.)
	//
	// Karatsuba multiplication is based on the observation that
	//
	//	x1:x0 * y1:y0 = x1*y1:(x0*y1+y0*x1):x0*y0
	//	              = x1*y1:((x0-x1)*(y1-y0)+x1*y1+x0*y0):x0*y0
	//
	// The second form uses only three half-width multiplications
	// instead of the four that the straightforward first form does.
	//
	// We call the three pieces z0, z1, z2:
	//
	//	z0 = x0*y0
	//	z2 = x1*y1
	//	z1 = (x0-x1)*(y1-y0) + z0 + z2

	 := ( + 1) / 2
	,  := &Int{abs: [:].norm()}, &Int{abs: [:].norm()}
	,  := &Int{abs: [:].norm()}, &Int{abs: [:].norm()}
	 := &Int{abs: [0 : 2*]}
	 := &Int{abs: [2*:]}

	// Allocate temporary storage for z1; repurpose z0 to hold tx and ty.
	defer .restore(.save())
	 := &Int{abs: .nat(2* + 1)}
	 := &Int{abs: [0:]}
	 := &Int{abs: [ : 2*]}

	.Sub(, )
	.Sub(, )
	.mul(, , )

	clear()
	.mul(, , )
	.mul(, , )
	.Add(, )
	.Add(, )
	addTo([:], .abs)

	// Debug mode: double-check answer and print trace on failure.
	const  = false
	if  {
		 := make(nat, len())
		basicMul(, , )
		if .cmp() != 0 {
			// All the temps were aliased to z and gone. Recompute.
			 = new(Int)
			.mul(, , )
			 = new(Int).Sub(, )
			 = new(Int).Sub(, )
			 = new(Int)
			.mul(, , )
			print("karatsuba wrong\n")
			trace("x ", &Int{abs: })
			trace("y ", &Int{abs: })
			trace("z ", &Int{abs: })
			trace("zz", &Int{abs: })
			trace("x0", )
			trace("x1", )
			trace("y0", )
			trace("y1", )
			trace("tx", )
			trace("ty", )
			trace("z0", )
			trace("z1", )
			trace("z2", )
			panic("karatsuba")
		}
	}

}

// karatsubaSqr squares x,
// writing the (non-normalized) result to z.
// z must have length 2*len(x).
// It is analogous to [karatsuba] but can run faster
// knowing both multiplicands are the same value.
func karatsubaSqr( *stack, ,  nat) {
	 := len()
	if len() != 2* {
		panic("bad karatsubaSqr length")
	}

	if  < karatsubaSqrThreshold ||  < 2 {
		basicSqr(, , )
		return
	}

	// Recall that for karatsuba we want to compute:
	//
	//	x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0
	//                = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0
	//	              = z2:z1:z0
	// where:
	//
	//	z0 = x0y0
	//	z2 = x1y1
	//	z1 = (x0-x1)*(y1-y0) + z0 + z2
	//
	// When x = y, these simplify to:
	//
	//	z0 = x0²
	//	z2 = x1²
	//	z1 = z0 + z2 - (x0-x1)²

	 := ( + 1) / 2
	,  := &Int{abs: [:].norm()}, &Int{abs: [:].norm()}
	 := &Int{abs: [0 : 2*]}
	 := &Int{abs: [2*:]}

	// Allocate temporary storage for z1; repurpose z0 to hold tx.
	defer .restore(.save())
	 := &Int{abs: .nat(2* + 1)}
	 := &Int{abs: [0:]}

	.Sub(, )
	.abs = .abs.sqr(, .abs)
	.neg = true

	clear()
	.abs = .abs.sqr(, .abs)
	.abs = .abs.sqr(, .abs)
	.Add(, )
	.Add(, )
	addTo([:], .abs)

	// Debug mode: double-check answer and print trace on failure.
	const  = false
	if  {
		 := make(nat, len())
		basicSqr(, , )
		if .cmp() != 0 {
			// All the temps were aliased to z and gone. Recompute.
			 = new(Int).Sub(, )
			 = new(Int).Mul(, )
			 = new(Int).Mul(, )
			 = new(Int).Mul(, )
			.Neg()
			.Add(, )
			.Add(, )
			print("karatsubaSqr wrong\n")
			trace("x ", &Int{abs: })
			trace("z ", &Int{abs: })
			trace("zz", &Int{abs: })
			trace("x0", )
			trace("x1", )
			trace("z0", )
			trace("z1", )
			trace("z2", )
			panic("karatsubaSqr")
		}
	}
}

// ifmt returns the debug formatting of the Int x: 0xHEX.
func ifmt( *Int) string {
	, ,  := "", .Text(16), ""
	if  == "" { // happens for denormalized zero
		 = "0x0"
	}
	if [0] == '-' {
		,  = "-", [1:]
	}

	// Add _ between words.
	const  = _W / 4 // digits per chunk
	for len() >  {
		,  = [:len()-], [len()-:]+"_"+
	}
	return  +  + 
}

// trace prints a single debug value.
func trace( string,  *Int) {
	print(, "=", ifmt(), "\n")
}