As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit 37076c46 authored by Igor Drozdov's avatar Igor Drozdov
Browse files

Merge branch '648-gitlab-sshd-should-include-data-transfer-bytes-in-logs-3' into 'main'

Resolve "GitLab sshd should include data transfer bytes in logs"

Closes #648

See merge request https://gitlab.com/gitlab-org/gitlab-shell/-/merge_requests/831



Merged-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
Approved-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
Reviewed-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
Reviewed-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Co-authored-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
parents 7898d8e6 4bf9c833
No related branches found
No related tags found
No related merge requests found
Loading
@@ -23,7 +23,7 @@ func main() {
Loading
@@ -23,7 +23,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime) command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{ readWriter := &readwriter.ReadWriter{
Out: os.Stdout, Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin, In: os.Stdin,
ErrOut: os.Stderr, ErrOut: os.Stderr,
} }
Loading
Loading
Loading
@@ -24,7 +24,7 @@ func main() {
Loading
@@ -24,7 +24,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime) command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{ readWriter := &readwriter.ReadWriter{
Out: os.Stdout, Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin, In: os.Stdin,
ErrOut: os.Stderr, ErrOut: os.Stderr,
} }
Loading
Loading
Loading
@@ -24,7 +24,7 @@ func main() {
Loading
@@ -24,7 +24,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime) command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{ readWriter := &readwriter.ReadWriter{
Out: os.Stdout, Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin, In: os.Stdin,
ErrOut: os.Stderr, ErrOut: os.Stderr,
} }
Loading
Loading
Loading
@@ -32,7 +32,7 @@ func main() {
Loading
@@ -32,7 +32,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime) command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{ readWriter := &readwriter.ReadWriter{
Out: os.Stdout, Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin, In: os.Stdin,
ErrOut: os.Stderr, ErrOut: os.Stderr,
} }
Loading
Loading
Loading
@@ -22,8 +22,9 @@ type LogMetadata struct {
Loading
@@ -22,8 +22,9 @@ type LogMetadata struct {
} }
type LogData struct { type LogData struct {
Username string `json:"username"` Username string `json:"username"`
Meta LogMetadata `json:"meta"` WrittenBytes int64 `json:"written_bytes"`
Meta LogMetadata `json:"meta"`
} }
func CheckForVersionFlag(osArgs []string, version, buildTime string) { func CheckForVersionFlag(osArgs []string, version, buildTime string) {
Loading
@@ -87,7 +88,8 @@ func NewLogData(project, username string) LogData {
Loading
@@ -87,7 +88,8 @@ func NewLogData(project, username string) LogData {
} }
return LogData{ return LogData{
Username: username, Username: username,
WrittenBytes: 0,
Meta: LogMetadata{ Meta: LogMetadata{
Project: project, Project: project,
RootNamespace: rootNameSpace, RootNamespace: rootNameSpace,
Loading
Loading
package readwriter package readwriter
import "io" import (
"io"
)
type ReadWriter struct { type ReadWriter struct {
Out io.Writer Out io.Writer
In io.Reader In io.Reader
ErrOut io.Writer ErrOut io.Writer
} }
// CountingWriter wraps an io.Writer and counts all the writes. Accessing
// the count N is not thread-safe.
type CountingWriter struct {
W io.Writer
N int64
}
func (cw *CountingWriter) Write(p []byte) (int, error) {
n, err := cw.W.Write(p)
cw.N += int64(n)
return n, err
}
package readwriter
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestCountingWriter_Write(t *testing.T) {
testString := []byte("test string")
buffer := &bytes.Buffer{}
cw := &CountingWriter{
W: buffer,
}
n, err := cw.Write(testString)
require.NoError(t, err)
require.Equal(t, 11, n)
require.Equal(t, int64(11), cw.N)
cw.Write(testString)
require.Equal(t, int64(22), cw.N)
}
Loading
@@ -167,8 +167,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
Loading
@@ -167,8 +167,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
NamespacePath: s.namespace, NamespacePath: s.namespace,
} }
countingWriter := &readwriter.CountingWriter{W: s.channel}
rw := &readwriter.ReadWriter{ rw := &readwriter.ReadWriter{
Out: s.channel, Out: countingWriter,
In: s.channel, In: s.channel,
ErrOut: s.channel.Stderr(), ErrOut: s.channel.Stderr(),
} }
Loading
@@ -183,6 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
Loading
@@ -183,6 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
} else { } else {
cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw) cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
} }
if err != nil { if err != nil {
if errors.Is(err, disallowedcommand.Error) { if errors.Is(err, disallowedcommand.Error) {
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd) s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
Loading
@@ -202,6 +205,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
Loading
@@ -202,6 +205,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
ctxWithLogData, err := cmd.Execute(ctx) ctxWithLogData, err := cmd.Execute(ctx)
logData := extractDataFromContext(ctxWithLogData)
logData.WrittenBytes = countingWriter.N
ctxWithLogData = context.WithValue(ctx, "logData", logData)
if err != nil { if err != nil {
grpcStatus := grpcstatus.Convert(err) grpcStatus := grpcstatus.Convert(err)
if grpcStatus.Code() != grpccodes.Internal { if grpcStatus.Code() != grpccodes.Internal {
Loading
Loading
Loading
@@ -18,6 +18,7 @@ import (
Loading
@@ -18,6 +18,7 @@ import (
type fakeChannel struct { type fakeChannel struct {
stdErr io.ReadWriter stdErr io.ReadWriter
stdOut io.ReadWriter
sentRequestName string sentRequestName string
sentRequestPayload []byte sentRequestPayload []byte
} }
Loading
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
Loading
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
} }
func (f *fakeChannel) Write(data []byte) (int, error) { func (f *fakeChannel) Write(data []byte) (int, error) {
return 0, nil return f.stdOut.Write(data)
} }
func (f *fakeChannel) Close() error { func (f *fakeChannel) Close() error {
Loading
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
Loading
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
}, },
} }
for _, s := range sessions { for _, s := range sessions {
out := &bytes.Buffer{} stdErr := &bytes.Buffer{}
f := &fakeChannel{stdErr: out} stdOut := &bytes.Buffer{}
f := &fakeChannel{stdErr: stdErr, stdOut: stdOut}
r := &ssh.Request{Payload: tc.payload} r := &ssh.Request{Payload: tc.payload}
s.channel = f s.channel = f
Loading
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
Loading
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
func TestHandleShell(t *testing.T) { func TestHandleShell(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
cmd string cmd string
errMsg string errMsg string
gitlabKeyId string gitlabKeyId string
expectedErrString string expectedOutString string
expectedExitCode uint32 expectedErrString string
expectedExitCode uint32
expectedWrittenBytes int64
}{ }{
{ {
desc: "fails to parse command", desc: "fails to parse command",
Loading
@@ -177,27 +181,32 @@ func TestHandleShell(t *testing.T) {
Loading
@@ -177,27 +181,32 @@ func TestHandleShell(t *testing.T) {
gitlabKeyId: "root", gitlabKeyId: "root",
expectedErrString: "Invalid SSH command: invalid command line string", expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128, expectedExitCode: 128,
}, { },
{
desc: "specified command is unknown", desc: "specified command is unknown",
cmd: "unknown-command", cmd: "unknown-command",
errMsg: "ERROR: Unknown command: unknown-command\n", errMsg: "ERROR: Unknown command: unknown-command\n",
gitlabKeyId: "root", gitlabKeyId: "root",
expectedErrString: "Disallowed command", expectedErrString: "Disallowed command",
expectedExitCode: 128, expectedExitCode: 128,
}, { },
{
desc: "fails to parse command", desc: "fails to parse command",
cmd: "discover", cmd: "discover",
gitlabKeyId: "", gitlabKeyId: "",
errMsg: "ERROR: Failed to get username: who='' is invalid\n", errMsg: "ERROR: Failed to get username: who='' is invalid\n",
expectedErrString: "Failed to get username: who='' is invalid", expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1, expectedExitCode: 1,
}, { },
desc: "fails to parse command", {
cmd: "discover", desc: "parses command",
errMsg: "", cmd: "discover",
gitlabKeyId: "root", errMsg: "",
expectedErrString: "", gitlabKeyId: "root",
expectedExitCode: 0, expectedOutString: "Welcome to GitLab, @test-user!\n",
expectedErrString: "",
expectedExitCode: 0,
expectedWrittenBytes: 31,
}, },
} }
Loading
@@ -205,29 +214,37 @@ func TestHandleShell(t *testing.T) {
Loading
@@ -205,29 +214,37 @@ func TestHandleShell(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
out := &bytes.Buffer{} stdOut := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
s := &session{ s := &session{
gitlabKeyId: tc.gitlabKeyId, gitlabKeyId: tc.gitlabKeyId,
execCmd: tc.cmd, execCmd: tc.cmd,
channel: &fakeChannel{stdErr: out}, channel: &fakeChannel{stdErr: stdErr, stdOut: stdOut},
cfg: &config.Config{GitlabUrl: url}, cfg: &config.Config{GitlabUrl: url},
} }
r := &ssh.Request{} r := &ssh.Request{}
_, exitCode, err := s.handleShell(context.Background(), r) ctxWithLogData, exitCode, err := s.handleShell(context.Background(), r)
logData := extractDataFromContext(ctxWithLogData)
if tc.expectedOutString != "" {
require.Equal(t, tc.expectedOutString, stdOut.String())
}
if tc.expectedErrString != "" { if tc.expectedErrString != "" {
require.Equal(t, tc.expectedErrString, err.Error()) require.Equal(t, tc.expectedErrString, err.Error())
} }
require.Equal(t, tc.expectedExitCode, exitCode) require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.expectedWrittenBytes, logData.WrittenBytes)
formattedErr := &bytes.Buffer{} formattedErr := &bytes.Buffer{}
if tc.errMsg != "" { if tc.errMsg != "" {
console.DisplayWarningMessage(tc.errMsg, formattedErr) console.DisplayWarningMessage(tc.errMsg, formattedErr)
require.Equal(t, formattedErr.String(), out.String()) require.Equal(t, formattedErr.String(), stdErr.String())
} else { } else {
require.Equal(t, tc.errMsg, out.String()) require.Equal(t, tc.errMsg, stdErr.String())
} }
}) })
} }
Loading
Loading
Loading
@@ -217,8 +217,9 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
Loading
@@ -217,8 +217,9 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
logData := extractDataFromContext(ctxWithLogData) logData := extractDataFromContext(ctxWithLogData)
ctxlog.WithFields(log.Fields{ ctxlog.WithFields(log.Fields{
"duration_s": time.Since(started).Seconds(), "duration_s": time.Since(started).Seconds(),
"meta": logData.Meta, "written_bytes": logData.WrittenBytes,
"meta": logData.Meta,
}).Info("access: finish") }).Info("access: finish")
} }
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment