As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit ce561187 authored by Igor Drozdov's avatar Igor Drozdov
Browse files

Return metadata context without using channels

On production environment a race condition causes a panic
because a message is being sent to a closed channel
parent 37b54f61
No related branches found
No related tags found
Loading
Loading
Loading
@@ -36,7 +36,7 @@ type connection struct {
remoteAddr string
}
type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error)
type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
func newConnection(cfg *config.Config, nconn net.Conn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
Loading
Loading
@@ -50,12 +50,12 @@ func newConnection(cfg *config.Config, nconn net.Conn) *connection {
}
}
func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) context.Context {
func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) {
log.WithContextFields(ctx, log.Fields{}).Info("server: handleConn: start")
sconn, chans, err := c.initServerConn(ctx, srvCfg)
if err != nil {
return ctx
return
}
if c.cfg.Server.ClientAliveInterval > 0 {
Loading
Loading
@@ -64,17 +64,10 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl
go c.sendKeepAliveMsg(ctx, sconn, ticker)
}
ctxWithLogMetadataChan := make(chan context.Context)
defer close(ctxWithLogMetadataChan)
go c.handleRequests(ctx, sconn, chans, ctxWithLogMetadataChan, handler)
ctxWithLogMetadata := <-ctxWithLogMetadataChan
c.handleRequests(ctx, sconn, chans, handler)
reason := sconn.Wait()
log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done")
return ctxWithLogMetadata
}
func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) {
Loading
Loading
@@ -101,7 +94,7 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi
return sconn, chans, err
}
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, ctxWithLogMetadataChan chan<- context.Context, handler channelHandler) {
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
Loading
Loading
@@ -144,12 +137,10 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn,
}()
metrics.SliSshdSessionsTotal.Inc()
ctxWithLogMetadata, err := handler(ctx, sconn, channel, requests)
err := handler(ctx, sconn, channel, requests)
if err != nil {
c.trackError(ctxlog, err)
}
ctxWithLogMetadataChan <- ctxWithLogMetadata
}()
}
Loading
Loading
Loading
Loading
@@ -81,25 +81,23 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo
return true, nil, nil
}
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel, chan<- context.Context) {
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}}
conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)}
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
ctxWithLogMetadataChan := make(chan context.Context)
return conn, chans, ctxWithLogMetadataChan
return conn, chans
}
func TestPanicDuringSessionIsRecovered(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
numSessions := 0
require.NotPanics(t, func() {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
Loading
Loading
@@ -114,10 +112,10 @@ func TestUnknownChannelType(t *testing.T) {
defer close(rejectCh)
newChannel := &fakeNewChannel{channelType: "unknown session", rejectCh: rejectCh}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
go func() {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, nil)
conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
Loading
Loading
@@ -131,15 +129,15 @@ func TestTooManySessions(t *testing.T) {
defer close(rejectCh)
newChannel := &fakeNewChannel{channelType: "session", rejectCh: rejectCh}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return ctx, nil
return nil
})
}()
Loading
Loading
@@ -149,14 +147,14 @@ func TestTooManySessions(t *testing.T) {
func TestAcceptSessionSucceeds(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
return ctx, nil
return nil
})
require.True(t, channelHandled)
Loading
Loading
@@ -168,14 +166,14 @@ func TestAcceptSessionFails(t *testing.T) {
acceptErr := errors.New("some failure")
newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
go func() {
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
return ctx, nil
return nil
})
}()
Loading
Loading
@@ -207,12 +205,12 @@ func TestSessionsMetrics(t *testing.T) {
initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal)
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return ctx, errors.New("custom error")
return errors.New("custom error")
})
eventuallyInDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
Loading
Loading
@@ -229,13 +227,13 @@ func TestSessionsMetrics(t *testing.T) {
{"not our ref", grpcstatus.Error(grpccodes.Internal, `rpc error: code = Internal desc = cmd wait: exit status 128, stderr: "fatal: git upload-pack: not our ref 9106d18f6a1b8022f6517f479696f3e3ea5e68c1"`)},
} {
t.Run(ignoredError.desc, func(t *testing.T) {
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ignored := ignoredError.err
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return ctx, ignored
return ignored
})
eventuallyInDelta(t, initialSessionsTotal+2+float64(i), testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
Loading
Loading
Loading
Loading
@@ -194,7 +194,8 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
conn := newConnection(s.Config, nconn)
ctxWithLogMetadata := conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) {
var ctxWithLogMetadata context.Context
conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
session := &session{
cfg: s.Config,
channel: channel,
Loading
Loading
@@ -204,7 +205,10 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started: started,
}
return session.handle(ctx, requests)
var err error
ctxWithLogMetadata, err = session.handle(ctx, requests)
return err
})
ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds(), "meta": extractMetaDataFromContext(ctxWithLogMetadata)}).Info("access: finish")
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment