As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit 37076c46 authored by Igor Drozdov's avatar Igor Drozdov
Browse files

Merge branch '648-gitlab-sshd-should-include-data-transfer-bytes-in-logs-3' into 'main'

Resolve "GitLab sshd should include data transfer bytes in logs"

Closes #648

See merge request https://gitlab.com/gitlab-org/gitlab-shell/-/merge_requests/831



Merged-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
Approved-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
Reviewed-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
Reviewed-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Co-authored-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
parents 7898d8e6 4bf9c833
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -23,7 +23,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{
Out: os.Stdout,
Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin,
ErrOut: os.Stderr,
}
Loading
Loading
Loading
Loading
@@ -24,7 +24,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{
Out: os.Stdout,
Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin,
ErrOut: os.Stderr,
}
Loading
Loading
Loading
Loading
@@ -24,7 +24,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{
Out: os.Stdout,
Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin,
ErrOut: os.Stderr,
}
Loading
Loading
Loading
Loading
@@ -32,7 +32,7 @@ func main() {
command.CheckForVersionFlag(os.Args, Version, BuildTime)
readWriter := &readwriter.ReadWriter{
Out: os.Stdout,
Out: &readwriter.CountingWriter{W: os.Stdout},
In: os.Stdin,
ErrOut: os.Stderr,
}
Loading
Loading
Loading
Loading
@@ -22,8 +22,9 @@ type LogMetadata struct {
}
type LogData struct {
Username string `json:"username"`
Meta LogMetadata `json:"meta"`
Username string `json:"username"`
WrittenBytes int64 `json:"written_bytes"`
Meta LogMetadata `json:"meta"`
}
func CheckForVersionFlag(osArgs []string, version, buildTime string) {
Loading
Loading
@@ -87,7 +88,8 @@ func NewLogData(project, username string) LogData {
}
return LogData{
Username: username,
Username: username,
WrittenBytes: 0,
Meta: LogMetadata{
Project: project,
RootNamespace: rootNameSpace,
Loading
Loading
package readwriter
import "io"
import (
"io"
)
type ReadWriter struct {
Out io.Writer
In io.Reader
ErrOut io.Writer
}
// CountingWriter wraps an io.Writer and counts all the writes. Accessing
// the count N is not thread-safe.
type CountingWriter struct {
W io.Writer
N int64
}
func (cw *CountingWriter) Write(p []byte) (int, error) {
n, err := cw.W.Write(p)
cw.N += int64(n)
return n, err
}
package readwriter
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestCountingWriter_Write(t *testing.T) {
testString := []byte("test string")
buffer := &bytes.Buffer{}
cw := &CountingWriter{
W: buffer,
}
n, err := cw.Write(testString)
require.NoError(t, err)
require.Equal(t, 11, n)
require.Equal(t, int64(11), cw.N)
cw.Write(testString)
require.Equal(t, int64(22), cw.N)
}
Loading
Loading
@@ -167,8 +167,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
NamespacePath: s.namespace,
}
countingWriter := &readwriter.CountingWriter{W: s.channel}
rw := &readwriter.ReadWriter{
Out: s.channel,
Out: countingWriter,
In: s.channel,
ErrOut: s.channel.Stderr(),
}
Loading
Loading
@@ -183,6 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
} else {
cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
}
if err != nil {
if errors.Is(err, disallowedcommand.Error) {
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
Loading
Loading
@@ -202,6 +205,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
ctxWithLogData, err := cmd.Execute(ctx)
logData := extractDataFromContext(ctxWithLogData)
logData.WrittenBytes = countingWriter.N
ctxWithLogData = context.WithValue(ctx, "logData", logData)
if err != nil {
grpcStatus := grpcstatus.Convert(err)
if grpcStatus.Code() != grpccodes.Internal {
Loading
Loading
Loading
Loading
@@ -18,6 +18,7 @@ import (
type fakeChannel struct {
stdErr io.ReadWriter
stdOut io.ReadWriter
sentRequestName string
sentRequestPayload []byte
}
Loading
Loading
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
}
func (f *fakeChannel) Write(data []byte) (int, error) {
return 0, nil
return f.stdOut.Write(data)
}
func (f *fakeChannel) Close() error {
Loading
Loading
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
},
}
for _, s := range sessions {
out := &bytes.Buffer{}
f := &fakeChannel{stdErr: out}
stdErr := &bytes.Buffer{}
stdOut := &bytes.Buffer{}
f := &fakeChannel{stdErr: stdErr, stdOut: stdOut}
r := &ssh.Request{Payload: tc.payload}
s.channel = f
Loading
Loading
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
func TestHandleShell(t *testing.T) {
testCases := []struct {
desc string
cmd string
errMsg string
gitlabKeyId string
expectedErrString string
expectedExitCode uint32
desc string
cmd string
errMsg string
gitlabKeyId string
expectedOutString string
expectedErrString string
expectedExitCode uint32
expectedWrittenBytes int64
}{
{
desc: "fails to parse command",
Loading
Loading
@@ -177,27 +181,32 @@ func TestHandleShell(t *testing.T) {
gitlabKeyId: "root",
expectedErrString: "Invalid SSH command: invalid command line string",
expectedExitCode: 128,
}, {
},
{
desc: "specified command is unknown",
cmd: "unknown-command",
errMsg: "ERROR: Unknown command: unknown-command\n",
gitlabKeyId: "root",
expectedErrString: "Disallowed command",
expectedExitCode: 128,
}, {
},
{
desc: "fails to parse command",
cmd: "discover",
gitlabKeyId: "",
errMsg: "ERROR: Failed to get username: who='' is invalid\n",
expectedErrString: "Failed to get username: who='' is invalid",
expectedExitCode: 1,
}, {
desc: "fails to parse command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedErrString: "",
expectedExitCode: 0,
},
{
desc: "parses command",
cmd: "discover",
errMsg: "",
gitlabKeyId: "root",
expectedOutString: "Welcome to GitLab, @test-user!\n",
expectedErrString: "",
expectedExitCode: 0,
expectedWrittenBytes: 31,
},
}
Loading
Loading
@@ -205,29 +214,37 @@ func TestHandleShell(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
out := &bytes.Buffer{}
stdOut := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
s := &session{
gitlabKeyId: tc.gitlabKeyId,
execCmd: tc.cmd,
channel: &fakeChannel{stdErr: out},
channel: &fakeChannel{stdErr: stdErr, stdOut: stdOut},
cfg: &config.Config{GitlabUrl: url},
}
r := &ssh.Request{}
_, exitCode, err := s.handleShell(context.Background(), r)
ctxWithLogData, exitCode, err := s.handleShell(context.Background(), r)
logData := extractDataFromContext(ctxWithLogData)
if tc.expectedOutString != "" {
require.Equal(t, tc.expectedOutString, stdOut.String())
}
if tc.expectedErrString != "" {
require.Equal(t, tc.expectedErrString, err.Error())
}
require.Equal(t, tc.expectedExitCode, exitCode)
require.Equal(t, tc.expectedWrittenBytes, logData.WrittenBytes)
formattedErr := &bytes.Buffer{}
if tc.errMsg != "" {
console.DisplayWarningMessage(tc.errMsg, formattedErr)
require.Equal(t, formattedErr.String(), out.String())
require.Equal(t, formattedErr.String(), stdErr.String())
} else {
require.Equal(t, tc.errMsg, out.String())
require.Equal(t, tc.errMsg, stdErr.String())
}
})
}
Loading
Loading
Loading
Loading
@@ -217,8 +217,9 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
logData := extractDataFromContext(ctxWithLogData)
ctxlog.WithFields(log.Fields{
"duration_s": time.Since(started).Seconds(),
"meta": logData.Meta,
"duration_s": time.Since(started).Seconds(),
"written_bytes": logData.WrittenBytes,
"meta": logData.Meta,
}).Info("access: finish")
}
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment