// socks5_read_access_handle.go (2.2 KB)

  1. package socks5
  2. import (
  3. "context"
  4. "errors"
  5. "github.com/cnlh/nps/core"
  6. "io"
  7. "net"
  8. )
  9. const (
  10. UserPassAuth = uint8(2)
  11. userAuthVersion = uint8(1)
  12. authSuccess = uint8(0)
  13. authFailure = uint8(1)
  14. UserNoAuth = uint8(0)
  15. )
  16. type Access struct {
  17. core.NpsPlugin
  18. clientConn net.Conn
  19. }
  20. func (access *Access) GetConfigName() *core.NpsConfigs {
  21. return core.NewNpsConfigs("socks5_check_access_check", "need check the permission simply")
  22. }
  23. func (access *Access) Run(ctx context.Context, config map[string]string) (context.Context, error) {
  24. access.clientConn = access.GetClientConn(ctx)
  25. if config["socks5_check_access"] != "true" {
  26. return ctx, access.sendAccessMsgToClient(UserNoAuth)
  27. }
  28. // need auth
  29. if err := access.sendAccessMsgToClient(UserPassAuth); err != nil {
  30. return ctx, err
  31. }
  32. // send auth reply to client ,and get the auth information
  33. username, password, err := access.getAuthInfoFromClient()
  34. if err != nil {
  35. return ctx, err
  36. }
  37. ctx = context.WithValue(ctx, "socks_client_username", username)
  38. ctx = context.WithValue(ctx, "socks_client_password", password)
  39. // check
  40. return ctx, nil
  41. }
  42. func (access *Access) sendAccessMsgToClient(auth uint8) error {
  43. buf := make([]byte, 2)
  44. buf[0] = 5
  45. buf[1] = auth
  46. n, err := access.clientConn.Write(buf)
  47. if err != nil || n != 2 {
  48. return errors.New("write access message to client error " + err.Error())
  49. }
  50. return nil
  51. }
  52. func (access *Access) getAuthInfoFromClient() (username string, password string, err error) {
  53. header := []byte{0, 0}
  54. if _, err = io.ReadAtLeast(access.clientConn, header, 2); err != nil {
  55. return
  56. }
  57. if header[0] != userAuthVersion {
  58. err = errors.New("authentication method is not supported")
  59. return
  60. }
  61. userLen := int(header[1])
  62. user := make([]byte, userLen)
  63. if _, err = io.ReadAtLeast(access.clientConn, user, userLen); err != nil {
  64. return
  65. }
  66. if _, err := access.clientConn.Read(header[:1]); err != nil {
  67. err = errors.New("get password length error" + err.Error())
  68. return
  69. }
  70. passLen := int(header[0])
  71. pass := make([]byte, passLen)
  72. if _, err := io.ReadAtLeast(access.clientConn, pass, passLen); err != nil {
  73. err = errors.New("get password error" + err.Error())
  74. return
  75. }
  76. username = string(user)
  77. password = string(pass)
  78. return
  79. }