As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit 9cb22b2f authored by Patrick Bajao's avatar Patrick Bajao
Browse files

Merge branch 'id-wait-until-gitaly-execution' into 'main'

Wait until all Gitaly sessions are executed

See merge request gitlab-org/gitlab-shell!624
parents 7cde0770 509e04b6
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -39,7 +39,7 @@ func TestCustomPrometheusMetrics(t *testing.T) {
require.NoError(t, err)
var actualNames []string
for _, m := range ms[0:9] {
for _, m := range ms[0:10] {
actualNames = append(actualNames, m.GetName())
}
Loading
Loading
@@ -47,6 +47,7 @@ func TestCustomPrometheusMetrics(t *testing.T) {
"gitlab_shell_http_in_flight_requests",
"gitlab_shell_http_request_duration_seconds",
"gitlab_shell_http_requests_total",
"gitlab_shell_sshd_canceled_sessions",
"gitlab_shell_sshd_concurrent_limited_sessions_total",
"gitlab_shell_sshd_in_flight_connections",
"gitlab_shell_sshd_session_duration_seconds",
Loading
Loading
Loading
Loading
@@ -22,6 +22,7 @@ const (
sshdHitMaxSessionsName = "concurrent_limited_sessions_total"
sshdSessionDurationSecondsName = "session_duration_seconds"
sshdSessionEstablishedDurationSecondsName = "session_established_duration_seconds"
sshdCanceledSessionsName = "canceled_sessions"
sliSshdSessionsTotalName = "gitlab_sli:shell_sshd_sessions:total"
sliSshdSessionsErrorsTotalName = "gitlab_sli:shell_sshd_sessions:errors_total"
Loading
Loading
@@ -76,6 +77,15 @@ var (
},
)
// SshdCanceledSessions counts gitlab-sshd sessions that ended because they
// were canceled. With the namespace and sshd subsystem applied it is exposed
// as "gitlab_shell_sshd_canceled_sessions".
SshdCanceledSessions = promauto.NewCounter(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: sshdSubsystem,
Name: sshdCanceledSessionsName,
Help: "The number of canceled gitlab-sshd sessions.",
},
)
SliSshdSessionsTotal = promauto.NewCounter(
prometheus.CounterOpts{
Name: sliSshdSessionsTotalName,
Loading
Loading
Loading
Loading
@@ -6,6 +6,8 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/sync/semaphore"
grpccodes "google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
Loading
Loading
@@ -15,19 +17,25 @@ import (
const KeepAliveMsg = "keepalive@openssh.com"
var EOFTimeout = 10 * time.Second
// connection bundles the state kept for a single accepted SSH connection.
type connection struct {
// cfg is the server configuration the connection was created with.
cfg *config.Config
// concurrentSessions is a weighted semaphore whose capacity equals
// maxSessions; handle acquires the full weight at the end to wait until
// every session has been released.
concurrentSessions *semaphore.Weighted
// remoteAddr is the remote address string supplied to newConnection.
remoteAddr string
// sconn is the underlying SSH server connection.
sconn *ssh.ServerConn
// maxSessions mirrors cfg.Server.ConcurrentSessionsLimit and is the total
// weight of concurrentSessions.
maxSessions int64
}
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request) error
func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
return &connection{
cfg: cfg,
concurrentSessions: semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit),
maxSessions: maxSessions,
concurrentSessions: semaphore.NewWeighted(maxSessions),
remoteAddr: remoteAddr,
sconn: sconn,
}
Loading
Loading
@@ -76,10 +84,27 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
}
}()
handler(ctx, channel, requests)
metrics.SliSshdSessionsTotal.Inc()
err := handler(ctx, channel, requests)
if err != nil {
if grpcstatus.Convert(err).Code() == grpccodes.Canceled {
metrics.SshdCanceledSessions.Inc()
} else {
metrics.SliSshdSessionsErrorsTotal.Inc()
}
}
ctxlog.Info("connection: handle: done")
}()
}
// When a connection has been closed prematurely, we block execution until all concurrent sessions are released
// in order to allow Gitaly to complete its operations and close all the channels gracefully.
// If that does not happen within EOFTimeout, we unblock the execution anyway.
// Related issue: https://gitlab.com/gitlab-org/gitlab-shell/-/issues/563
ctx, cancel := context.WithTimeout(ctx, EOFTimeout)
defer cancel()
c.concurrentSessions.Acquire(ctx, c.maxSessions)
}
func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
Loading
Loading
Loading
Loading
@@ -7,10 +7,14 @@ import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
grpccodes "google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
)
type rejectCall struct {
Loading
Loading
@@ -90,7 +94,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
Loading
Loading
@@ -128,8 +132,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return nil
})
}()
Loading
Loading
@@ -142,9 +147,10 @@ func TestAcceptSessionSucceeds(t *testing.T) {
conn, chans := setup(1, newChannel)
channelHandled := false
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
return nil
})
require.True(t, channelHandled)
Loading
Loading
@@ -160,8 +166,9 @@ func TestAcceptSessionFails(t *testing.T) {
channelHandled := false
go func() {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) {
conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
return nil
})
}()
Loading
Loading
@@ -186,3 +193,33 @@ func TestClientAliveInterval(t *testing.T) {
require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
}
// TestSessionsMetrics verifies which session counters connection.handle bumps
// depending on the error returned by the channel handler: a generic error is
// counted as a session error, while a gRPC Canceled status is counted as a
// canceled session. Every handled session increases the total.
func TestSessionsMetrics(t *testing.T) {
	// A plain Counter (unlike a CounterVec) cannot be reset, so the
	// assertions are made relative to the values observed at test start.
	// https://pkg.go.dev/github.com/prometheus/client_golang/prometheus#pkg-index
	baseTotal := testutil.ToFloat64(metrics.SliSshdSessionsTotal)
	baseErrors := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal)
	baseCanceled := testutil.ToFloat64(metrics.SshdCanceledSessions)

	// The steps run in order and their expectations are cumulative.
	steps := []struct {
		handlerErr   error
		wantTotal    float64
		wantErrors   float64
		wantCanceled float64
	}{
		{
			handlerErr:   errors.New("custom error"),
			wantTotal:    baseTotal + 1,
			wantErrors:   baseErrors + 1,
			wantCanceled: baseCanceled,
		},
		{
			handlerErr:   grpcstatus.Error(grpccodes.Canceled, "error"),
			wantTotal:    baseTotal + 2,
			wantErrors:   baseErrors + 1,
			wantCanceled: baseCanceled + 1,
		},
	}

	newChannel := &fakeNewChannel{channelType: "session"}

	for _, step := range steps {
		handlerErr := step.handlerErr

		conn, chans := setup(1, newChannel)
		conn.handle(context.Background(), chans, func(context.Context, ssh.Channel, <-chan *ssh.Request) error {
			close(chans)
			return handlerErr
		})

		require.InDelta(t, step.wantTotal, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
		require.InDelta(t, step.wantErrors, testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal), 0.1)
		require.InDelta(t, step.wantCanceled, testutil.ToFloat64(metrics.SshdCanceledSessions), 0.1)
	}
}
Loading
Loading
@@ -22,7 +22,6 @@ type session struct {
channel ssh.Channel
gitlabKeyId string
remoteAddr string
success bool
// State managed by the session
execCmd string
Loading
Loading
@@ -42,11 +41,12 @@ type exitStatusReq struct {
ExitStatus uint32
}
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("session: handle: entering request loop")
var err error
for req := range requests {
sessionLog := ctxlog.WithFields(log.Fields{
"bytesize": len(req.Payload),
Loading
Loading
@@ -58,12 +58,14 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
var shouldContinue bool
switch req.Type {
case "env":
shouldContinue = s.handleEnv(ctx, req)
shouldContinue, err = s.handleEnv(ctx, req)
case "exec":
shouldContinue = s.handleExec(ctx, req)
shouldContinue, err = s.handleExec(ctx, req)
case "shell":
shouldContinue = false
s.exit(ctx, s.handleShell(ctx, req))
var status uint32
status, err = s.handleShell(ctx, req)
s.exit(ctx, status)
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
Loading
Loading
@@ -84,15 +86,17 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
}
ctxlog.Debug("session: handle: exiting request loop")
return err
}
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
var accepted bool
var envRequest envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
return false
return false, err
}
switch envRequest.Name {
Loading
Loading
@@ -113,23 +117,24 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
).Debug("session: handleEnv: processed")
return true
return true, nil
}
func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) {
var execRequest execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
return false
return false, err
}
s.execCmd = execRequest.Command
s.exit(ctx, s.handleShell(ctx, req))
status, err := s.handleShell(ctx, req)
s.exit(ctx, status)
return false
return false, err
}
func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) {
ctxlog := log.ContextLogger(ctx)
if req.WantReply {
Loading
Loading
@@ -157,7 +162,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
s.toStderr(ctx, "Failed to parse command: %v\n", err.Error())
}
s.toStderr(ctx, "Unknown command: %v\n", s.execCmd)
return 128
return 128, err
}
cmdName := reflect.TypeOf(cmd).String()
Loading
Loading
@@ -165,12 +170,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
if err := cmd.Execute(ctx); err != nil {
s.toStderr(ctx, "remote: ERROR: %v\n", err.Error())
return 1
return 1, err
}
ctxlog.Info("session: handleShell: command executed successfully")
return 0
return 0, nil
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
Loading
Loading
@@ -183,8 +188,6 @@ func (s *session) exit(ctx context.Context, status uint32) {
log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
s.success = status == 0
s.channel.CloseWrite()
s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
}
Loading
Loading
@@ -3,6 +3,7 @@ package sshd
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"testing"
Loading
Loading
@@ -60,22 +61,26 @@ func TestHandleEnv(t *testing.T) {
testCases := []struct {
desc string
payload []byte
expectedErr error
expectedProtocolVersion string
expectedResult bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
expectedErr: errors.New("ssh: unmarshal error for field Name of type envRequest"),
expectedProtocolVersion: "1",
expectedResult: false,
}, {
desc: "valid payload",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL", Value: "2"}),
expectedErr: nil,
expectedProtocolVersion: "2",
expectedResult: true,
}, {
desc: "valid payload with forbidden env var",
payload: ssh.Marshal(envRequest{Name: "GIT_PROTOCOL_ENV", Value: "2"}),
expectedErr: nil,
expectedProtocolVersion: "1",
expectedResult: true,
},
Loading
Loading
@@ -86,8 +91,11 @@ func TestHandleEnv(t *testing.T) {
s := &session{gitProtocolVersion: "1"}
r := &ssh.Request{Payload: tc.payload}
require.Equal(t, s.handleEnv(context.Background(), r), tc.expectedResult)
require.Equal(t, s.gitProtocolVersion, tc.expectedProtocolVersion)
shouldContinue, err := s.handleEnv(context.Background(), r)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, tc.expectedResult, shouldContinue)
require.Equal(t, tc.expectedProtocolVersion, s.gitProtocolVersion)
})
}
}
Loading
Loading
@@ -96,23 +104,24 @@ func TestHandleExec(t *testing.T) {
testCases := []struct {
desc string
payload []byte
expectedErr error
expectedExecCmd string
sentRequestName string
sentRequestPayload []byte
success bool
}{
{
desc: "invalid payload",
payload: []byte("invalid"),
expectedErr: errors.New("ssh: unmarshal error for field Command of type execRequest"),
expectedExecCmd: "",
sentRequestName: "",
}, {
desc: "valid payload",
payload: ssh.Marshal(execRequest{Command: "discover"}),
expectedErr: nil,
expectedExecCmd: "discover",
sentRequestName: "exit-status",
sentRequestPayload: ssh.Marshal(exitStatusReq{ExitStatus: 0}),
success: true,
},
}
Loading
Loading
@@ -129,47 +138,53 @@ func TestHandleExec(t *testing.T) {
}
r := &ssh.Request{Payload: tc.payload}
require.Equal(t, false, s.handleExec(context.Background(), r))
shouldContinue, err := s.handleExec(context.Background(), r)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, false, shouldContinue)
require.Equal(t, tc.sentRequestName, f.sentRequestName)
require.Equal(t, tc.sentRequestPayload, f.sentRequestPayload)
require.Equal(t, tc.success, s.success)
})
}
}
func TestHandleShell(t *testing.T) {
testCases := []struct {
desc string
cmd string
errMsg string
gitlabKeyId string
expectedExitCode uint32
success bool
desc string
cmd string
errMsg string
gitlabKeyId string
expectedErrString string
expectedExitCode uint32
}{
{
desc: "fails to parse command",
cmd: `\`,
errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
gitlabKeyId: "root",
expectedExitCode: 128,
desc: "fails to parse command",
cmd: `\`,
errMsg: "Failed to parse command: Invalid SSH command: invalid command line string\nUnknown command: \\\n",
gitlabKeyId: "root",
expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128,
}, {
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "Unknown command: unknown-command\n",
gitlabKeyId: "root",
expectedExitCode: 128,
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "Unknown command: unknown-command\n",
gitlabKeyId: "root",
expectedErrString: "Disallowed command",
expectedExitCode: 128,
}, {
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
expectedExitCode: 1,
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
errMsg: "remote: ERROR: Failed to get username: who='' is invalid\n",
expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1,
}, {
desc: "fails to parse command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedExitCode: 0,
desc: "fails to parse command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedErrString: "",
expectedExitCode: 0,
},
}
Loading
Loading
@@ -186,7 +201,13 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
require.Equal(t, tc.expectedExitCode, s.handleShell(context.Background(), r))
exitCode, err := s.handleShell(context.Background(), r)
if tc.expectedErrString != "" {
require.Equal(t, tc.expectedErrString, err.Error())
}
require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.errMsg, out.String())
})
}
Loading
Loading
Loading
Loading
@@ -181,7 +181,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
var establishSessionDuration float64
conn := newConnection(s.Config, remoteAddr, sconn)
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) error {
establishSessionDuration = time.Since(started).Seconds()
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
Loading
Loading
@@ -192,11 +192,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
remoteAddr: remoteAddr,
}
metrics.SliSshdSessionsTotal.Inc()
session.handle(ctx, requests)
if !session.success {
metrics.SliSshdSessionsErrorsTotal.Inc()
}
return session.handle(ctx, requests)
})
reason := sconn.Wait()
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment