// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package math

import "math/bits"

// zero returns 1 if x is zero and 0 otherwise.
func zero(x uint64) uint64 {
	if x == 0 {
		return 1
	}
	return 0
	// branchless:
	// return ((x>>1 | x&1) - 1) >> 63
}

// nonzero returns 1 if x is not zero and 0 otherwise.
func nonzero(x uint64) uint64 {
	if x != 0 {
		return 1
	}
	return 0
	// branchless:
	// return 1 - ((x>>1|x&1)-1)>>63
}

// shl shifts the two-word value u1:u2 (u1 is the high word) left by n
// bits and returns the result as r1:r2. Bits shifted out of the high
// word are discarded.
//
// n is unsigned, so 64-n and n-64 wrap around to very large shift
// counts whenever they would be negative; Go defines such shifts to
// yield 0, which makes the inapplicable terms vanish. This lets a
// single expression handle n < 64, n == 64 and n > 64.
func shl(u1, u2 uint64, n uint) (r1, r2 uint64) {
	r1 = u1<<n | u2>>(64-n) | u2<<(n-64)
	r2 = u2 << n
	return
}

// shr shifts the two-word value u1:u2 (u1 is the high word) right by n
// bits and returns the result as r1:r2. Bits shifted out of the low
// word are discarded.
//
// n is unsigned, so 64-n and n-64 wrap around to very large shift
// counts whenever they would be negative; Go defines such shifts to
// yield 0, which makes the inapplicable terms vanish. This lets a
// single expression handle n < 64, n == 64 and n > 64.
func shr(u1, u2 uint64, n uint) (r1, r2 uint64) {
	r2 = u2>>n | u1<<(64-n) | u1>>(n-64)
	r1 = u1 >> n
	return
}

// shrcompress compresses the bottom n+1 bits of the two-word
// value into a single bit. the result is equal to the value
// shifted to the right by n, except the result's 0th bit is
// set to the bitwise OR of the bottom n+1 bits.
func shrcompress(,  uint64,  uint) (,  uint64) {
	// TODO: Performance here is really sensitive to the
	// order/placement of these branches. n == 0 is common
	// enough to be in the fast path. Perhaps more measurement
	// needs to be done to find the optimal order/placement?
	switch {
	case  == 0:
		return , 
	case  == 64:
		return 0,  | nonzero()
	case  >= 128:
		return 0, nonzero( | )
	case  < 64:
		,  = shr(, , )
		 |= nonzero( & (1<< - 1))
	case  < 128:
		,  = shr(, , )
		 |= nonzero(&(1<<(-64)-1) | )
	}
	return
}

// lz returns the number of leading zero bits in the two-word value
// u1:u2, where u1 is the high word. The low word is only examined
// when the high word is entirely zero, so the result ranges from 0
// to 128.
func lz(u1, u2 uint64) (l int32) {
	l = int32(bits.LeadingZeros64(u1))
	if l == 64 {
		l += int32(bits.LeadingZeros64(u2))
	}
	return l
}

// split splits b into sign, biased exponent, and mantissa.
// It adds the implicit 1 bit to the mantissa for normal values,
// and normalizes subnormal values.
func split( uint64) ( uint32,  int32,  uint64) {
	 = uint32( >> 63)
	 = int32(>>52) & mask
	 =  & fracMask

	if  == 0 {
		// Normalize value if subnormal.
		 := uint(bits.LeadingZeros64() - 11)
		 <<= 
		 = 1 - int32()
	} else {
		// Add implicit 1 bit
		 |= 1 << 52
	}
	return
}

// FMA returns x * y + z, computed with only one rounding.
// (That is, FMA returns the fused multiply-add of x, y, and z.)
func (, ,  float64) float64 {
	, ,  := Float64bits(), Float64bits(), Float64bits()

	// Inf or NaN or zero involved. At most one rounding will occur.
	if  == 0.0 ||  == 0.0 ||  == 0.0 || &uvinf == uvinf || &uvinf == uvinf {
		return * + 
	}
	// Handle non-finite z separately. Evaluating x*y+z where
	// x and y are finite, but z is infinite, should always result in z.
	if &uvinf == uvinf {
		return 
	}

	// Inputs are (sub)normal.
	// Split x, y, z into sign, exponent, mantissa.
	, ,  := split()
	, ,  := split()
	, ,  := split()

	// Compute product p = x*y as sign, exponent, two-word mantissa.
	// Start with exponent. "is normal" bit isn't subtracted yet.
	 :=  +  - bias + 1

	// pm1:pm2 is the double-word mantissa for the product p.
	// Shift left to leave top bit in product. Effectively
	// shifts the 106-bit product to the left by 21.
	,  := bits.Mul64(<<10, <<11)
	,  := <<10, uint64(0)
	 :=  ^  // product sign

	// normalize to 62nd bit
	 := uint((^ >> 62) & 1)
	,  = shl(, , )
	 -= int32()

	// Swap addition operands so |p| >= |z|
	if  <  ||  ==  &&  <  {
		, , , , , , ,  = , , , , , , , 
	}

	// Special case: if p == -z the result is always +0 since neither operand is zero.
	if  !=  &&  ==  &&  ==  &&  ==  {
		return 0
	}

	// Align significands
	,  = shrcompress(, , uint(-))

	// Compute resulting significands, normalizing if necessary.
	var ,  uint64
	if  ==  {
		// Adding (pm1:pm2) + (zm1:zm2)
		,  = bits.Add64(, , 0)
		, _ = bits.Add64(, , )
		 -= int32(^ >> 63)
		,  = shrcompress(, , uint(64+>>63))
	} else {
		// Subtracting (pm1:pm2) - (zm1:zm2)
		// TODO: should we special-case cancellation?
		,  = bits.Sub64(, , 0)
		, _ = bits.Sub64(, , )
		 := lz(, )
		 -= 
		,  = shl(, , uint(-1))
		 |= nonzero()
	}

	// Round and break ties to even
	if  > 1022+bias ||  == 1022+bias && (+1<<9)>>63 == 1 {
		// rounded value overflows exponent range
		return Float64frombits(uint64()<<63 | uvinf)
	}
	if  < 0 {
		 := uint(-)
		 = >> | nonzero(&(1<<-1))
		 = 0
	}
	 = (( + 1<<9) >> 10) & ^zero((&(1<<10-1))^1<<9)
	 &= -int32(nonzero())
	return Float64frombits(uint64()<<63 + uint64()<<52 + )
}