socks5_read_access_handle.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package socks5
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/cnlh/nps/core"
  6. "io"
  7. "net"
  8. )
  // SOCKS5 authentication constants.
const (
	// Method codes written in the SOCKS5 method-selection reply (RFC 1928 §3).
	UserNoAuth   = uint8(0) // NO AUTHENTICATION REQUIRED
	UserPassAuth = uint8(2) // USERNAME/PASSWORD

	// Username/password sub-negotiation values (RFC 1929 §2).
	userAuthVersion = uint8(1) // VER byte the client must send
	authSuccess     = uint8(0) // STATUS byte signalling success
	authFailure     = uint8(1) // non-zero STATUS byte signalling failure
)
  // Access is the run-stage handler that answers the SOCKS5 method-selection
// step and, when access checking is enabled, reads the client's
// username/password request. Run stores the client connection in clientConn
// for use by the helper methods.
//
// NOTE(review): clientConn is per-connection state on the handler; if one
// Access value is shared across connections, concurrent Run calls would race
// on it — confirm how handlers are instantiated.
type Access struct{ clientConn net.Conn }
  19. func (access *Access) GetConfigName() *core.NpsConfigs {
  20. return core.NewNpsConfigs("socks5_check_access_check", "need check the permission simply")
  21. }
  22. func (access *Access) GetStage() core.Stage {
  23. return core.STAGE_RUN
  24. }
  25. func (access *Access) Start(ctx context.Context, config map[string]string) error {
  26. return nil
  27. }
  28. func (access *Access) End(ctx context.Context, config map[string]string) error {
  29. return nil
  30. }
  31. func (access *Access) Run(ctx context.Context, config map[string]string) error {
  32. clientCtxConn := ctx.Value(core.CLIENT_CONNECTION)
  33. if clientCtxConn == nil {
  34. return core.CLIENT_CONNECTION_NOT_EXIST
  35. }
  36. access.clientConn = clientCtxConn.(net.Conn)
  37. if config["socks5_check_access"] != "true" {
  38. return access.sendAccessMsgToClient(UserNoAuth)
  39. }
  40. // need auth
  41. if err := access.sendAccessMsgToClient(UserPassAuth); err != nil {
  42. return err
  43. }
  44. // send auth reply to client ,and get the auth information
  45. username, password, err := access.getAuthInfoFromClient()
  46. if err != nil {
  47. return err
  48. }
  49. context.WithValue(ctx, "socks_client_username", username)
  50. context.WithValue(ctx, "socks_client_password", password)
  51. // check
  52. return nil
  53. }
  54. func (access *Access) sendAccessMsgToClient(auth uint8) error {
  55. buf := make([]byte, 2)
  56. buf[0] = 5
  57. buf[1] = auth
  58. n, err := access.clientConn.Write(buf)
  59. if err != nil || n != 2 {
  60. return errors.New("write access message to client error " + err.Error())
  61. }
  62. return nil
  63. }
  64. func (access *Access) getAuthInfoFromClient() (username string, password string, err error) {
  65. header := []byte{0, 0}
  66. if _, err = io.ReadAtLeast(access.clientConn, header, 2); err != nil {
  67. return
  68. }
  69. if header[0] != userAuthVersion {
  70. err = errors.New("authentication method is not supported")
  71. return
  72. }
  73. userLen := int(header[1])
  74. user := make([]byte, userLen)
  75. if _, err = io.ReadAtLeast(access.clientConn, user, userLen); err != nil {
  76. return
  77. }
  78. if _, err := access.clientConn.Read(header[:1]); err != nil {
  79. err = errors.New("get password length error" + err.Error())
  80. return
  81. }
  82. passLen := int(header[0])
  83. pass := make([]byte, passLen)
  84. if _, err := io.ReadAtLeast(access.clientConn, pass, passLen); err != nil {
  85. err = errors.New("get password error" + err.Error())
  86. return
  87. }
  88. username = string(user)
  89. password = string(pass)
  90. return
  91. }