diff --git a/pam/integration-tests/cli_test.go b/pam/integration-tests/cli_test.go index 7f5c10f321..cc4d9151cc 100644 --- a/pam/integration-tests/cli_test.go +++ b/pam/integration-tests/cli_test.go @@ -38,6 +38,7 @@ func TestCLIAuthenticate(t *testing.T) { socketPath string currentUserNotRoot bool wantLocalGroups bool + skipRunnerCheck bool oldDB string stopDaemonAfter time.Duration }{ @@ -228,6 +229,14 @@ func TestCLIAuthenticate(t *testing.T) { tape: "authd_stopped", stopDaemonAfter: sleepDuration(defaultSleepValues[authdSleepLong] * 5), }, + "Exit_the_pam_client_if_parent_pam_application_is_stopped": { + tape: "pam_app_killed", + tapeVariables: map[string]string{ + "AUTHD_TEST_TAPE_AUTHD_PAM_BINARY_NAME": authdPamBinaryName, + vhsCommandFinalAuthWaitVariable: "", + }, + skipRunnerCheck: true, + }, "Error_if_cannot_connect_to_authd": { tape: "connection_error", @@ -284,7 +293,9 @@ func TestCLIAuthenticate(t *testing.T) { localgroupstestutils.RequireGroupFile(t, groupFileOutput, golden.Path(t)) - requireRunnerResultForUser(t, authd.SessionMode_LOGIN, tc.clientOptions.PamUser, got) + if !tc.skipRunnerCheck { + requireRunnerResultForUser(t, authd.SessionMode_LOGIN, tc.clientOptions.PamUser, got) + } }) } } diff --git a/pam/integration-tests/cmd/exec-client/client.go b/pam/integration-tests/cmd/exec-client/client.go index 062262deda..7174fb97de 100644 --- a/pam/integration-tests/cmd/exec-client/client.go +++ b/pam/integration-tests/cmd/exec-client/client.go @@ -23,10 +23,9 @@ import ( ) var ( - pamFlags = flag.Int64("flags", 0, "pam flags") - serverAddress = flag.String("server-address", "", "the dbus connection address to use to communicate with module") - logFile = flag.String("client-log", "", "the file where to save logs") - argsFile = flag.String("client-args-file", "", "the file where arguments are saved") + pamFlags = flag.Int64("flags", 0, "pam flags") + logFile = flag.String("client-log", "", "the file where to save logs") + argsFile = flag.String("client-args-file", "", "the file where arguments are saved") ) func main() { @@ -52,20 +51,29 @@ func mainFunc() error { return fmt.Errorf("%w: not enough arguments", pam_test.ErrInvalid) } - serverAddressEnv := os.Getenv("AUTHD_PAM_SERVER_ADDRESS") - if serverAddressEnv != "" { - *serverAddress = serverAddressEnv - } - - if serverAddress == nil { + serverAddress := os.Getenv("AUTHD_PAM_SERVER_ADDRESS") + if serverAddress == "" { return fmt.Errorf("%w: no connection provided", pam_test.ErrInvalid) } - mTx, closeFunc, err := newModuleWrapper(*serverAddress) + mTx, closeFunc, err := newModuleWrapper(serverAddress) if err != nil { return fmt.Errorf("%w: can't connect to server: %w", pam_test.ErrInvalid, err) } - defer closeFunc() + + clientReturned := make(chan struct{}) + defer func() { + close(clientReturned) + closeFunc() + }() + + go func() { + select { + case <-clientReturned: + case <-mTx.Context().Done(): + panic(fmt.Sprintf("D-Bus Connection lost: %v", mTx.Context().Err())) + } + }() action, args := args[0], args[1:] diff --git a/pam/integration-tests/cmd/exec-client/modulewrapper.go b/pam/integration-tests/cmd/exec-client/modulewrapper.go index cfe54374a4..1095dca6ce 100644 --- a/pam/integration-tests/cmd/exec-client/modulewrapper.go +++ b/pam/integration-tests/cmd/exec-client/modulewrapper.go @@ -18,33 +18,35 @@ import ( ) type moduleWrapper struct { - pam.ModuleTransaction + dbusmodule.Transaction } -func newModuleWrapper(serverAddress string) (pam.ModuleTransaction, func(), error) { +// Statically Ensure that [moduleWrapper] implements [pam.ModuleTransaction]. +var _ pam.ModuleTransaction = &moduleWrapper{} + +func newModuleWrapper(serverAddress string) (moduleWrapper, func(), error) { mTx, closeFunc, err := dbusmodule.NewTransaction(context.TODO(), serverAddress) - return &moduleWrapper{mTx}, closeFunc, err + return moduleWrapper{mTx}, closeFunc, err } // SimulateClientPanic forces the client to panic with the provided text. -func (m *moduleWrapper) CallUnhandledMethod() error { +func (m moduleWrapper) CallUnhandledMethod() error { method := "com.ubuntu.authd.pam.UnhandledMethod" - tx, _ := m.ModuleTransaction.(*dbusmodule.Transaction) - return tx.BusObject().Call(method, dbus.FlagNoAutoStart).Err + return m.BusObject().Call(method, dbus.FlagNoAutoStart).Err } // SimulateClientPanic forces the client to panic with the provided text. -func (m *moduleWrapper) SimulateClientPanic(text string) { +func (m moduleWrapper) SimulateClientPanic(text string) { panic(text) } // SimulateClientError forces the client to return a new Go error with no PAM type. -func (m *moduleWrapper) SimulateClientError(errorMsg string) error { +func (m moduleWrapper) SimulateClientError(errorMsg string) error { return errors.New(errorMsg) } // SimulateClientSignal sends a signal to the child process. -func (m *moduleWrapper) SimulateClientSignal(sig syscall.Signal, shouldExit bool) { +func (m moduleWrapper) SimulateClientSignal(sig syscall.Signal, shouldExit bool) { pid := os.Getpid() log.Debugf(context.Background(), "Sending signal %v to self pid (%v)", sig, pid) diff --git a/pam/integration-tests/helpers_test.go b/pam/integration-tests/helpers_test.go index 2f22577f6b..3f17706bd4 100644 --- a/pam/integration-tests/helpers_test.go +++ b/pam/integration-tests/helpers_test.go @@ -30,6 +30,8 @@ import ( "gorbe.io/go/osrelease" ) +const authdPamBinaryName = "authd-pam" + var ( authdTestSessionTime = time.Now() authdArtifactsDir string @@ -220,7 +222,7 @@ func buildPAMExecChild(t *testing.T) string { cmd.Args = append(cmd.Args, "-tags=pam_debug") cmd.Env = append(os.Environ(), `CGO_CFLAGS=-O0 -g3`) - authdPam := filepath.Join(t.TempDir(), "authd-pam") + authdPam := filepath.Join(t.TempDir(), authdPamBinaryName) t.Logf("Compiling Exec child at %s", authdPam) t.Log(strings.Join(cmd.Args, " ")) diff --git a/pam/integration-tests/native_test.go b/pam/integration-tests/native_test.go index a0b42953a0..697622c604 100644 --- a/pam/integration-tests/native_test.go +++ b/pam/integration-tests/native_test.go @@ -386,6 +386,15 @@ func TestNativeAuthenticate(t *testing.T) { tape: "authd_stopped", wantSeparateDaemon: true, }, + "Exit_the_pam_client_if_parent_pam_application_is_stopped": { + tape: "pam_app_killed", + tapeVariables: map[string]string{ + "AUTHD_TEST_TAPE_AUTHD_PAM_BINARY_NAME": authdPamBinaryName, + vhsCommandFinalAuthWaitVariable: "", + }, + userSelection: true, + skipRunnerCheck: true, + }, "Error_if_cannot_connect_to_authd": { tape: "connection_error", diff --git a/pam/integration-tests/ssh_test.go b/pam/integration-tests/ssh_test.go index 90d99bd922..b7781a7f11 100644 --- a/pam/integration-tests/ssh_test.go +++ b/pam/integration-tests/ssh_test.go @@ -536,7 +536,7 @@ func createSshdServiceFile(t *testing.T, module, execChild, mkHomeModule, socket "socket=" + socketPath, fmt.Sprintf("connection_timeout=%d", defaultConnectionTimeout), "debug=true", - "logfile=" + os.Stderr.Name(), + "logfile=/dev/stderr", "--exec-debug", } diff --git a/pam/integration-tests/testdata/golden/TestCLIAuthenticate/Exit_the_pam_client_if_parent_pam_application_is_stopped b/pam/integration-tests/testdata/golden/TestCLIAuthenticate/Exit_the_pam_client_if_parent_pam_application_is_stopped new file mode 100644 index 0000000000..12c050edd0 --- /dev/null +++ b/pam/integration-tests/testdata/golden/TestCLIAuthenticate/Exit_the_pam_client_if_parent_pam_application_is_stopped @@ -0,0 +1,15 @@ +> +──────────────────────────────────────────────────────────────────────────────── +> +──────────────────────────────────────────────────────────────────────────────── +> +──────────────────────────────────────────────────────────────────────────────── +./pam_authd login socket=${AUTHD_TEST_TAPE_SOCKET} +Terminateduser name +> +──────────────────────────────────────────────────────────────────────────────── +> +──────────────────────────────────────────────────────────────────────────────── +Parent Process killed +> +──────────────────────────────────────────────────────────────────────────────── diff --git a/pam/integration-tests/testdata/golden/TestNativeAuthenticate/Exit_the_pam_client_if_parent_pam_application_is_stopped b/pam/integration-tests/testdata/golden/TestNativeAuthenticate/Exit_the_pam_client_if_parent_pam_application_is_stopped new file mode 100644 index 0000000000..7007a7d43e --- /dev/null +++ b/pam/integration-tests/testdata/golden/TestNativeAuthenticate/Exit_the_pam_client_if_parent_pam_application_is_stopped @@ -0,0 +1,15 @@ +> +──────────────────────────────────────────────────────────────────────────────── +> +──────────────────────────────────────────────────────────────────────────────── +> +──────────────────────────────────────────────────────────────────────────────── +./pam_authd login socket=${AUTHD_TEST_TAPE_SOCKET} force_native_client=true +Terminated +> +──────────────────────────────────────────────────────────────────────────────── +> +──────────────────────────────────────────────────────────────────────────────── +Parent Process killed +> +──────────────────────────────────────────────────────────────────────────────── diff --git a/pam/integration-tests/testdata/tapes/cli/pam_app_killed.tape b/pam/integration-tests/testdata/tapes/cli/pam_app_killed.tape new file mode 100644 index 0000000000..e708a38605 --- /dev/null +++ b/pam/integration-tests/testdata/tapes/cli/pam_app_killed.tape @@ -0,0 +1,34 @@ +Hide +TypeInPrompt+Shell "${AUTHD_TEST_TAPE_COMMAND} &" +Enter +Wait + +Type "clear" +Enter +Wait + +# Find out the PID of the authd PAM binary process, so that we can kill it later. +Type `while true; do sleep 0.2 && child_pid=$(pgrep -f "${AUTHD_TEST_TAPE_AUTHD_PAM_BINARY_NAME} .* socket=${AUTHD_TEST_TAPE_SOCKET}"); [ -n "${child_pid}" ] && break; done` +Enter +Wait + +Type "clear" +Enter +Wait + +# Get back to the PAM application to foreground, while killing it in the background. +Type `"$0" -c 'sleep 5 && pkill -f "${AUTHD_TEST_TAPE_COMMAND}" &' && clear && fg` +Enter +Wait /Username: user name\n/ +Wait +Show + +ClearTerminal + +Hide +# Ensure that the child PID has been killed when destroying the parent application. +Type `if kill -0 "${child_pid}" &>/dev/null; then clear && echo "Parent process still alive"; else clear && echo Parent Process killed; fi` +Enter +Wait /Parent Process killed\n/ +Wait ${AUTHD_TEST_TAPE_COMMAND_AUTH_FINAL_WAIT} +Show diff --git a/pam/integration-tests/testdata/tapes/native/pam_app_killed.tape b/pam/integration-tests/testdata/tapes/native/pam_app_killed.tape new file mode 100644 index 0000000000..ca372715a5 --- /dev/null +++ b/pam/integration-tests/testdata/tapes/native/pam_app_killed.tape @@ -0,0 +1,33 @@ +Hide +TypeInPrompt+Shell "${AUTHD_TEST_TAPE_COMMAND} &" +Enter +Wait + +Type "clear" +Enter +Wait + +# Find out the PID of the authd PAM binary process, so that we can kill it later. +Type `while true; do sleep 0.2 && child_pid=$(pgrep -f "${AUTHD_TEST_TAPE_AUTHD_PAM_BINARY_NAME} .* socket=${AUTHD_TEST_TAPE_SOCKET}"); [ -n "${child_pid}" ] && break; done` +Enter +Wait + +Type "clear" +Enter +Wait + +# Get back to the PAM application to foreground, while killing it in the background. +Type `"$0" -c 'sleep 5 && pkill -f "${AUTHD_TEST_TAPE_COMMAND}" &' && clear && fg` +Enter +Wait +Show + +ClearTerminal + +Hide +# Ensure that the child PID has been killed when destroying the parent application. +Type `if kill -0 "${child_pid}" &>/dev/null; then clear && echo "Parent process still alive"; else clear && echo Parent Process killed; fi` +Enter +Wait /Parent Process killed\n/ +Wait ${AUTHD_TEST_TAPE_COMMAND_AUTH_FINAL_WAIT} +Show diff --git a/pam/internal/dbusmodule/transaction.go b/pam/internal/dbusmodule/transaction.go index 7e0830910b..fcbb7bb90c 100644 --- a/pam/internal/dbusmodule/transaction.go +++ b/pam/internal/dbusmodule/transaction.go @@ -15,7 +15,8 @@ import ( // Transaction is a [pam.Transaction] with dbus support. type Transaction struct { - obj dbus.BusObject + conn *dbus.Conn + obj dbus.BusObject } type options struct { @@ -38,10 +39,13 @@ const objectPath = "/com/ubuntu/authd/pam" // FIXME: dbus.Variant does not support maybe types, so we're using a variant string instead. const variantNothing = "<@mv nothing>" +// Statically Ensure that [Transaction] implements [pam.ModuleTransaction]. +var _ pam.ModuleTransaction = &Transaction{} + // NewTransaction creates a new [dbusmodule.Transaction] with the provided connection. // A [pam.ModuleTransaction] implementation is returned together with a cleanup function that // should be called to release the connection. -func NewTransaction(ctx context.Context, address string, o ...TransactionOptions) (tx pam.ModuleTransaction, cleanup func(), err error) { +func NewTransaction(ctx context.Context, address string, o ...TransactionOptions) (tx Transaction, cleanup func(), err error) { opts := options{} for _, f := range o { f(&opts) @@ -50,31 +54,38 @@ func NewTransaction(ctx context.Context, address string, o ...TransactionOptions log.Debugf(context.TODO(), "Connecting to %s", address) conn, err := dbus.Dial(address, dbus.WithContext(ctx)) if err != nil { - return nil, nil, err + return Transaction{}, nil, err } cleanup = func() { conn.Close() } if err = conn.Auth(nil); err != nil { cleanup() - return nil, nil, err + return Transaction{}, nil, err } if opts.isSharedConnection { if err = conn.Hello(); err != nil { cleanup() - return nil, nil, err + return Transaction{}, nil, err } } obj := conn.Object(ifaceName, objectPath) - return &Transaction{obj: obj}, cleanup, nil + + return Transaction{conn: conn, obj: obj}, cleanup, nil } // BusObject gets the DBus object. -func (tx *Transaction) BusObject() dbus.BusObject { +func (tx Transaction) BusObject() dbus.BusObject { return tx.obj } +// Context returns the context associated with the connection. The +// context will be cancelled when the connection is closed. +func (tx Transaction) Context() context.Context { + return tx.conn.Context() +} + // SetData allows to save any value in the module data that is preserved // during the whole time the module is loaded. -func (tx *Transaction) SetData(key string, data any) error { +func (tx Transaction) SetData(key string, data any) error { if data == nil { return dbusUnsetter(tx.obj, "UnsetData", key) } @@ -83,7 +94,7 @@ func (tx *Transaction) SetData(key string, data any) error { // GetData allows to get any value from the module data saved using SetData // that is preserved across the whole time the module is loaded. -func (tx *Transaction) GetData(key string) (any, error) { +func (tx Transaction) GetData(key string) (any, error) { // See the FIXME on variantNothing, all this should be managed by variant. data, err := dbusGetter[any](tx.obj, "GetData", key) if data == variantNothing { @@ -93,12 +104,12 @@ func (tx *Transaction) GetData(key string) (any, error) { } // SetItem sets a PAM item. -func (tx *Transaction) SetItem(item pam.Item, value string) error { +func (tx Transaction) SetItem(item pam.Item, value string) error { return dbusSetter(tx.obj, "SetItem", item, value) } // GetItem retrieves a PAM item. -func (tx *Transaction) GetItem(item pam.Item) (string, error) { +func (tx Transaction) GetItem(item pam.Item) (string, error) { return dbusGetter[string](tx.obj, "GetItem", item) } @@ -107,7 +118,7 @@ func (tx *Transaction) GetItem(item pam.Item) (string, error) { // NAME=value will set a variable to a value. // NAME= will set a variable to an empty value. // NAME (without an "=") will delete a variable. -func (tx *Transaction) PutEnv(nameVal string) error { +func (tx Transaction) PutEnv(nameVal string) error { if !strings.Contains(nameVal, "=") { return dbusUnsetter(tx.obj, "UnsetEnv", nameVal) } @@ -116,7 +127,7 @@ func (tx *Transaction) PutEnv(nameVal string) error { } // GetEnv is used to retrieve a PAM environment variable. -func (tx *Transaction) GetEnv(name string) string { +func (tx Transaction) GetEnv(name string) string { env, err := dbusGetter[string](tx.obj, "GetEnv", name) if err != nil { return "" @@ -125,7 +136,7 @@ func (tx *Transaction) GetEnv(name string) string { } // GetEnvList returns a copy of the PAM environment as a map. -func (tx *Transaction) GetEnvList() (map[string]string, error) { +func (tx Transaction) GetEnvList() (map[string]string, error) { var r int var envMap map[string]string method := fmt.Sprintf("%s.GetEnvList", ifaceName) @@ -143,7 +154,7 @@ func (tx *Transaction) GetEnvList() (map[string]string, error) { // GetUser is similar to GetItem(User), but it would start a conversation if // no user is currently set in PAM. -func (tx *Transaction) GetUser(prompt string) (string, error) { +func (tx Transaction) GetUser(prompt string) (string, error) { user, err := tx.GetItem(pam.User) if err != nil { return "", err @@ -162,7 +173,7 @@ func (tx *Transaction) GetUser(prompt string) (string, error) { // StartStringConv starts a text-based conversation using the provided style // and prompt. -func (tx *Transaction) StartStringConv(style pam.Style, prompt string) ( +func (tx Transaction) StartStringConv(style pam.Style, prompt string) ( pam.StringConvResponse, error) { res, err := tx.StartConv(pam.NewStringConvRequest(style, prompt)) if err != nil { @@ -177,19 +188,19 @@ func (tx *Transaction) StartStringConv(style pam.Style, prompt string) ( } // StartStringConvf allows to start string conversation with formatting support. -func (tx *Transaction) StartStringConvf(style pam.Style, format string, args ...interface{}) ( +func (tx Transaction) StartStringConvf(style pam.Style, format string, args ...interface{}) ( pam.StringConvResponse, error) { return tx.StartStringConv(style, fmt.Sprintf(format, args...)) } // StartBinaryConv starts a binary conversation using the provided bytes. -func (tx *Transaction) StartBinaryConv(bytes []byte) ( +func (tx Transaction) StartBinaryConv(bytes []byte) ( pam.BinaryConvResponse, error) { return nil, fmt.Errorf("%w: binary conversations are not supported", pam.ErrConv) } // StartConv initiates a PAM conversation using the provided ConvRequest. -func (tx *Transaction) StartConv(req pam.ConvRequest) ( +func (tx Transaction) StartConv(req pam.ConvRequest) ( pam.ConvResponse, error) { resp, err := tx.StartConvMulti([]pam.ConvRequest{req}) if err != nil { @@ -201,7 +212,7 @@ func (tx *Transaction) StartConv(req pam.ConvRequest) ( return resp[0], nil } -func (tx *Transaction) handleStringRequest(req pam.StringConvRequest) (pam.StringConvResponse, error) { +func (tx Transaction) handleStringRequest(req pam.StringConvRequest) (pam.StringConvResponse, error) { if req.Style() == pam.BinaryPrompt { return nil, fmt.Errorf("%w: binary style is not supported", pam.ErrConv) } @@ -225,7 +236,7 @@ func (tx *Transaction) handleStringRequest(req pam.StringConvRequest) (pam.Strin } // StartConvMulti initiates a PAM conversation with multiple ConvRequest's. -func (tx *Transaction) StartConvMulti(requests []pam.ConvRequest) ( +func (tx Transaction) StartConvMulti(requests []pam.ConvRequest) ( responses []pam.ConvResponse, err error) { defer decorate.OnError(&err, "%v", err) @@ -251,7 +262,7 @@ func (tx *Transaction) StartConvMulti(requests []pam.ConvRequest) ( } // InvokeHandler is called by the C code to invoke the proper handler. -func (tx *Transaction) InvokeHandler(handler pam.ModuleHandlerFunc, +func (tx Transaction) InvokeHandler(handler pam.ModuleHandlerFunc, flags pam.Flags, args []string) error { return pam.ErrAbort } diff --git a/pam/internal/dbusmodule/transaction_test.go b/pam/internal/dbusmodule/transaction_test.go index 34e21d345c..555ae69182 100644 --- a/pam/internal/dbusmodule/transaction_test.go +++ b/pam/internal/dbusmodule/transaction_test.go @@ -22,7 +22,7 @@ func TestTransactionConnectionError(t *testing.T) { t.Parallel() tx, cleanup, err := dbusmodule.NewTransaction(context.TODO(), "invalid-address") - require.Nil(t, tx, "Transaction must be unset") + require.Zero(t, tx, "Transaction must be unset") require.Nil(t, cleanup, "Cleanup func must be unset") require.NotNil(t, err, "Error must be set") } @@ -31,9 +31,7 @@ func TestTransactionHandler(t *testing.T) { t.Parallel() tx, _ := prepareTransaction(t, nil) - dbusTx, ok := tx.(*dbusmodule.Transaction) - require.True(t, ok, "Transaction should be a dbus module Transaction") - require.ErrorIs(t, dbusTx.InvokeHandler(nil, 0, nil), pam.ErrAbort) + require.ErrorIs(t, tx.InvokeHandler(nil, 0, nil), pam.ErrAbort) } func TestTransactionSetEnv(t *testing.T) { @@ -604,6 +602,19 @@ func TestStartBinaryConv(t *testing.T) { } } +func TestDisconnectionHandler(t *testing.T) { + address, _, cleanup := prepareTestServerWithCleanup(t, nil) + tx, txCleanup, err := dbusmodule.NewTransaction(context.TODO(), address, + dbusmodule.WithSharedConnection(true)) + require.NoError(t, err, "Setup: Can't connect to %s", address) + t.Cleanup(txCleanup) + + require.NoError(t, tx.Context().Err(), "Context must not be cancelled") + cleanup() + <-tx.Context().Done() + require.ErrorIs(t, tx.Context().Err(), context.Canceled, "Context must be cancelled") +} + type methodCallExpectations struct { methodReturns []methodReturn wantMethodCalls []methodCall @@ -629,13 +640,14 @@ func requireDbusErrorIs(t *testing.T, err error, wantError error) { } } -func prepareTransaction(t *testing.T, expectedReturns []methodReturn) (pam.ModuleTransaction, *testServer) { +func prepareTransaction(t *testing.T, expectedReturns []methodReturn) (dbusmodule.Transaction, *testServer) { t.Helper() address, obj := prepareTestServer(t, expectedReturns) tx, cleanup, err := dbusmodule.NewTransaction(context.TODO(), address, dbusmodule.WithSharedConnection(true)) require.NoError(t, err, "Setup: Can't connect to %s", address) + t.Cleanup(func() { <-tx.Context().Done() }) t.Cleanup(cleanup) t.Logf("Using bus at address %s", address) @@ -646,6 +658,13 @@ func prepareTransaction(t *testing.T, expectedReturns []methodReturn) (pam.Modul func prepareTestServer(t *testing.T, expectedReturns []methodReturn) (string, *testServer) { t.Helper() + obj, conn, _ := prepareTestServerWithCleanup(t, expectedReturns) + return obj, conn +} + +func prepareTestServerWithCleanup(t *testing.T, expectedReturns []methodReturn) (address string, obj *testServer, cleanup func()) { + t.Helper() + address, cleanup, err := testutils.StartBusMock() require.NoError(t, err, "Setup: Creating mock bus failed") t.Cleanup(cleanup) @@ -658,7 +677,7 @@ func prepareTestServer(t *testing.T, expectedReturns []methodReturn) (string, *t } }) - obj := &testServer{t: t, mu: &sync.Mutex{}, returns: expectedReturns} + obj = &testServer{t: t, mu: &sync.Mutex{}, returns: expectedReturns} err = conn.Export(obj, objectPath, ifaceName) require.NoError(t, err, "Setup: Exporting test server object to bus failed") @@ -667,7 +686,7 @@ func prepareTestServer(t *testing.T, expectedReturns []methodReturn) (string, *t require.Equal(t, reply, dbus.RequestNameReplyPrimaryOwner, "Setup: can't get dbus name") - return address, obj + return address, obj, cleanup } func TestMain(m *testing.M) { diff --git a/pam/main-exec.go b/pam/main-exec.go index 5720c1586c..3504229b4d 100644 --- a/pam/main-exec.go +++ b/pam/main-exec.go @@ -7,7 +7,6 @@ import ( "fmt" "os" "runtime" - "time" "github.com/msteinert/pam/v2" "github.com/ubuntu/authd/log" @@ -15,9 +14,7 @@ import ( ) var ( - pamFlags = flag.Int64("flags", 0, "pam flags") - serverAddress = flag.String("server-address", "", "the dbus connection to use to communicate with module") - timeout = flag.Int64("timeout", 120, "timeout for the server connection (in seconds)") + pamFlags = flag.Int64("flags", 0, "pam flags") ) func init() { @@ -37,23 +34,33 @@ func mainFunc() error { return errors.New("not enough arguments") } - serverAddressEnv := os.Getenv("AUTHD_PAM_SERVER_ADDRESS") - if serverAddressEnv != "" { - *serverAddress = serverAddressEnv - } - - if serverAddress == nil { + serverAddress := os.Getenv("AUTHD_PAM_SERVER_ADDRESS") + if serverAddress == "" { return fmt.Errorf("%w: no connection provided", pam.ErrSystem) } - ctx, cancel := context.WithTimeout(context.TODO(), time.Duration(*timeout)*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mTx, closeFunc, err := dbusmodule.NewTransaction(ctx, *serverAddress) + + mTx, closeFunc, err := dbusmodule.NewTransaction(ctx, serverAddress) if err != nil { return fmt.Errorf("%w: can't connect to server: %w", pam.ErrSystem, err) } defer closeFunc() + actionDone := make(chan struct{}) + defer close(actionDone) + + go func() { + select { + case <-actionDone: + case <-mTx.Context().Done(): + log.Warningf(context.Background(), "D-Bus Connection closed: %v", + mTx.Context().Err()) + os.Exit(255) + } + }() + action, args := args[0], args[1:] flags := pam.Flags(0) @@ -61,6 +68,8 @@ func mainFunc() error { flags = pam.Flags(*pamFlags) } + log.Debugf(context.Background(), "Starting action %q (%v)", action, flags) + switch action { case "authenticate": return module.Authenticate(mTx, flags, args)