rate.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package rate
  2. import (
  3. "errors"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. )
  // Rate is a token bucket. Once Start is called, a background goroutine
// deposits bucketAddSize tokens every second, and the balance never exceeds
// bucketSize. Get blocks until enough tokens are available and removes them;
// Write deposits tokens directly.
//
// cond.L guards every field except nowRate, which is accessed only through
// sync/atomic so that GetNowRate never blocks.
type Rate struct {
	// nowRate is kept first: 64-bit atomic operations require 8-byte
	// alignment on 32-bit platforms, and only the first word of an
	// allocated struct is guaranteed to be aligned.
	nowRate           int64
	bucketSize        int64      // maximum token balance
	bucketSurplusSize int64      // current token balance
	bucketAddSize     int64      // tokens deposited by each refill tick
	stopChan          chan bool  // closed by Stop to end the refill goroutine
	cond              *sync.Cond // broadcast whenever tokens are deposited or Stop runs
	hasStop           bool       // set once by Stop
	hasStart          bool       // set once by Start
}

// NewRate returns an empty token bucket that gains addSize tokens per second
// once Start is called and holds at most 2*addSize tokens.
func NewRate(addSize int64) *Rate {
	return &Rate{
		bucketSize:    addSize * 2,
		bucketAddSize: addSize,
		stopChan:      make(chan bool),
		cond:          sync.NewCond(new(sync.Mutex)),
	}
}

// Start launches the goroutine that refills the bucket every second. Only the
// first call has any effect; calling Start again (even after Stop) is a no-op.
func (r *Rate) Start() {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()
	if r.hasStart {
		return
	}
	r.hasStart = true
	go r.session()
}

// add deposits size tokens, capping the balance at bucketSize, and wakes any
// Get callers waiting for tokens.
func (r *Rate) add(size int64) {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()
	r.depositLocked(size)
}

// depositLocked adds size tokens (capped at bucketSize) and broadcasts to
// waiters. The caller must hold r.cond.L.
func (r *Rate) depositLocked(size int64) {
	// Compare against the free space rather than computing surplus+size, so a
	// huge size cannot overflow int64.
	if free := r.bucketSize - r.bucketSurplusSize; size >= free {
		r.bucketSurplusSize = r.bucketSize
	} else {
		r.bucketSurplusSize += size
	}
	r.cond.Broadcast()
}

// Write deposits size tokens into the bucket (capped at its capacity) and
// wakes blocked Get callers so they can use the new tokens immediately.
func (r *Rate) Write(size int64) {
	r.add(size)
}

// Stop ends the refill goroutine and makes every pending and future Get
// return an error. It does nothing if Start was never called, and it is safe
// to call more than once.
func (r *Rate) Stop() {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()
	if !r.hasStart || r.hasStop {
		return
	}
	r.hasStop = true
	// Closing never blocks, unlike a send to a goroutine that may already
	// have exited; hasStop guarantees the close happens only once.
	close(r.stopChan)
	// Broadcast while holding the lock so no waiter can miss the wake-up.
	r.cond.Broadcast()
}

// Get blocks until size tokens are available and then removes them. It
// returns an error if the bucket has been stopped, including for callers that
// were already waiting when Stop was called.
//
// NOTE: a request larger than the bucket capacity (2*addSize) can never be
// satisfied; such a call only returns once Stop takes effect.
func (r *Rate) Get(size int64) error {
	r.cond.L.Lock()
	defer r.cond.L.Unlock()
	for {
		if r.hasStop {
			return errors.New("the rate has closed")
		}
		if r.bucketSurplusSize >= size {
			r.bucketSurplusSize -= size
			return nil
		}
		// Wait releases the lock while sleeping. The loop re-checks both
		// conditions on every wake-up because another Get may have taken
		// the tokens first.
		r.cond.Wait()
	}
}

// GetNowRate returns the current rate. It is only a rough number, recomputed
// from the bucket balance on each refill tick, and it is safe to call
// concurrently with the other methods.
func (r *Rate) GetNowRate() int64 {
	return atomic.LoadInt64(&r.nowRate)
}

// session refills the bucket once per second until stopChan is closed.
func (r *Rate) session() {
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			r.cond.L.Lock()
			// Rough rate figure, derived from the balance left just before
			// this refill.
			est := r.bucketAddSize - r.bucketSurplusSize
			if est <= 0 {
				est = r.bucketSize - r.bucketSurplusSize
			}
			atomic.StoreInt64(&r.nowRate, est)
			r.depositLocked(r.bucketAddSize)
			r.cond.L.Unlock()
		case <-r.stopChan:
			return
		}
	}
}