diff --git a/quadtree.go b/quadtree.go index cb30b8a..c9647f4 100644 --- a/quadtree.go +++ b/quadtree.go @@ -7,31 +7,175 @@ const ( rightUp leftDown rightDown + + maxCap = 500 // 节点最大容量 + maxDeep = 5 // 节点最大深度 + radius = 20 // 视野半径 ) +type QuadOption func(*QuadTree) + type Node struct { - AreaWidth int // 格子宽度(长=宽) - XStart int // 起始范围 - YStart int // 起始范围 - Deep int // 深度 - Leaf bool // 是否为叶子节点 - Parent *Node // 父节点 - Child [4]*Node // 子节点 - Entities sync.Map // 实体 + Leaf bool // 是否为叶子节点 + Deep int // 深度 + AreaWidth float64 // 格子宽度(长=宽) + XStart float64 // 起始范围 + YStart float64 // 起始范围 + Parent *Node // 父节点 + Tree *QuadTree // 树指针 + Child [4]*Node // 子节点 + Entities *sync.Map // 实体 } type QuadTree struct { - Root *Node + maxCap, maxDeep int + radius float64 + *Node } -func (q QuadTree) Add(entity *Entity) { - panic("implement me") +func NewSonNode(xStart, yStart float64, parent *Node) *Node { + son := &Node{ + Leaf: true, + Deep: parent.Deep + 1, + AreaWidth: parent.AreaWidth / 2, + XStart: xStart, + YStart: yStart, + Parent: parent, + Tree: parent.Tree, + Entities: &sync.Map{}, + } + + return son } -func (q QuadTree) Delete(entity *Entity) { - panic("implement me") +// canCut 检查节点是否可以分割 +func (n *Node) canCut() bool { + if n.XStart+n.AreaWidth/2 > 0 && n.YStart+n.AreaWidth/2 > 0 { + return true + } + return false } -func (q QuadTree) Search(entity *Entity) (result []*Entity) { - panic("implement me") +// needCut 检查节点是否需要分割 +func (n *Node) needCut() bool { + lens := 0 + n.Entities.Range(func(key, value interface{}) bool { + lens++ + return true + }) + return lens+1 > n.Tree.maxCap && n.Deep+1 <= n.Tree.maxDeep && n.canCut() +} + +// intersects 检查坐标是否在节点范围内 +func (n *Node) intersects(x, y float64) bool { + if n.XStart <= x && x < n.XStart+n.AreaWidth && n.YStart <= y && y < n.YStart+n.AreaWidth { + return true + } + return false +} + +// findSonQuadrant 根据坐标寻找子节点的方位 +func (n *Node) findSonQuadrant(x, y float64) int { + if x < n.Child[rightDown].XStart { + if y < n.Child[rightDown].YStart { + return leftUp + } + return leftDown + } + if y < n.Child[rightDown].YStart { + return rightUp + } + return rightDown +} + +// cutNode 分割节点 +func (n *Node) cutNode() { + n.Leaf = false + half := n.AreaWidth / 2 + + n.Child[leftUp] = NewSonNode(n.XStart, n.YStart, n) + n.Child[rightUp] = NewSonNode(n.XStart+half, n.YStart, n) + n.Child[leftDown] = NewSonNode(n.XStart, n.YStart+half, n) + n.Child[rightDown] = NewSonNode(n.XStart+half, n.YStart+half, n) + + // 将实体迁移到对应子节点 + n.Entities.Range(func(_, v interface{}) bool { + entity := v.(*Entity) + for _, node := range n.Child { + if node.intersects(entity.X, entity.Y) { + node.Entities.Store(entity.Name, entity) + } + } + return true + }) + + // 清空容器 + n.Entities = nil +} + +func NewQuadTree(xStart, yStart, width float64, opts ...QuadOption) *QuadTree { + basicNode := &Node{ + Leaf: true, + Deep: 1, + AreaWidth: width, + XStart: xStart, + YStart: yStart, + Child: [4]*Node{}, + Entities: &sync.Map{}, + } + tree := &QuadTree{ + maxDeep: maxDeep, + maxCap: maxCap, + radius: radius, + Node: basicNode, + } + basicNode.Tree = tree + + return tree +} + +func (n *Node) Add(entity *Entity) { + // 判断是否需要分割 + if n.Leaf && n.needCut() { + n.cutNode() + } + + // 非叶子节点往下递归 + if !n.Leaf { + n.Child[n.findSonQuadrant(entity.X, entity.Y)].Add(entity) + return + } + + // 叶子节点进行存储 + n.Entities.Store(entity.Name, entity) +} + +func (n *Node) Delete(entity *Entity) { + if !n.Leaf { + n.Child[n.findSonQuadrant(entity.X, entity.Y)].Delete(entity) + return + } + + n.Entities.Delete(entity.Name) +} + +func (n *Node) Search(entity *Entity) (result []*Entity) { + if !n.Leaf { + minX, maxX := entity.X-n.Tree.radius, entity.X+n.Tree.radius + minY, maxY := entity.Y-n.Tree.radius, entity.Y+n.Tree.radius + + for _, son := range n.Child { + if son.intersects(minX, minY) || son.intersects(maxX, minY) || + son.intersects(minX, maxY) || son.intersects(maxX, maxY) { + result = append(result, son.Search(entity)...) + } + } + return + } + + n.Entities.Range(func(key, value interface{}) bool { + result = append(result, value.(*Entity)) + return true + }) + return } diff --git a/quadtree_test.go b/quadtree_test.go index 5d87797..a994f42 100644 --- a/quadtree_test.go +++ b/quadtree_test.go @@ -1 +1,112 @@ package aoi + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFindQuadrant(t *testing.T) { + tree := NewQuadTree(0, 0, 100) + tree.cutNode() + + tests := []struct { + x, y float64 + want int + }{ + { + x: 49.9, y: 49.9, want: leftUp, + }, + { + x: 50, y: 50, want: rightDown, + }, + { + x: 49.9, y: 50, want: leftDown, + }, + { + x: 50, y: 49.9, want: rightUp, + }, + } + + for _, tt := range tests { + d := tree.findSonQuadrant(tt.x, tt.y) + assert.Equal(t, tt.want, d) + } + + // 再次分割 + tree.Child[rightUp].cutNode() + tests2 := []struct { + x, y float64 + want int + }{ + { + x: 74.9, y: 24.9, want: leftUp, + }, + { + x: 75, y: 25, want: rightDown, + }, + { + x: 74.9, y: 25, want: leftDown, + }, + { + x: 75, y: 24.9, want: rightUp, + }, + } + for _, tt := range tests2 { + d := tree.Child[rightUp].findSonQuadrant(tt.x, tt.y) + assert.Equal(t, tt.want, d) + } +} + +func TestNode_Search(t *testing.T) { + tree := NewQuadTree(0, 0, 100) + tree.maxCap = 2 // 超过两人节点分裂 + tree.radius = 5 + player1 := &Entity{ + X: 60.9, + Y: 24.9, + Name: "player1", + } + tree.Add(player1) + + player2 := &Entity{ + X: 25, + Y: 25, + Name: "player2", + } + tree.Add(player2) + entities := tree.Search(player1) + assert.Equal(t, 2, len(entities), "player1 player2") + + // 当出现第三个玩家超过节点最大容量产生分裂 + player3 := &Entity{ + X: 99, + Y: 24, + Name: "player3", + } + tree.Add(player3) + entities = tree.Search(player1) + assert.Equal(t, 2, len(entities), "player1 player3") + + // 添加第四个玩家 + player4 := &Entity{ + X: 72, + Y: 23, + Name: "player4", + } + tree.Add(player4) + entities = tree.Search(player1) + assert.Equal(t, 2, len(entities), "player1 player4") + + entities = tree.Search(player2) + assert.Equal(t, 1, len(entities), "player2") + + // 添加第五个玩家 + player5 := &Entity{ + X: 49.9, + Y: 49.9, + Name: "player5", + } + tree.Add(player5) + entities = tree.Search(player2) + assert.Equal(t, 2, len(entities), "player2 player5") +}