// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package asmgen

// mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
func mulAddVWW( *Asm) {
	 := .Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")

	if .AltCarry().Valid() {
		addMulVirtualCarry(, 0)
		return
	}
	addMul(, "", "x", 0)
}

// addMulVVWW generates addMulVVWW which does z, c = x + y*m + a.
// (A more pedantic name would be addMulAddVVWW.)
func addMulVVWW( *Asm) {
	 := .Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")

	// If the architecture has virtual carries, emit that version unconditionally.
	if .AltCarry().Valid() {
		addMulVirtualCarry(, 1)
		return
	}

	// If the architecture optionally has two carries, test and emit both versions.
	if .JmpEnable(OptionAltCarry, "altcarry") {
		 := .RegsUsed()
		addMul(, "x", "y", 1)
		.Label("altcarry")
		.SetOption(OptionAltCarry, true)
		.SetRegsUsed()
		addMulAlt()
		.SetOption(OptionAltCarry, false)
		return
	}

	// Otherwise emit the one-carry form.
	addMul(, "x", "y", 1)
}

// Computing z = addsrc + m*mulsrc + a, we need:
//
//	for i := range z {
//		lo, hi := m * mulsrc[i]
//		lo, carry = bits.Add(lo, a, 0)
//		lo, carryAlt = bits.Add(lo, addsrc[i], 0)
//		z[i] = lo
//		a = hi + carry + carryAlt  // cannot overflow
//	}
//
// The final addition cannot overflow because after processing N words,
// the maximum possible value is (for a 64-bit system):
//
//	  (2**64N - 1) + (2**64 - 1)*(2**64N - 1) + (2**64 - 1)
//	= (2**64)*(2**64N - 1) + (2**64 - 1)
//	= 2**64(N+1) - 1,
//
// which fits in N+1 words (the high order one being the new value of a).
//
// (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
//
// If we unroll the loop a bit, then we can chain the carries in two passes.
// Consider:
//
//	lo0, hi0 := m * mulsrc[i]
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	z[i] = lo0
//	a = hi0 + carry + carryAlt // cannot overflow
//
//	lo1, hi1 := m * mulsrc[i+1]
//	lo1, carry = bits.Add(lo1, a, 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
//	z[i+1] = lo1
//	a = hi1 + carry + carryAlt // cannot overflow
//
//	lo2, hi2 := m * mulsrc[i+2]
//	lo2, carry = bits.Add(lo2, a, 0)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
//	z[i+2] = lo2
//	a = hi2 + carry + carryAlt // cannot overflow
//
//	lo3, hi3 := m * mulsrc[i+3]
//	lo3, carry = bits.Add(lo3, a, 0)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
//	z[i+3] = lo3
//	a = hi3 + carry + carryAlt // cannot overflow
//
// There are three ways we can optimize this sequence.
//
// (1) Reordering, we can chain carries so that we can use one hardware carry flag
// but amortize the cost of saving and restoring it across multiple instructions:
//
//	// multiply
//	lo0, hi0 := m * mulsrc[i]
//	lo1, hi1 := m * mulsrc[i+1]
//	lo2, hi2 := m * mulsrc[i+2]
//	lo3, hi3 := m * mulsrc[i+3]
//
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo1, carry = bits.Add(lo1, hi0, carry)
//	lo2, carry = bits.Add(lo2, hi1, carry)
//	lo3, carry = bits.Add(lo3, hi2, carry)
//	a = hi3 + carry // cannot overflow
//
//	// add
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
//	a = a + carryAlt // cannot overflow
//
//	z[i] = lo0
//	z[i+1] = lo1
//	z[i+2] = lo2
//	z[i+3] = lo3
//
// addMul takes this approach, using the hardware carry flag
// first for carry and then for carryAlt.
//
// (2) addMulAlt assumes there are two hardware carry flags available.
// It dedicates one each to carry and carryAlt, so that a multi-block
// unrolling can keep the flags in hardware across all the blocks.
// So even if the block size is 1, the code can do:
//
//	// multiply and add
//	lo0, hi0 := m * mulsrc[i]
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	z[i] = lo0
//
//	lo1, hi1 := m * mulsrc[i+1]
//	lo1, carry = bits.Add(lo1, hi0, carry)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
//	z[i+1] = lo1
//
//	lo2, hi2 := m * mulsrc[i+2]
//	lo2, carry = bits.Add(lo2, hi1, carry)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
//	z[i+2] = lo2
//
//	lo3, hi3 := m * mulsrc[i+3]
//	lo3, carry = bits.Add(lo3, hi2, carry)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
//	z[i+3] = lo3
//
//	a = hi3 + carry + carryAlt // cannot overflow
//
// (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
// (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
// Look again at the original word-at-a-time version:
//
//	lo1, hi1 := m * mulsrc[i]
//	lo1, carry = bits.Add(lo1, a, 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
//	z[i] = lo1
//	a = hi1 + carry + carryAlt // cannot overflow
//
// Although it uses four adds per word, those are cheap adds: the two bits.Add adds
// use two instructions each (ADD+SLTU) and the final + adds only use one ADD each,
// for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
// only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
// five instructions each, for a total of 10 instructions per word. So the word-at-a-time
// loop is actually better. And we can reorder things slightly to use only a single carry bit:
//
//	lo1, hi1 := m * mulsrc[i]
//	lo1, carry = bits.Add(lo1, a, 0)
//	a = hi1 + carry
//	lo1, carry = bits.Add(lo1, addsrc[i], 0)
//	a = a + carry
//	z[i] = lo1
func addMul( *Func, ,  string,  int) {
	 := .Asm
	 := HintNone
	if .Arch == Arch386 &&  != "" {
		 = HintMemOK // too few registers otherwise
	}
	 := .ArgHint("m", )
	 := .Arg("a")
	 := .Arg("z_len")

	 := .Pipe()
	if  != "" {
		.SetHint(, HintMemOK)
	}
	.SetHint(, HintMulSrc)
	 := []int{1, 4}
	switch .Arch {
	case Arch386:
		 = []int{1} // too few registers
	case ArchARM:
		.SetMaxColumns(2) // too few registers (but more than 386)
	case ArchARM64:
		 = []int{1, 8} // 5% speedup on c4as16
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy 1.
	.Start(, ...)
	.Loop(func(,  [][]Reg) {
		.Comment("multiply")
		 := 
		 := SetCarry
		for ,  := range [] {
			 := .RegHint(HintMulHi)
			.MulWide(, , , )
			.Add(, , , )
			 = UseCarry | SetCarry
			if  !=  {
				.Free()
			}
			[0][] = 
			 = 
		}
		.Add(.Imm(0), , , UseCarry|SmashCarry)
		if  != "" {
			.Comment("add")
			 := SetCarry
			for ,  := range [0] {
				.Add(, [0][], [0][], )
				 = UseCarry | SetCarry
			}
			.Add(.Imm(0), , , UseCarry|SmashCarry)
		}
		.StoreN()
	})

	.StoreArg(, "c")
	.Ret()
}

func addMulAlt( *Func) {
	 := .Asm
	 := .ArgHint("m", HintMulSrc)
	 := .Arg("a")
	 := .Arg("z_len")

	// On amd64, we need a non-immediate for the AtUnrollEnd adds.
	 := .ZR()
	if !.Valid() {
		 = .Reg()
		.Mov(.Imm(0), )
	}

	 := .Pipe()
	.SetLabel("alt")
	.SetHint("x", HintMemOK)
	.SetHint("y", HintMemOK)
	if .Arch == ArchAMD64 {
		.SetMaxColumns(2)
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (2).
	var  Reg
	 := 
	.Start(, 1, 8)
	.AtUnrollStart(func() {
		.Comment("multiply and add")
		.ClearCarry(AddCarry | AltCarry)
		.ClearCarry(AddCarry)
		 = .Reg()
	})
	.AtUnrollEnd(func() {
		.Add(, , , UseCarry|SmashCarry)
		.Add(, , , UseCarry|SmashCarry|AltCarry)
		 = 
	})
	.Loop(func(,  [][]Reg) {
		for ,  := range [1] {
			 := [0][]
			 := 
			if .IsMem() {
				 = .Reg()
			}
			.MulWide(, , , )
			.Add(, , , UseCarry|SetCarry)
			.Add(, , , UseCarry|SetCarry|AltCarry)
			[0][] = 
			,  = , 
		}
		.StoreN()
	})

	.StoreArg(, "c")
	.Ret()
}

func addMulVirtualCarry( *Func,  int) {
	 := .Asm
	 := .Arg("m")
	 := .Arg("a")
	 := .Arg("z_len")

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (3).
	 := .Pipe()
	.Start(, 1, 4)
	.Loop(func(,  [][]Reg) {
		.Comment("synthetic carry, one column at a time")
		,  := .Reg(), .Reg()
		for ,  := range [] {
			.MulWide(, , , )
			if  == 1 {
				.Add([0][], , , SetCarry)
				.Add(.Imm(0), , , UseCarry|SmashCarry)
			}
			.Add(, , , SetCarry)
			.Add(.Imm(0), , , UseCarry|SmashCarry)
			[0][] = 
		}
		.StoreN()
	})
	.StoreArg(, "c")
	.Ret()
}