rate.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. package limiter
  2. import (
  3. "ehang.io/nps/lib/enet"
  4. "ehang.io/nps/lib/rate"
  5. )
  6. // RateLimiter is used to limit the speed of transport
  7. type RateLimiter struct {
  8. baseLimiter
  9. RateLimit int64 `json:"rate_limit" required:"true" placeholder:"10(kb)" zh_name:"最大速度"`
  10. rate *rate.Rate
  11. }
  12. func (rl *RateLimiter) GetName() string {
  13. return "rate"
  14. }
  15. func (rl *RateLimiter) GetZhName() string {
  16. return "带宽限制"
  17. }
  18. // Init the rate controller
  19. func (rl *RateLimiter) Init() error {
  20. if rl.RateLimit > 0 && rl.rate == nil {
  21. rl.rate = rate.NewRate(rl.RateLimit)
  22. rl.rate.Start()
  23. }
  24. return nil
  25. }
  26. // DoLimit return limited Conn
  27. func (rl *RateLimiter) DoLimit(c enet.Conn) (enet.Conn, error) {
  28. return NewRateConn(c, rl.rate), nil
  29. }
  30. // rateConn is used to limiter the rate fo connection
  31. type rateConn struct {
  32. enet.Conn
  33. rate *rate.Rate
  34. }
  35. // NewRateConn return limited connection by rate interface
  36. func NewRateConn(rc enet.Conn, rate *rate.Rate) enet.Conn {
  37. return &rateConn{
  38. Conn: rc,
  39. rate: rate,
  40. }
  41. }
  42. // Read data and remove capacity from rate pool
  43. func (s *rateConn) Read(b []byte) (n int, err error) {
  44. n, err = s.Conn.Read(b)
  45. if s.rate != nil && err == nil {
  46. err = s.rate.Get(int64(n))
  47. }
  48. return
  49. }
  50. // Write data and remove capacity from rate pool
  51. func (s *rateConn) Write(b []byte) (n int, err error) {
  52. n, err = s.Conn.Write(b)
  53. if s.rate != nil && err == nil {
  54. err = s.rate.Get(int64(n))
  55. }
  56. return
  57. }