As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit c40ad688 authored by Stan Hu
Browse files

Merge branch 'id-login-grace-time' into 'main'

Close the connection when context is canceled

See merge request gitlab-org/gitlab-shell!646
parents 4d2459f3 0110b9ea
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -3,6 +3,8 @@ package sshd
import (
"context"
"errors"
"net"
"strings"
"time"
"golang.org/x/crypto/ssh"
Loading
Loading
@@ -22,52 +24,91 @@ const KeepAliveMsg = "keepalive@openssh.com"
var EOFTimeout = 10 * time.Second
type connection struct {
cfg *config.Config
concurrentSessions *semaphore.Weighted
remoteAddr string
sconn *ssh.ServerConn
maxSessions int64
cfg *config.Config
concurrentSessions *semaphore.Weighted
nconn net.Conn
maxSessions int64
remoteAddr string
started time.Time
establishSessionDuration float64
}
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error
type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
func newConnection(cfg *config.Config, nconn net.Conn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
return &connection{
cfg: cfg,
maxSessions: maxSessions,
concurrentSessions: semaphore.NewWeighted(maxSessions),
remoteAddr: remoteAddr,
sconn: sconn,
nconn: nconn,
remoteAddr: nconn.RemoteAddr().String(),
started: time.Now(),
}
}
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) {
sconn, chans, err := c.initServerConn(ctx, srvCfg)
if err != nil {
return
}
if c.cfg.Server.ClientAliveInterval > 0 {
ticker := time.NewTicker(time.Duration(c.cfg.Server.ClientAliveInterval))
defer ticker.Stop()
go c.sendKeepAliveMsg(ctx, ticker)
go c.sendKeepAliveMsg(ctx, sconn, ticker)
}
c.handleRequests(ctx, sconn, chans, handler)
reason := sconn.Wait()
log.WithContextFields(ctx, log.Fields{
"duration_s": time.Since(c.started).Seconds(),
"establish_session_duration_s": c.establishSessionDuration,
"reason": reason,
}).Info("server: handleConn: done")
}
// initServerConn turns the raw network connection held by c into an SSH
// server connection by calling ssh.NewServerConn with srvCfg.
//
// On failure the error is logged together with the remote address and then
// returned; no connection or channel stream is returned in that case. Errors
// whose text contains "no common algorithm for host key", or is exactly
// "EOF", are logged at debug level; every other error is logged at warn level.
//
// On success the request channel returned by ssh.NewServerConn is handed to
// ssh.DiscardRequests in its own goroutine, and the server connection is
// returned along with its stream of incoming channel requests.
func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) {
	sconn, chans, reqs, err := ssh.NewServerConn(c.nconn, srvCfg)
	if err == nil {
		go ssh.DiscardRequests(reqs)

		return sconn, chans, nil
	}

	const msg = "connection: initServerConn: failed to initialize SSH connection"
	entry := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}).WithError(err)

	// Both conditions are matched on the error text, exactly as spelled here.
	switch text := err.Error(); {
	case text == "EOF", strings.Contains(text, "no common algorithm for host key"):
		entry.Debug(msg)
	default:
		entry.Warn(msg)
	}

	return nil, nil, err
}
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
ctxlog.Info("connection: handle: unknown channel type")
ctxlog.Info("connection: handleRequests: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
ctxlog.Info("connection: handle: too many concurrent sessions")
ctxlog.Info("connection: handleRequests: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
ctxlog.WithError(err).Error("connection: handle: accepting channel failed")
ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed")
c.concurrentSessions.Release(1)
continue
}
Loading
Loading
@@ -76,6 +117,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
defer func(started time.Time) {
metrics.SshdSessionDuration.Observe(time.Since(started).Seconds())
}(time.Now())
c.establishSessionDuration = time.Since(c.started).Seconds()
defer c.concurrentSessions.Release(1)
Loading
Loading
@@ -87,12 +129,12 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}()
metrics.SliSshdSessionsTotal.Inc()
err := handler(ctx, channel, requests)
err := handler(sconn, channel, requests)
if err != nil {
c.trackError(err)
}
ctxlog.Info("connection: handle: done")
ctxlog.Info("connection: handleRequests: done")
}()
}
Loading
Loading
@@ -105,7 +147,7 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
c.concurrentSessions.Acquire(ctx, c.maxSessions)
}
func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for {
Loading
Loading
@@ -113,9 +155,9 @@ func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker)
case <-ctx.Done():
return
case <-ticker.C:
ctxlog.Debug("session: handleShell: send keepalive message to a client")
ctxlog.Debug("connection: sendKeepAliveMsg: send keepalive message to a client")
c.sconn.SendRequest(KeepAliveMsg, true, nil)
sconn.SendRequest(KeepAliveMsg, true, nil)
}
}
}
Loading
Loading
Loading
Loading
@@ -10,6 +10,7 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
grpccodes "google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
Loading
Loading
@@ -81,7 +82,7 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}}
conn := newConnection(cfg, "127.0.0.1:50000", &ssh.ServerConn{&fakeConn{}, nil})
conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)}
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
Loading
Loading
@@ -95,7 +96,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
Loading
Loading
@@ -113,7 +114,7 @@ func TestUnknownChannelType(t *testing.T) {
conn, chans := setup(1, newChannel)
go func() {
conn.handle(context.Background(), chans, nil)
conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
Loading
Loading
@@ -133,7 +134,7 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return nil
})
Loading
Loading
@@ -148,7 +149,7 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
return nil
Loading
Loading
@@ -167,7 +168,7 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
return nil
})
Loading
Loading
@@ -185,12 +186,11 @@ func TestAcceptSessionFails(t *testing.T) {
func TestClientAliveInterval(t *testing.T) {
f := &fakeConn{}
conn := newConnection(&config.Config{}, "127.0.0.1:50000", &ssh.ServerConn{f, nil})
ticker := time.NewTicker(time.Millisecond)
defer ticker.Stop()
go conn.sendKeepAliveMsg(context.Background(), ticker)
conn := &connection{}
go conn.sendKeepAliveMsg(context.Background(), &ssh.ServerConn{f, nil}, ticker)
require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
}
Loading
Loading
@@ -204,7 +204,7 @@ func TestSessionsMetrics(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return errors.New("custom error")
})
Loading
Loading
@@ -213,7 +213,7 @@ func TestSessionsMetrics(t *testing.T) {
require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
conn, chans = setup(1, newChannel)
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return grpcstatus.Error(grpccodes.Canceled, "canceled")
})
Loading
Loading
@@ -222,7 +222,7 @@ func TestSessionsMetrics(t *testing.T) {
require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
conn, chans = setup(1, newChannel)
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return &client.ApiError{"api error"}
})
Loading
Loading
@@ -231,7 +231,7 @@ func TestSessionsMetrics(t *testing.T) {
require.InDelta(t, initialSessionsErrorTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
conn, chans = setup(1, newChannel)
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return grpcstatus.Error(grpccodes.Unavailable, "unavailable")
})
Loading
Loading
Loading
Loading
@@ -10,7 +10,6 @@ import (
"time"
"github.com/pires/go-proxyproto"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
Loading
Loading
@@ -39,18 +38,6 @@ type Server struct {
serverConfig *serverConfig
}
// logSSHInitError records a failure to initialize an SSH connection on ctxlog,
// attaching err to the entry.
//
// Errors whose text contains "no common algorithm for host key", or is exactly
// "EOF", are logged at debug level; every other error is logged at warn level.
func logSSHInitError(ctxlog *logrus.Entry, err error) {
	const msg = "server: handleConn: failed to initialize SSH connection"
	entry := ctxlog.WithError(err)

	// Both conditions are matched on the error text, exactly as spelled here.
	switch text := err.Error(); {
	case text == "EOF", strings.Contains(text, "no common algorithm for host key"):
		entry.Debug(msg)
	default:
		entry.Warn(msg)
	}
}
func NewServer(cfg *config.Config) (*Server, error) {
serverConfig, err := newServerConfig(cfg)
if err != nil {
Loading
Loading
@@ -159,18 +146,21 @@ func (s *Server) getStatus() status {
}
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
defer s.wg.Done()
metrics.SshdConnectionsInFlight.Inc()
defer metrics.SshdConnectionsInFlight.Dec()
remoteAddr := nconn.RemoteAddr().String()
defer s.wg.Done()
defer nconn.Close()
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
defer cancel()
go func() {
<-ctx.Done()
nconn.Close() // Close the connection when context is cancelled
}()
remoteAddr := nconn.RemoteAddr().String()
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": remoteAddr})
ctxlog.Debug("server: handleConn: start")
// Prevent a panic in a single connection from taking out the whole server
defer func() {
Loading
Loading
@@ -181,22 +171,8 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
}
}()
ctxlog.Debug("server: handleConn: start")
sconn, chans, reqs, err := ssh.NewServerConn(nconn, s.serverConfig.get(ctx))
if err != nil {
logSSHInitError(ctxlog, err)
return
}
go ssh.DiscardRequests(reqs)
started := time.Now()
var establishSessionDuration float64
conn := newConnection(s.Config, remoteAddr, sconn)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error {
establishSessionDuration = time.Since(started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
conn := newConnection(s.Config, nconn)
conn.handle(ctx, s.serverConfig.get(ctx), func(sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
session := &session{
cfg: s.Config,
channel: channel,
Loading
Loading
@@ -206,13 +182,6 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
return session.handle(ctx, requests)
})
reason := sconn.Wait()
ctxlog.WithFields(log.Fields{
"duration_s": time.Since(started).Seconds(),
"establish_session_duration_s": establishSessionDuration,
"reason": reason,
}).Info("server: handleConn: done")
}
func (s *Server) requirePolicy(_ net.Addr) (proxyproto.Policy, error) {
Loading
Loading
Loading
Loading
@@ -222,6 +222,35 @@ func TestInvalidServerConfig(t *testing.T) {
require.Nil(t, s.Shutdown())
}
// TestClosingHangedConnections starts a server bound to a cancellable context,
// opens a client connection that stalls inside its host key callback, and then
// checks that Shutdown returns without error and that, once the context is
// canceled, the server reports StatusClosed even though the client never
// finished the handshake.
func TestClosingHangedConnections(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	s := setupServerWithContext(t, nil, ctx)

	unauthenticatedRequestStatus := make(chan string)
	completed := make(chan struct{})
	// Unblock the stalled client when the test returns; otherwise the dialing
	// goroutine would stay parked in the host key callback for the rest of
	// the test binary's lifetime.
	defer close(completed)

	clientCfg := clientConfig(t)
	clientCfg.HostKeyCallback = func(_ string, _ net.Addr, _ ssh.PublicKey) error {
		unauthenticatedRequestStatus <- "authentication-started"
		<-completed // Stall the handshake until the test is over

		return nil
	}

	go func() {
		// Start an SSH connection that stays stuck in the handshake for the
		// whole test. The result is deliberately ignored: the dial is only
		// there to hold a connection open on the server side.
		_, _ = ssh.Dial("tcp", serverUrl, clientCfg)
	}()

	// Wait until the client has reached the host key callback before shutting down.
	require.Equal(t, "authentication-started", <-unauthenticatedRequestStatus)

	require.NoError(t, s.Shutdown())
	cancel()
	verifyStatus(t, s, StatusClosed)
}
func setupServer(t *testing.T) *Server {
t.Helper()
Loading
Loading
@@ -231,6 +260,12 @@ func setupServer(t *testing.T) *Server {
// setupServerWithConfig is a test helper that delegates to
// setupServerWithContext, passing cfg along with a background context.
func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
	t.Helper()

	ctx := context.Background()

	return setupServerWithContext(t, cfg, ctx)
}
func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Context) *Server {
t.Helper()
requests := []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/authorized_keys",
Loading
Loading
@@ -270,7 +305,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
s, err := NewServer(cfg)
require.NoError(t, err)
go func() { require.NoError(t, s.ListenAndServe(context.Background())) }()
go func() { require.NoError(t, s.ListenAndServe(ctx)) }()
t.Cleanup(func() { s.Shutdown() })
verifyStatus(t, s, StatusReady)
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment