123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- package socks5
- import (
- "context"
- "errors"
- "github.com/cnlh/nps/core"
- "io"
- "net"
- )
// SOCKS5 method identifiers written as the second byte of the
// method-selection reply (see sendAccessMsgToClient).
const (
	// UserNoAuth tells the client that no authentication is required.
	UserNoAuth uint8 = 0
	// UserPassAuth tells the client that username/password authentication
	// is required.
	UserPassAuth uint8 = 2
)

// Values used by the username/password sub-negotiation.
const (
	// userAuthVersion is the version byte expected at the start of the
	// client's username/password request.
	userAuthVersion uint8 = 1
	// authSuccess and authFailure are not referenced in this part of the file.
	// NOTE(review): presumably the status byte of the sub-negotiation reply
	// (RFC 1929) — confirm against the code that sends it.
	authSuccess uint8 = 0
	authFailure uint8 = 1
)
// Access is the SOCKS5 access-check step. It answers the client's method
// selection and, when authentication is enabled, reads the client's
// username/password request.
type Access struct {
	// clientConn is the client connection; Run sets it from the context
	// before any message is written or read.
	clientConn net.Conn
}
- func (access *Access) GetConfigName() *core.NpsConfigs {
- return core.NewNpsConfigs("socks5_check_access_check", "need check the permission simply")
- }
- func (access *Access) GetStage() core.Stage {
- return core.STAGE_RUN
- }
- func (access *Access) Start(ctx context.Context, config map[string]string) error {
- return nil
- }
- func (access *Access) End(ctx context.Context, config map[string]string) error {
- return nil
- }
- func (access *Access) Run(ctx context.Context, config map[string]string) error {
- clientCtxConn := ctx.Value(core.CLIENT_CONNECTION)
- if clientCtxConn == nil {
- return core.CLIENT_CONNECTION_NOT_EXIST
- }
- access.clientConn = clientCtxConn.(net.Conn)
- if config["socks5_check_access"] != "true" {
- return access.sendAccessMsgToClient(UserNoAuth)
- }
- // need auth
- if err := access.sendAccessMsgToClient(UserPassAuth); err != nil {
- return err
- }
- // send auth reply to client ,and get the auth information
- username, password, err := access.getAuthInfoFromClient()
- if err != nil {
- return err
- }
- context.WithValue(ctx, "socks_client_username", username)
- context.WithValue(ctx, "socks_client_password", password)
- // check
- return nil
- }
- func (access *Access) sendAccessMsgToClient(auth uint8) error {
- buf := make([]byte, 2)
- buf[0] = 5
- buf[1] = auth
- n, err := access.clientConn.Write(buf)
- if err != nil || n != 2 {
- return errors.New("write access message to client error " + err.Error())
- }
- return nil
- }
- func (access *Access) getAuthInfoFromClient() (username string, password string, err error) {
- header := []byte{0, 0}
- if _, err = io.ReadAtLeast(access.clientConn, header, 2); err != nil {
- return
- }
- if header[0] != userAuthVersion {
- err = errors.New("authentication method is not supported")
- return
- }
- userLen := int(header[1])
- user := make([]byte, userLen)
- if _, err = io.ReadAtLeast(access.clientConn, user, userLen); err != nil {
- return
- }
- if _, err := access.clientConn.Read(header[:1]); err != nil {
- err = errors.New("get password length error" + err.Error())
- return
- }
- passLen := int(header[0])
- pass := make([]byte, passLen)
- if _, err := io.ReadAtLeast(access.clientConn, pass, passLen); err != nil {
- err = errors.New("get password error" + err.Error())
- return
- }
- username = string(user)
- password = string(pass)
- return
- }
|