// ip_conn_num.go (2.1 KB)
  1. package limiter
  2. import (
  3. "ehang.io/nps/lib/enet"
  4. "github.com/pkg/errors"
  5. "net"
  6. "sync"
  7. )
  // ipNumMap counts the live connections of each IP address.
// Every access to m, reads included, must hold the embedded mutex.
type ipNumMap struct {
	m map[string]int32
	sync.Mutex
}

// AddOrSet increments the connection count for key, creating the entry with
// a count of 1 if it does not exist yet.
func (i *ipNumMap) AddOrSet(key string) {
	i.Lock()
	defer i.Unlock()
	// A missing key reads as 0, so ++ covers both the "add" and "set" cases.
	i.m[key]++
}

// SubOrDel decrements the connection count for key and removes the entry once
// the count reaches zero. Keys that are not present are ignored.
func (i *ipNumMap) SubOrDel(key string) {
	i.Lock()
	defer i.Unlock()
	v, ok := i.m[key]
	if !ok {
		return
	}
	// Delete instead of storing 0 (or a negative value) so the map does not
	// accumulate entries for IPs that no longer have connections.
	if v <= 1 {
		delete(i.m, key)
		return
	}
	i.m[key] = v - 1
}

// Get returns the current connection count for key, or 0 if it has none.
func (i *ipNumMap) Get(key string) int32 {
	// The map is written concurrently by AddOrSet/SubOrDel; an unsynchronized
	// read is a data race, so take the lock here too.
	i.Lock()
	defer i.Unlock()
	return i.m[key]
}
  38. // IpConnNumLimiter is used to limit the connection num of a service at the same time of same ip
  39. type IpConnNumLimiter struct {
  40. m *ipNumMap
  41. MaxNum int32 `json:"max_num" required:"true" placeholder:"10" zh_name:"单ip最大连接数"`
  42. sync.Mutex
  43. }
  44. func (cl *IpConnNumLimiter) GetName() string {
  45. return "ip_conn_num"
  46. }
  47. func (cl *IpConnNumLimiter) GetZhName() string {
  48. return "单ip连接数限制"
  49. }
  50. // Init the ipNumMap
  51. func (cl *IpConnNumLimiter) Init() error {
  52. cl.m = &ipNumMap{m: make(map[string]int32)}
  53. return nil
  54. }
  55. // DoLimit reports whether the connection num of the ip exceed the maximum number
  56. // If true, return error
  57. func (cl *IpConnNumLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
  58. ip, _, err := net.SplitHostPort(c.RemoteAddr().String())
  59. if err != nil {
  60. return c, errors.Wrap(err, "split ip addr")
  61. }
  62. if cl.m.Get(ip) >= cl.MaxNum {
  63. return c, errors.Errorf("the ip(%s) exceed the maximum number(%d)", ip, cl.MaxNum)
  64. }
  65. return NewNumConn(c, ip, cl.m), nil
  66. }
  67. // numConn is an implement of enet.Conn
  68. type numConn struct {
  69. key string
  70. m *ipNumMap
  71. enet.Conn
  72. }
  73. // NewNumConn return a numConn
  74. func NewNumConn(c enet.Conn, key string, m *ipNumMap) *numConn {
  75. m.AddOrSet(key)
  76. return &numConn{
  77. m: m,
  78. key: key,
  79. Conn: c,
  80. }
  81. }
  82. // Close is used to decrease the connection num of a ip when connection closing
  83. func (n *numConn) Close() error {
  84. n.m.SubOrDel(n.key)
  85. return n.Conn.Close()
  86. }