As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit 8e0c2360 authored by Igor Drozdov's avatar Igor Drozdov
Browse files

Refactor sshd.go and move the connection logic to connection.go

parent c8ba21bd
No related branches found
No related tags found
No related merge requests found
package sshd
import (
"net"
"context"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
Loading
Loading
@@ -13,19 +15,71 @@ import (
type connection struct {
concurrentSessions *semaphore.Weighted
remoteAddr string
nconn net.Conn
remoteAddr string
started time.Time
}
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
func newConnection(maxSessions int64, remoteAddr string) *connection {
func newConnection(maxSessions int64, nconn net.Conn) *connection {
return &connection{
concurrentSessions: semaphore.NewWeighted(maxSessions),
remoteAddr: remoteAddr,
nconn: nconn,
remoteAddr: nconn.RemoteAddr().String(),
started: time.Now(),
}
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
func (c *connection) handle(ctx context.Context, cfg *ssh.ServerConfig, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
// Prevent a panic in a single connection from taking out the whole server
defer func() {
if err := recover(); err != nil {
ctxlog.Warn("panic handling session")
}
metrics.SliSshdSessionsErrorsTotal.Inc()
}()
ctxlog.Info("server: handleConn: start")
metrics.SshdConnectionsInFlight.Inc()
defer func() {
metrics.SshdConnectionsInFlight.Dec()
metrics.SshdSessionDuration.Observe(time.Since(c.started).Seconds())
}()
// Initialize the connection with server
sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, cfg)
// Track the time required to establish a session
establishSessionDuration := time.Since(c.started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
// Most of the times a connection failes due to the client's misconfiguration or when
// a client cancels a request, so we shouldn't treat them as an error
// Warnings will helps us to track the errors whether they happend on the server side
if err != nil {
ctxlog.WithError(err).WithFields(log.Fields{
"establish_session_duration_s": establishSessionDuration,
}).Warn("conn: init: failed to initialize SSH connection")
return
}
go ssh.DiscardRequests(reqs)
// Handle incoming requests
c.handleRequests(ctx, sconn, chans, handler)
ctxlog.WithFields(log.Fields{
"duration_s": time.Since(c.started).Seconds(),
"establish_session_duration_s": establishSessionDuration,
}).Info("server: handleConn: done")
}
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
Loading
Loading
@@ -55,10 +109,16 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
defer func() {
if err := recover(); err != nil {
ctxlog.WithField("recovered_error", err).Warn("panic handling session")
metrics.SliSshdSessionsErrorsTotal.Inc()
}
}()
handler(ctx, channel, requests)
err := handler(ctx, sconn, channel, requests)
if err != nil {
metrics.SliSshdSessionsErrorsTotal.Inc()
}
ctxlog.Info("connection: handle: done")
}()
}
Loading
Loading
Loading
Loading
@@ -5,6 +5,8 @@ import (
"errors"
"testing"
"golang.org/x/sync/semaphore"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
Loading
Loading
@@ -48,7 +50,9 @@ func (f *fakeNewChannel) ExtraData() []byte {
}
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
conn := newConnection(sessionsNum, "127.0.0.1:50000")
conn := &connection{
concurrentSessions: semaphore.NewWeighted(sessionsNum),
}
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
Loading
Loading
@@ -62,10 +66,11 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
return nil
})
})
Loading
Loading
@@ -80,7 +85,7 @@ func TestUnknownChannelType(t *testing.T) {
conn, chans := setup(1, newChannel)
go func() {
conn.handle(context.Background(), chans, nil)
conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
Loading
Loading
@@ -100,8 +105,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return nil
})
}()
Loading
Loading
@@ -114,9 +120,10 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
return nil
})
require.True(t, channelHandled)
Loading
Loading
@@ -132,8 +139,9 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
return nil
})
}()
Loading
Loading
Loading
Loading
@@ -22,7 +22,6 @@ type session struct {
channel ssh.Channel
gitlabKeyId string
remoteAddr string
success bool
// State managed by the session
execCmd string
Loading
Loading
@@ -42,11 +41,12 @@ type exitStatusReq struct {
ExitStatus uint32
}
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("session: handle: entering request loop")
var err error
for req := range requests {
sessionLog := ctxlog.WithFields(log.Fields{
"bytesize": len(req.Payload),
Loading
Loading
@@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
var shouldContinue bool
switch req.Type {
case "env":
shouldContinue = s.handleEnv(ctx, req)
shouldContinue, err = s.handleEnv(ctx, req)
case "exec":
shouldContinue = s.handleExec(ctx, req)
shouldContinue, err = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
s.exit(ctx, s.handleShell(ctx, req))
var status uint32
status, err = s.handleShell(ctx, req)
s.exit(ctx, status)
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
Loading
Loading
@@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
ctxlog.Debug("session: handle: exiting request loop")
return err
}
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
return false
return false, err
}
switch envRequest.Name {
Loading
Loading
@@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
).Debug("session: handleEnv: processed")
return true
return true, nil
}
func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) {
var execRequest execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
return false
return false, err
}
s.execCmd = execRequest.Command
s.exit(ctx, s.handleShell(ctx, req))
status, err := s.handleShell(ctx, req)
s.exit(ctx, status)
return false
return false, err
}
func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) {
ctxlog := log.ContextLogger(ctx)
if req.WantReply {
Loading
Loading
@@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
}
s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
return 128
return 128, err
}
cmdName := reflect.TypeOf(cmd).String()
Loading
Loading
@@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
if err := cmd.Execute(ctx); err != nil {
s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
return 1
return 1, err
}
ctxlog.Info("session: handleShell: command executed successfully")
return 0
return 0, nil
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
Loading
Loading
@@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) {
log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
s.success = status == 0
s.channel.CloseWrite()
s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
}
Loading
Loading
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"testing"
"errors"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
Loading
Loading
@@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) {
testCases := []struct {
desc string
payload []byte
expectedErr error
expectedProtocolVersion string
expectedResult bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"),
expectedProtocolVersion: "1",
expectedResult: false,
}, {
desc: "valid payload",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
expectedErr: nil,
expectedProtocolVersion: "2",
expectedResult: true,
}, {
desc: "valid payload with forbidden env var",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
expectedErr: nil,
expectedProtocolVersion: "1",
expectedResult: true,
},
Loading
Loading
@@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) {
s := &session{gitProtocolVersion: "1"}
r := &ssh.Request{Payload: tc.payload}
require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
shouldContinue, err := s.handleEnv(context.Background(), r)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, tc.expectedResult, shouldContinue)
require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion)
})
}
}
Loading
Loading
@@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) {
testCases := []struct {
desc string
payload []byte
expectedErr error
expectedExecCmd string
sentRequestName string
sentRequestPayload []byte
success bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"),
expectedExecCmd: "",
sentRequestName: "",
}, {
desc: "valid payload",
payload: ssh.Marshal(execRequest{Command: "discover"}),
expectedErr: nil,
expectedExecCmd: "discover",
sentRequestName: "exit-status",
sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
success: true,
},
}
Loading
Loading
@@ -129,10 +138,12 @@ func TestHandleExec(t *testing.T) {
}
r := &ssh.Request{Payload: tc.payload}
require.Equal(t, false, s.handleExec(context.Background(), r))
shouldContinue, err := s.handleExec(context.Background(), r)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, false, shouldContinue)
require.Equal(t, tc.sentRequestName, f.sentRequestName)
require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
require.Equal(t, tc.success, s.success)
})
}
}
Loading
Loading
@@ -143,32 +154,36 @@ func TestHandleShell(t *testing.T) {
cmd string
errMsg string
gitlabKeyId string
expectedErrString string
expectedExitCode uint32
success bool
}{
{
desc: "fails to parse command",
cmd: `\`,
errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
gitlabKeyId: "root",
expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128,
}, {
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "Unknown command: unknown-command\n",
gitlabKeyId: "root",
expectedErrString: "Disallowed command",
expectedExitCode: 128,
}, {
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1,
}, {
desc: "fails to parse command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedErrString: "",
expectedExitCode: 0,
},
}
Loading
Loading
@@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
exitCode, err := s.handleShell(context.Background(), r)
if tc.expectedErrString != "" {
require.Equal(t, tc.expectedErrString, err.Error())
}
require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.errMsg, out.String())
})
}
Loading
Loading
Loading
Loading
@@ -12,7 +12,6 @@ import (
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
Loading
Loading
@@ -146,68 +145,23 @@ func (s *Server) getStatus() status {
}
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
success := false
metrics.SshdConnectionsInFlight.Inc()
started := time.Now()
defer func() {
metrics.SshdConnectionsInFlight.Dec()
metrics.SshdSessionDuration.Observe(time.Since(started).Seconds())
metrics.SliSshdSessionsTotal.Inc()
if !success {
metrics.SliSshdSessionsErrorsTotal.Inc()
}
}()
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
defer nconn.Close()
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
// Prevent a panic in a single connection from taking out the whole server
defer func() {
if err := recover(); err != nil {
ctxlog.Warn("panic handling session")
}
}()
ctxlog.Info("server: handleConn: start")
sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
ctxlog.WithError(err).Error("server: handleConn: failed to initialize SSH connection")
return
}
go ssh.DiscardRequests(reqs)
var establishSessionDuration float64
conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
establishSessionDuration = time.Since(started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, nconn)
conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
session := &session{
cfg: s.Config,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
remoteAddr: remoteAddr,
remoteAddr: nconn.RemoteAddr().String(),
}
session.handle(ctx, requests)
success = session.success
return session.handle(ctx, requests)
})
ctxlog.WithFields(log.Fields{
"duration_s": time.Since(started).Seconds(),
"establish_session_duration_s": establishSessionDuration,
}).Info("server: handleConn: done")
}
func unconditionalRequirePolicy(_ net.Addr) (proxyproto.Policy, error) {
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment