This commit is contained in:
knight0zh
2022-08-17 17:08:08 +08:00
parent 59e6f5d6c1
commit d18f22e8f8
7 changed files with 199 additions and 177 deletions

View File

@@ -1,2 +1,2 @@
bench:
go test -bench=. -benchmem -count=3
go test -bench=. -benchmem

View File

@@ -3,4 +3,5 @@
#### 目前实现:
- 九宫格
- 四叉树
- 四叉树

24
aoi.go
View File

@@ -1,12 +1,28 @@
package aoi
import "sync"
type AOI interface {
Add(entity *Entity) // 添加实体
Delete(entity *Entity) // 移除实体
Search(entity *Entity) (result []*Entity) // 范围查询
Add(x, y float64, name string) // 添加实体
Delete(x, y float64, name string) // 移除实体
Search(x, y float64) (result []string) // 范围查询
}
type Entity struct {
X, Y float64
Name string
Key string
}
var (
resultPool sync.Pool
entityPool sync.Pool
)
func init() {
resultPool.New = func() interface{} {
return make([]string, 0, 500)
}
entityPool.New = func() interface{} {
return &Entity{}
}
}

90
grid.go
View File

@@ -2,6 +2,12 @@ package aoi
import "sync"
var (
// 分别将这8个方向的方向向量按顺序写入x, y的分量数组
dx = []int{-1, -1, -1, 0, 0, 1, 1, 1}
dy = []int{-1, 0, 1, -1, 1, -1, 0, 1}
)
// Grid 格子
type Grid struct {
GID int //格子ID
@@ -15,6 +21,7 @@ type GridManger struct {
AreaWidth int // 格子宽度(长=宽)
GridCount int // 格子数量
grids map[int]*Grid
pool sync.Pool
}
func NewGrid(gid int) *Grid {
@@ -31,6 +38,9 @@ func NewGridManger(startX, startY, areaWidth, gridCount int) AOI {
GridCount: gridCount,
grids: make(map[int]*Grid),
}
manager.pool.New = func() interface{} {
return make([]*Grid, 0, 9)
}
for y := 0; y < gridCount; y++ {
for x := 0; x < gridCount; x++ {
@@ -47,66 +57,76 @@ func (g *GridManger) gridWidth() int {
return g.AreaWidth / g.GridCount
}
// GetGIDByPos 通过横纵坐标获取对应的格子ID
func (g *GridManger) GetGIDByPos(entity *Entity) int {
gx := (int(entity.X) - g.StartX) / g.gridWidth()
gy := (int(entity.Y) - g.StartY) / g.gridWidth()
// getGIDByPos 通过横纵坐标获取对应的格子ID
func (g *GridManger) getGIDByPos(x, y float64) int {
gx := (int(x) - g.StartX) / g.gridWidth()
gy := (int(y) - g.StartY) / g.gridWidth()
return gy*g.GridCount + gx
}
// GetSurroundGrids 根据格子的gID得到当前周边的九宫格信息
func (g *GridManger) GetSurroundGrids(gID int) (grids []*Grid) {
// getSurroundGrids 根据格子的gID得到当前周边的九宫格信息
func (g *GridManger) getSurroundGrids(gID int) []*Grid {
grids := g.pool.Get().([]*Grid)
defer func() {
grids = grids[:0]
g.pool.Put(grids)
}()
if _, ok := g.grids[gID]; !ok {
return
return grids
}
grids = append(grids, g.grids[gID])
// 根据gID, 得到格子所在的坐标
x, y := gID%g.GridCount, gID/g.GridCount
// 分别将这8个方向的方向向量按顺序写入x, y的分量数组
dx := []int{-1, -1, -1, 0, 0, 1, 1, 1}
dy := []int{-1, 0, 1, -1, 1, -1, 0, 1}
surroundGID := make([]int, 0)
for i := 0; i < 8; i++ {
newX := x + dx[i]
newY := y + dy[i]
if newX >= 0 && newX < g.GridCount && newY >= 0 && newY < g.GridCount {
surroundGID = append(surroundGID, newY*g.GridCount+newX)
grids = append(grids, g.grids[newY*g.GridCount+newX])
}
}
for _, gID := range surroundGID {
grids = append(grids, g.grids[gID])
return grids
}
func (g *GridManger) Add(x, y float64, key string) {
entity := entityPool.Get().(*Entity)
entity.X = x
entity.Y = y
entity.Key = key
ID := g.getGIDByPos(x, y)
grid := g.grids[ID]
grid.Entities.Store(key, entity)
}
func (g *GridManger) Delete(x, y float64, key string) {
ID := g.getGIDByPos(x, y)
grid := g.grids[ID]
if entity, ok := grid.Entities.Load(key); ok {
grid.Entities.Delete(key)
entityPool.Put(entity)
}
return
}
func (g *GridManger) Add(entity *Entity) {
ID := g.GetGIDByPos(entity)
grid := g.grids[ID]
grid.Entities.Store(entity.Name, entity)
}
func (g *GridManger) Delete(entity *Entity) {
ID := g.GetGIDByPos(entity)
grid := g.grids[ID]
grid.Entities.Delete(entity.Name)
}
func (g *GridManger) Search(entity *Entity) (result []*Entity) {
ID := g.GetGIDByPos(entity)
grids := g.GetSurroundGrids(ID)
func (g *GridManger) Search(x, y float64) []string {
result := resultPool.Get().([]string)
defer func() {
result = result[:0]
resultPool.Put(result)
}()
ID := g.getGIDByPos(x, y)
grids := g.getSurroundGrids(ID)
for _, grid := range grids {
grid.Entities.Range(func(_, value interface{}) bool {
result = append(result, value.(*Entity))
result = append(result, value.(*Entity).Key)
return true
})
}
return
return result
}

View File

@@ -24,44 +24,34 @@ func TestGridManger_GetSurroundGrids(t *testing.T) {
aol := NewGridManger(0, 0, 250, 5)
manger := aol.(*GridManger)
tests := []struct {
entity *Entity
want []int
x, y float64
want []int
}{
{
entity: &Entity{
X: 0, Y: 0,
},
x: 0, y: 0,
want: []int{0, 1, 5, 6},
},
{
entity: &Entity{
X: 150, Y: 0,
},
x: 150, y: 0,
want: []int{2, 3, 4, 7, 8, 9},
},
{
entity: &Entity{
X: 50, Y: 50,
},
x: 50, y: 50,
want: []int{0, 1, 2, 5, 6, 7, 10, 11, 12},
},
{
entity: &Entity{
X: 200, Y: 100,
},
x: 200, y: 100,
want: []int{8, 9, 13, 14, 18, 19},
},
{
entity: &Entity{
X: 200, Y: 200,
},
x: 200, y: 200,
want: []int{18, 19, 23, 24},
},
}
for _, tt := range tests {
ID := manger.GetGIDByPos(tt.entity)
grids := manger.GetSurroundGrids(ID)
ID := manger.getGIDByPos(tt.x, tt.y)
grids := manger.getSurroundGrids(ID)
gID := make([]int, 0)
for _, grid := range grids {
gID = append(gID, grid.GID)
@@ -76,51 +66,51 @@ func TestNewGridManger(t *testing.T) {
manger := aol.(*GridManger)
entities := []*Entity{
{
X: 0, Y: 0, Name: "a",
X: 0, Y: 0, Key: "a",
},
{
X: 50, Y: 0, Name: "b",
X: 50, Y: 0, Key: "b",
},
{
X: 100, Y: 0, Name: "c",
X: 100, Y: 0, Key: "c",
},
{
X: 50, Y: 0, Name: "d",
X: 50, Y: 0, Key: "d",
},
{
X: 50, Y: 50, Name: "e",
X: 50, Y: 50, Key: "e",
},
{
X: 50, Y: 100, Name: "f",
X: 50, Y: 100, Key: "f",
},
{
X: 100, Y: 0, Name: "g",
X: 100, Y: 0, Key: "g",
},
{
X: 100, Y: 50, Name: "h",
X: 100, Y: 50, Key: "h",
},
{
X: 100, Y: 100, Name: "i",
X: 100, Y: 100, Key: "i",
},
}
for _, entity := range entities {
manger.Add(entity)
manger.Add(entity.X, entity.Y, entity.Key)
}
search := manger.Search(&Entity{X: 50, Y: 50})
search := manger.Search(50, 50)
result := make([]string, 0)
for _, entity := range search {
result = append(result, entity.Name)
result = append(result, entity)
}
sort.Strings(result)
assert.Equal(t, []string{"a", "b", "c", "d", "e", "f", "g", "h", "i"}, result)
manger.Delete(&Entity{X: 100, Y: 100, Name: "i"})
search2 := manger.Search(&Entity{X: 50, Y: 50})
manger.Delete(100, 100, "i")
search2 := manger.Search(50, 50)
result2 := make([]string, 0)
for _, entity := range search2 {
result2 = append(result2, entity.Name)
result2 = append(result2, entity)
}
sort.Strings(result2)
assert.Equal(t, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, result2)
@@ -128,7 +118,7 @@ func TestNewGridManger(t *testing.T) {
func BenchmarkGridManger(b *testing.B) {
var wg sync.WaitGroup
aol := NewGridManger(0, 0, 256, 16)
aol := NewGridManger(0, 0, 1024, 16)
manger := aol.(*GridManger)
rand.Seed(time.Now().UnixNano())
@@ -136,28 +126,28 @@ func BenchmarkGridManger(b *testing.B) {
wg.Add(30000)
for j := 0; j < 10000; j++ {
go func() {
manger.Add(&Entity{
X: float64(rand.Intn(5) * 10),
Y: float64(rand.Intn(5) * 10),
Name: fmt.Sprintf("player%d", rand.Intn(50)),
})
manger.Add(
float64(rand.Intn(10)*10+rand.Intn(10)),
float64(rand.Intn(10)*10+rand.Intn(10)),
fmt.Sprintf("player%d", rand.Intn(100)),
)
wg.Done()
}()
go func() {
manger.Delete(&Entity{
X: float64(rand.Intn(5) * 10),
Y: float64(rand.Intn(5) * 10),
Name: fmt.Sprintf("player%d", rand.Intn(50)),
})
manger.Delete(
float64(rand.Intn(10)*10+rand.Intn(10)),
float64(rand.Intn(10)*10+rand.Intn(10)),
fmt.Sprintf("player%d", rand.Intn(100)),
)
wg.Done()
}()
go func() {
manger.Search(&Entity{
X: float64(rand.Intn(5) * 10),
Y: float64(rand.Intn(5) * 10),
})
manger.Search(
float64(rand.Intn(10)*10+rand.Intn(10)),
float64(rand.Intn(10)*10+rand.Intn(10)),
)
wg.Done()
}()
}

View File

@@ -9,19 +9,16 @@ const (
rightDown
maxCap = 500 // 节点最大容量
maxDeep = 3 // 节点最大深度
maxDeep = 4 // 节点最大深度
radius = 16 // 视野半径
)
type QuadOption func(*QuadTree)
type Node struct {
Leaf bool // 是否为叶子节点
Deep int // 深度
AreaWidth float64 // 格子宽度(长=宽)
XStart float64 // 起始范围
YStart float64 // 起始范围
Parent *Node // 父节点
Tree *QuadTree // 树指针
Child [4]*Node // 子节点
Entities *sync.Map // 实体
@@ -30,6 +27,7 @@ type Node struct {
type QuadTree struct {
maxCap, maxDeep int
radius float64
mPool sync.Pool
*Node
}
@@ -40,9 +38,8 @@ func NewSonNode(xStart, yStart float64, parent *Node) *Node {
AreaWidth: parent.AreaWidth / 2,
XStart: xStart,
YStart: yStart,
Parent: parent,
Tree: parent.Tree,
Entities: &sync.Map{},
Entities: parent.Tree.mPool.Get().(*sync.Map),
}
return son
@@ -99,21 +96,22 @@ func (n *Node) cutNode() {
n.Child[rightDown] = NewSonNode(n.XStart+half, n.YStart+half, n)
// 将实体迁移到对应子节点
n.Entities.Range(func(_, v interface{}) bool {
n.Entities.Range(func(k, v interface{}) bool {
entity := v.(*Entity)
for _, node := range n.Child {
if node.intersects(entity.X, entity.Y) {
node.Entities.Store(entity.Name, entity)
node.Entities.Store(entity.Key, entity)
}
}
n.Entities.Delete(k)
return true
})
// 清空容器
n.Tree.mPool.Put(n.Entities)
n.Entities = nil
}
func NewQuadTree(xStart, yStart, width float64, opts ...QuadOption) AOI {
func NewQuadTree(xStart, yStart, width float64) AOI {
basicNode := &Node{
Leaf: true,
Deep: 1,
@@ -121,7 +119,6 @@ func NewQuadTree(xStart, yStart, width float64, opts ...QuadOption) AOI {
XStart: xStart,
YStart: yStart,
Child: [4]*Node{},
Entities: &sync.Map{},
}
tree := &QuadTree{
maxDeep: maxDeep,
@@ -129,12 +126,15 @@ func NewQuadTree(xStart, yStart, width float64, opts ...QuadOption) AOI {
radius: radius,
Node: basicNode,
}
tree.mPool.New = func() interface{} {
return &sync.Map{}
}
basicNode.Tree = tree
basicNode.Entities = tree.mPool.Get().(*sync.Map)
return tree
}
func (n *Node) Add(entity *Entity) {
func (n *Node) Add(x, y float64, name string) {
// 判断是否需要分割
if n.Leaf && n.needCut() {
n.cutNode()
@@ -142,39 +142,57 @@ func (n *Node) Add(entity *Entity) {
// 非叶子节点往下递归
if !n.Leaf {
n.Child[n.findSonQuadrant(entity.X, entity.Y)].Add(entity)
n.Child[n.findSonQuadrant(x, y)].Add(x, y, name)
return
}
entity := entityPool.Get().(*Entity)
entity.X = x
entity.Y = y
entity.Key = name
// 叶子节点进行存储
n.Entities.Store(entity.Name, entity)
n.Entities.Store(entity.Key, entity)
}
func (n *Node) Delete(entity *Entity) {
func (n *Node) Delete(x, y float64, name string) {
if !n.Leaf {
n.Child[n.findSonQuadrant(entity.X, entity.Y)].Delete(entity)
n.Child[n.findSonQuadrant(x, y)].Delete(x, y, name)
return
}
n.Entities.Delete(entity.Name)
if entity, ok := n.Entities.Load(name); ok {
n.Entities.Delete(name)
entityPool.Put(entity)
}
}
func (n *Node) Search(entity *Entity) (result []*Entity) {
func (n *Node) Search(x, y float64) []string {
result := resultPool.Get().([]string)
defer func() {
result = result[:0]
resultPool.Put(result)
}()
n.search(x, y, &result)
return result
}
func (n *Node) search(x, y float64, result *[]string) {
if !n.Leaf {
minX, maxX := entity.X-n.Tree.radius, entity.X+n.Tree.radius
minY, maxY := entity.Y-n.Tree.radius, entity.Y+n.Tree.radius
minX, maxX := x-n.Tree.radius, x+n.Tree.radius
minY, maxY := y-n.Tree.radius, y+n.Tree.radius
for _, son := range n.Child {
if son.intersects(minX, minY) || son.intersects(maxX, minY) ||
son.intersects(minX, maxY) || son.intersects(maxX, maxY) {
result = append(result, son.Search(entity)...)
son.search(x, y, result)
}
}
return
}
n.Entities.Range(func(key, value interface{}) bool {
result = append(result, value.(*Entity))
*result = append(*result, value.(*Entity).Key)
return true
})
return

View File

@@ -68,20 +68,10 @@ func Test_NeedCut(t *testing.T) {
tree.maxCap = 2 // 超过两人节点分裂
assert.Equal(t, false, tree.needCut())
player1 := &Entity{
X: 60.9,
Y: 24.9,
Name: "player1",
}
tree.Add(player1)
tree.Add(60.9, 24.9, "player1")
assert.Equal(t, false, tree.needCut())
player2 := &Entity{
X: 25,
Y: 25,
Name: "player2",
}
tree.Add(player2)
tree.Add(25, 25, "player2")
assert.Equal(t, true, tree.needCut())
}
@@ -91,63 +81,50 @@ func TestNode_Search(t *testing.T) {
tree := aoi.(*QuadTree)
tree.maxCap = 2 // 超过两人节点分裂
tree.radius = 5
player1 := &Entity{
X: 60.9,
Y: 24.9,
Name: "player1",
}
tree.Add(player1)
player2 := &Entity{
X: 25,
Y: 25,
Name: "player2",
}
tree.Add(player2)
entities := tree.Search(player1)
tree.Add(60.9, 24.9, "player1")
tree.Add(25, 25, "player2")
// 查询player1附近
entities := tree.Search(60.9, 24.9)
assert.Equal(t, 2, len(entities), "player1 player2")
// 当出现第三个玩家超过节点最大容量产生分裂
player3 := &Entity{
X: 99,
Y: 24,
Name: "player3",
}
tree.Add(player3)
entities = tree.Search(player1)
tree.Add(99, 24, "player3")
// 查询player1附近
entities = tree.Search(60.9, 24.9)
assert.Equal(t, 2, len(entities), "player1 player3")
// 添加第四个玩家
player4 := &Entity{
X: 72,
Y: 23,
Name: "player4",
}
tree.Add(player4)
entities = tree.Search(player1)
tree.Add(72, 23, "player4")
// 查询player1附近
entities = tree.Search(60.9, 24.9)
assert.Equal(t, 2, len(entities), "player1 player4")
entities = tree.Search(player2)
// 查询player2附近
entities = tree.Search(25, 25)
assert.Equal(t, 1, len(entities), "player2")
// 添加第五个玩家
player5 := &Entity{
X: 49.9,
Y: 49.9,
Name: "player5",
}
tree.Add(player5)
entities = tree.Search(player2)
tree.Add(49.9, 49.9, "player5")
// 查询player2附近
entities = tree.Search(25, 25)
assert.Equal(t, 2, len(entities), "player2 player5")
tree.Delete(player5)
entities = tree.Search(player2)
// 移除player5
tree.Delete(49.9, 49.9, "player5")
// 查询player2附近
entities = tree.Search(25, 25)
assert.Equal(t, 1, len(entities), "player2")
}
func BenchmarkQuadtree(b *testing.B) {
var wg sync.WaitGroup
aoi := NewQuadTree(0, 0, 100)
aoi := NewQuadTree(0, 0, 1024)
tree := aoi.(*QuadTree)
rand.Seed(time.Now().UnixNano())
@@ -155,28 +132,28 @@ func BenchmarkQuadtree(b *testing.B) {
wg.Add(30000)
for j := 0; j < 10000; j++ {
go func() {
tree.Add(&Entity{
X: float64(rand.Intn(5) * 10),
Y: float64(rand.Intn(5) * 10),
Name: fmt.Sprintf("player%d", rand.Intn(50)),
})
tree.Add(
float64(rand.Intn(10)*10+rand.Intn(10)),
float64(rand.Intn(10)*10+rand.Intn(10)),
fmt.Sprintf("player%d", rand.Intn(100)),
)
wg.Done()
}()
go func() {
tree.Delete(&Entity{
X: float64(rand.Intn(5) * 10),
Y: float64(rand.Intn(5) * 10),
Name: fmt.Sprintf("player%d", rand.Intn(50)),
})
tree.Delete(
float64(rand.Intn(10)*10+rand.Intn(10)),
float64(rand.Intn(10)*10+rand.Intn(10)),
fmt.Sprintf("player%d", rand.Intn(100)),
)
wg.Done()
}()
go func() {
tree.Search(&Entity{
X: float64(rand.Intn(5) * 10),
Y: float64(rand.Intn(5) * 10),
})
tree.Search(
float64(rand.Intn(10)*10+rand.Intn(10)),
float64(rand.Intn(10)*10+rand.Intn(10)),
)
wg.Done()
}()
}