As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit 5b6f4016 authored by Ash McKenzie's avatar Ash McKenzie Committed by GitLab
Browse files

Merge branch 'fix-lint-sshd-session' into 'main'

Resolve `make lint` (golangci-lint) issues for `internal/sshd/session.go` and `internal/sshd/session_test.go`

Closes #714

See merge request https://gitlab.com/gitlab-org/gitlab-shell/-/merge_requests/1032



Merged-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Approved-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Reviewed-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Co-authored-by: default avatargaurav.marwal <gauravmarwal@gmail.com>
parents c28c003b 7ebf87cb
No related branches found
No related tags found
No related merge requests found
// Package sshd provides functionality for handling SSH sessions
package sshd
import (
Loading
Loading
@@ -26,7 +27,7 @@ type session struct {
// State set up by the connection
cfg *config.Config
channel ssh.Channel
gitlabKeyId string
gitlabKeyID string
gitlabKrb5Principal string
gitlabUsername string
namespace string
Loading
Loading
@@ -73,7 +74,8 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con
case "exec":
// The command has been executed as `ssh user@host command` or `exec` channel has been used
// in the app implementation
ctxWithLogData, shouldContinue, err = s.handleExec(ctx, req)
shouldContinue = false
ctxWithLogData, err = s.handleExec(ctx, req)
case "shell":
// The command has been entered into the shell or `shell` channel has been used
// in the app implementation
Loading
Loading
@@ -86,7 +88,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con
shouldContinue = true
if req.WantReply {
if err := req.Reply(false, []byte{}); err != nil {
if err = req.Reply(false, []byte{}); err != nil {
sessionLog.WithError(err).Debug("session: handle: Failed to reply")
}
}
Loading
Loading
@@ -95,7 +97,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con
sessionLog.WithField("should_continue", shouldContinue).Debug("session: handle: request processed")
if !shouldContinue {
s.channel.Close()
_ = s.channel.Close()
break
}
}
Loading
Loading
@@ -107,16 +109,16 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
var accepted bool
var envRequest envRequest
var envReq envRequest
if err := ssh.Unmarshal(req.Payload, &envRequest); err != nil {
if err := ssh.Unmarshal(req.Payload, &envReq); err != nil {
log.ContextLogger(ctx).WithError(err).Error("session: handleEnv: failed to unmarshal request")
return false, err
}
switch envRequest.Name {
switch envReq.Name {
case sshenv.GitProtocolEnv:
s.gitProtocolVersion = envRequest.Value
s.gitProtocolVersion = envReq.Value
accepted = true
default:
// Client requested a forbidden envvar, nothing to do
Loading
Loading
@@ -129,25 +131,25 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error)
}
log.WithContextFields(
ctx, log.Fields{"accepted": accepted, "env_request": envRequest},
ctx, log.Fields{"accepted": accepted, "env_request": envReq},
).Debug("session: handleEnv: processed")
return true, nil
}
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, bool, error) {
var execRequest execRequest
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, error) {
var execReq execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
return ctx, false, err
if err := ssh.Unmarshal(req.Payload, &execReq); err != nil {
return ctx, err
}
s.execCmd = execRequest.Command
s.execCmd = execReq.Command
ctxWithLogData, status, err := s.handleShell(ctx, req)
s.exit(ctxWithLogData, status)
return ctxWithLogData, false, err
return ctxWithLogData, err
}
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) {
Loading
Loading
@@ -175,25 +177,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
ErrOut: s.channel.Stderr(),
}
var cmd command.Command
var err error
if s.gitlabKrb5Principal != "" {
cmd, err = shellCmd.NewWithKrb5Principal(s.gitlabKrb5Principal, env, s.cfg, rw)
} else if s.gitlabUsername != "" {
cmd, err = shellCmd.NewWithUsername(s.gitlabUsername, env, s.cfg, rw)
} else {
cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
}
cmd, err := s.getCommand(env, rw)
if err != nil {
if errors.Is(err, disallowedcommand.Error) {
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
} else {
s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error())
}
return ctx, 128, err
return s.handleCommandError(ctx, err)
}
cmdName := reflect.TypeOf(cmd).String()
Loading
Loading
@@ -206,10 +193,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
ctxWithLogData, err := cmd.Execute(ctx)
logData := extractDataFromContext(ctxWithLogData)
logData := extractLogDataFromContext(ctxWithLogData)
logData.WrittenBytes = countingWriter.N
ctxWithLogData = context.WithValue(ctx, "logData", logData)
ctxWithLogData = context.WithValue(ctx, logInfo{}, logData)
if err != nil {
grpcStatus := grpcstatus.Convert(err)
Loading
Loading
@@ -225,6 +212,31 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
return ctxWithLogData, 0, nil
}
func (s *session) handleCommandError(ctx context.Context, err error) (context.Context, uint32, error) {
if errors.Is(err, disallowedcommand.Error) {
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
} else {
s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error())
}
return ctx, 128, err
}
func (s *session) getCommand(env sshenv.Env, rw *readwriter.ReadWriter) (command.Command, error) {
var cmd command.Command
var err error
switch {
case s.gitlabKrb5Principal != "":
cmd, err = shellCmd.NewWithKrb5Principal(s.gitlabKrb5Principal, env, s.cfg, rw)
case s.gitlabUsername != "":
cmd, err = shellCmd.NewWithUsername(s.gitlabUsername, env, s.cfg, rw)
default:
cmd, err = shellCmd.NewWithKey(s.gitlabKeyID, env, s.cfg, rw)
}
return cmd, err
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
out := fmt.Sprintf(format, args...)
log.WithContextFields(ctx, log.Fields{"stderr": out}).Debug("session: toStderr: output")
Loading
Loading
@@ -235,6 +247,6 @@ func (s *session) exit(ctx context.Context, status uint32) {
log.WithContextFields(ctx, log.Fields{"exit_status": status}).Info("session: exit: exiting")
req := exitStatusReq{ExitStatus: status}
s.channel.CloseWrite()
s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
_ = s.channel.CloseWrite()
_, _ = s.channel.SendRequest("exit-status", false, ssh.Marshal(req))
}
Loading
Loading
@@ -23,7 +23,7 @@ type fakeChannel struct {
sentRequestPayload []byte
}
func (f *fakeChannel) Read(data []byte) (int, error) {
func (f *fakeChannel) Read(_ []byte) (int, error) {
return 0, nil
}
Loading
Loading
@@ -39,7 +39,7 @@ func (f *fakeChannel) CloseWrite() error {
return nil
}
func (f *fakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
func (f *fakeChannel) SendRequest(name string, _ bool, payload []byte) (bool, error) {
f.sentRequestName = name
f.sentRequestPayload = payload
Loading
Loading
@@ -53,7 +53,7 @@ func (f *fakeChannel) Stderr() io.ReadWriter {
var requests = []testserver.TestRequestHandler{
{
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
Handler: func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte(`{"id": 1000, "name": "Test User", "username": "test-user"}`))
},
},
Loading
Loading
@@ -133,7 +133,7 @@ func TestHandleExec(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
sessions := []*session{
{
gitlabKeyId: "id",
gitlabKeyID: "id",
cfg: &config.Config{GitlabUrl: url},
},
{
Loading
Loading
@@ -152,7 +152,8 @@ func TestHandleExec(t *testing.T) {
r := &ssh.Request{Payload: tc.payload}
s.channel = f
_, shouldContinue, err := s.handleExec(context.Background(), r)
shouldContinue := false
_, err := s.handleExec(context.Background(), r)
require.Equal(t, tc.expectedErr, err)
require.False(t, shouldContinue)
Loading
Loading
@@ -168,7 +169,7 @@ func TestHandleShell(t *testing.T) {
desc string
cmd string
errMsg string
gitlabKeyId string
gitlabKeyID string
expectedOutString string
expectedErrString string
expectedExitCode uint32
Loading
Loading
@@ -178,7 +179,7 @@ func TestHandleShell(t *testing.T) {
desc: "fails to parse command",
cmd: `\`,
errMsg: "ERROR: Failed to parse command: Invalid SSH command: invalid command line string\n",
gitlabKeyId: "root",
gitlabKeyID: "root",
expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128,
},
Loading
Loading
@@ -186,14 +187,14 @@ func TestHandleShell(t *testing.T) {
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "ERROR: Unknown command: unknown-command\n",
gitlabKeyId: "root",
gitlabKeyID: "root",
expectedErrString: "Disallowed command",
expectedExitCode: 128,
},
{
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
gitlabKeyID: "",
errMsg: "ERROR: Failed to get username: who='' is invalid\n",
expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1,
Loading
Loading
@@ -202,7 +203,7 @@ func TestHandleShell(t *testing.T) {
desc: "parses command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
gitlabKeyID: "root",
expectedOutString: "Welcome to GitLab, @test-user!\n",
expectedErrString: "",
expectedExitCode: 0,
Loading
Loading
@@ -217,7 +218,7 @@ func TestHandleShell(t *testing.T) {
stdOut := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
s := &session{
gitlabKeyId: tc.gitlabKeyId,
gitlabKeyID: tc.gitlabKeyID,
execCmd: tc.cmd,
channel: &fakeChannel{stdErr: stdErr, stdOut: stdOut},
cfg: &config.Config{GitlabUrl: url},
Loading
Loading
@@ -226,7 +227,7 @@ func TestHandleShell(t *testing.T) {
ctxWithLogData, exitCode, err := s.handleShell(context.Background(), r)
logData := extractDataFromContext(ctxWithLogData)
logInfo := extractLogDataFromContext(ctxWithLogData)
if tc.expectedOutString != "" {
require.Equal(t, tc.expectedOutString, stdOut.String())
Loading
Loading
@@ -237,7 +238,7 @@ func TestHandleShell(t *testing.T) {
}
require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.expectedWrittenBytes, logData.WrittenBytes)
require.Equal(t, tc.expectedWrittenBytes, logInfo.WrittenBytes)
formattedErr := &bytes.Buffer{}
if tc.errMsg != "" {
Loading
Loading
Loading
Loading
@@ -223,7 +223,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
session := &session{
cfg: s.Config,
channel: channel,
gitlabKeyId: sconn.Permissions.Extensions["key-id"],
gitlabKeyID: sconn.Permissions.Extensions["key-id"],
gitlabKrb5Principal: sconn.Permissions.Extensions["krb5principal"],
gitlabUsername: sconn.Permissions.Extensions["username"],
namespace: sconn.Permissions.Extensions["namespace"],
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment