As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit 4bf9c833 authored by Ash McKenzie's avatar Ash McKenzie
Browse files

Set logData.WrittenBytes from CountingWriter.N

parent a509a44a
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -167,8 +167,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
NamespacePath: s.namespace,
}
countingWriter := &readwriter.CountingWriter{W: s.channel}
rw := &readwriter.ReadWriter{
Out: &readwriter.CountingWriter{W: s.channel},
Out: countingWriter,
In: s.channel,
ErrOut: s.channel.Stderr(),
}
Loading
Loading
@@ -183,6 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
} else {
cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
}
if err != nil {
if errors.Is(err, disallowedcommand.Error) {
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
Loading
Loading
@@ -202,6 +205,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
ctxWithLogData, err := cmd.Execute(ctx)
logData := extractDataFromContext(ctxWithLogData)
logData.WrittenBytes = countingWriter.N
ctxWithLogData = context.WithValue(ctx, "logData", logData)
if err != nil {
grpcStatus := grpcstatus.Convert(err)
if grpcStatus.Code() != grpccodes.Internal {
Loading
Loading
Loading
Loading
@@ -18,6 +18,7 @@ import (
type fakeChannel struct {
stdErr io.ReadWriter
stdOut io.ReadWriter
sentRequestName string
sentRequestPayload []byte
}
Loading
Loading
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
}
func (f *fakeChannel) Write(data []byte) (int, error) {
return 0, nil
return f.stdOut.Write(data)
}
func (f *fakeChannel) Close() error {
Loading
Loading
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
},
}
for _, s := range sessions {
out := &bytes.Buffer{}
f := &fakeChannel{stdErr: out}
stdErr := &bytes.Buffer{}
stdOut := &bytes.Buffer{}
f := &fakeChannel{stdErr: stdErr, stdOut: stdOut}
r := &ssh.Request{Payload: tc.payload}
s.channel = f
Loading
Loading
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
func TestHandleShell(t *testing.T) {
testCases := []struct {
desc string
cmd string
errMsg string
gitlabKeyId string
expectedErrString string
expectedExitCode uint32
desc string
cmd string
errMsg string
gitlabKeyId string
expectedOutString string
expectedErrString string
expectedExitCode uint32
expectedWrittenBytes int64
}{
{
desc: "fails to parse command",
Loading
Loading
@@ -177,27 +181,32 @@ func TestHandleShell(t *testing.T) {
gitlabKeyId: "root",
expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128,
}, {
},
{
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "ERROR: Unknown command: unknown-command\n",
gitlabKeyId: "root",
expectedErrString: "Disallowed command",
expectedExitCode: 128,
}, {
},
{
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
errMsg: "ERROR: Failed to get username: who='' is invalid\n",
expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1,
}, {
desc: "fails to parse command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedErrString: "",
expectedExitCode: 0,
},
{
desc: "parses command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedOutString: "Welcome to GitLab, @test-user!\n",
expectedErrString: "",
expectedExitCode: 0,
expectedWrittenBytes: 31,
},
}
Loading
Loading
@@ -205,29 +214,37 @@ func TestHandleShell(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
out := &bytes.Buffer{}
stdOut := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
s := &session{
gitlabKeyId: tc.gitlabKeyId,
execCmd: tc.cmd,
channel: &fakeChannel{stdErr: out},
channel: &fakeChannel{stdErr: stdErr, stdOut: stdOut},
cfg: &config.Config{GitlabUrl: url},
}
r := &ssh.Request{}
_, exitCode, err := s.handleShell(context.Background(), r)
ctxWithLogData, exitCode, err := s.handleShell(context.Background(), r)
logData := extractDataFromContext(ctxWithLogData)
if tc.expectedOutString != "" {
require.Equal(t, tc.expectedOutString, stdOut.String())
}
if tc.expectedErrString != "" {
require.Equal(t, tc.expectedErrString, err.Error())
}
require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.expectedWrittenBytes, logData.WrittenBytes)
formattedErr := &bytes.Buffer{}
if tc.errMsg != "" {
console.DisplayWarningMessage(tc.errMsg, formattedErr)
require.Equal(t, formattedErr.String(), out.String())
require.Equal(t, formattedErr.String(), stdErr.String())
} else {
require.Equal(t, tc.errMsg, out.String())
require.Equal(t, tc.errMsg, stdErr.String())
}
})
}
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment