// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package concurrent

import (
	
	
	
	
	
	
)

// HashTrieMap is an implementation of a concurrent hash-trie. The implementation
// is designed around frequent loads, but offers decent performance for stores
// and deletes as well, especially if the map is larger. Its primary use-case is
// the unique package, but it can be used elsewhere as well.
type HashTrieMap[,  comparable] struct {
	// root is the top-level indirect node; the constructor creates it with a
	// nil parent.
	root     *indirect[, ]
	// keyHash hashes a key (passed by pointer) together with seed.
	keyHash  hashFunc
	// keyEqual and valEqual compare two keys / two values passed by pointer.
	keyEqual equalFunc
	valEqual equalFunc
	// seed is passed to keyHash on every hash computation; the constructor
	// initializes it from rand.Uint64.
	seed     uintptr
}

// NewHashTrieMap creates a new HashTrieMap for the provided key and value.
//
// The hash and equality functions are not supplied by the caller: they are
// read from the map type descriptor that abi.TypeOf(...).MapType() returns
// for a locally declared map variable (Hasher, Key.Equal and Elem.Equal).
// The hash seed is drawn from rand.Uint64 and truncated to uintptr.
func [,  comparable]() *HashTrieMap[, ] {
	// Declared only so that its type descriptor can be inspected below.
	var  map[]
	 := abi.TypeOf().MapType()
	 := &HashTrieMap[, ]{
		// The root has no parent.
		root:     newIndirectNode[, ](nil),
		keyHash:  .Hasher,
		keyEqual: .Key.Equal,
		valEqual: .Elem.Equal,
		seed:     uintptr(rand.Uint64()),
	}
	return 
}

// hashFunc hashes the value behind the pointer. In this file the uintptr
// argument is always the map's seed, and the result is the hash whose bits
// select a child at each trie level.
type hashFunc func(unsafe.Pointer, uintptr) uintptr

// equalFunc reports whether the two values behind the pointers are equal.
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool

// Load returns the value stored in the map for a key, or the zero value if no
// value is present.
// The ok result indicates whether value was found in the map.
//
// Load takes no locks: it descends from the root using only atomic loads of
// child slots, consuming nChildrenLog2 bits of the key's hash per level.
func ( *HashTrieMap[, ]) ( ) ( ,  bool) {
	 := .keyHash(abi.NoEscape(unsafe.Pointer(&)), .seed)

	 := .root
	// Start with all hash bits unconsumed (8 * goarch.PtrSize, presumably the
	// width of uintptr in bits).
	 := 8 * goarch.PtrSize
	for  != 0 {
		 -= nChildrenLog2

		 := .children[(>>)&nChildrenMask].Load()
		if  == nil {
			// Empty slot: nothing is stored under this hash prefix.
			return *new(), false
		}
		if .isEntry {
			// Leaf reached: the key, if present, is on this entry's overflow chain.
			return .entry().lookup(, .keyEqual)
		}
		 = .indirect()
	}
	// Every level consumed hash bits without reaching a nil slot or an entry.
	panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
//
// The operation has two phases. First, a lock-free descent finds either the
// key itself (returned immediately) or a candidate slot: a nil slot or a slot
// holding an entry. Second, the mutex of the indirect node owning that slot is
// taken and the slot is re-read; if the slot changed shape or the node was
// marked dead in the meantime, the whole search restarts from the root.
func ( *HashTrieMap[, ]) ( ,  ) ( ,  bool) {
	 := .keyHash(abi.NoEscape(unsafe.Pointer(&)), .seed)
	var  *indirect[, ]
	var  uint
	var  *atomic.Pointer[node[, ]]
	var  *node[, ]
	for {
		// Find the key or a candidate location for insertion.
		 = .root
		 = 8 * goarch.PtrSize
		 := false
		for  != 0 {
			 -= nChildrenLog2

			 = &.children[(>>)&nChildrenMask]
			 = .Load()
			if  == nil {
				// We found a nil slot which is a candidate for insertion.
				 = true
				break
			}
			if .isEntry {
				// We found an existing entry, which is as far as we can go.
				// If it stays this way, we'll have to replace it with an
				// indirect node.
				if ,  := .entry().lookup(, .keyEqual);  {
					return , true
				}
				 = true
				break
			}
			 = .indirect()
		}
		if ! {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		.mu.Lock()
		 = .Load()
		if ( == nil || .isEntry) && !.dead.Load() {
			// What we saw is still true, so we can continue with the insert.
			break
		}
		// We have to start over.
		.mu.Unlock()
	}
	// N.B. This lock is held from when we broke out of the outer loop above.
	// We specifically break this out so that we can use defer here safely.
	// One option is to break this out into a new function instead, but
	// there's so much local iteration state used below that this turns out
	// to be cleaner.
	defer .mu.Unlock()

	var  *entry[, ]
	if  != nil {
		// The slot holds an entry (guaranteed by the check under the lock above).
		 = .entry()
		if ,  := .lookup(, .keyEqual);  {
			// Easy case: by loading again, it turns out exactly what we wanted is here!
			return , true
		}
	}
	 := newEntryNode(, )
	if  == nil {
		// Easy case: create a new entry and store it.
		.Store(&.node)
	} else {
		// We possibly need to expand the entry already there into one or more new nodes.
		//
		// Publish the node last, which will make both oldEntry and newEntry visible. We
		// don't want readers to be able to observe that oldEntry isn't in the tree.
		.Store(.expand(, , , , ))
	}
	return , false
}

// expand takes oldEntry and newEntry whose hashes conflict from bit 64 down to hashShift and
// produces a subtree of indirect nodes to hold the two new entries.
//
// The old entry's hash is recomputed here from its key. If the two hashes are
// fully equal, no indirect node is created: the old entry is chained onto the
// new entry's overflow list and the new entry's node is returned. Otherwise
// indirect nodes are added one level at a time until the two hashes select
// different children, and the node returned is presumably the topmost of the
// newly created indirect nodes.
//
// The result is not yet reachable from the tree; the caller publishes it by
// storing it into the parent's slot.
func ( *HashTrieMap[, ]) (,  *entry[, ],  uintptr,  uint,  *indirect[, ]) *node[, ] {
	// Check for a hash collision.
	 := .keyHash(unsafe.Pointer(&.key), .seed)
	if  ==  {
		// Store the old entry in the new entry's overflow list, then store
		// the new entry.
		.overflow.Store()
		return &.node
	}
	// We have to add an indirect node. Worse still, we may need to add more than one.
	 := newIndirectNode()
	 := 
	for {
		if  == 0 {
			// The hashes differ, so this should be unreachable unless an
			// invariant is broken.
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while inserting")
		}
		 -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
		 := ( >> ) & nChildrenMask
		 := ( >> ) & nChildrenMask
		if  !=  {
			// The hashes diverge at this level: place both entries and stop.
			.children[].Store(&.node)
			.children[].Store(&.node)
			break
		}
		// Same child index at this level too: add another indirect node and descend.
		 := newIndirectNode()
		.children[].Store(&.node)
		 = 
	}
	return &.node
}

// CompareAndDelete deletes the entry for key if its value is equal to old.
//
// If there is no current value for key in the map, CompareAndDelete returns false
// (even if the old value is the nil interface value).
//
// Like LoadOrStore, it first searches without locks, then takes the mutex of
// the indirect node owning the slot and re-validates before mutating. After
// removing the last entry from a slot it walks up the tree, marking each
// now-empty non-root indirect node dead and clearing it from its parent.
//
// NOTE(review): the bare return statements in the search phase appear to rely
// on a named bool result being false at that point — confirm.
func ( *HashTrieMap[, ]) ( ,  ) ( bool) {
	 := .keyHash(abi.NoEscape(unsafe.Pointer(&)), .seed)
	var  *indirect[, ]
	var  uint
	var  *atomic.Pointer[node[, ]]
	var  *node[, ]
	for {
		// Find the key or return when there's nothing to delete.
		 = .root
		 = 8 * goarch.PtrSize
		 := false
		for  != 0 {
			 -= nChildrenLog2

			 = &.children[(>>)&nChildrenMask]
			 = .Load()
			if  == nil {
				// Nothing to delete. Give up.
				return
			}
			if .isEntry {
				// We found an entry. Check if it matches.
				if ,  := .entry().lookup(, .keyEqual); ! {
					// No match, nothing to delete.
					return
				}
				// We've got something to delete.
				 = true
				break
			}
			 = .indirect()
		}
		if ! {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		.mu.Lock()
		 = .Load()
		if !.dead.Load() {
			if  == nil {
				// Valid node that doesn't contain what we need. Nothing to delete.
				.mu.Unlock()
				return
			}
			if .isEntry {
				// What we saw is still true, so we can continue with the delete.
				break
			}
		}
		// We have to start over.
		.mu.Unlock()
	}
	// The lock taken in the loop above is still held here; every path below
	// releases it explicitly before returning.
	// Try to delete the entry.
	,  := .entry().compareAndDelete(, , .keyEqual, .valEqual)
	if ! {
		// Nothing was actually deleted, which means the node is no longer there.
		.mu.Unlock()
		return false
	}
	if  != nil {
		// We didn't actually delete the whole entry, just one entry in the chain.
		// Nothing else to do, since the parent is definitely not empty.
		.Store(&.node)
		.mu.Unlock()
		return true
	}
	// Delete the entry.
	.Store(nil)

	// Check if the node is now empty (and isn't the root), and delete it if able.
	for .parent != nil && .empty() {
		if  == 8*goarch.PtrSize {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}
		// Move the shift back up one level so it indexes the parent's children.
		 += nChildrenLog2

		// Delete the current node in the parent.
		 := .parent
		.mu.Lock()
		// Mark the node dead before unlinking it, so that a writer which
		// already found it restarts after taking its lock.
		.dead.Store(true)
		.children[(>>)&nChildrenMask].Store(nil)
		.mu.Unlock()
		 = 
	}
	.mu.Unlock()
	return true
}

// All returns an iter.Seq2 that produces all key-value pairs in the map.
// The enumeration does not represent any consistent snapshot of the map,
// but is guaranteed to visit each unique key-value pair only once. It is
// safe to operate on the tree during iteration. No particular enumeration
// order is guaranteed.
//
// The returned function starts a traversal at the root each time it is
// called; the boolean result of that traversal is discarded.
func ( *HashTrieMap[, ]) () func( func(, ) bool) {
	return func( func( ,  ) bool) {
		.iter(.root, )
	}
}

// iter walks the subtree rooted at the given indirect node, recursing into
// child indirect nodes and invoking the callback with the key and value of
// every entry, including each element of an entry's overflow chain.
//
// It returns false as soon as the callback returns false, which also stops
// the enclosing recursion, and true once the whole subtree has been visited.
// Child slots are read with atomic loads; no locks are taken.
func ( *HashTrieMap[, ]) ( *indirect[, ],  func( ,  ) bool) bool {
	for  := range .children {
		 := .children[].Load()
		if  == nil {
			continue
		}
		if !.isEntry {
			// Indirect child: recurse, propagating an early stop.
			if !.(.indirect(), ) {
				return false
			}
			continue
		}
		// Entry child: visit it and everything on its overflow chain.
		 := .entry()
		for  != nil {
			if !(.key, .value) {
				return false
			}
			 = .overflow.Load()
		}
	}
	return true
}

const (
	// 16 children. This seems to be the sweet spot for
	// load performance: any smaller and we lose out on
	// 50% or more in CPU performance. Any larger and the
	// returns are minuscule (~1% improvement for 32 children).
	//
	// nChildrenLog2 is also the number of hash bits consumed per trie level.
	nChildrenLog2 = 4
	// nChildren is the fan-out of every indirect node.
	nChildren     = 1 << nChildrenLog2
	// nChildrenMask extracts a child index from a shifted hash.
	nChildrenMask = nChildren - 1
)

// indirect is an internal node in the hash-trie.
type indirect[,  comparable] struct {
	// node is the shared header (isEntry is false for indirect nodes).
	// NOTE(review): presumably this must remain the first field so that the
	// unsafe pointer conversion from *node to *indirect is valid — confirm.
	node[, ]
	// dead is set when CompareAndDelete prunes a node from the tree. Writers
	// re-check it after acquiring mu and restart their search if it is set.
	dead     atomic.Bool
	mu       sync.Mutex // Protects mutation to children and any children that are entry nodes.
	// parent is nil for the root node.
	parent   *indirect[, ]
	// children holds one slot per possible child index; a nil slot is empty.
	children [nChildren]atomic.Pointer[node[, ]]
}

// newIndirectNode allocates an indirect node with the given parent (nil for
// the root). All child slots start out empty and the node is not dead.
func newIndirectNode[,  comparable]( *indirect[, ]) *indirect[, ] {
	return &indirect[, ]{node: node[, ]{isEntry: false}, parent: }
}

// empty reports whether every child slot of the node is currently nil.
//
// Each slot is read with an atomic load and the scan always visits all
// nChildren slots (there is no early exit). The method takes no lock itself.
func ( *indirect[, ]) () bool {
	 := 0
	for  := range .children {
		if .children[].Load() != nil {
			++
		}
	}
	return  == 0
}

// entry is a leaf node in the hash-trie.
type entry[,  comparable] struct {
	// node is the shared header (isEntry is true for entries).
	// NOTE(review): presumably this must remain the first field so that the
	// unsafe pointer conversion from *node to *entry is valid — confirm.
	node[, ]
	overflow atomic.Pointer[entry[, ]] // Overflow for hash collisions.
	// key and value are set at construction; nothing in this file writes
	// them afterwards.
	key      
	value    
}

// newEntryNode allocates a leaf entry holding the given key and value, with
// an empty overflow chain.
func newEntryNode[,  comparable]( ,  ) *entry[, ] {
	return &entry[, ]{
		node:  node[, ]{isEntry: true},
		key:   ,
		value: ,
	}
}

// lookup searches the overflow chain for the given key, using the supplied
// equalFunc to compare keys by pointer. The search presumably starts at the
// receiver entry.
//
// It returns the value of the first matching entry and true, or the zero
// value and false if the chain is exhausted. The chain is followed with
// atomic loads and no lock is taken.
func ( *entry[, ]) ( ,  equalFunc) (, bool) {
	for  != nil {
		if (unsafe.Pointer(&.key), abi.NoEscape(unsafe.Pointer(&))) {
			return .value, true
		}
		 = .overflow.Load()
	}
	return *new(), false
}

// compareAndDelete deletes an entry in the overflow chain if both the key and value compare
// equal. Returns the new entry chain and whether or not anything was deleted.
//
// If the head of the chain matches, the returned chain is the head's overflow
// pointer, which is nil when the head was the only element. If a later element
// matches, it is unlinked in place by storing its successor into the
// predecessor's overflow pointer. If nothing matches, the second result is
// false.
//
// compareAndDelete must be called under the mutex of the indirect node which e is a child of.
func ( *entry[, ]) ( ,  , ,  equalFunc) (*entry[, ], bool) {
	if (unsafe.Pointer(&.key), abi.NoEscape(unsafe.Pointer(&))) &&
		(unsafe.Pointer(&.value), abi.NoEscape(unsafe.Pointer(&))) {
		// Drop the head of the list.
		return .overflow.Load(), true
	}
	// Walk the rest of the chain, tracking the pointer that links to the
	// current element so it can be overwritten on a match.
	 := &.overflow
	 := .Load()
	for  != nil {
		if (unsafe.Pointer(&.key), abi.NoEscape(unsafe.Pointer(&))) &&
			(unsafe.Pointer(&.value), abi.NoEscape(unsafe.Pointer(&))) {
			// Unlink the matching element.
			.Store(.overflow.Load())
			return , true
		}
		 = &.overflow
		 = .overflow.Load()
	}
	return , false
}

// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
type node[,  comparable] struct {
	// isEntry distinguishes the two concrete node kinds: true for entry,
	// false for indirect. The constructors set it and the accessor methods
	// check it before converting.
	isEntry bool
}

// entry reinterprets the node header as the entry that contains it, using an
// unsafe pointer conversion. It panics if the node is not an entry.
func ( *node[, ]) () *entry[, ] {
	if !.isEntry {
		panic("called entry on non-entry node")
	}
	return (*entry[, ])(unsafe.Pointer())
}

// indirect reinterprets the node header as the indirect node that contains
// it, using an unsafe pointer conversion. It panics if the node is an entry.
func ( *node[, ]) () *indirect[, ] {
	if .isEntry {
		panic("called indirect on entry node")
	}
	return (*indirect[, ])(unsafe.Pointer())
}