From 4b982c872840fd04e79827fdb8532bddc392fbef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 25 Sep 2023 18:46:03 +0200 Subject: [PATCH 01/24] transaction: Add a transaction base type to define more transaction kinds A pam handler can be used both by a module and by an Application, go-pam is meant to be used in the application side right now, but it can be easily changed to also create modules. This is the prerequisite work to support this. --- transaction.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/transaction.go b/transaction.go index dc2d378..161cbaa 100644 --- a/transaction.go +++ b/transaction.go @@ -124,12 +124,20 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { return C.CString(r), success } -// Transaction is the application's handle for a PAM transaction. -type Transaction struct { +// transactionBase is a handler for a PAM transaction that can be used to +// group the operations that can be performed both by the application and the +// module side +type transactionBase struct { handle *C.pam_handle_t - conv *C.struct_pam_conv lastStatus atomic.Int32 - c cgo.Handle +} + +// Transaction is the application's handle for a PAM transaction. +type Transaction struct { + transactionBase + + conv *C.struct_pam_conv + c cgo.Handle } // End cleans up the PAM handle and deletes the callback function. @@ -146,7 +154,7 @@ func (t *Transaction) End() error { } // Allows to call pam functions managing return status -func (t *Transaction) handlePamStatus(cStatus C.int) error { +func (t *transactionBase) handlePamStatus(cStatus C.int) error { t.lastStatus.Store(int32(cStatus)) if status := Error(cStatus); status != success { return status @@ -268,14 +276,14 @@ const ( ) // SetItem sets a PAM information item. -func (t *Transaction) SetItem(i Item, item string) error { +func (t *transactionBase) SetItem(i Item, item string) error { cs := unsafe.Pointer(C.CString(item)) defer C.free(cs) return t.handlePamStatus(C.pam_set_item(t.handle, C.int(i), cs)) } // GetItem retrieves a PAM information item. -func (t *Transaction) GetItem(i Item) (string, error) { +func (t *transactionBase) GetItem(i Item) (string, error) { var s unsafe.Pointer err := t.handlePamStatus(C.pam_get_item(t.handle, C.int(i), &s)) if err != nil { @@ -360,14 +368,14 @@ func (t *Transaction) CloseSession(f Flags) error { // NAME=value will set a variable to a value. // NAME= will set a variable to an empty value. // NAME (without an "=") will delete a variable. -func (t *Transaction) PutEnv(nameval string) error { +func (t *transactionBase) PutEnv(nameval string) error { cs := C.CString(nameval) defer C.free(unsafe.Pointer(cs)) return t.handlePamStatus(C.pam_putenv(t.handle, cs)) } // GetEnv is used to retrieve a PAM environment variable. -func (t *Transaction) GetEnv(name string) string { +func (t *transactionBase) GetEnv(name string) string { cs := C.CString(name) defer C.free(unsafe.Pointer(cs)) value := C.pam_getenv(t.handle, cs) @@ -382,7 +390,7 @@ func next(p **C.char) **C.char { } // GetEnvList returns a copy of the PAM environment as a map. -func (t *Transaction) GetEnvList() (map[string]string, error) { +func (t *transactionBase) GetEnvList() (map[string]string, error) { env := make(map[string]string) p := C.pam_getenvlist(t.handle) if p == nil { From 8cf5c51c31bb9773a69d60dca3b65df955cfe458 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 25 Sep 2023 18:52:56 +0200 Subject: [PATCH 02/24] transaction: Add ModuleTransaction type and ModuleHandler interface This allows to easily define go-handlers for module operations. We need to expose few more types externally so that it's possible to create the module transaction handler and return specific transaction errors --- module-transaction.go | 29 ++++++++ module-transaction_test.go | 131 +++++++++++++++++++++++++++++++++++++ transaction.go | 6 +- 3 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 module-transaction.go create mode 100644 module-transaction_test.go diff --git a/module-transaction.go b/module-transaction.go new file mode 100644 index 0000000..9a3aae6 --- /dev/null +++ b/module-transaction.go @@ -0,0 +1,29 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +// ModuleTransaction is an interface that a pam module transaction +// should implement. +type ModuleTransaction interface { + SetItem(Item, string) error + GetItem(Item) (string, error) + PutEnv(nameVal string) error + GetEnv(name string) string + GetEnvList() (map[string]string, error) +} + +// ModuleHandlerFunc is a function type used by the ModuleHandler. +type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error + +// ModuleTransaction is the module-side handle for a PAM transaction. +type moduleTransaction = transactionBase + +// ModuleHandler is an interface for objects that can be used to create +// PAM modules from go. +type ModuleHandler interface { + AcctMgmt(ModuleTransaction, Flags, []string) error + Authenticate(ModuleTransaction, Flags, []string) error + ChangeAuthTok(ModuleTransaction, Flags, []string) error + CloseSession(ModuleTransaction, Flags, []string) error + OpenSession(ModuleTransaction, Flags, []string) error + SetCred(ModuleTransaction, Flags, []string) error +} diff --git a/module-transaction_test.go b/module-transaction_test.go new file mode 100644 index 0000000..d5c7533 --- /dev/null +++ b/module-transaction_test.go @@ -0,0 +1,131 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +import ( + "errors" + "reflect" + "testing" +) + +func Test_NewNullModuleTransaction(t *testing.T) { + t.Parallel() + mt := moduleTransaction{} + + if mt.handle != nil { + t.Fatalf("unexpected handle value: %v", mt.handle) + } + + if s := Error(mt.lastStatus.Load()); s != success { + t.Fatalf("unexpected status: %v", s) + } + + tests := map[string]struct { + testFunc func(t *testing.T) (any, error) + expectedError error + ignoreError bool + }{ + "GetItem": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetItem(Service) + }, + }, + "SetItem": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetItem(Service, "foo") + }, + }, + "GetEnv": { + ignoreError: true, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetEnv("foo"), nil + }, + }, + "PutEnv": { + expectedError: ErrAbort, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.PutEnv("foo=bar") + }, + }, + "GetEnvList": { + expectedError: ErrBuf, + testFunc: func(t *testing.T) (any, error) { + t.Helper() + list, err := mt.GetEnvList() + if len(list) > 0 { + t.Fatalf("unexpected list: %v", list) + } + return nil, err + }, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name+"-error-check", func(t *testing.T) { + t.Parallel() + data, err := tc.testFunc(t) + + switch d := data.(type) { + case string: + if d != "" { + t.Fatalf("empty value was expected, got %s", d) + } + case interface{}: + if !reflect.ValueOf(d).IsNil() { + t.Fatalf("nil value was expected, got %v", d) + } + default: + if d != nil { + t.Fatalf("nil value was expected, got %v", d) + } + } + + if tc.ignoreError { + return + } + if err == nil { + t.Fatal("error was expected, but got none") + } + + var expectedError error = ErrSystem + if tc.expectedError != nil { + expectedError = tc.expectedError + } + + if !errors.Is(err, expectedError) { + t.Fatalf("status %v was expected, but got %v", + expectedError, err) + } + }) + } + + for name, tc := range tests { + // These can't be parallel - we test a private value that is not thread safe + t.Run(name+"-lastStatus-check", func(t *testing.T) { + mt.lastStatus.Store(99999) + _, err := tc.testFunc(t) + status := Error(mt.lastStatus.Load()) + + if tc.ignoreError { + return + } + if err == nil { + t.Fatal("error was expected, but got none") + } + + expectedStatus := ErrSystem + if tc.expectedError != nil { + errors.As(err, &expectedStatus) + } + + if status != expectedStatus { + t.Fatalf("status %v was expected, but got %d", + expectedStatus, status) + } + }) + } +} diff --git a/transaction.go b/transaction.go index 161cbaa..3ca2771 100644 --- a/transaction.go +++ b/transaction.go @@ -124,11 +124,15 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { return C.CString(r), success } +// NativeHandle is the type of the native PAM handle for a transaction so that +// it can be exported +type NativeHandle = *C.pam_handle_t + // transactionBase is a handler for a PAM transaction that can be used to // group the operations that can be performed both by the application and the // module side type transactionBase struct { - handle *C.pam_handle_t + handle NativeHandle lastStatus atomic.Int32 } From 22b9e813d58f8d00d67d34c9a6e916c966e32d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Tue, 10 Oct 2023 06:04:49 +0200 Subject: [PATCH 03/24] transaction: Properly handle nil bytes in binary transactions If returned binaries are nil, we should pass them as nil and not as an empty bytes array. --- transaction.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transaction.go b/transaction.go index 3ca2771..10c1a55 100644 --- a/transaction.go +++ b/transaction.go @@ -105,6 +105,9 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { if err != nil { return nil, C.int(ErrConv) } + if bytes == nil { + return nil, success + } return (*C.char)(C.CBytes(bytes)), success } handler = cb From 959de379760aab19867f3aadcd3032f07fb68409 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 25 Sep 2023 19:31:46 +0200 Subject: [PATCH 04/24] pam-moduler: Add first implementation of a Go PAM Module generator A PAM module can be generated using pam-moduler and implemented fully in go without having to manually deal with the C setup. Module can be compiled using go generate, so go:generate directives can be used to make this process automatic, with a single go generate call as shown in the example. --- .codecov.yml | 3 + .github/workflows/test.yaml | 6 + .gitignore | 2 + README.md | 118 ++++++++++++++ cmd/pam-moduler/moduler.go | 305 +++++++++++++++++++++++++++++++++++ example-module/module.go | 50 ++++++ example-module/pam_module.go | 96 +++++++++++ module-transaction.go | 10 +- 8 files changed, 589 insertions(+), 1 deletion(-) create mode 100644 .codecov.yml create mode 100644 cmd/pam-moduler/moduler.go create mode 100644 example-module/module.go create mode 100644 example-module/pam_module.go diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..5066aeb --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,3 @@ +ignore: + # Ignore pam-moduler generated files + - "**/pam_module.go" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index ff083ab..a4007dc 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,6 +29,12 @@ jobs: run: sudo useradd -d /tmp/test -p '$1$Qd8H95T5$RYSZQeoFbEB.gS19zS99A0' -s /bin/false test - name: Checkout code uses: actions/checkout@v4 + - name: Generate example module + run: | + rm -f example-module/pam_go.so + go generate -C example-module -v + test -e example-module/pam_go.so + git diff --exit-code example-module - name: Test run: sudo go test -v -cover -coverprofile=coverage.out ./... - name: Upload coverage reports to Codecov diff --git a/.gitignore b/.gitignore index 2d83068..a2f238d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ coverage.out +example-module/*.so +example-module/*.h diff --git a/README.md b/README.md index fab308e..8738e00 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,124 @@ This is a Go wrapper for the PAM application API. +## Module support + +Go PAM can also used to create PAM modules in a simple way, using the go. + +The code can be generated using [pam-moduler](cmd/pam-moduler/moduler.go) and +an example how to use it using `go generate` create them is available as an +[example module](example-module/module.go). + +### Modules and PAM applications + +The modules generated with go can be used by any PAM application, however there +are some caveats, in fact a Go shared library could misbehave when loaded +improperly. In particular if a Go shared library is loaded and then the program +`fork`s, the library will have an undefined behavior. + +This is the case of SSHd that loads a pam library before forking, making any +go PAM library to make it hang. + +To solve this case, we can use a little workaround: to ensure that the go +library is loaded only after the program has forked, we can just `dload` it once +a PAM library is called, in this way go code will be loaded only after that the +PAM application has `fork`'ed. + +To do this, we can use a very simple wrapper written in C: + +```c +#include +#include +#include +#include + +typedef int (*PamHandler)(pam_handle_t *, + int flags, + int argc, + const char **argv); + +static void +on_go_module_removed (pam_handle_t *pamh, + void *go_module, + int error_status) +{ + dlclose (go_module); +} + +static void * +load_module (pam_handle_t *pamh, + const char *module_path) +{ + void *go_module; + + if (pam_get_data (pamh, "go-module", (const void **) &go_module) == PAM_SUCCESS) + return go_module; + + go_module = dlopen (module_path, RTLD_LAZY); + if (!go_module) + return NULL; + + pam_set_data (pamh, "go-module", go_module, on_go_module_removed); + + return go_module; +} + +static inline int +call_pam_function (pam_handle_t *pamh, + const char *function, + int flags, + int argc, + const char **argv) +{ + char module_path[PATH_MAX] = {0}; + const char *sub_module; + PamHandler func; + void *go_module; + + if (argc < 1) + { + pam_error (pamh, "%s: no module provided", function); + return PAM_MODULE_UNKNOWN; + } + + sub_module = argv[0]; + argc -= 1; + argv = (argc == 0) ? NULL : &argv[1]; + + strncpy (module_path, sub_module, PATH_MAX - 1); + + go_module = load_module (pamh, module_path); + if (!go_module) + { + pam_error (pamh, "Impossible to load module %s", module_path); + return PAM_OPEN_ERR; + } + + *(void **) (&func) = dlsym (go_module, function); + if (!func) + { + pam_error (pamh, "Symbol %s not found in %s", function, module_path); + return PAM_OPEN_ERR; + } + + return func (pamh, flags, argc, argv); +} + +#define DEFINE_PAM_WRAPPER(name) \ + PAM_EXTERN int \ + (pam_sm_ ## name) (pam_handle_t * pamh, int flags, int argc, const char **argv) \ + { \ + return call_pam_function (pamh, "pam_sm_" #name, flags, argc, argv); \ + } + +DEFINE_PAM_WRAPPER (acct_mgmt) +DEFINE_PAM_WRAPPER (authenticate) +DEFINE_PAM_WRAPPER (chauthtok) +DEFINE_PAM_WRAPPER (close_session) +DEFINE_PAM_WRAPPER (open_session) +DEFINE_PAM_WRAPPER (setcred) +``` + ## Testing To run the full suite, the tests must be run as the root user. To setup your diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go new file mode 100644 index 0000000..b95195d --- /dev/null +++ b/cmd/pam-moduler/moduler.go @@ -0,0 +1,305 @@ +// pam-moduler is a tool to automate the creation of PAM Modules in go +// +// The file is created in the same package and directory as the package that +// creates the module +// +// The module implementation should define a pamModuleHandler object that +// implements the pam.ModuleHandler type and that will be used for each callback +// +// Otherwise it's possible to provide a typename from command line that will +// be used for this purpose +// +// For example: +// +// //go:generate go run github.com/msteinert/pam/pam-moduler +// //go:generate go generate --skip="pam_module" +// package main +// +// import "github.com/msteinert/pam/v2" +// +// type ExampleHandler struct{} +// var pamModuleHandler pam.ModuleHandler = &ExampleHandler{} +// +// func (h *ExampleHandler) AcctMgmt(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) Authenticate(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) ChangeAuthTok(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) OpenSession(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) CloseSession(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } +// +// func (h *ExampleHandler) SetCred(pam.ModuleTransaction, pam.Flags, []string) error { +// return nil +// } + +// Package main provides the module shared library. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "log" + "os" + "path/filepath" + "strings" +) + +const toolName = "pam-moduler" + +var ( + output = flag.String("output", "", "output file name; default srcdir/pam_module.go") + libName = flag.String("libname", "", "output library name; default pam_go.so") + typeName = flag.String("type", "", "type name to be used as pam.ModuleHandler") + buildTags = flag.String("tags", "", "build tags expression to append to use in the go:build directive") + skipGenerator = flag.Bool("no-generator", false, "whether to add go:generator directives to the generated source") + moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") + moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") + noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") +) + +// Usage is a replacement usage function for the flags package. +func Usage() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", toolName) + fmt.Fprintf(os.Stderr, "\t%s [flags] [-output O] [-libname pam_go] [-type N]\n", toolName) + flag.PrintDefaults() +} + +func main() { + log.SetFlags(0) + log.SetPrefix(toolName + ": ") + flag.Usage = Usage + flag.Parse() + + if *skipGenerator { + if *libName != "" { + fmt.Fprintf(os.Stderr, + "Generator directives disabled, libname will have no effect\n") + } + if *moduleBuildTags != "" { + fmt.Fprintf(os.Stderr, + "Generator directives disabled, build-tags will have no effect\n") + } + if *moduleBuildFlags != "" { + fmt.Fprintf(os.Stderr, + "Generator directives disabled, build-flags will have no effect\n") + } + } + + lib := *libName + if lib == "" { + lib = "pam_go" + } else { + lib, _ = strings.CutSuffix(lib, ".so") + lib, _ = strings.CutPrefix(lib, "lib") + } + + outputName, _ := strings.CutSuffix(*output, ".go") + if outputName == "" { + baseName := "pam_module" + outputName = filepath.Join(".", strings.ToLower(baseName)) + } + outputName = outputName + ".go" + + var tags string + if *buildTags != "" { + tags = *buildTags + } + + var generateTags []string + if len(*moduleBuildTags) > 0 { + generateTags = strings.Split(*moduleBuildTags, ",") + } + + var buildFlags []string + if *moduleBuildFlags != "" { + buildFlags = strings.Split(*moduleBuildFlags, ",") + } + + g := Generator{ + outputName: outputName, + libName: lib, + tags: tags, + buildFlags: buildFlags, + generateTags: generateTags, + noMain: *noMain, + typeName: *typeName, + } + + // Print the header and package clause. + g.printf("// Code generated by \"%s %s\"; DO NOT EDIT.\n", + toolName, strings.Join(os.Args[1:], " ")) + g.printf("\n") + + // Generate the code + g.generate() + + // Format the output. + src := g.format() + + // Write to file. + err := os.WriteFile(outputName, src, 0600) + if err != nil { + log.Fatalf("writing output: %s", err) + } +} + +// Generator holds the state of the analysis. Primarily used to buffer +// the output for format.Source. +type Generator struct { + buf bytes.Buffer // Accumulated output. + + libName string + outputName string + typeName string + tags string + generateTags []string + buildFlags []string + noMain bool +} + +func (g *Generator) printf(format string, args ...interface{}) { + fmt.Fprintf(&g.buf, format, args...) +} + +// generate produces the String method for the named type. +func (g *Generator) generate() { + if g.tags != "" { + g.printf("//go:build %s\n", g.tags) + } + + var buildTagsArg string + if len(g.generateTags) > 0 { + buildTagsArg = fmt.Sprintf("-tags %s", strings.Join(g.generateTags, ",")) + } + + // We use a slice since we want to keep order, for reproducible builds. + vFuncs := []struct { + cName string + goName string + }{ + {"authenticate", "Authenticate"}, + {"setcred", "SetCred"}, + {"acct_mgmt", "AcctMgmt"}, + {"open_session", "OpenSession"}, + {"close_session", "CloseSession"}, + {"chauthtok", "ChangeAuthTok"}, + } + + g.printf(`//go:generate go build "-ldflags=-extldflags -Wl,-soname,%[2]s.so" `+ + `-buildmode=c-shared -o %[2]s.so %[3]s %[4]s +`, + g.outputName, g.libName, buildTagsArg, strings.Join(g.buildFlags, " ")) + + g.printf(` +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "os" + "unsafe" + "github.com/msteinert/pam/v2" +) +`) + + if g.typeName != "" { + g.printf(` +var pamModuleHandler pam.ModuleHandler = &%[1]s{} +`, g.typeName) + } else { + g.printf(` +// Do a typecheck at compile time +var _ pam.ModuleHandler = pamModuleHandler; +`) + } + + g.printf(` +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)), + pam.Flags(flags), sliceFromArgv(argc, argv)) + + if err == nil { + return 0; + } + + if (pam.Flags(flags) & pam.Silent) == 0 { + fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} +`) + + for _, f := range vFuncs { + g.printf(` +//export pam_sm_%[1]s +func pam_sm_%[1]s(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.%[2]s) +} +`, f.cName, f.goName) + } + + if !g.noMain { + g.printf("\nfunc main() {}\n") + } +} + +// format returns the gofmt-ed contents of the Generator's buffer. +func (g *Generator) format() []byte { + src, err := format.Source(g.buf.Bytes()) + if err != nil { + // Should never happen, but can arise when developing this code. + // The user can compile the output to see the error. + log.Printf("warning: internal error: invalid Go generated: %s", err) + log.Printf("warning: compile the package to analyze the error") + return g.buf.Bytes() + } + return src +} diff --git a/example-module/module.go b/example-module/module.go new file mode 100644 index 0000000..634e3ac --- /dev/null +++ b/example-module/module.go @@ -0,0 +1,50 @@ +// These go:generate directive allow to generate the module by just using +// `go generate` once in the module directory. +// This is not strictly needed + +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler +//go:generate go generate --skip="pam_module.go" + +// Package main provides the module shared library. +package main + +import ( + "fmt" + + "github.com/msteinert/pam/v2" +) + +type exampleHandler struct{} + +var pamModuleHandler pam.ModuleHandler = &exampleHandler{} +var _ = pamModuleHandler + +// AcctMgmt is the module handle function for account management. +func (h *exampleHandler) AcctMgmt(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("AcctMgmt not implemented: %w", pam.ErrIgnore) +} + +// Authenticate is the module handle function for authentication. +func (h *exampleHandler) Authenticate(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return pam.ErrAuthinfoUnavail +} + +// ChangeAuthTok is the module handle function for changing authentication token. +func (h *exampleHandler) ChangeAuthTok(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("ChangeAuthTok not implemented: %w", pam.ErrIgnore) +} + +// OpenSession is the module handle function for open session. +func (h *exampleHandler) OpenSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("OpenSession not implemented: %w", pam.ErrIgnore) +} + +// CloseSession is the module handle function for close session. +func (h *exampleHandler) CloseSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("CloseSession not implemented: %w", pam.ErrIgnore) +} + +// SetCred is the module handle function for set credentials. +func (h *exampleHandler) SetCred(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return fmt.Errorf("SetCred not implemented: %w", pam.ErrIgnore) +} diff --git a/example-module/pam_module.go b/example-module/pam_module.go new file mode 100644 index 0000000..b3bfb08 --- /dev/null +++ b/example-module/pam_module.go @@ -0,0 +1,96 @@ +// Code generated by "pam-moduler "; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +// Do a typecheck at compile time +var _ pam.ModuleHandler = pamModuleHandler + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)), + pam.Flags(flags), sliceFromArgv(argc, argv)) + + if err == nil { + return 0 + } + + if (pam.Flags(flags) & pam.Silent) == 0 { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/module-transaction.go b/module-transaction.go index 9a3aae6..12b3a40 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -15,7 +15,9 @@ type ModuleTransaction interface { type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error // ModuleTransaction is the module-side handle for a PAM transaction. -type moduleTransaction = transactionBase +type moduleTransaction struct { + transactionBase +} // ModuleHandler is an interface for objects that can be used to create // PAM modules from go. @@ -27,3 +29,9 @@ type ModuleHandler interface { OpenSession(ModuleTransaction, Flags, []string) error SetCred(ModuleTransaction, Flags, []string) error } + +// NewModuleTransaction allows initializing a transaction invoker from +// the module side. +func NewModuleTransaction(handle NativeHandle) ModuleTransaction { + return &moduleTransaction{transactionBase{handle: handle}} +} From 75278e8a4fef941b63ab4dfbb0dec40ae017ba1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 20 Nov 2023 21:06:58 +0100 Subject: [PATCH 05/24] transaction: Define C functions as unexported static inlines This will make it easier to avoid exporting unexpected symbols to the generated PAM libraries. Also it makes less messy handling C code inside go files. --- transaction.go | 41 ++++++++++++++-------------------- transaction.c => transaction.h | 24 +++++++++++++------- 2 files changed, 33 insertions(+), 32 deletions(-) rename transaction.c => transaction.h (62%) diff --git a/transaction.go b/transaction.go index 10c1a55..376f5ff 100644 --- a/transaction.go +++ b/transaction.go @@ -4,21 +4,7 @@ package pam //#cgo CFLAGS: -Wall -Wno-unused-variable -std=c99 //#cgo LDFLAGS: -lpam // -//#include -//#include -//#include -// -//#ifdef PAM_BINARY_PROMPT -//#define BINARY_PROMPT_IS_SUPPORTED 1 -//#else -//#include -//#define PAM_BINARY_PROMPT INT_MAX -//#define BINARY_PROMPT_IS_SUPPORTED 0 -//#endif -// -//void init_pam_conv(struct pam_conv *conv, uintptr_t); -//int pam_start_confdir(const char *service_name, const char *user, const struct pam_conv *pam_conversation, const char *confdir, pam_handle_t **pamh) __attribute__ ((weak)); -//int check_pam_start_confdir(void); +//#include "transaction.h" import "C" import ( @@ -89,16 +75,24 @@ func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { return f(s, msg) } -// cbPAMConv is a wrapper for the conversation callback function. +// _go_pam_conv_handler is a C wrapper for the conversation callback function. // -//export cbPAMConv -func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { +//export _go_pam_conv_handler +func _go_pam_conv_handler(msg *C.struct_pam_message, c C.uintptr_t, outMsg **C.char) C.int { + convHandler, ok := cgo.Handle(c).Value().(ConversationHandler) + if !ok || convHandler == nil { + return C.int(ErrConv) + } + replyMsg, r := pamConvHandler(Style(msg.msg_style), msg.msg, convHandler) + *outMsg = replyMsg + return r +} + +// pamConvHandler is a Go wrapper for the conversation callback function. +func pamConvHandler(style Style, msg *C.char, handler ConversationHandler) (*C.char, C.int) { var r string var err error - v := cgo.Handle(c).Value() - style := Style(s) - var handler ConversationHandler - switch cb := v.(type) { + switch cb := handler.(type) { case BinaryConversationHandler: if style == BinaryPrompt { bytes, err := cb.RespondPAMBinary(BinaryPointer(msg)) @@ -116,8 +110,7 @@ func cbPAMConv(s C.int, msg *C.char, c C.uintptr_t) (*C.char, C.int) { return nil, C.int(ErrConv) } handler = cb - } - if handler == nil { + default: return nil, C.int(ErrConv) } r, err = handler.RespondPAM(style, C.GoString(msg)) diff --git a/transaction.c b/transaction.h similarity index 62% rename from transaction.c rename to transaction.h index 8abed03..88d2766 100644 --- a/transaction.c +++ b/transaction.h @@ -1,15 +1,25 @@ -#include "_cgo_export.h" #include #include +#include #include +#ifdef PAM_BINARY_PROMPT +#define BINARY_PROMPT_IS_SUPPORTED 1 +#else +#include +#define PAM_BINARY_PROMPT INT_MAX +#define BINARY_PROMPT_IS_SUPPORTED 0 +#endif + #ifdef __sun #define PAM_CONST #else #define PAM_CONST const #endif -int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_response **resp, void *appdata_ptr) +extern int _go_pam_conv_handler(struct pam_message *, uintptr_t, char **reply); + +static inline int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_response **resp, void *appdata_ptr) { if (num_msg <= 0 || num_msg > PAM_MAX_NUM_MSG) return PAM_CONV_ERR; @@ -19,11 +29,9 @@ int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_resp return PAM_BUF_ERR; for (size_t i = 0; i < num_msg; ++i) { - struct cbPAMConv_return result = cbPAMConv(msg[i]->msg_style, (char *)msg[i]->msg, (uintptr_t)appdata_ptr); - if (result.r1 != PAM_SUCCESS) + int result = _go_pam_conv_handler((struct pam_message *)msg[i], (uintptr_t)appdata_ptr, &(*resp)[i].resp); + if (result != PAM_SUCCESS) goto error; - - (*resp)[i].resp = result.r0; } return PAM_SUCCESS; @@ -41,7 +49,7 @@ int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_resp return PAM_CONV_ERR; } -void init_pam_conv(struct pam_conv *conv, uintptr_t appdata) +static inline void init_pam_conv(struct pam_conv *conv, uintptr_t appdata) { conv->conv = cb_pam_conv; conv->appdata_ptr = (void *)appdata; @@ -52,7 +60,7 @@ void init_pam_conv(struct pam_conv *conv, uintptr_t appdata) int pam_start_confdir(const char *service_name, const char *user, const struct pam_conv *pam_conversation, const char *confdir, pam_handle_t **pamh) __attribute__((weak)); -int check_pam_start_confdir(void) +static inline int check_pam_start_confdir(void) { if (pam_start_confdir == NULL) return 1; From 4b6191083955194a46d38c6b29346baa5badcf37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 20 Nov 2023 21:09:43 +0100 Subject: [PATCH 06/24] transaction, moduler: Do not export PAM conv handler function to modules This function is only needed when using go PAM for creating applications so it's not something we expect to have exported to library modules. To prevent this use an `asPamModule` tag to prevent compilation of application-only features. --- app-transaction.go | 24 ++++++++++++++++++++++++ cmd/pam-moduler/moduler.go | 4 ++-- example-module/pam_module.go | 2 +- transaction.go | 13 ------------- 4 files changed, 27 insertions(+), 16 deletions(-) create mode 100644 app-transaction.go diff --git a/app-transaction.go b/app-transaction.go new file mode 100644 index 0000000..2ddfaaa --- /dev/null +++ b/app-transaction.go @@ -0,0 +1,24 @@ +//go:build !go_pam_module + +package pam + +/* +#include +#include +*/ +import "C" + +import "runtime/cgo" + +// _go_pam_conv_handler is a C wrapper for the conversation callback function. +// +//export _go_pam_conv_handler +func _go_pam_conv_handler(msg *C.struct_pam_message, c C.uintptr_t, outMsg **C.char) C.int { + convHandler, ok := cgo.Handle(c).Value().(ConversationHandler) + if !ok || convHandler == nil { + return C.int(ErrConv) + } + replyMsg, r := pamConvHandler(Style(msg.msg_style), msg.msg, convHandler) + *outMsg = replyMsg + return r +} diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index b95195d..94298dc 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -119,9 +119,9 @@ func main() { tags = *buildTags } - var generateTags []string + generateTags := []string{"go_pam_module"} if len(*moduleBuildTags) > 0 { - generateTags = strings.Split(*moduleBuildTags, ",") + generateTags = append(generateTags, strings.Split(*moduleBuildTags, ",")...) } var buildFlags []string diff --git a/example-module/pam_module.go b/example-module/pam_module.go index b3bfb08..b13924e 100644 --- a/example-module/pam_module.go +++ b/example-module/pam_module.go @@ -1,6 +1,6 @@ // Code generated by "pam-moduler "; DO NOT EDIT. -//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module // Package main is the package for the PAM module library. package main diff --git a/transaction.go b/transaction.go index 376f5ff..462136d 100644 --- a/transaction.go +++ b/transaction.go @@ -75,19 +75,6 @@ func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { return f(s, msg) } -// _go_pam_conv_handler is a C wrapper for the conversation callback function. -// -//export _go_pam_conv_handler -func _go_pam_conv_handler(msg *C.struct_pam_message, c C.uintptr_t, outMsg **C.char) C.int { - convHandler, ok := cgo.Handle(c).Value().(ConversationHandler) - if !ok || convHandler == nil { - return C.int(ErrConv) - } - replyMsg, r := pamConvHandler(Style(msg.msg_style), msg.msg, convHandler) - *outMsg = replyMsg - return r -} - // pamConvHandler is a Go wrapper for the conversation callback function. func pamConvHandler(style Style, msg *C.char, handler ConversationHandler) (*C.char, C.int) { var r string From c0be3144217073088aea7a0a643ed82eca11f544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 20 Nov 2023 21:29:55 +0100 Subject: [PATCH 07/24] transaction: Move PAM app side function only to app-transaction In this way all these features not even compiled when creating modules, avoiding generating unused code. --- app-transaction.go | 218 ++++++++++++++++++++++++++++++++++++++++++++- transaction.go | 210 ------------------------------------------- 2 files changed, 215 insertions(+), 213 deletions(-) diff --git a/app-transaction.go b/app-transaction.go index 2ddfaaa..671b48e 100644 --- a/app-transaction.go +++ b/app-transaction.go @@ -3,12 +3,46 @@ package pam /* -#include -#include +#include "transaction.h" */ import "C" -import "runtime/cgo" +import ( + "fmt" + "runtime/cgo" + "sync/atomic" + "unsafe" +) + +// ConversationHandler is an interface for objects that can be used as +// conversation callbacks during PAM authentication. +type ConversationHandler interface { + // RespondPAM receives a message style and a message string. If the + // message Style is PromptEchoOff or PromptEchoOn then the function + // should return a response string. + RespondPAM(Style, string) (string, error) +} + +// BinaryConversationHandler is an interface for objects that can be used as +// conversation callbacks during PAM authentication if binary protocol is going +// to be supported. +type BinaryConversationHandler interface { + ConversationHandler + // RespondPAMBinary receives a pointer to the binary message. It's up to + // the receiver to parse it according to the protocol specifications. + // The function can return a byte array that will passed as pointer back + // to the module. + RespondPAMBinary(BinaryPointer) ([]byte, error) +} + +// ConversationFunc is an adapter to allow the use of ordinary functions as +// conversation callbacks. +type ConversationFunc func(Style, string) (string, error) + +// RespondPAM is a conversation callback adapter. +func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { + return f(s, msg) +} // _go_pam_conv_handler is a C wrapper for the conversation callback function. // @@ -22,3 +56,181 @@ func _go_pam_conv_handler(msg *C.struct_pam_message, c C.uintptr_t, outMsg **C.c *outMsg = replyMsg return r } + +// pamConvHandler is a Go wrapper for the conversation callback function. +func pamConvHandler(style Style, msg *C.char, handler ConversationHandler) (*C.char, C.int) { + var r string + var err error + switch cb := handler.(type) { + case BinaryConversationHandler: + if style == BinaryPrompt { + bytes, err := cb.RespondPAMBinary(BinaryPointer(msg)) + if err != nil { + return nil, C.int(ErrConv) + } + if bytes == nil { + return nil, success + } + return (*C.char)(C.CBytes(bytes)), success + } + handler = cb + case ConversationHandler: + if style == BinaryPrompt { + return nil, C.int(ErrConv) + } + handler = cb + default: + return nil, C.int(ErrConv) + } + r, err = handler.RespondPAM(style, C.GoString(msg)) + if err != nil { + return nil, C.int(ErrConv) + } + return C.CString(r), success +} + +// Transaction is the application's handle for a PAM transaction. +type Transaction struct { + transactionBase + + conv *C.struct_pam_conv + c cgo.Handle +} + +// Start initiates a new PAM transaction. Service is treated identically to +// how pam_start treats it internally. +// +// All application calls to PAM begin with Start*. The returned +// transaction provides an interface to the remainder of the API. +// +// It's responsibility of the Transaction owner to release all the resources +// allocated underneath by PAM by calling End() once done. +// +// It's not advised to End the transaction using a runtime.SetFinalizer unless +// you're absolutely sure that your stack is multi-thread friendly (normally it +// is not!) and using a LockOSThread/UnlockOSThread pair. +func Start(service, user string, handler ConversationHandler) (*Transaction, error) { + return start(service, user, handler, "") +} + +// StartFunc registers the handler func as a conversation handler and starts +// the transaction (see Start() documentation). +func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) { + return start(service, user, ConversationFunc(handler), "") +} + +// StartConfDir initiates a new PAM transaction. Service is treated identically to +// how pam_start treats it internally. +// confdir allows to define where all pam services are defined. This is used to provide +// custom paths for tests. +// +// All application calls to PAM begin with Start*. The returned +// transaction provides an interface to the remainder of the API. +// +// It's responsibility of the Transaction owner to release all the resources +// allocated underneath by PAM by calling End() once done. +// +// It's not advised to End the transaction using a runtime.SetFinalizer unless +// you're absolutely sure that your stack is multi-thread friendly (normally it +// is not!) and using a LockOSThread/UnlockOSThread pair. +func StartConfDir(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { + if !CheckPamHasStartConfdir() { + return nil, fmt.Errorf( + "%w: StartConfDir was used, but the pam version on the system is not recent enough", + ErrSystem) + } + + return start(service, user, handler, confDir) +} + +func start(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { + switch handler.(type) { + case BinaryConversationHandler: + if !CheckPamHasBinaryProtocol() { + return nil, fmt.Errorf("%w: BinaryConversationHandler was used, but it is not supported by this platform", + ErrSystem) + } + } + t := &Transaction{ + conv: &C.struct_pam_conv{}, + c: cgo.NewHandle(handler), + } + + C.init_pam_conv(t.conv, C.uintptr_t(t.c)) + s := C.CString(service) + defer C.free(unsafe.Pointer(s)) + var u *C.char + if len(user) != 0 { + u = C.CString(user) + defer C.free(unsafe.Pointer(u)) + } + var err error + if confDir == "" { + err = t.handlePamStatus(C.pam_start(s, u, t.conv, &t.handle)) + } else { + c := C.CString(confDir) + defer C.free(unsafe.Pointer(c)) + err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) + } + if err != nil { + var _ = t.End() + return nil, err + } + return t, nil +} + +// Authenticate is used to authenticate the user. +// +// Valid flags: Silent, DisallowNullAuthtok +func (t *Transaction) Authenticate(f Flags) error { + return t.handlePamStatus(C.pam_authenticate(t.handle, C.int(f))) +} + +// SetCred is used to establish, maintain and delete the credentials of a +// user. +// +// Valid flags: EstablishCred, DeleteCred, ReinitializeCred, RefreshCred +func (t *Transaction) SetCred(f Flags) error { + return t.handlePamStatus(C.pam_setcred(t.handle, C.int(f))) +} + +// AcctMgmt is used to determine if the user's account is valid. +// +// Valid flags: Silent, DisallowNullAuthtok +func (t *Transaction) AcctMgmt(f Flags) error { + return t.handlePamStatus(C.pam_acct_mgmt(t.handle, C.int(f))) +} + +// ChangeAuthTok is used to change the authentication token. +// +// Valid flags: Silent, ChangeExpiredAuthtok +func (t *Transaction) ChangeAuthTok(f Flags) error { + return t.handlePamStatus(C.pam_chauthtok(t.handle, C.int(f))) +} + +// OpenSession sets up a user session for an authenticated user. +// +// Valid flags: Slient +func (t *Transaction) OpenSession(f Flags) error { + return t.handlePamStatus(C.pam_open_session(t.handle, C.int(f))) +} + +// CloseSession closes a previously opened session. +// +// Valid flags: Silent +func (t *Transaction) CloseSession(f Flags) error { + return t.handlePamStatus(C.pam_close_session(t.handle, C.int(f))) +} + +// End cleans up the PAM handle and deletes the callback function. +// It must be called when done with the transaction. +func (t *Transaction) End() error { + handle := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.handle)), nil) + if handle == nil { + return nil + } + + defer t.c.Delete() + return t.handlePamStatus(C.pam_end((*C.pam_handle_t)(handle), + C.int(t.lastStatus.Load()))) +} diff --git a/transaction.go b/transaction.go index 462136d..bd2876d 100644 --- a/transaction.go +++ b/transaction.go @@ -8,8 +8,6 @@ package pam import "C" import ( - "fmt" - "runtime/cgo" "strings" "sync/atomic" "unsafe" @@ -40,73 +38,11 @@ const ( BinaryPrompt Style = C.PAM_BINARY_PROMPT ) -// ConversationHandler is an interface for objects that can be used as -// conversation callbacks during PAM authentication. -type ConversationHandler interface { - // RespondPAM receives a message style and a message string. If the - // message Style is PromptEchoOff or PromptEchoOn then the function - // should return a response string. - RespondPAM(Style, string) (string, error) -} - // BinaryPointer exposes the type used for the data in a binary conversation // it represents a pointer to data that is produced by the module and that // must be parsed depending on the protocol in use type BinaryPointer unsafe.Pointer -// BinaryConversationHandler is an interface for objects that can be used as -// conversation callbacks during PAM authentication if binary protocol is going -// to be supported. -type BinaryConversationHandler interface { - ConversationHandler - // RespondPAMBinary receives a pointer to the binary message. It's up to - // the receiver to parse it according to the protocol specifications. - // The function can return a byte array that will passed as pointer back - // to the module. - RespondPAMBinary(BinaryPointer) ([]byte, error) -} - -// ConversationFunc is an adapter to allow the use of ordinary functions as -// conversation callbacks. -type ConversationFunc func(Style, string) (string, error) - -// RespondPAM is a conversation callback adapter. -func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { - return f(s, msg) -} - -// pamConvHandler is a Go wrapper for the conversation callback function. -func pamConvHandler(style Style, msg *C.char, handler ConversationHandler) (*C.char, C.int) { - var r string - var err error - switch cb := handler.(type) { - case BinaryConversationHandler: - if style == BinaryPrompt { - bytes, err := cb.RespondPAMBinary(BinaryPointer(msg)) - if err != nil { - return nil, C.int(ErrConv) - } - if bytes == nil { - return nil, success - } - return (*C.char)(C.CBytes(bytes)), success - } - handler = cb - case ConversationHandler: - if style == BinaryPrompt { - return nil, C.int(ErrConv) - } - handler = cb - default: - return nil, C.int(ErrConv) - } - r, err = handler.RespondPAM(style, C.GoString(msg)) - if err != nil { - return nil, C.int(ErrConv) - } - return C.CString(r), success -} - // NativeHandle is the type of the native PAM handle for a transaction so that // it can be exported type NativeHandle = *C.pam_handle_t @@ -119,27 +55,6 @@ type transactionBase struct { lastStatus atomic.Int32 } -// Transaction is the application's handle for a PAM transaction. -type Transaction struct { - transactionBase - - conv *C.struct_pam_conv - c cgo.Handle -} - -// End cleans up the PAM handle and deletes the callback function. -// It must be called when done with the transaction. -func (t *Transaction) End() error { - handle := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.handle)), nil) - if handle == nil { - return nil - } - - defer t.c.Delete() - return t.handlePamStatus(C.pam_end((*C.pam_handle_t)(handle), - C.int(t.lastStatus.Load()))) -} - // Allows to call pam functions managing return status func (t *transactionBase) handlePamStatus(cStatus C.int) error { t.lastStatus.Store(int32(cStatus)) @@ -149,88 +64,6 @@ func (t *transactionBase) handlePamStatus(cStatus C.int) error { return nil } -// Start initiates a new PAM transaction. Service is treated identically to -// how pam_start treats it internally. -// -// All application calls to PAM begin with Start*. The returned -// transaction provides an interface to the remainder of the API. -// -// It's responsibility of the Transaction owner to release all the resources -// allocated underneath by PAM by calling End() once done. -// -// It's not advised to End the transaction using a runtime.SetFinalizer unless -// you're absolutely sure that your stack is multi-thread friendly (normally it -// is not!) and using a LockOSThread/UnlockOSThread pair. -func Start(service, user string, handler ConversationHandler) (*Transaction, error) { - return start(service, user, handler, "") -} - -// StartFunc registers the handler func as a conversation handler and starts -// the transaction (see Start() documentation). -func StartFunc(service, user string, handler func(Style, string) (string, error)) (*Transaction, error) { - return start(service, user, ConversationFunc(handler), "") -} - -// StartConfDir initiates a new PAM transaction. Service is treated identically to -// how pam_start treats it internally. -// confdir allows to define where all pam services are defined. This is used to provide -// custom paths for tests. -// -// All application calls to PAM begin with Start*. The returned -// transaction provides an interface to the remainder of the API. -// -// It's responsibility of the Transaction owner to release all the resources -// allocated underneath by PAM by calling End() once done. -// -// It's not advised to End the transaction using a runtime.SetFinalizer unless -// you're absolutely sure that your stack is multi-thread friendly (normally it -// is not!) and using a LockOSThread/UnlockOSThread pair. -func StartConfDir(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { - if !CheckPamHasStartConfdir() { - return nil, fmt.Errorf( - "%w: StartConfDir was used, but the pam version on the system is not recent enough", - ErrSystem) - } - - return start(service, user, handler, confDir) -} - -func start(service, user string, handler ConversationHandler, confDir string) (*Transaction, error) { - switch handler.(type) { - case BinaryConversationHandler: - if !CheckPamHasBinaryProtocol() { - return nil, fmt.Errorf("%w: BinaryConversationHandler was used, but it is not supported by this platform", - ErrSystem) - } - } - t := &Transaction{ - conv: &C.struct_pam_conv{}, - c: cgo.NewHandle(handler), - } - - C.init_pam_conv(t.conv, C.uintptr_t(t.c)) - s := C.CString(service) - defer C.free(unsafe.Pointer(s)) - var u *C.char - if len(user) != 0 { - u = C.CString(user) - defer C.free(unsafe.Pointer(u)) - } - var err error - if confDir == "" { - err = t.handlePamStatus(C.pam_start(s, u, t.conv, &t.handle)) - } else { - c := C.CString(confDir) - defer C.free(unsafe.Pointer(c)) - err = t.handlePamStatus(C.pam_start_confdir(s, u, t.conv, c, &t.handle)) - } - if err != nil { - var _ = t.End() - return nil, err - } - return t, nil -} - // Item is a an PAM information type. type Item int @@ -307,49 +140,6 @@ const ( ChangeExpiredAuthtok Flags = C.PAM_CHANGE_EXPIRED_AUTHTOK ) -// Authenticate is used to authenticate the user. -// -// Valid flags: Silent, DisallowNullAuthtok -func (t *Transaction) Authenticate(f Flags) error { - return t.handlePamStatus(C.pam_authenticate(t.handle, C.int(f))) -} - -// SetCred is used to establish, maintain and delete the credentials of a -// user. -// -// Valid flags: EstablishCred, DeleteCred, ReinitializeCred, RefreshCred -func (t *Transaction) SetCred(f Flags) error { - return t.handlePamStatus(C.pam_setcred(t.handle, C.int(f))) -} - -// AcctMgmt is used to determine if the user's account is valid. -// -// Valid flags: Silent, DisallowNullAuthtok -func (t *Transaction) AcctMgmt(f Flags) error { - return t.handlePamStatus(C.pam_acct_mgmt(t.handle, C.int(f))) -} - -// ChangeAuthTok is used to change the authentication token. -// -// Valid flags: Silent, ChangeExpiredAuthtok -func (t *Transaction) ChangeAuthTok(f Flags) error { - return t.handlePamStatus(C.pam_chauthtok(t.handle, C.int(f))) -} - -// OpenSession sets up a user session for an authenticated user. -// -// Valid flags: Slient -func (t *Transaction) OpenSession(f Flags) error { - return t.handlePamStatus(C.pam_open_session(t.handle, C.int(f))) -} - -// CloseSession closes a previously opened session. -// -// Valid flags: Silent -func (t *Transaction) CloseSession(f Flags) error { - return t.handlePamStatus(C.pam_close_session(t.handle, C.int(f))) -} - // PutEnv adds or changes the value of PAM environment variables. // // NAME=value will set a variable to a value. From 1b0984110ed9ce945d6988e2461de52bd36279af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 25 Sep 2023 23:08:08 +0200 Subject: [PATCH 08/24] moduler: Move module transaction invoke handling to transaction itself So we can reduce the generated code and add more unit tests --- cmd/pam-moduler/moduler.go | 10 ++-- example-module/pam_module.go | 8 +-- module-transaction.go | 58 ++++++++++++++++++- module-transaction_test.go | 106 +++++++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 11 deletions(-) diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 94298dc..68f4852 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -257,14 +257,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)), - pam.Flags(flags), sliceFromArgv(argc, argv)) - + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) if err == nil { - return 0; + return 0 } - if (pam.Flags(flags) & pam.Silent) == 0 { + if (pam.Flags(flags) & pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { fmt.Fprintf(os.Stderr, "module returned error: %%v\n", err) } diff --git a/example-module/pam_module.go b/example-module/pam_module.go index b13924e..080e97c 100644 --- a/example-module/pam_module.go +++ b/example-module/pam_module.go @@ -44,14 +44,14 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - err := moduleFunc(pam.NewModuleTransaction(pam.NativeHandle(pamh)), - pam.Flags(flags), sliceFromArgv(argc, argv)) - + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) if err == nil { return 0 } - if (pam.Flags(flags) & pam.Silent) == 0 { + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) } diff --git a/module-transaction.go b/module-transaction.go index 12b3a40..0e87fe5 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -1,6 +1,13 @@ // Package pam provides a wrapper for the PAM application API. package pam +import "C" + +import ( + "errors" + "fmt" +) + // ModuleTransaction is an interface that a pam module transaction // should implement. type ModuleTransaction interface { @@ -30,8 +37,55 @@ type ModuleHandler interface { SetCred(ModuleTransaction, Flags, []string) error } -// NewModuleTransaction allows initializing a transaction invoker from +// ModuleTransactionInvoker is an interface that a pam module transaction +// should implement to redirect requests from C handlers to go, +type ModuleTransactionInvoker interface { + ModuleTransaction + InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error +} + +// NewModuleTransactionInvoker allows initializing a transaction invoker from // the module side. -func NewModuleTransaction(handle NativeHandle) ModuleTransaction { +func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker { return &moduleTransaction{transactionBase{handle: handle}} } + +func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, + flags Flags, args []string) error { + invoker := func() error { + if handler == nil { + return ErrIgnore + } + err := handler(m, flags, args) + if err != nil { + service, _ := m.GetItem(Service) + + var pamErr Error + if !errors.As(err, &pamErr) { + err = ErrSystem + } + + if pamErr == ErrIgnore || service == "" { + return err + } + + return fmt.Errorf("%s failed: %w", service, err) + } + return nil + } + err := invoker() + if errors.Is(err, Error(0)) { + err = nil + } + var status int32 + if err != nil { + status = int32(ErrSystem) + + var pamErr Error + if errors.As(err, &pamErr) { + status = int32(pamErr) + } + } + m.lastStatus.Store(status) + return err +} diff --git a/module-transaction_test.go b/module-transaction_test.go index d5c7533..8661f68 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -4,6 +4,7 @@ package pam import ( "errors" "reflect" + "strings" "testing" ) @@ -129,3 +130,108 @@ func Test_NewNullModuleTransaction(t *testing.T) { }) } } + +func Test_ModuleTransaction_InvokeHandler(t *testing.T) { + t.Parallel() + mt := &moduleTransaction{} + + err := mt.InvokeHandler(nil, 0, nil) + if !errors.Is(err, ErrIgnore) { + t.Fatalf("unexpected err: %v", err) + } + + tests := map[string]struct { + flags Flags + args []string + returnedError error + expectedError error + expectedErrorMsg string + }{ + "success": { + expectedError: nil, + }, + "success-with-flags": { + expectedError: nil, + flags: Silent | RefreshCred, + }, + "success-with-args": { + expectedError: nil, + args: []string{"foo", "bar"}, + }, + "success-with-args-and-flags": { + expectedError: nil, + flags: Silent | RefreshCred, + args: []string{"foo", "bar"}, + }, + "ignore": { + expectedError: ErrIgnore, + returnedError: ErrIgnore, + }, + "ignore-with-args-and-flags": { + expectedError: ErrIgnore, + returnedError: ErrIgnore, + args: []string{"foo", "bar"}, + }, + "generic-error": { + expectedError: ErrSystem, + returnedError: errors.New("this is a generic go error"), + expectedErrorMsg: "this is a generic go error", + }, + "transaction-error-service-error": { + expectedError: ErrService, + returnedError: errors.Join(ErrService, errors.New("ErrService")), + expectedErrorMsg: ErrService.Error(), + }, + "return-type-as-error-success": { + expectedError: nil, + returnedError: Error(0), + }, + "return-type-as-error": { + expectedError: ErrNoModuleData, + returnedError: ErrNoModuleData, + expectedErrorMsg: ErrNoModuleData.Error(), + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := mt.InvokeHandler(func(handlerMt ModuleTransaction, + handlerFlags Flags, handlerArgs []string) error { + if handlerMt != mt { + t.Fatalf("unexpected mt: %#v vs %#v", mt, handlerMt) + } + if handlerFlags != tc.flags { + t.Fatalf("unexpected mt: %#v vs %#v", tc.flags, handlerFlags) + } + if strings.Join(handlerArgs, "") != strings.Join(tc.args, "") { + t.Fatalf("unexpected mt: %#v vs %#v", tc.args, handlerArgs) + } + + return tc.returnedError + }, tc.flags, tc.args) + + status := Error(mt.lastStatus.Load()) + + if !errors.Is(err, tc.expectedError) { + t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError) + } + + var expectedStatus Error + if err != nil { + var pamErr Error + if errors.As(err, &pamErr) { + expectedStatus = pamErr + } else { + expectedStatus = ErrSystem + } + } + + if status != expectedStatus { + t.Fatalf("unexpected status: %#v vs %#v", status, expectedStatus) + } + }) + } +} From e9dffa7a8240ee3d9aa518fc058616b49c83ee03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Mon, 25 Sep 2023 23:13:34 +0200 Subject: [PATCH 09/24] pam-moduler: Add test that generates a new debug module and verify it works We mimic what pam_debug.so does by default, by implementing a similar module fully in go, generated using pam-moduler. This requires various utilities to generate the module and run the tests that are in a separate internal modules so that it can be shared between multiple implementations --- .gitignore | 2 + .../tests/debug-module/debug-module.go | 119 ++++++++++++ .../tests/debug-module/debug-module_test.go | 120 ++++++++++++ .../tests/debug-module/pam_module.go | 96 ++++++++++ .../tests/internal/utils/base-module.go | 38 ++++ .../tests/internal/utils/base-module_test.go | 35 ++++ .../tests/internal/utils/test-setup.go | 135 +++++++++++++ .../tests/internal/utils/test-setup_test.go | 180 ++++++++++++++++++ .../tests/internal/utils/test-utils.go | 99 ++++++++++ 9 files changed, 824 insertions(+) create mode 100644 cmd/pam-moduler/tests/debug-module/debug-module.go create mode 100644 cmd/pam-moduler/tests/debug-module/debug-module_test.go create mode 100644 cmd/pam-moduler/tests/debug-module/pam_module.go create mode 100644 cmd/pam-moduler/tests/internal/utils/base-module.go create mode 100644 cmd/pam-moduler/tests/internal/utils/base-module_test.go create mode 100644 cmd/pam-moduler/tests/internal/utils/test-setup.go create mode 100644 cmd/pam-moduler/tests/internal/utils/test-setup_test.go create mode 100644 cmd/pam-moduler/tests/internal/utils/test-utils.go diff --git a/.gitignore b/.gitignore index a2f238d..0700a89 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ coverage.out example-module/*.so example-module/*.h +cmd/pam-moduler/tests/*/*.so +cmd/pam-moduler/tests/*/*.h diff --git a/cmd/pam-moduler/tests/debug-module/debug-module.go b/cmd/pam-moduler/tests/debug-module/debug-module.go new file mode 100644 index 0000000..843b329 --- /dev/null +++ b/cmd/pam-moduler/tests/debug-module/debug-module.go @@ -0,0 +1,119 @@ +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -libname "pam_godebug.so" +//go:generate go generate --skip="pam_module.go" + +// This is a similar implementation of pam_debug.so + +// Package main is the package for the debug PAM module library +package main + +import ( + "fmt" + "strings" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +var pamModuleHandler pam.ModuleHandler = &DebugModule{} +var _ = pamModuleHandler + +var moduleArgsRetTypes = map[string]error{ + "success": nil, + "open_err": pam.ErrOpen, + "symbol_err": pam.ErrSymbol, + "service_err": pam.ErrService, + "system_err": pam.ErrSystem, + "buf_err": pam.ErrBuf, + "perm_denied": pam.ErrPermDenied, + "auth_err": pam.ErrAuth, + "cred_insufficient": pam.ErrCredInsufficient, + "authinfo_unavail": pam.ErrAuthinfoUnavail, + "user_unknown": pam.ErrUserUnknown, + "maxtries": pam.ErrMaxtries, + "new_authtok_reqd": pam.ErrNewAuthtokReqd, + "acct_expired": pam.ErrAcctExpired, + "session_err": pam.ErrSession, + "cred_unavail": pam.ErrCredUnavail, + "cred_expired": pam.ErrCredExpired, + "cred_err": pam.ErrCred, + "no_module_data": pam.ErrNoModuleData, + "conv_err": pam.ErrConv, + "authtok_err": pam.ErrAuthtok, + "authtok_recover_err": pam.ErrAuthtokRecovery, + "authtok_lock_busy": pam.ErrAuthtokLockBusy, + "authtok_disable_aging": pam.ErrAuthtokDisableAging, + "try_again": pam.ErrTryAgain, + "ignore": pam.ErrIgnore, + "abort": pam.ErrAbort, + "authtok_expired": pam.ErrAuthtokExpired, + "module_unknown": pam.ErrModuleUnknown, + "bad_item": pam.ErrBadItem, + "conv_again": pam.ErrConvAgain, + "incomplete": pam.ErrIncomplete, +} + +var debugModuleArgs = []string{"auth", "cred", "acct", "prechauthtok", + "chauthtok", "open_session", "close_session"} + +// DebugModule is the PAM module structure. +type DebugModule struct { + utils.BaseModule +} + +func (dm *DebugModule) getReturnType(args []string, key string) error { + var value string + for _, a := range args { + v, found := strings.CutPrefix(a, key+"=") + if found { + value = v + } + } + + if value == "" { + return fmt.Errorf("Value not found") + } + + if ret, found := moduleArgsRetTypes[value]; found { + return ret + } + return fmt.Errorf("Parameter %s not known", value) +} + +func (dm *DebugModule) handleCall(args []string, action string) error { + err := dm.getReturnType(args, action) + if err == nil { + return nil + } + + return fmt.Errorf("error %w", err) +} + +// AcctMgmt is a PAM handler. +func (dm *DebugModule) AcctMgmt(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "acct") +} + +// Authenticate is a PAM handler. +func (dm *DebugModule) Authenticate(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "auth") +} + +// ChangeAuthTok is a PAM handler. +func (dm *DebugModule) ChangeAuthTok(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "chauthtok") +} + +// OpenSession is a PAM handler. +func (dm *DebugModule) OpenSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "open_session") +} + +// CloseSession is a PAM handler. +func (dm *DebugModule) CloseSession(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "close_session") +} + +// SetCred is a PAM handler. +func (dm *DebugModule) SetCred(mt pam.ModuleTransaction, flags pam.Flags, args []string) error { + return dm.handleCall(args, "cred") +} diff --git a/cmd/pam-moduler/tests/debug-module/debug-module_test.go b/cmd/pam-moduler/tests/debug-module/debug-module_test.go new file mode 100644 index 0000000..8a5d58d --- /dev/null +++ b/cmd/pam-moduler/tests/debug-module/debug-module_test.go @@ -0,0 +1,120 @@ +package main + +import ( + "errors" + "fmt" + "testing" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func Test_DebugModule_ActionStatus(t *testing.T) { + t.Parallel() + + module := DebugModule{} + + for ret, expected := range moduleArgsRetTypes { + ret := ret + expected := expected + for actionName, action := range utils.Actions { + actionName := actionName + action := action + t.Run(fmt.Sprintf("%s %s", ret, actionName), func(t *testing.T) { + t.Parallel() + moduleArgs := make([]string, 0) + for _, a := range debugModuleArgs { + moduleArgs = append(moduleArgs, fmt.Sprintf("%s=%s", a, ret)) + } + + mt := pam.ModuleTransactionInvoker(nil) + var err error + + switch action { + case utils.Account: + err = module.AcctMgmt(mt, 0, moduleArgs) + case utils.Auth: + err = module.Authenticate(mt, 0, moduleArgs) + case utils.Password: + err = module.ChangeAuthTok(mt, 0, moduleArgs) + case utils.Session: + err = module.OpenSession(mt, 0, moduleArgs) + } + + if !errors.Is(err, expected) { + t.Fatalf("error #unexpected %#v vs %#v", expected, err) + } + }) + } + } +} + +func Test_DebugModuleTransaction_ActionStatus(t *testing.T) { + t.Parallel() + if !pam.CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + modulePath := ts.GenerateModule(".", "pam_godebug.so") + + for ret, expected := range moduleArgsRetTypes { + ret := ret + expected := expected + for actionName, action := range utils.Actions { + ret := ret + expected := expected + actionName := actionName + action := action + t.Run(fmt.Sprintf("%s %s", ret, actionName), func(t *testing.T) { + t.Parallel() + serviceName := ret + "-" + actionName + moduleArgs := make([]string, 0) + for _, a := range debugModuleArgs { + moduleArgs = append(moduleArgs, fmt.Sprintf("%s=%s", a, ret)) + } + control := utils.Requisite + fallbackModule := utils.Permit + if ret == "success" { + fallbackModule = utils.Deny + control = utils.Sufficient + } + ts.CreateService(serviceName, []utils.ServiceLine{ + {Action: action, Control: control, Module: modulePath, Args: moduleArgs}, + {Action: action, Control: control, Module: fallbackModule.String(), Args: []string{}}, + }) + + tx, err := pam.StartConfDir(serviceName, "user", nil, ts.WorkDir()) + if err != nil { + t.Fatalf("start #error: %v", err) + } + defer func() { + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } + }() + + switch action { + case utils.Account: + err = tx.AcctMgmt(pam.Silent) + case utils.Auth: + err = tx.Authenticate(pam.Silent) + case utils.Password: + err = tx.ChangeAuthTok(pam.Silent) + case utils.Session: + err = tx.OpenSession(pam.Silent) + } + + if errors.Is(expected, pam.ErrIgnore) { + // Ignore can't be returned + expected = nil + } + + if !errors.Is(err, expected) { + t.Fatalf("error #unexpected %#v vs %#v", expected, err) + } + }) + } + } +} diff --git a/cmd/pam-moduler/tests/debug-module/pam_module.go b/cmd/pam-moduler/tests/debug-module/pam_module.go new file mode 100644 index 0000000..837842e --- /dev/null +++ b/cmd/pam-moduler/tests/debug-module/pam_module.go @@ -0,0 +1,96 @@ +// Code generated by "pam-moduler -libname pam_godebug.so"; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_godebug.so" -buildmode=c-shared -o pam_godebug.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +// Do a typecheck at compile time +var _ pam.ModuleHandler = pamModuleHandler + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/cmd/pam-moduler/tests/internal/utils/base-module.go b/cmd/pam-moduler/tests/internal/utils/base-module.go new file mode 100644 index 0000000..494b077 --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/base-module.go @@ -0,0 +1,38 @@ +package utils + +import "github.com/msteinert/pam/v2" + +// BaseModule is the type for a base PAM module. +type BaseModule struct{} + +// AcctMgmt is the handler function for PAM AcctMgmt. +func (h *BaseModule) AcctMgmt(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// Authenticate is the handler function for PAM Authenticate. +func (h *BaseModule) Authenticate(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// ChangeAuthTok is the handler function for PAM ChangeAuthTok. +func (h *BaseModule) ChangeAuthTok(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// OpenSession is the handler function for PAM OpenSession. +func (h *BaseModule) OpenSession(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// CloseSession is the handler function for PAM CloseSession. +func (h *BaseModule) CloseSession(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +// SetCred is the handler function for PAM SetCred. +func (h *BaseModule) SetCred(pam.ModuleTransaction, pam.Flags, []string) error { + return nil +} + +var _ pam.ModuleHandler = &BaseModule{} diff --git a/cmd/pam-moduler/tests/internal/utils/base-module_test.go b/cmd/pam-moduler/tests/internal/utils/base-module_test.go new file mode 100644 index 0000000..461d90f --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/base-module_test.go @@ -0,0 +1,35 @@ +package utils + +import ( + "testing" + + "github.com/msteinert/pam/v2" +) + +func TestMain(t *testing.T) { + bm := BaseModule{} + + if bm.AcctMgmt(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.Authenticate(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.ChangeAuthTok(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.OpenSession(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.CloseSession(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } + + if bm.SetCred(nil, pam.Flags(0), nil) != nil { + t.Fatalf("Unexpected non-nil value") + } +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-setup.go b/cmd/pam-moduler/tests/internal/utils/test-setup.go new file mode 100644 index 0000000..77fc71d --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/test-setup.go @@ -0,0 +1,135 @@ +// Package utils contains the internal test utils +package utils + +import ( + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/msteinert/pam/v2" +) + +// TestSetup is an utility type for having a playground for test PAM modules. +type TestSetup struct { + t *testing.T + workDir string +} + +type withWorkDir struct{} + +//nolint:revive +func WithWorkDir() withWorkDir { + return withWorkDir{} +} + +// NewTestSetup creates a new TestSetup. +func NewTestSetup(t *testing.T, args ...interface{}) *TestSetup { + t.Helper() + + ts := &TestSetup{t: t} + for _, arg := range args { + switch argType := arg.(type) { + case withWorkDir: + ts.ensureWorkDir() + default: + t.Fatalf("Unknown parameter of type %v", argType) + } + } + + return ts +} + +// CreateTemporaryDir creates a temporary directory with provided basename. +func (ts *TestSetup) CreateTemporaryDir(basename string) string { + tmpDir, err := os.MkdirTemp(os.TempDir(), basename) + if err != nil { + ts.t.Fatalf("can't create service path %v", err) + } + + ts.t.Cleanup(func() { os.RemoveAll(tmpDir) }) + return tmpDir +} + +func (ts *TestSetup) ensureWorkDir() string { + if ts.workDir != "" { + return ts.workDir + } + + ts.workDir = ts.CreateTemporaryDir("go-pam-*") + return ts.workDir +} + +// WorkDir returns the test setup work directory. +func (ts TestSetup) WorkDir() string { + return ts.workDir +} + +// GenerateModule generates a PAM module for the provided path and name. +func (ts *TestSetup) GenerateModule(testModulePath string, moduleName string) string { + cmd := exec.Command("go", "generate", "-C", testModulePath) + out, err := cmd.CombinedOutput() + if err != nil { + ts.t.Fatalf("can't build pam module %v: %s", err, out) + } + + builtFile := filepath.Join(cmd.Dir, testModulePath, moduleName) + modulePath := filepath.Join(ts.ensureWorkDir(), filepath.Base(builtFile)) + if err = os.Rename(builtFile, modulePath); err != nil { + ts.t.Fatalf("can't move module: %v", err) + os.Remove(builtFile) + } + + return modulePath +} + +func (ts TestSetup) currentFile(skip int) string { + _, currentFile, _, ok := runtime.Caller(skip) + if !ok { + ts.t.Fatalf("can't get current binary path") + } + return currentFile +} + +// GetCurrentFile returns the current file path. +func (ts TestSetup) GetCurrentFile() string { + // This is a library so we care about the caller location + return ts.currentFile(2) +} + +// GetCurrentFileDir returns the current file directory. +func (ts TestSetup) GetCurrentFileDir() string { + return filepath.Dir(ts.currentFile(2)) +} + +// GenerateModuleDefault generates a default module. +func (ts *TestSetup) GenerateModuleDefault(testModulePath string) string { + return ts.GenerateModule(testModulePath, "pam_go.so") +} + +// CreateService creates a service file. +func (ts *TestSetup) CreateService(serviceName string, services []ServiceLine) string { + if !pam.CheckPamHasStartConfdir() { + ts.t.Skip("PAM has no support for custom service paths") + return "" + } + + serviceName = strings.ToLower(serviceName) + serviceFile := filepath.Join(ts.ensureWorkDir(), serviceName) + var contents = []string{} + + for _, s := range services { + contents = append(contents, strings.TrimRight(strings.Join([]string{ + s.Action.String(), s.Control.String(), s.Module, strings.Join(s.Args, " "), + }, "\t"), "\t")) + } + + if err := os.WriteFile(serviceFile, + []byte(strings.Join(contents, "\n")), 0600); err != nil { + ts.t.Fatalf("can't create service file %v: %v", serviceFile, err) + } + + return serviceFile +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-setup_test.go b/cmd/pam-moduler/tests/internal/utils/test-setup_test.go new file mode 100644 index 0000000..f8a17a6 --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/test-setup_test.go @@ -0,0 +1,180 @@ +package utils + +import ( + "fmt" + "math/rand" + "os" + "path/filepath" + "strings" + "testing" +) + +func isDir(t *testing.T, path string) bool { + t.Helper() + if file, err := os.Open(path); err == nil { + if fileInfo, err := file.Stat(); err == nil { + return fileInfo.IsDir() + } + t.Fatalf("error: %v", err) + } else { + t.Fatalf("error: %v", err) + } + return false +} + +func Test_CreateTemporaryDir(t *testing.T) { + t.Parallel() + ts := NewTestSetup(t) + dir := ts.CreateTemporaryDir("") + if !isDir(t, dir) { + t.Fatalf("%s not a directory", dir) + } + + dir = ts.CreateTemporaryDir("foo-prefix-*") + if !isDir(t, dir) { + t.Fatalf("%s not a directory", dir) + } +} + +func Test_TestSetupWithWorkDir(t *testing.T) { + t.Parallel() + ts := NewTestSetup(t, WithWorkDir()) + if !isDir(t, ts.WorkDir()) { + t.Fatalf("%s not a directory", ts.WorkDir()) + } +} + +func Test_CreateService(t *testing.T) { + t.Parallel() + ts := NewTestSetup(t) + + tests := map[string]struct { + services []ServiceLine + expectedContent string + }{ + "empty": {}, + "CApital-Empty": {}, + "auth-sufficient-permit": { + services: []ServiceLine{ + {Auth, Sufficient, Permit.String(), []string{}}, + }, + expectedContent: "auth sufficient pam_permit.so", + }, + "auth-sufficient-permit-args": { + services: []ServiceLine{ + {Auth, Required, Deny.String(), []string{"a b c [d e]"}}, + }, + expectedContent: "auth required pam_deny.so a b c [d e]", + }, + "complete-custom": { + services: []ServiceLine{ + {Account, Required, "pam_account_module.so", []string{"a", "b", "c", "[d e]"}}, + {Account, Required, Deny.String(), []string{}}, + {Auth, Requisite, "pam_auth_module.so", []string{}}, + {Auth, Requisite, Deny.String(), []string{}}, + {Password, Sufficient, "pam_password_module.so", []string{"arg"}}, + {Password, Sufficient, Deny.String(), []string{}}, + {Session, Optional, "pam_session_module.so", []string{""}}, + {Session, Optional, Deny.String(), []string{}}, + }, + expectedContent: `account required pam_account_module.so a b c [d e] +account required pam_deny.so +auth requisite pam_auth_module.so +auth requisite pam_deny.so +password sufficient pam_password_module.so arg +password sufficient pam_deny.so +session optional pam_session_module.so +session optional pam_deny.so`, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + service := ts.CreateService(name, tc.services) + + if filepath.Base(service) != strings.ToLower(name) { + t.Fatalf("Invalid service name %s", service) + } + + if bytes, err := os.ReadFile(service); err != nil { + t.Fatalf("Failed reading %s: %v", service, err) + } else { + if string(bytes) != tc.expectedContent { + t.Fatalf("Unexpected file content:\n%s\n---\n%s", + tc.expectedContent, string(bytes)) + } + } + }) + } +} + +func Test_GenerateModule(t *testing.T) { + ts := NewTestSetup(t) + dir := ts.CreateTemporaryDir("") + if !isDir(t, dir) { + t.Fatalf("%s not a directory", dir) + } + + f, err := os.Create(filepath.Join(dir, "test-generate.go")) + if err != nil { + t.Fatalf("can't create file %v", err) + } + defer f.Close() + + randomName := "" + for i := 0; i < 10; i++ { + // #nosec:G404 - it's a test, we don't care. + randomName += string(byte('a' + rand.Intn('z'-'a'))) + } + + wantFile := randomName + ".so" + fmt.Fprintf(f, `//go:generate touch %s +package generate_file +`, wantFile) + + mod, err := os.Create(filepath.Join(dir, "go.mod")) + if err != nil { + t.Fatalf("can't create file %v", err) + } + defer mod.Close() + + fmt.Fprintf(mod, `module example.com/greetings + +go 1.20 +`) + + fakeModule := ts.GenerateModule(dir, wantFile) + if _, err := os.Stat(fakeModule); err != nil { + t.Fatalf("module not generated %v", err) + } + + fmt.Fprint(f, `//go:generate touch pam_go.so +package generate_file +`, wantFile) + + fakeModule = ts.GenerateModuleDefault(dir) + if _, err := os.Stat(fakeModule); err != nil { + t.Fatalf("module not generated %v", err) + } +} + +func Test_GetCurrentFileDir(t *testing.T) { + t.Parallel() + + ts := NewTestSetup(t) + if !strings.HasSuffix(ts.GetCurrentFileDir(), filepath.Join("internal", "utils")) { + t.Fatalf("unexpected file %v", ts.GetCurrentFileDir()) + } +} + +func Test_GetCurrentFile(t *testing.T) { + t.Parallel() + + ts := NewTestSetup(t) + if !strings.HasSuffix(ts.GetCurrentFile(), filepath.Join("utils", "test-setup_test.go")) { + t.Fatalf("unexpected file %v", ts.GetCurrentFile()) + } +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go new file mode 100644 index 0000000..556f160 --- /dev/null +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -0,0 +1,99 @@ +// Package utils contains the internal test utils +package utils + +// Action represents a PAM action to perform. +type Action int + +const ( + // Account is the account. + Account Action = iota + 1 + // Auth is the auth. + Auth + // Password is the password. + Password + // Session is the session. + Session +) + +func (a Action) String() string { + switch a { + case Account: + return "account" + case Auth: + return "auth" + case Password: + return "password" + case Session: + return "session" + default: + return "" + } +} + +// Actions is a map with all the available Actions by their name. +var Actions = map[string]Action{ + Account.String(): Account, + Auth.String(): Auth, + Password.String(): Password, + Session.String(): Session, +} + +// Control represents how a PAM module should controlled in PAM service file. +type Control int + +const ( + // Required implies that the module is required. + Required Control = iota + 1 + // Requisite implies that the module is requisite. + Requisite + // Sufficient implies that the module is sufficient. + Sufficient + // Optional implies that the module is optional. + Optional +) + +func (c Control) String() string { + switch c { + case Required: + return "required" + case Requisite: + return "requisite" + case Sufficient: + return "sufficient" + case Optional: + return "optional" + default: + return "" + } +} + +// ServiceLine is the representation of a PAM module service file line. +type ServiceLine struct { + Action Action + Control Control + Module string + Args []string +} + +// FallBackModule is a type to represent the module that should be used as fallback. +type FallBackModule int + +const ( + // NoFallback add no fallback module. + NoFallback FallBackModule = iota + 1 + // Permit uses a module that always permits. + Permit + // Deny uses a module that always denys. + Deny +) + +func (a FallBackModule) String() string { + switch a { + case Permit: + return "pam_permit.so" + case Deny: + return "pam_deny.so" + default: + return "" + } +} From 5fbbb88843a7050586f214033a73a0b157d708ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 29 Sep 2023 02:12:06 +0200 Subject: [PATCH 10/24] tests: Add a module implementation with dynamic control from the app In order to properly test the interaction of a module transaction from the application point of view, we need to perform operation in the module and ensure that the expected values are returned and handled In order to do this, without using the PAM apis that we want to test, use a simple trick: - Create an application that works as server using an unix socket - Create a module that connects to it - Pass the socket to the module via the module service file arguments - Add some basic protocol that allows the application to send a request and to the module to reply to that. - Use reflection and serialization to automatically call module methods and return the values to the application where we do the check --- cmd/pam-moduler/moduler.go | 2 +- .../communication.go | 230 +++++ .../communication_test.go | 107 +++ .../integration-tester-module.go | 137 +++ .../integration-tester-module_test.go | 783 ++++++++++++++++++ .../integration-tester-module/pam_module.go | 95 +++ .../serialization.go | 35 + .../tests/internal/utils/test-utils.go | 9 + 8 files changed, 1397 insertions(+), 1 deletion(-) create mode 100644 cmd/pam-moduler/tests/integration-tester-module/communication.go create mode 100644 cmd/pam-moduler/tests/integration-tester-module/communication_test.go create mode 100644 cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go create mode 100644 cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go create mode 100644 cmd/pam-moduler/tests/integration-tester-module/pam_module.go create mode 100644 cmd/pam-moduler/tests/integration-tester-module/serialization.go diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 68f4852..165d5a6 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -11,7 +11,7 @@ // // For example: // -// //go:generate go run github.com/msteinert/pam/pam-moduler +// //go:generate go run github.com/msteinert/pam/v2/pam-moduler // //go:generate go generate --skip="pam_module" // package main // diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication.go b/cmd/pam-moduler/tests/integration-tester-module/communication.go new file mode 100644 index 0000000..67bada4 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/communication.go @@ -0,0 +1,230 @@ +// Package main is the package for the integration tester module PAM shared library. +package main + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "io" + "net" + "runtime" +) + +// Request is a serializable integration module tester structure request. +type Request struct { + Action string + ActionArgs []interface{} +} + +// Result is a serializable integration module tester structure result. +type Result = Request + +// NewRequest returns a new Request. +func NewRequest(action string, actionArgs ...interface{}) Request { + return Request{action, actionArgs} +} + +// GOB serializes the request in binary format. +func (r *Request) GOB() ([]byte, error) { + b := bytes.Buffer{} + e := gob.NewEncoder(&b) + if err := e.Encode(r); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// NewRequestFromGOB gets a Request from a serialized binary. +func NewRequestFromGOB(data []byte) (*Request, error) { + b := bytes.Buffer{} + b.Write(data) + d := gob.NewDecoder(&b) + + var req Request + if err := d.Decode(&req); err != nil { + return nil, err + } + return &req, nil +} + +const bufSize = 1024 + +type connectionHandler struct { + inOutData chan []byte + outErr chan error + SocketPath string +} + +// Listener is a socket listener. +type Listener struct { + connectionHandler + listener net.Listener +} + +// NewListener creates a new Listener. +func NewListener(socketPath string) *Listener { + if len(socketPath) > 90 { + // See https://manpages.ubuntu.com/manpages/jammy/man7/sys_un.h.7posix.html#application%20usage + panic(fmt.Sprintf("Socket path %s too long", socketPath)) + } + return &Listener{connectionHandler{SocketPath: socketPath}, nil} +} + +// WaitForData waits for result data (or an error) on connection to be returned. +func (c *connectionHandler) WaitForData() (*Result, error) { + data, err := <-c.inOutData, <-c.outErr + if err != nil { + if errors.Is(err, io.EOF) { + return nil, nil + } + return nil, err + } + + req, err := NewRequestFromGOB(data) + if err != nil { + return nil, err + } + + return req, nil +} + +// SendRequest sends a request to the connection. +func (c *connectionHandler) SendRequest(req *Request) error { + bytes, err := req.GOB() + if err != nil { + return err + } + + c.inOutData <- bytes + return nil +} + +// SendResult sends the Result to the connection. +func (c *connectionHandler) SendResult(res *Result) error { + return c.SendRequest(res) +} + +// DoRequest performs a Request on the connection, waiting for data. +func (c *connectionHandler) DoRequest(req *Request) (*Result, error) { + if err := c.SendRequest(req); err != nil { + return nil, err + } + + return c.WaitForData() +} + +// Send performs a request. +func (r *Request) Send(c *connectionHandler) error { + return c.SendRequest(r) +} + +// ErrAlreadyListening is the error if a listener is already set. +var ErrAlreadyListening = errors.New("listener already set") + +// StartListening initiates the unix listener. +func (l *Listener) StartListening() error { + if l.listener != nil { + return ErrAlreadyListening + } + + listener, err := net.Listen("unix", l.SocketPath) + if err != nil { + return err + } + + l.listener = listener + l.inOutData, l.outErr = make(chan []byte), make(chan error) + + go func() { + bytes, err := func() ([]byte, error) { + for { + c, err := l.listener.Accept() + if err != nil { + return nil, err + } + + for { + buf := make([]byte, bufSize) + nr, err := c.Read(buf) + if err != nil { + return buf, err + } + + data := buf[0:nr] + l.inOutData <- data + l.outErr <- nil + + _, err = c.Write(<-l.inOutData) + if err != nil { + return nil, err + } + } + } + }() + + l.inOutData <- bytes + l.outErr <- err + }() + + return nil +} + +// Connector is a connection type. +type Connector struct { + connectionHandler + connection net.Conn +} + +// NewConnector creates a new connection. +func NewConnector(socketPath string) *Connector { + return &Connector{connectionHandler{SocketPath: socketPath}, nil} +} + +// ErrAlreadyConnected is the error if a connection is already set. +var ErrAlreadyConnected = errors.New("connection already set") + +// Connect connects to a listening unix socket. +func (c *Connector) Connect() error { + if c.connection != nil { + return ErrAlreadyConnected + } + + connection, err := net.Dial("unix", c.SocketPath) + if err != nil { + return err + } + + runtime.SetFinalizer(c, func(c *Connector) { + c.connection.Close() + }) + + c.connection = connection + c.inOutData, c.outErr = make(chan []byte), make(chan error) + + go func() { + buf := make([]byte, bufSize) + writeAndRead := func() ([]byte, error) { + data := <-c.inOutData + _, err := c.connection.Write(data) + if err != nil { + return nil, err + } + + n, err := c.connection.Read(buf[:]) + if err != nil { + return nil, err + } + + return buf[0:n], nil + } + + for { + bytes, err := writeAndRead() + c.inOutData <- bytes + c.outErr <- err + } + }() + + return nil +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go new file mode 100644 index 0000000..7abc2e3 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "errors" + "path/filepath" + "reflect" + "testing" + + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func ensureError(t *testing.T, err error, expected error) { + t.Helper() + if err == nil { + t.Fatalf("error was expected, got none") + } + if !errors.Is(err, expected) { + t.Fatalf("error %v was expected, got %v", err, expected) + } +} + +func ensureEqual(t *testing.T, a any, b any) { + t.Helper() + if !reflect.DeepEqual(a, b) { + t.Fatalf("values mismatch %v vs %v", a, b) + } +} + +func Test_Communication(t *testing.T) { + t.Parallel() + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + + for _, name := range []string{"test-1", "test-2"} { + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + + listener := NewListener(socketPath) + connector := NewConnector(socketPath) + + ensureNoError(t, listener.StartListening()) + ensureNoError(t, connector.Connect()) + + ensureError(t, listener.StartListening(), ErrAlreadyListening) + ensureError(t, connector.Connect(), ErrAlreadyConnected) + + resChan, errChan := make(chan *Result), make(chan error) + go func() { + res, err := listener.WaitForData() + resChan <- res + errChan <- err + }() + + req := NewRequest("A Request") + ensureNoError(t, connector.SendRequest(&req)) + + res, err := <-resChan, <-errChan + ensureNoError(t, err) + ensureEqual(t, *res, req) + + go func() { + res := NewRequest("Listener result") + ensureNoError(t, listener.SendResult(&res)) + }() + + res, err = connector.WaitForData() + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Listener result")) + + go func() { + req, err := listener.WaitForData() + res := NewRequest("Response", *req) + + defer func() { + resChan <- &res + errChan <- err + }() + ensureNoError(t, listener.SendResult(&res)) + }() + + done := make(chan bool) + req = NewRequest("Requesting...") + go func() { + defer func() { + done <- true + }() + res, err := connector.DoRequest(&req) + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Response", req)) + }() + + res, err = <-resChan, <-errChan + ensureNoError(t, err) + ensureEqual(t, *res, NewRequest("Response", req)) + <-done + }) + } +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go new file mode 100644 index 0000000..995e0c2 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -0,0 +1,137 @@ +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule +//go:generate go generate --skip="pam_module.go" + +// Package main is the package for the integration tester module PAM shared library. +package main + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +type integrationTesterModule struct { + utils.BaseModule +} + +type authRequest struct { + mt pam.ModuleTransaction + lastError error +} + +func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request) (res *Result, err error) { + switch r.Action { + case "bye": + return nil, authReq.lastError + } + + defer func() { + if p := recover(); p != nil { + if s, ok := p.(string); ok { + if strings.HasPrefix(s, "reflect:") { + res = nil + err = &utils.SerializableError{Msg: fmt.Sprintf( + "error on request %v: %v", *r, p)} + authReq.lastError = err + return + } + } + panic(p) + } + + if err != nil { + authReq.lastError = err + } + }() + + method := reflect.ValueOf(authReq.mt).MethodByName(r.Action) + if method == (reflect.Value{}) { + return nil, &utils.SerializableError{Msg: fmt.Sprintf( + "no method %s found", r.Action)} + } + + var args []reflect.Value + for _, arg := range r.ActionArgs { + args = append(args, reflect.ValueOf(arg)) + } + + res = &Result{Action: "return"} + for _, ret := range method.Call(args) { + iface := ret.Interface() + switch value := iface.(type) { + case pam.Error: + authReq.lastError = value + res.ActionArgs = append(res.ActionArgs, value) + case error: + var pamError pam.Error + if errors.As(value, &pamError) { + retErr := &SerializablePamError{Msg: value.Error(), + RetStatus: pamError} + authReq.lastError = retErr + res.ActionArgs = append(res.ActionArgs, retErr) + return res, err + } + authReq.lastError = value + res.ActionArgs = append(res.ActionArgs, + &utils.SerializableError{Msg: value.Error()}) + default: + res.ActionArgs = append(res.ActionArgs, iface) + } + } + return res, err +} + +func (m *integrationTesterModule) handleError(err error) *Result { + return &Result{ + Action: "error", + ActionArgs: []interface{}{&utils.SerializableError{Msg: err.Error()}}, + } +} + +func (m *integrationTesterModule) Authenticate(mt pam.ModuleTransaction, _ pam.Flags, args []string) error { + if len(args) != 1 { + return errors.New("Invalid arguments") + } + + authRequest := authRequest{mt, nil} + connection := NewConnector(args[0]) + if err := connection.Connect(); err != nil { + return err + } + + connectionHandler := func() error { + if err := connection.SendRequest(&Request{Action: "hello"}); err != nil { + return err + } + + for { + req, err := connection.WaitForData() + if err != nil { + return err + } + + res, err := m.handleRequest(&authRequest, req) + if err != nil { + _ = connection.SendResult(m.handleError(err)) + return err + } + if res == nil { + return nil + } + if err := connection.SendResult(res); err != nil { + _ = connection.SendResult(m.handleError(err)) + return err + } + } + } + + if err := connectionHandler(); err != nil { + return err + } + + return nil +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go new file mode 100644 index 0000000..ecde5ce --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -0,0 +1,783 @@ +package main + +import ( + "errors" + "fmt" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +func (r *Request) check(res *Result, expectedResults []interface{}) error { + switch res.Action { + case "return": + case "error": + return fmt.Errorf("module error: %v", res.ActionArgs...) + default: + return fmt.Errorf("unexpected action %v", res.Action) + } + + if !reflect.DeepEqual(res.ActionArgs, expectedResults) { + return fmt.Errorf("unexpected return values %#v vs %#v", + res.ActionArgs, expectedResults) + } + + return nil +} + +func (r *Request) checkRemote(listener *Listener, expectedResults []interface{}) error { + res, err := listener.DoRequest(r) + if err != nil { + return err + } + + return res.check(res, expectedResults) +} + +type checkedRequest struct { + r Request + exp []interface{} + compareWithTestState bool +} + +func (cr *checkedRequest) checkRemote(listener *Listener) error { + return cr.r.checkRemote(listener, cr.exp) +} + +func (cr *checkedRequest) check(res *Result) error { + return cr.r.check(res, cr.exp) +} + +func ensureItem(tx *pam.Transaction, item pam.Item, expected string) error { + if value, err := tx.GetItem(item); err != nil { + return err + } else if value != expected { + return fmt.Errorf("invalid item %v value: %s vs %v", item, value, expected) + } + return nil +} + +func ensureEnv(tx *pam.Transaction, variable string, expected string) error { + if env := tx.GetEnv(variable); env != expected { + return fmt.Errorf("unexpected env %s value: %s vs %s", variable, env, expected) + } + return nil +} + +func Test_Moduler_IntegrationTesterModule(t *testing.T) { + t.Parallel() + if !pam.CheckPamHasStartConfdir() { + t.Skip("this requires PAM with Conf dir support") + } + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + modulePath := ts.GenerateModuleDefault(ts.GetCurrentFileDir()) + + type testState = map[string]interface{} + + tests := map[string]struct { + expectedError error + user string + credentials pam.ConversationHandler + checkedRequests []checkedRequest + setup func(*pam.Transaction, *Listener, testState) error + finish func(*pam.Transaction, *Listener, testState) error + }{ + "success": { + expectedError: nil, + }, + "get-item-Service": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"get-item-service", nil}, + }}, + }, + "get-item-User-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-User-preset": { + user: "test-user", + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"test-user", nil}, + }}, + }, + "get-item-Authtok-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Authtok), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-Oldauthtok-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.Oldauthtok), + exp: []interface{}{"", nil}, + }}, + }, + "get-item-UserPrompt-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.UserPrompt), + exp: []interface{}{"", nil}, + }}, + }, + "set-item-Service": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.Service, "foo-service"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"foo-service", nil}, + }, + }, + }, + "set-item-User-empty": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.User, "an-user"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"an-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureItem(tx, pam.User, "an-user") + }, + }, + "set-item-User-preset": { + user: "test-user", + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetItem", pam.User, "an-user"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"an-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureItem(tx, pam.User, "an-user") + }, + }, + "set-get-item-User-empty": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"setup-user", nil}, + }}, + }, + "set-get-item-User-preset": { + user: "test-user", + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"setup-user", nil}, + }}, + }, + "get-env-unset": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_HOPEFULLY_NOT_SET"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_HOPEFULLY_NOT_SET", "") + }, + }, + "get-env-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"foobar"}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "foobar") + }, + }, + "get-env-preset-empty": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + if err := tx.PutEnv("_PAM_GO_ENV_SET_VAR=value"); err != nil { + return err + } + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "get-env-preset-unset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + if err := tx.PutEnv("_PAM_GO_ENV_SET_VAR=value"); err != nil { + return err + } + return tx.PutEnv("_PAM_GO_ENV_SET_VAR") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-not-preset": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "a value") + }, + }, + "put-env-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=another value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"another value"}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "another value") + }, + }, + "put-env-resets-not-preset": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-resets-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.PutEnv("_PAM_GO_ENV_SET_VAR=foobar") + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR=a value"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{"a value"}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{""}, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureEnv(tx, "_PAM_GO_ENV_SET_VAR", "") + }, + }, + "put-env-unsets-not-set": { + expectedError: pam.ErrBadItem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_NEVER_SET"), + exp: []interface{}{pam.ErrBadItem}, + }, + }, + }, + "put-env-unsets-empty-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + exp: []interface{}{ + map[string]string{"_PAM_GO_ENV_SET_VAR": ""}, nil, + }, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + exp: []interface{}{map[string]string{}, nil}, + }, + }, + }, + "put-env-invalid-syntax": { + expectedError: pam.ErrBadItem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "="), + exp: []interface{}{pam.ErrBadItem}, + }, + { + r: NewRequest("PutEnv", "=bar"), + exp: []interface{}{pam.ErrBadItem}, + }, + { + r: NewRequest("PutEnv", "with spaces"), + exp: []interface{}{pam.ErrBadItem}, + }, + }, + }, + "get-env-list-empty": { + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnvList"), + exp: []interface{}{map[string]string{}, nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return nil + }, + }, + "get-env-list-preset": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + expected := map[string]string{ + "_PAM_GO_ENV_SET_VAR1": "value1", + "_PAM_GO_ENV_SET_VAR2": "value due", + "_PAM_GO_ENV_SET_VAR3": "3", + "_PAM_GO_ENV_SET_VAR_EMPTY": "", + "_PAM_GO_ENV WITH SPACES": "yes works", + } + + for env, value := range expected { + if err := tx.PutEnv(fmt.Sprintf("%s=%s", env, value)); err != nil { + return err + } + } + ts["expected"] = expected + ts["expectedResults"] = [][]interface{}{{expected, nil}} + return nil + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetEnvList"), + compareWithTestState: true, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + if list, err := tx.GetEnvList(); err != nil { + return err + } else if !reflect.DeepEqual(list, ts["expected"]) { + return fmt.Errorf("Unexpected return values %#v vs %#v", + list, ts["expected"]) + } + return nil + }, + }, + "get-env-list-module-set": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + expected := map[string]string{ + "_PAM_GO_ENV_SET_VAR1": "value1", + "_PAM_GO_ENV_SET_VAR2": "value due", + "_PAM_GO_ENV_SET_VAR3": "3", + "_PAM_GO_ENV_SET_VAR_EMPTY": "", + "_PAM_GO_ENV WITH SPACES": "yes works", + } + + ts["expected"] = expected + ts["expectedResults"] = [][]interface{}{ + nil, nil, nil, nil, nil, nil, nil, {expected, nil}, + } + return nil + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR1=value1"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR2=value due"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR3=3"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_EMPTY="), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_TO_UNSET=unset"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV_SET_VAR_TO_UNSET"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("PutEnv", "_PAM_GO_ENV WITH SPACES=yes works"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetEnvList"), + compareWithTestState: true, + }, + }, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + if list, err := tx.GetEnvList(); err != nil { + return err + } else if !reflect.DeepEqual(list, ts["expected"]) { + return fmt.Errorf("unexpected return values %#v vs %#v", + list, ts["expected"]) + } + return nil + }, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + ts.CreateService(name, []utils.ServiceLine{ + {Action: utils.Auth, Control: utils.Requisite, Module: modulePath, + Args: []string{socketPath}}, + }) + + tx, err := pam.StartConfDir(name, tc.user, tc.credentials, ts.WorkDir()) + if err != nil { + t.Fatalf("start #error: %v", err) + } + defer func() { + err := tx.End() + if err != nil { + t.Fatalf("end #error: %v", err) + } + }() + + listener := NewListener(socketPath) + if err := listener.StartListening(); err != nil { + t.Fatalf("listening #error: %v", err) + } + + listenerHandler := func() error { + res, err := listener.WaitForData() + if err != nil { + return err + } + + if res == nil || res.Action != "hello" { + return errors.New("missing hello packet") + } + + req := NewRequest("GetItem", pam.Service) + if err := req.checkRemote(listener, + []interface{}{strings.ToLower(name), nil}); err != nil { + return err + } + + testState := testState{} + if tc.setup != nil { + if err := tc.setup(tx, listener, testState); err != nil { + return err + } + } + + for i, req := range tc.checkedRequests { + if req.compareWithTestState { + expectedResults, _ := testState["expectedResults"].([][]interface{}) + if err := req.r.checkRemote(listener, expectedResults[i]); err != nil { + return err + } + } else if err := req.checkRemote(listener); err != nil { + return err + } + } + + if tc.finish != nil { + if err := tc.finish(tx, listener, testState); err != nil { + return err + } + } + + if err := listener.SendRequest(&Request{Action: "bye"}); err != nil { + return err + } + + return nil + } + + serverError := make(chan error) + go func() { + serverError <- listenerHandler() + }() + + authResult := make(chan error) + go func() { + authResult <- tx.Authenticate(pam.Silent) + }() + + if err = <-serverError; err != nil { + t.Fatalf("communication #error: %v", err) + } + + err = <-authResult + if !errors.Is(err, tc.expectedError) { + t.Fatalf("authenticate #unexpected: %#v vs %#v", + err, tc.expectedError) + } + }) + } + + t.Cleanup(func() { + // Ensure GC will happen, so that transaction's pam_end will be called + runtime.GC() + time.Sleep(5 * time.Millisecond) + }) +} + +func Test_Moduler_IntegrationTesterModule_handleRequest(t *testing.T) { + t.Parallel() + + module := integrationTesterModule{} + mt := pam.NewModuleTransactionInvoker(nil) + + tests := []struct { + checkedRequest + name string + parallel bool + }{ + { + name: "putEnv", + checkedRequest: checkedRequest{ + r: NewRequest("PutEnv", "FOO_ENV=Bar"), + exp: []interface{}{pam.ErrAbort}, + }, + }, + { + parallel: true, + name: "get-item-Service", + checkedRequest: checkedRequest{ + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + { + parallel: true, + name: "set-item-Service", + checkedRequest: checkedRequest{ + r: NewRequest("SetItem", pam.Service, "foo"), + exp: []interface{}{pam.ErrSystem}, + }, + }, + } + + for _, cr := range tests { + cr := cr + t.Run(cr.name, func(t *testing.T) { + if cr.parallel { + t.Parallel() + } + + authRequest := authRequest{mt, nil} + res, err := module.handleRequest(&authRequest, &cr.r) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if res.Action != "return" { + t.Fatalf("unexpected result action %v", res.Action) + } + + if err := cr.check(res); err != nil { + t.Fatalf("unexpected result %v", err) + } + }) + } + + t.Run("missing-method", func(t *testing.T) { + t.Parallel() + req := NewRequest("Hopefully a missing method") + res, err := module.handleRequest(&authRequest{mt, nil}, &req) + + if err == nil { + t.Fatalf("error was expected, got %v", res) + } + if res != nil { + t.Fatalf("unexpected result %v", res) + } + }) + + t.Run("wrong-signature", func(t *testing.T) { + t.Parallel() + req := NewRequest("GetItem", "this", "and", 3, "of that") + res, err := module.handleRequest(&authRequest{mt, nil}, &req) + + if err == nil { + t.Fatalf("error was expected, got %v", res) + } + if res != nil { + t.Fatalf("unexpected result %v", res) + } + }) +} + +func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { + t.Parallel() + + ts := utils.NewTestSetup(t, utils.WithWorkDir()) + module := integrationTesterModule{} + + tests := map[string]struct { + expectedError error + credentials pam.ConversationHandler + checkedRequests []checkedRequest + }{ + "success": { + expectedError: nil, + }, + "get-item-Service": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("GetItem", pam.Service), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + }, + "get-item-User": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("GetItem", pam.User), + exp: []interface{}{"", pam.ErrSystem}, + }, + }, + }, + "putEnv": { + expectedError: pam.ErrAbort, + checkedRequests: []checkedRequest{ + { + r: NewRequest("PutEnv", "FooBar=Baz"), + exp: []interface{}{pam.ErrAbort}, + }, + }, + }, + } + + for name, tc := range tests { + tc := tc + name := name + t.Run(name, func(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(ts.WorkDir(), name+".socket") + listener := NewListener(socketPath) + if err := listener.StartListening(); err != nil { + t.Fatalf("listening #error: %v", err) + } + + listenerHandler := func() error { + res, err := listener.WaitForData() + if err != nil { + return err + } + + if res == nil || res.Action != "hello" { + return errors.New("missing hello packet") + } + + for _, req := range tc.checkedRequests { + if err := req.checkRemote(listener); err != nil { + return err + } + } + + if err := listener.SendRequest(&Request{Action: "bye"}); err != nil { + return err + } + + return nil + } + + serverError := make(chan error) + go func() { + serverError <- listenerHandler() + }() + + authResult := make(chan error) + go func() { + authResult <- module.Authenticate( + pam.NewModuleTransactionInvoker(nil), + pam.Silent, []string{socketPath}) + }() + + if err := <-serverError; err != nil { + t.Fatalf("communication #error: %v", err) + } + + err := <-authResult + if !errors.Is(err, tc.expectedError) { + t.Fatalf("authenticate #unexpected: %#v vs %#v", + err, tc.expectedError) + } + }) + } +} diff --git a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go new file mode 100644 index 0000000..39a22b7 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go @@ -0,0 +1,95 @@ +// Code generated by "pam-moduler -type integrationTesterModule"; DO NOT EDIT. + +//go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module + +// Package main is the package for the PAM module library. +package main + +/* +#cgo LDFLAGS: -lpam -fPIC +#include + +typedef const char _const_char_t; +*/ +import "C" + +import ( + "errors" + "fmt" + "github.com/msteinert/pam/v2" + "os" + "unsafe" +) + +var pamModuleHandler pam.ModuleHandler = &integrationTesterModule{} + +// sliceFromArgv returns a slice of strings given to the PAM module. +func sliceFromArgv(argc C.int, argv **C._const_char_t) []string { + r := make([]string, 0, argc) + for _, s := range unsafe.Slice(argv, argc) { + r = append(r, C.GoString(s)) + } + return r +} + +// handlePamCall is the function that translates C pam requests to Go. +func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, + argv **C._const_char_t, moduleFunc pam.ModuleHandlerFunc) C.int { + if pamModuleHandler == nil { + return C.int(pam.ErrNoModuleData) + } + + if moduleFunc == nil { + return C.int(pam.ErrIgnore) + } + + mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + sliceFromArgv(argc, argv)) + if err == nil { + return 0 + } + + if (pam.Flags(flags)&pam.Silent) == 0 && !errors.Is(err, pam.ErrIgnore) { + fmt.Fprintf(os.Stderr, "module returned error: %v\n", err) + } + + var pamErr pam.Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + + return C.int(pam.ErrSystem) +} + +//export pam_sm_authenticate +func pam_sm_authenticate(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.Authenticate) +} + +//export pam_sm_setcred +func pam_sm_setcred(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.SetCred) +} + +//export pam_sm_acct_mgmt +func pam_sm_acct_mgmt(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.AcctMgmt) +} + +//export pam_sm_open_session +func pam_sm_open_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.OpenSession) +} + +//export pam_sm_close_session +func pam_sm_close_session(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.CloseSession) +} + +//export pam_sm_chauthtok +func pam_sm_chauthtok(pamh *C.pam_handle_t, flags C.int, argc C.int, argv **C._const_char_t) C.int { + return handlePamCall(pamh, flags, argc, argv, pamModuleHandler.ChangeAuthTok) +} + +func main() {} diff --git a/cmd/pam-moduler/tests/integration-tester-module/serialization.go b/cmd/pam-moduler/tests/integration-tester-module/serialization.go new file mode 100644 index 0000000..33b26a7 --- /dev/null +++ b/cmd/pam-moduler/tests/integration-tester-module/serialization.go @@ -0,0 +1,35 @@ +package main + +import ( + "encoding/gob" + + "github.com/msteinert/pam/v2" + "github.com/msteinert/pam/v2/cmd/pam-moduler/tests/internal/utils" +) + +// SerializablePamError represents a [pam.Error] in a +// serializable way that splits message and return code. +type SerializablePamError struct { + Msg string + RetStatus pam.Error +} + +// NewSerializablePamError initializes a SerializablePamError from +// the default status error message. +func NewSerializablePamError(status pam.Error) SerializablePamError { + return SerializablePamError{Msg: status.Error(), RetStatus: status} +} + +func (e *SerializablePamError) Error() string { + return e.RetStatus.Error() +} + +func init() { + gob.Register(map[string]string{}) + gob.Register(Request{}) + gob.Register(pam.Item(0)) + gob.Register(pam.Error(0)) + gob.RegisterName("main.SerializablePamError", + SerializablePamError{}) + gob.Register(utils.SerializableError{}) +} diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go index 556f160..3fc6b0c 100644 --- a/cmd/pam-moduler/tests/internal/utils/test-utils.go +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -97,3 +97,12 @@ func (a FallBackModule) String() string { return "" } } + +// SerializableError is a representation of an error in a way can be serialized. +type SerializableError struct { + Msg string +} + +func (e *SerializableError) Error() string { + return e.Msg +} From 5f3c15c157c712bfae5d7b6d6382d46ab0809c92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 29 Sep 2023 15:13:43 +0200 Subject: [PATCH 11/24] module-transaction: Add GetUser() method that prompts an user if non-set We can now finally test this properly both using a mock and through the interactive module that will do the request for us in various conditions. --- .../communication_test.go | 2 +- .../integration-tester-module_test.go | 60 +++++++++- .../tests/internal/utils/test-utils.go | 47 ++++++++ module-transaction-mock.go | 106 ++++++++++++++++++ module-transaction.go | 38 +++++++ module-transaction_test.go | 99 ++++++++++++++++ 6 files changed, 348 insertions(+), 4 deletions(-) create mode 100644 module-transaction-mock.go diff --git a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go index 7abc2e3..7ef01f7 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/communication_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/communication_test.go @@ -29,7 +29,7 @@ func ensureError(t *testing.T, err error, expected error) { func ensureEqual(t *testing.T, a any, b any) { t.Helper() if !reflect.DeepEqual(a, b) { - t.Fatalf("values mismatch %v vs %v", a, b) + t.Fatalf("values mismatch %#v vs %#v", a, b) } } diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index ecde5ce..38e95c3 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -54,7 +54,8 @@ func (cr *checkedRequest) check(res *Result) error { return cr.r.check(res, cr.exp) } -func ensureItem(tx *pam.Transaction, item pam.Item, expected string) error { +func ensureUser(tx *pam.Transaction, expected string) error { + item := pam.User if value, err := tx.GetItem(item); err != nil { return err } else if value != expected { @@ -152,7 +153,7 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { exp: []interface{}{"an-user", nil}, }}, finish: func(tx *pam.Transaction, l *Listener, ts testState) error { - return ensureItem(tx, pam.User, "an-user") + return ensureUser(tx, "an-user") }, }, "set-item-User-preset": { @@ -167,7 +168,7 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { exp: []interface{}{"an-user", nil}, }}, finish: func(tx *pam.Transaction, l *Listener, ts testState) error { - return ensureItem(tx, pam.User, "an-user") + return ensureUser(tx, "an-user") }, }, "set-get-item-User-empty": { @@ -488,6 +489,59 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { return nil }, }, + "get-user-empty-no-conv-set": { + expectedError: pam.ErrConv, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"", pam.ErrConv}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "") + }, + }, + "get-user-empty-with-conv": { + credentials: utils.Credentials{ + User: "replying-user", + ExpectedMessage: "who are you? ", + ExpectedStyle: pam.PromptEchoOn, + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"replying-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "replying-user") + }, + }, + "get-user-preset-without-conv": { + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"setup-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "setup-user") + }, + }, + "get-user-preset-with-conv": { + credentials: utils.Credentials{ + User: "replying-user", + ExpectedMessage: "No message should have been shown!", + ExpectedStyle: pam.PromptEchoOn, + }, + setup: func(tx *pam.Transaction, l *Listener, ts testState) error { + return tx.SetItem(pam.User, "setup-user") + }, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetUser", "who are you? "), + exp: []interface{}{"setup-user", nil}, + }}, + finish: func(tx *pam.Transaction, l *Listener, ts testState) error { + return ensureUser(tx, "setup-user") + }, + }, } for name, tc := range tests { diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go index 3fc6b0c..095994b 100644 --- a/cmd/pam-moduler/tests/internal/utils/test-utils.go +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -1,6 +1,13 @@ // Package utils contains the internal test utils package utils +import ( + "errors" + "fmt" + + "github.com/msteinert/pam/v2" +) + // Action represents a PAM action to perform. type Action int @@ -106,3 +113,43 @@ type SerializableError struct { func (e *SerializableError) Error() string { return e.Msg } + +// Credentials is a test [pam.ConversationHandler] implementation. +type Credentials struct { + User string + Password string + ExpectedMessage string + CheckEmptyMessage bool + ExpectedStyle pam.Style + CheckZeroStyle bool + Context interface{} +} + +// RespondPAM handles PAM string conversations. +func (c Credentials) RespondPAM(s pam.Style, msg string) (string, error) { + if (c.ExpectedMessage != "" || c.CheckEmptyMessage) && + msg != c.ExpectedMessage { + return "", errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("unexpected prompt: %s vs %s", msg, c.ExpectedMessage), + }) + } + + if (c.ExpectedStyle != 0 || c.CheckZeroStyle) && + s != c.ExpectedStyle { + return "", errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("unexpected style: %#v vs %#v", s, c.ExpectedStyle), + }) + } + + switch s { + case pam.PromptEchoOn: + return c.User, nil + case pam.PromptEchoOff: + return c.Password, nil + } + + return "", errors.Join(pam.ErrConv, + &SerializableError{fmt.Sprintf("unhandled style: %v", s)}) +} diff --git a/module-transaction-mock.go b/module-transaction-mock.go new file mode 100644 index 0000000..f00202e --- /dev/null +++ b/module-transaction-mock.go @@ -0,0 +1,106 @@ +//go:build !go_pam_module + +package pam + +/* +#cgo CFLAGS: -Wall -std=c99 +#include +#include +*/ +import "C" + +import ( + "errors" + "fmt" + "runtime" + "testing" + "unsafe" +) + +type mockModuleTransactionExpectations struct { + UserPrompt string +} + +type mockModuleTransactionReturnedData struct { + User string + InteractiveUser bool + Status Error +} + +type mockModuleTransaction struct { + moduleTransaction + T *testing.T + Expectations mockModuleTransactionExpectations + RetData mockModuleTransactionReturnedData + ConversationHandler ConversationHandler + allocatedData []unsafe.Pointer +} + +func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { + runtime.SetFinalizer(m, func(m *mockModuleTransaction) { + for _, ptr := range m.allocatedData { + C.free(ptr) + } + }) + return m +} + +func (m *mockModuleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { + goPrompt := C.GoString(prompt) + if goPrompt != m.Expectations.UserPrompt { + m.T.Fatalf("unexpected prompt: %s vs %s", goPrompt, m.Expectations.UserPrompt) + return C.int(ErrAbort) + } + + user := m.RetData.User + if m.RetData.InteractiveUser || (m.RetData.User == "" && m.ConversationHandler != nil) { + if m.ConversationHandler == nil { + m.T.Fatalf("no conversation handler provided") + } + u, err := m.ConversationHandler.RespondPAM(PromptEchoOn, goPrompt) + user = u + + if err != nil { + var pamErr Error + if errors.As(err, &pamErr) { + return C.int(pamErr) + } + return C.int(ErrAbort) + } + } + + cUser := C.CString(user) + m.allocatedData = append(m.allocatedData, unsafe.Pointer(cUser)) + + *outUser = cUser + return C.int(m.RetData.Status) +} + +type mockConversationHandler struct { + User string + ExpectedMessage string + CheckEmptyMessage bool + ExpectedStyle Style + CheckZeroStyle bool +} + +func (c mockConversationHandler) RespondPAM(s Style, msg string) (string, error) { + if (c.ExpectedMessage != "" || c.CheckEmptyMessage) && + msg != c.ExpectedMessage { + return "", fmt.Errorf("%w: unexpected prompt: %s vs %s", + ErrConv, msg, c.ExpectedMessage) + } + + if (c.ExpectedStyle != 0 || c.CheckZeroStyle) && + s != c.ExpectedStyle { + return "", fmt.Errorf("%w: unexpected style: %#v vs %#v", + ErrConv, s, c.ExpectedStyle) + } + + switch s { + case PromptEchoOn: + return c.User, nil + } + + return "", fmt.Errorf("%w: unhandled style: %v", ErrConv, s) +} diff --git a/module-transaction.go b/module-transaction.go index 0e87fe5..a698f42 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -1,11 +1,20 @@ // Package pam provides a wrapper for the PAM application API. package pam +/* +#cgo CFLAGS: -Wall -std=c99 +#cgo LDFLAGS: -lpam + +#include +#include +#include +*/ import "C" import ( "errors" "fmt" + "unsafe" ) // ModuleTransaction is an interface that a pam module transaction @@ -16,6 +25,7 @@ type ModuleTransaction interface { PutEnv(nameVal string) error GetEnv(name string) string GetEnvList() (map[string]string, error) + GetUser(prompt string) (string, error) } // ModuleHandlerFunc is a function type used by the ModuleHandler. @@ -89,3 +99,31 @@ func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, m.lastStatus.Store(status) return err } + +type moduleTransactionIface interface { + getUser(outUser **C.char, prompt *C.char) C.int +} + +func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { + return C.pam_get_user(m.handle, outUser, prompt) +} + +// getUserImpl is the default implementation for GetUser, but kept as private so +// that can be used to test the pam package +func (m *moduleTransaction) getUserImpl(iface moduleTransactionIface, + prompt string) (string, error) { + var user *C.char + var cPrompt = C.CString(prompt) + defer C.free(unsafe.Pointer(cPrompt)) + err := m.handlePamStatus(iface.getUser(&user, cPrompt)) + if err != nil { + return "", err + } + return C.GoString(user), nil +} + +// GetUser is similar to GetItem(User), but it would start a conversation if +// no user is currently set in PAM. +func (m *moduleTransaction) GetUser(prompt string) (string, error) { + return m.getUserImpl(m, prompt) +} diff --git a/module-transaction_test.go b/module-transaction_test.go index 8661f68..7a44fd3 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -62,6 +62,12 @@ func Test_NewNullModuleTransaction(t *testing.T) { return nil, err }, }, + "GetUser": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetUser("prompt") + }, + }, } for name, tc := range tests { @@ -235,3 +241,96 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { }) } } + +func Test_MockModuleTransaction(t *testing.T) { + t.Parallel() + + mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) + + tests := map[string]struct { + testFunc func(mock *mockModuleTransaction) (any, error) + mockExpectations mockModuleTransactionExpectations + mockRetData mockModuleTransactionReturnedData + conversationHandler ConversationHandler + + expectedError error + expectedValue any + ignoreError bool + }{ + "GetUser-empty": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-preset-value": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + mockRetData: mockModuleTransactionReturnedData{User: "dummy-user"}, + expectedValue: "dummy-user", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-value": { + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "who are you?", + User: "returned-dummy-user", + }, + expectedValue: "returned-dummy-user", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-error-prompt": { + expectedError: ErrConv, + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "who are you???", + }, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + "GetUser-conversation-error-style": { + expectedError: ErrConv, + mockExpectations: mockModuleTransactionExpectations{ + UserPrompt: "who are you?"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "who are you?", + }, + expectedValue: "", + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getUserImpl(mock, "who are you?") + }, + }, + } + + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mock := newMockModuleTransaction(&mockModuleTransaction{T: t, + Expectations: tc.mockExpectations, RetData: tc.mockRetData, + ConversationHandler: tc.conversationHandler}) + data, err := tc.testFunc(mock) + + if !tc.ignoreError && !errors.Is(err, tc.expectedError) { + t.Fatalf("unexpected err: %#v vs %#v", err, tc.expectedError) + } + + if !reflect.DeepEqual(data, tc.expectedValue) { + t.Fatalf("data mismatch, %#v vs %#v", data, tc.expectedValue) + } + }) + } +} From 13989bbd5cc927a58969daf12a514471e671242a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Tue, 3 Oct 2023 14:37:28 +0200 Subject: [PATCH 12/24] module-transaction: Add support for setting/getting module data Module data is data associated with a module handle that is available for the whole module loading time so it can be used also during different operations. We use cgo handles to preserve the life of the go objects so any value can be associated with a pam transaction. --- .../integration-tester-module.go | 8 +- .../integration-tester-module_test.go | 106 ++++++++++++++++++ module-transaction-mock.go | 34 ++++++ module-transaction.go | 65 ++++++++++- module-transaction_test.go | 82 ++++++++++++++ transaction.h | 22 ++++ 6 files changed, 309 insertions(+), 8 deletions(-) diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 995e0c2..76cebe8 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -55,8 +55,12 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request } var args []reflect.Value - for _, arg := range r.ActionArgs { - args = append(args, reflect.ValueOf(arg)) + for i, arg := range r.ActionArgs { + if arg == nil { + args = append(args, reflect.Zero(method.Type().In(i))) + } else { + args = append(args, reflect.ValueOf(arg)) + } } res = &Result{Action: "return"} diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index 38e95c3..d17ba51 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -542,6 +542,94 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { return ensureUser(tx, "setup-user") }, }, + "get-data-not-available": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{{ + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }}, + }, + "set-data-empty-nil": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, + "set-data-empty-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "", []string{"hello", "world"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", ""), + exp: []interface{}{[]string{"hello", "world"}, nil}, + }, + }, + }, + "set-data-to-value": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-error-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-error-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + }, + }, + "set-data-to-value-replacing": { + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", "Hello"), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{"Hello", nil}, + }, + }, + }, + "set-data-to-value-unset": { + expectedError: pam.ErrNoModuleData, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", + utils.SerializableError{Msg: "An error"}), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{utils.SerializableError{Msg: "An error"}, nil}, + }, + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{nil}, + }, + { + r: NewRequest("GetData", "some-data"), + exp: []interface{}{nil, pam.ErrNoModuleData}, + }, + }, + }, } for name, tc := range tests { @@ -774,6 +862,24 @@ func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { }, }, }, + "SetData-nil": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", nil), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, + "SetData": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{ + { + r: NewRequest("SetData", "some-data", true), + exp: []interface{}{pam.ErrSystem}, + }, + }, + }, } for name, tc := range tests { diff --git a/module-transaction-mock.go b/module-transaction-mock.go index f00202e..968026a 100644 --- a/module-transaction-mock.go +++ b/module-transaction-mock.go @@ -6,6 +6,7 @@ package pam #cgo CFLAGS: -Wall -std=c99 #include #include +#include */ import "C" @@ -19,6 +20,7 @@ import ( type mockModuleTransactionExpectations struct { UserPrompt string + DataKey string } type mockModuleTransactionReturnedData struct { @@ -33,14 +35,19 @@ type mockModuleTransaction struct { Expectations mockModuleTransactionExpectations RetData mockModuleTransactionReturnedData ConversationHandler ConversationHandler + moduleData map[string]uintptr allocatedData []unsafe.Pointer } func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { + m.moduleData = make(map[string]uintptr) runtime.SetFinalizer(m, func(m *mockModuleTransaction) { for _, ptr := range m.allocatedData { C.free(ptr) } + for _, handle := range m.moduleData { + _go_pam_data_cleanup(nil, C.uintptr_t(handle), C.PAM_DATA_SILENT) + } }) return m } @@ -76,6 +83,33 @@ func (m *mockModuleTransaction) getUser(outUser **C.char, prompt *C.char) C.int return C.int(m.RetData.Status) } +func (m *mockModuleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if handle, ok := m.moduleData[goKey]; ok { + *outHandle = C.uintptr_t(handle) + } else { + *outHandle = 0 + } + return C.int(m.RetData.Status) +} + +func (m *mockModuleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + goKey := C.GoString(key) + if m.Expectations.DataKey != "" && goKey != m.Expectations.DataKey { + m.T.Fatalf("data key mismatch: %#v vs %#v", goKey, m.Expectations.DataKey) + } + if oldHandle, ok := m.moduleData[goKey]; ok { + _go_pam_data_cleanup(nil, C.uintptr_t(oldHandle), C.PAM_DATA_REPLACE) + } + if handle != 0 { + m.moduleData[goKey] = uintptr(handle) + } + return C.int(m.RetData.Status) +} + type mockConversationHandler struct { User string ExpectedMessage string diff --git a/module-transaction.go b/module-transaction.go index a698f42..71419e0 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -2,18 +2,14 @@ package pam /* -#cgo CFLAGS: -Wall -std=c99 -#cgo LDFLAGS: -lpam - -#include -#include -#include +#include "transaction.h" */ import "C" import ( "errors" "fmt" + "runtime/cgo" "unsafe" ) @@ -26,6 +22,8 @@ type ModuleTransaction interface { GetEnv(name string) string GetEnvList() (map[string]string, error) GetUser(prompt string) (string, error) + SetData(key string, data any) error + GetData(key string) (any, error) } // ModuleHandlerFunc is a function type used by the ModuleHandler. @@ -102,6 +100,8 @@ func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, type moduleTransactionIface interface { getUser(outUser **C.char, prompt *C.char) C.int + setData(key *C.char, handle C.uintptr_t) C.int + getData(key *C.char, outHandle *C.uintptr_t) C.int } func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { @@ -127,3 +127,56 @@ func (m *moduleTransaction) getUserImpl(iface moduleTransactionIface, func (m *moduleTransaction) GetUser(prompt string) (string, error) { return m.getUserImpl(m, prompt) } + +// SetData allows to save any value in the module data that is preserved +// during the whole time the module is loaded. +func (m *moduleTransaction) SetData(key string, data any) error { + return m.setDataImpl(m, key, data) +} + +func (m *moduleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { + return C.set_data(m.handle, key, handle) +} + +// setDataImpl is the implementation for SetData for testing purposes. +func (m *moduleTransaction) setDataImpl(iface moduleTransactionIface, + key string, data any) error { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle cgo.Handle + if data != nil { + handle = cgo.NewHandle(data) + } + return m.handlePamStatus(iface.setData(cKey, C.uintptr_t(handle))) +} + +//export _go_pam_data_cleanup +func _go_pam_data_cleanup(h NativeHandle, handle C.uintptr_t, status C.int) { + cgo.Handle(handle).Delete() +} + +// GetData allows to get any value from the module data saved using SetData +// that is preserved across the whole time the module is loaded. +func (m *moduleTransaction) GetData(key string) (any, error) { + return m.getDataImpl(m, key) +} + +func (m *moduleTransaction) getData(key *C.char, outHandle *C.uintptr_t) C.int { + return C.get_data(m.handle, key, outHandle) +} + +// getDataImpl is the implementation for GetData for testing purposes. +func (m *moduleTransaction) getDataImpl(iface moduleTransactionIface, + key string) (any, error) { + var cKey = C.CString(key) + defer C.free(unsafe.Pointer(cKey)) + var handle C.uintptr_t + if err := m.handlePamStatus(iface.getData(cKey, &handle)); err != nil { + return nil, err + } + if goHandle := cgo.Handle(handle); goHandle != cgo.Handle(0) { + return goHandle.Value(), nil + } + + return nil, m.handlePamStatus(C.int(ErrNoModuleData)) +} diff --git a/module-transaction_test.go b/module-transaction_test.go index 7a44fd3..fa4c1be 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -8,6 +8,13 @@ import ( "testing" ) +func ensureNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + func Test_NewNullModuleTransaction(t *testing.T) { t.Parallel() mt := moduleTransaction{} @@ -68,6 +75,24 @@ func Test_NewNullModuleTransaction(t *testing.T) { return mt.GetUser("prompt") }, }, + "GetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.GetData("some-data") + }, + }, + "SetData": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", []interface{}{}) + }, + }, + "SetData-nil": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return nil, mt.SetData("foo", nil) + }, + }, } for name, tc := range tests { @@ -313,6 +338,63 @@ func Test_MockModuleTransaction(t *testing.T) { return mt.getUserImpl(mock, "who are you?") }, }, + "GetData-not-available": { + expectedError: ErrNoModuleData, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "GetData-not-available-other-failure": { + expectedError: ErrBuf, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "not-available-data"}, + mockRetData: mockModuleTransactionReturnedData{Status: ErrBuf}, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.getDataImpl(mock, "not-available-data") + }, + }, + "SetData-empty-nil": { + expectedError: ErrNoModuleData, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", nil)) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-empty-to-value": { + expectedValue: []string{"hello", "world"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "", + []string{"hello", "world"})) + return mt.getDataImpl(mock, "") + }, + }, + "SetData-to-value": { + expectedValue: []interface{}{"a string", true, 0.55, errors.New("oh no")}, + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "some-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "some-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + return mt.getDataImpl(mock, "some-data") + }, + }, + "SetData-to-value-replacing": { + expectedValue: "just a value", + mockExpectations: mockModuleTransactionExpectations{ + DataKey: "replaced-data"}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + []interface{}{"a string", true, 0.55, errors.New("oh no")})) + ensureNoError(mock.T, mt.setDataImpl(mock, "replaced-data", + "just a value")) + return mt.getDataImpl(mock, "replaced-data") + }, + }, } for name, tc := range tests { diff --git a/transaction.h b/transaction.h index 88d2766..b19ce3e 100644 --- a/transaction.h +++ b/transaction.h @@ -1,4 +1,7 @@ +#pragma once + #include +#include #include #include #include @@ -18,6 +21,7 @@ #endif extern int _go_pam_conv_handler(struct pam_message *, uintptr_t, char **reply); +extern void _go_pam_data_cleanup(pam_handle_t *, uintptr_t, int status); static inline int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, struct pam_response **resp, void *appdata_ptr) { @@ -67,3 +71,21 @@ static inline int check_pam_start_confdir(void) return 0; } + +static inline void data_cleanup(pam_handle_t *pamh, void *data, int error_status) +{ + _go_pam_data_cleanup(pamh, (uintptr_t)data, error_status); +} + +static inline int set_data(pam_handle_t *pamh, const char *name, uintptr_t handle) +{ + if (handle) + return pam_set_data(pamh, name, (void *)handle, data_cleanup); + + return pam_set_data(pamh, name, NULL, NULL); +} + +static inline int get_data(pam_handle_t *pamh, const char *name, uintptr_t *out_handle) +{ + return pam_get_data(pamh, name, (const void **)out_handle); +} From bc34d3b63960b8c35c06c8d5531aa1ddffa444dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Wed, 4 Oct 2023 23:34:20 +0200 Subject: [PATCH 13/24] module-transaction: Add support for initiating PAM Conversations Modules have the ability to start PAM conversations, so while the transaction code can handle them we did not have a way to init them. Yet. So add some APIs allowing this, making it easier from the go side to handle the conversations. In this commit we only support text-based conversations, but code is designed with the idea of supporting binary cases too. Added the integration tests using the module that is now able to both start conversation and handle them using Go only. --- .../integration-tester-module.go | 17 +- .../integration-tester-module_test.go | 205 ++++++++++++++ .../serialization.go | 18 ++ .../tests/internal/utils/test-utils.go | 18 +- module-transaction-mock.go | 56 +++- module-transaction.go | 233 ++++++++++++++++ module-transaction_test.go | 252 ++++++++++++++++++ transaction.h | 5 + 8 files changed, 790 insertions(+), 14 deletions(-) diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 76cebe8..7437534 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -56,10 +56,16 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request var args []reflect.Value for i, arg := range r.ActionArgs { - if arg == nil { - args = append(args, reflect.Zero(method.Type().In(i))) - } else { - args = append(args, reflect.ValueOf(arg)) + switch v := arg.(type) { + case SerializableStringConvRequest: + args = append(args, reflect.ValueOf( + pam.NewStringConvRequest(v.Style, v.Request))) + default: + if arg == nil { + args = append(args, reflect.Zero(method.Type().In(i))) + } else { + args = append(args, reflect.ValueOf(arg)) + } } } @@ -67,6 +73,9 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request for _, ret := range method.Call(args) { iface := ret.Interface() switch value := iface.(type) { + case pam.StringConvResponse: + res.ActionArgs = append(res.ActionArgs, + SerializableStringConvResponse{value.Style(), value.Response()}) case pam.Error: authReq.lastError = value res.ActionArgs = append(res.ActionArgs, value) diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index d17ba51..a8b9bb8 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -630,6 +630,194 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { }, }, }, + "start-conv-no-conv-set": { + expectedError: pam.ErrConv, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "hello PAM!", + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, "hello PAM!"), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-prompt-text-info": { + credentials: utils.Credentials{ + ExpectedMessage: "hello PAM!", + ExpectedStyle: pam.TextInfo, + TextInfo: "nice to see you, Go!", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "hello PAM!", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.TextInfo, + "nice to see you, Go!", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, "hello PAM!"), + exp: []interface{}{SerializableStringConvResponse{ + pam.TextInfo, + "nice to see you, Go!", + }, nil}, + }, + { + r: NewRequest("StartStringConvf", pam.TextInfo, "hello %s!", "PAM"), + exp: []interface{}{SerializableStringConvResponse{ + pam.TextInfo, + "nice to see you, Go!", + }, nil}, + }, + }, + }, + "start-conv-prompt-error-msg": { + credentials: utils.Credentials{ + ExpectedMessage: "This is wrong, PAM!", + ExpectedStyle: pam.ErrorMsg, + ErrorMsg: "ops, sorry...", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.ErrorMsg, + "This is wrong, PAM!", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.ErrorMsg, + "ops, sorry...", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.ErrorMsg, + "This is wrong, PAM!", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.ErrorMsg, + "ops, sorry...", + }, nil}, + }, + { + r: NewRequest("StartStringConvf", pam.ErrorMsg, + "This is wrong, %s!", "PAM", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.ErrorMsg, + "ops, sorry...", + }, nil}, + }, + }, + }, + "start-conv-prompt-echo-on": { + credentials: utils.Credentials{ + ExpectedMessage: "Give me your non-private infos", + ExpectedStyle: pam.PromptEchoOn, + EchoOn: "here's my public data", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.PromptEchoOn, + "Give me your non-private infos", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOn, + "here's my public data", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.PromptEchoOn, + "Give me your non-private infos", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOn, + "here's my public data", + }, nil}, + }, + }, + }, + "start-conv-prompt-echo-off": { + credentials: utils.Credentials{ + ExpectedMessage: "Give me your super-secret data", + ExpectedStyle: pam.PromptEchoOff, + EchoOff: "here's my private token", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.PromptEchoOff, + "Give me your super-secret data", + }), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOff, + "here's my private token", + }, nil}, + }, + { + r: NewRequest("StartStringConv", pam.PromptEchoOff, + "Give me your super-secret data", + ), + exp: []interface{}{SerializableStringConvResponse{ + pam.PromptEchoOff, + "here's my private token", + }, nil}, + }, + }, + }, + "start-conv-text-info-handle-failure-message-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.Credentials{ + ExpectedMessage: "This is an info message", + ExpectedStyle: pam.TextInfo, + TextInfo: "And this is what is returned", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "This should have been an info message, but is not", + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, + "This should have been an info message, but is not", + ), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-text-info-handle-failure-style-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.Credentials{ + ExpectedMessage: "This is an info message", + ExpectedStyle: pam.PromptEchoOff, + TextInfo: "And this is what is returned", + }, + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "This is an info message", + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartStringConv", pam.TextInfo, + "This is an info message", + ), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, } for name, tc := range tests { @@ -880,6 +1068,23 @@ func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { }, }, }, + "StartConv": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{{ + r: NewRequest("StartConv", SerializableStringConvRequest{ + pam.TextInfo, + "hello PAM!", + }), + exp: []interface{}{nil, pam.ErrSystem}, + }}, + }, + "StartStringConv": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{{ + r: NewRequest("StartStringConv", pam.TextInfo, "hello PAM!"), + exp: []interface{}{nil, pam.ErrSystem}, + }}, + }, } for name, tc := range tests { diff --git a/cmd/pam-moduler/tests/integration-tester-module/serialization.go b/cmd/pam-moduler/tests/integration-tester-module/serialization.go index 33b26a7..2eae17b 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/serialization.go +++ b/cmd/pam-moduler/tests/integration-tester-module/serialization.go @@ -24,12 +24,30 @@ func (e *SerializablePamError) Error() string { return e.RetStatus.Error() } +// SerializableStringConvRequest is a serializable string request. +type SerializableStringConvRequest struct { + Style pam.Style + Request string +} + +// SerializableStringConvResponse is a serializable string response. +type SerializableStringConvResponse struct { + Style pam.Style + Response string +} + func init() { gob.Register(map[string]string{}) gob.Register(Request{}) gob.Register(pam.Item(0)) gob.Register(pam.Error(0)) + gob.Register(pam.Style(0)) + gob.Register([]pam.ConvResponse{}) gob.RegisterName("main.SerializablePamError", SerializablePamError{}) + gob.RegisterName("main.SerializableStringConvRequest", + SerializableStringConvRequest{}) + gob.RegisterName("main.SerializableStringConvResponse", + SerializableStringConvResponse{}) gob.Register(utils.SerializableError{}) } diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go index 095994b..ce2281b 100644 --- a/cmd/pam-moduler/tests/internal/utils/test-utils.go +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -118,6 +118,10 @@ func (e *SerializableError) Error() string { type Credentials struct { User string Password string + EchoOn string + EchoOff string + TextInfo string + ErrorMsg string ExpectedMessage string CheckEmptyMessage bool ExpectedStyle pam.Style @@ -145,9 +149,19 @@ func (c Credentials) RespondPAM(s pam.Style, msg string) (string, error) { switch s { case pam.PromptEchoOn: - return c.User, nil + if c.User != "" { + return c.User, nil + } + return c.EchoOn, nil case pam.PromptEchoOff: - return c.Password, nil + if c.Password != "" { + return c.Password, nil + } + return c.EchoOff, nil + case pam.TextInfo: + return c.TextInfo, nil + case pam.ErrorMsg: + return c.ErrorMsg, nil } return "", errors.Join(pam.ErrConv, diff --git a/module-transaction-mock.go b/module-transaction-mock.go index 968026a..a789b1e 100644 --- a/module-transaction-mock.go +++ b/module-transaction-mock.go @@ -7,6 +7,8 @@ package pam #include #include #include + +void init_pam_conv(struct pam_conv *conv, uintptr_t appdata); */ import "C" @@ -14,6 +16,7 @@ import ( "errors" "fmt" "runtime" + "runtime/cgo" "testing" "unsafe" ) @@ -110,17 +113,41 @@ func (m *mockModuleTransaction) setData(key *C.char, handle C.uintptr_t) C.int { return C.int(m.RetData.Status) } +func (m *mockModuleTransaction) getConv() (*C.struct_pam_conv, error) { + if m.ConversationHandler != nil { + conv := C.struct_pam_conv{} + handler := cgo.NewHandle(m.ConversationHandler) + C.init_pam_conv(&conv, C.uintptr_t(handler)) + return &conv, nil + } + if C.int(m.RetData.Status) != success { + return nil, m.RetData.Status + } + return nil, nil +} + type mockConversationHandler struct { - User string - ExpectedMessage string - CheckEmptyMessage bool - ExpectedStyle Style - CheckZeroStyle bool + User string + PromptEchoOn string + PromptEchoOff string + TextInfo string + ErrorMsg string + ExpectedMessage string + ExpectedMessagesByStyle map[Style]string + CheckEmptyMessage bool + ExpectedStyle Style + CheckZeroStyle bool + IgnoreUnknownStyle bool } func (c mockConversationHandler) RespondPAM(s Style, msg string) (string, error) { - if (c.ExpectedMessage != "" || c.CheckEmptyMessage) && - msg != c.ExpectedMessage { + var expectedMsg = c.ExpectedMessage + if msg, ok := c.ExpectedMessagesByStyle[s]; ok { + expectedMsg = msg + } + + if (expectedMsg != "" || c.CheckEmptyMessage) && + msg != expectedMsg { return "", fmt.Errorf("%w: unexpected prompt: %s vs %s", ErrConv, msg, c.ExpectedMessage) } @@ -133,7 +160,20 @@ func (c mockConversationHandler) RespondPAM(s Style, msg string) (string, error) switch s { case PromptEchoOn: - return c.User, nil + if c.User != "" { + return c.User, nil + } + return c.PromptEchoOn, nil + case PromptEchoOff: + return c.PromptEchoOff, nil + case TextInfo: + return c.TextInfo, nil + case ErrorMsg: + return c.ErrorMsg, nil + } + + if c.IgnoreUnknownStyle { + return c.ExpectedMessage, nil } return "", fmt.Errorf("%w: unhandled style: %v", ErrConv, s) diff --git a/module-transaction.go b/module-transaction.go index 71419e0..bbb1e13 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -13,6 +13,8 @@ import ( "unsafe" ) +const maxNumMsg = C.PAM_MAX_NUM_MSG + // ModuleTransaction is an interface that a pam module transaction // should implement. type ModuleTransaction interface { @@ -24,6 +26,11 @@ type ModuleTransaction interface { GetUser(prompt string) (string, error) SetData(key string, data any) error GetData(key string) (any, error) + StartStringConv(style Style, prompt string) (StringConvResponse, error) + StartStringConvf(style Style, format string, args ...interface{}) ( + StringConvResponse, error) + StartConv(ConvRequest) (ConvResponse, error) + StartConvMulti([]ConvRequest) ([]ConvResponse, error) } // ModuleHandlerFunc is a function type used by the ModuleHandler. @@ -102,6 +109,10 @@ type moduleTransactionIface interface { getUser(outUser **C.char, prompt *C.char) C.int setData(key *C.char, handle C.uintptr_t) C.int getData(key *C.char, outHandle *C.uintptr_t) C.int + getConv() (*C.struct_pam_conv, error) + startConv(conv *C.struct_pam_conv, nMsg C.int, + messages **C.struct_pam_message, + outResponses **C.struct_pam_response) C.int } func (m *moduleTransaction) getUser(outUser **C.char, prompt *C.char) C.int { @@ -180,3 +191,225 @@ func (m *moduleTransaction) getDataImpl(iface moduleTransactionIface, return nil, m.handlePamStatus(C.int(ErrNoModuleData)) } + +// getConv is a private function to get the conversation pointer to be used +// with C.do_conv() to initiate conversations. +func (m *moduleTransaction) getConv() (*C.struct_pam_conv, error) { + var convPtr unsafe.Pointer + + if err := m.handlePamStatus( + C.pam_get_item(m.handle, C.PAM_CONV, &convPtr)); err != nil { + return nil, err + } + + return (*C.struct_pam_conv)(convPtr), nil +} + +// ConvRequest is an interface that all the Conversation requests should +// implement. +type ConvRequest interface { + Style() Style +} + +// ConvResponse is an interface that all the Conversation responses should +// implement. +type ConvResponse interface { + Style() Style +} + +// StringConvRequest is a ConvRequest for performing text-based conversations. +type StringConvRequest struct { + style Style + prompt string +} + +// NewStringConvRequest creates a new StringConvRequest. +func NewStringConvRequest(style Style, prompt string) StringConvRequest { + return StringConvRequest{style, prompt} +} + +// Style returns the conversation style of the StringConvRequest. +func (s StringConvRequest) Style() Style { + return s.style +} + +// Prompt returns the conversation style of the StringConvRequest. +func (s StringConvRequest) Prompt() string { + return s.prompt +} + +// StringConvResponse is an interface that string Conversation responses implements. +type StringConvResponse interface { + ConvResponse + Response() string +} + +// stringConvResponse is a StringConvResponse implementation used for text-based +// conversation responses. +type stringConvResponse struct { + style Style + response string +} + +// Style returns the conversation style of the StringConvResponse. +func (s stringConvResponse) Style() Style { + return s.style +} + +// Response returns the string response of the conversation. +func (s stringConvResponse) Response() string { + return s.response +} + +// StartStringConv starts a text-based conversation using the provided style +// and prompt. +func (m *moduleTransaction) StartStringConv(style Style, prompt string) ( + StringConvResponse, error) { + return m.startStringConvImpl(m, style, prompt) +} + +func (m *moduleTransaction) startStringConvImpl(iface moduleTransactionIface, + style Style, prompt string) ( + StringConvResponse, error) { + switch style { + case BinaryPrompt: + return nil, fmt.Errorf("%w: binary style is not supported", ErrConv) + } + + res, err := m.startConvImpl(iface, NewStringConvRequest(style, prompt)) + if err != nil { + return nil, err + } + + stringRes, _ := res.(stringConvResponse) + return stringRes, nil +} + +// StartStringConvf allows to start string conversation with formatting support. +func (m *moduleTransaction) StartStringConvf(style Style, format string, args ...interface{}) ( + StringConvResponse, error) { + return m.StartStringConv(style, fmt.Sprintf(format, args...)) +} + +// StartConv initiates a PAM conversation using the provided ConvRequest. +func (m *moduleTransaction) StartConv(req ConvRequest) ( + ConvResponse, error) { + return m.startConvImpl(m, req) +} + +func (m *moduleTransaction) startConvImpl(iface moduleTransactionIface, req ConvRequest) ( + ConvResponse, error) { + resp, err := m.startConvMultiImpl(iface, []ConvRequest{req}) + if err != nil { + return nil, err + } + if len(resp) != 1 { + return nil, fmt.Errorf("%w: not enough values returned", ErrConv) + } + return resp[0], nil +} + +func (m *moduleTransaction) startConv(conv *C.struct_pam_conv, nMsg C.int, + messages **C.struct_pam_message, outResponses **C.struct_pam_response) C.int { + return C.start_pam_conv(conv, nMsg, messages, outResponses) +} + +// startConvMultiImpl is the implementation for GetData for testing purposes. +func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, + requests []ConvRequest) (responses []ConvResponse, err error) { + defer func() { + if err == nil { + _ = m.handlePamStatus(success) + return + } + var pamErr Error + if !errors.As(err, &pamErr) { + err = errors.Join(ErrConv, err) + pamErr = ErrConv + } + _ = m.handlePamStatus(C.int(pamErr)) + }() + + if len(requests) == 0 { + return nil, errors.New("no requests defined") + } + if len(requests) > maxNumMsg { + return nil, errors.New("too many requests") + } + + conv, err := iface.getConv() + if err != nil { + return nil, err + } + + if conv == nil || conv.conv == nil { + return nil, errors.New("impossible to find conv handler") + } + + // FIXME: Just use make([]C.struct_pam_message, 0, len(requests)) + // and append, when it's possible to use runtime.Pinner + var cMessagePtr *C.struct_pam_message + cMessages := (**C.struct_pam_message)(C.calloc(C.size_t(len(requests)), + (C.size_t)(unsafe.Sizeof(cMessagePtr)))) + defer C.free(unsafe.Pointer(cMessages)) + goMsgs := unsafe.Slice(cMessages, len(requests)) + + for i, req := range requests { + var cBytes unsafe.Pointer + switch r := req.(type) { + case StringConvRequest: + cBytes = unsafe.Pointer(C.CString(r.Prompt())) + defer C.free(cBytes) + default: + return nil, fmt.Errorf("unsupported conversation type %#v", r) + } + + goMsgs[i] = &C.struct_pam_message{ + msg_style: C.int(req.Style()), + msg: (*C.char)(cBytes), + } + } + + var cResponses *C.struct_pam_response + ret := iface.startConv(conv, C.int(len(requests)), cMessages, &cResponses) + if ret != success { + return nil, Error(ret) + } + + goResponses := unsafe.Slice(cResponses, len(requests)) + defer func() { + for _, resp := range goResponses { + C.free(unsafe.Pointer(resp.resp)) + } + C.free(unsafe.Pointer(cResponses)) + }() + + responses = make([]ConvResponse, 0, len(requests)) + for i, resp := range goResponses { + msgStyle := requests[i].Style() + switch msgStyle { + case PromptEchoOff: + fallthrough + case PromptEchoOn: + fallthrough + case ErrorMsg: + fallthrough + case TextInfo: + responses = append(responses, stringConvResponse{ + style: msgStyle, + response: C.GoString(resp.resp), + }) + default: + return nil, + fmt.Errorf("unsupported conversation type %v", msgStyle) + } + } + + return responses, nil +} + +// StartConvMulti initiates a PAM conversation with multiple ConvRequest's. +func (m *moduleTransaction) StartConvMulti(requests []ConvRequest) ( + []ConvResponse, error) { + return m.startConvMultiImpl(m, requests) +} diff --git a/module-transaction_test.go b/module-transaction_test.go index fa4c1be..9c4da20 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -8,6 +8,12 @@ import ( "testing" ) +type customConvRequest int + +func (r customConvRequest) Style() Style { + return Style(r) +} + func ensureNoError(t *testing.T, err error) { t.Helper() if err != nil { @@ -93,6 +99,33 @@ func Test_NewNullModuleTransaction(t *testing.T) { return nil, mt.SetData("foo", nil) }, }, + "StartConv-StringConv": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartConv(NewStringConvRequest(TextInfo, "a prompt")) + }, + }, + "StartStringConv": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartStringConv(TextInfo, "a prompt") + }, + }, + "StartStringConvf": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartStringConvf(TextInfo, "a prompt %s", "with info") + }, + }, + "StartConvMulti": { + testFunc: func(t *testing.T) (any, error) { + t.Helper() + return mt.StartConvMulti([]ConvRequest{ + NewStringConvRequest(TextInfo, "a prompt"), + NewStringConvRequest(ErrorMsg, "another prompt"), + }) + }, + }, } for name, tc := range tests { @@ -395,6 +428,225 @@ func Test_MockModuleTransaction(t *testing.T) { return mt.getDataImpl(mock, "replaced-data") }, }, + "StartConv-no-conv-set": { + expectedError: ErrConv, + expectedValue: nil, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + TextInfo, + "hello PAM!", + }) + }, + }, + "StartConv-text-info": { + expectedValue: stringConvResponse{TextInfo, "nice to see you, Go!"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: TextInfo, + ExpectedMessage: "hello PAM!", + TextInfo: "nice to see you, Go!", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + TextInfo, + "hello PAM!", + }) + }, + }, + "StartConv-error-msg": { + expectedValue: stringConvResponse{ErrorMsg, "ops, sorry..."}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: ErrorMsg, + ExpectedMessage: "This is wrong, PAM!", + ErrorMsg: "ops, sorry...", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + ErrorMsg, + "This is wrong, PAM!", + }) + }, + }, + "StartConv-prompt-echo-on": { + expectedValue: stringConvResponse{PromptEchoOn, "here's my public data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "Give me your non-private infos", + PromptEchoOn: "here's my public data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + PromptEchoOn, + "Give me your non-private infos", + }) + }, + }, + "StartConv-prompt-echo-off": { + expectedValue: stringConvResponse{PromptEchoOff, "here's my private data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "Give me your private secrets", + PromptEchoOff: "here's my private data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + PromptEchoOff, + "Give me your private secrets", + }) + }, + }, + "StartConv-unknown-style": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: Style(9999), + ExpectedMessage: "hello PAM!", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + Style(9999), + "hello PAM!", + }) + }, + }, + "StartConv-unknown-style-response": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: Style(9999), + ExpectedMessage: "hello PAM!", + IgnoreUnknownStyle: true, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, StringConvRequest{ + Style(9999), + "hello PAM!", + }) + }, + }, + "StartStringConv-text-info": { + expectedValue: stringConvResponse{TextInfo, "nice to see you, Go!"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: TextInfo, + ExpectedMessage: "hello PAM!", + TextInfo: "nice to see you, Go!", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, TextInfo, + "hello PAM!") + }, + }, + "StartStringConv-error-msg": { + expectedValue: stringConvResponse{ErrorMsg, "ops, sorry..."}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: ErrorMsg, + ExpectedMessage: "This is wrong, PAM!", + ErrorMsg: "ops, sorry...", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, ErrorMsg, + "This is wrong, PAM!") + }, + }, + "StartStringConv-prompt-echo-on": { + expectedValue: stringConvResponse{PromptEchoOn, "here's my public data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOn, + ExpectedMessage: "Give me your non-private infos", + PromptEchoOn: "here's my public data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, PromptEchoOn, + "Give me your non-private infos") + }, + }, + "StartStringConv-prompt-echo-off": { + expectedValue: stringConvResponse{PromptEchoOff, "here's my private data"}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: PromptEchoOff, + ExpectedMessage: "Give me your private secrets", + PromptEchoOff: "here's my private data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, PromptEchoOff, + "Give me your private secrets") + }, + }, + "StartStringConv-binary": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedMessage: "require binary data", + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startStringConvImpl(mock, PromptEchoOff, + "require binary data") + }, + }, + "StartConvMulti-missing": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvMultiImpl(mock, nil) + }, + }, + "StartConvMulti-too-many": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + reqs := [maxNumMsg + 1]ConvRequest{} + return mt.startConvMultiImpl(mock, reqs[:]) + }, + }, + "StartConvMulti-unexpected-style": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + var req ConvRequest = customConvRequest(0xdeadbeef) + return mt.startConvMultiImpl(mock, []ConvRequest{req}) + }, + }, + "StartConvMulti-string-as-binary": { + expectedError: ErrConv, + expectedValue: ([]ConvResponse)(nil), + conversationHandler: mockConversationHandler{}, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvMultiImpl(mock, []ConvRequest{ + NewStringConvRequest(BinaryPrompt, "no binary!"), + }) + }, + }, + "StartConvMulti-all-types": { + expectedValue: []ConvResponse{ + stringConvResponse{TextInfo, "nice to see you, Go!"}, + stringConvResponse{ErrorMsg, "ops, sorry..."}, + stringConvResponse{PromptEchoOn, "here's my public data"}, + stringConvResponse{PromptEchoOff, "here's my private data"}, + }, + conversationHandler: mockConversationHandler{ + TextInfo: "nice to see you, Go!", + ErrorMsg: "ops, sorry...", + PromptEchoOn: "here's my public data", + PromptEchoOff: "here's my private data", + ExpectedMessagesByStyle: map[Style]string{ + TextInfo: "hello PAM!", + ErrorMsg: "This is wrong, PAM!", + PromptEchoOn: "Give me your non-private infos", + PromptEchoOff: "Give me your private secrets", + }, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvMultiImpl(mock, []ConvRequest{ + NewStringConvRequest(TextInfo, "hello PAM!"), + NewStringConvRequest(ErrorMsg, "This is wrong, PAM!"), + NewStringConvRequest(PromptEchoOn, "Give me your non-private infos"), + NewStringConvRequest(PromptEchoOff, "Give me your private secrets"), + }) + }, + }, } for name, tc := range tests { diff --git a/transaction.h b/transaction.h index b19ce3e..4c9f000 100644 --- a/transaction.h +++ b/transaction.h @@ -59,6 +59,11 @@ static inline void init_pam_conv(struct pam_conv *conv, uintptr_t appdata) conv->appdata_ptr = (void *)appdata; } +static inline int start_pam_conv(struct pam_conv *pc, int num_msgs, const struct pam_message **msgs, struct pam_response **out_resp) +{ + return pc->conv(num_msgs, msgs, out_resp, pc->appdata_ptr); +} + // pam_start_confdir is a recent PAM api to declare a confdir (mostly for // tests) weaken the linking dependency to detect if it’s present. int pam_start_confdir(const char *service_name, const char *user, const struct pam_conv *pam_conversation, From ce46e15c8e20073cb5608eaf148eb7d3af5366a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Thu, 5 Oct 2023 05:49:22 +0200 Subject: [PATCH 14/24] module-transaction: Add support for binary conversations A module can now initiate a binary conversation decoding the native pointer value as it wants. Added tests to verify the main cases --- .../integration-tester-module.go | 9 + .../integration-tester-module_test.go | 129 +++++++++ .../serialization.go | 14 + .../tests/internal/utils/test-utils.go | 94 +++++++ module-transaction-mock.go | 53 ++++ module-transaction.go | 204 +++++++++++++- module-transaction_test.go | 260 +++++++++++++++++- transaction.h | 5 +- 8 files changed, 753 insertions(+), 15 deletions(-) diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 7437534..7991d5b 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -60,6 +60,9 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request case SerializableStringConvRequest: args = append(args, reflect.ValueOf( pam.NewStringConvRequest(v.Style, v.Request))) + case SerializableBinaryConvRequest: + args = append(args, reflect.ValueOf( + pam.NewBinaryConvRequestFromBytes(v.Request))) default: if arg == nil { args = append(args, reflect.Zero(method.Type().In(i))) @@ -76,6 +79,12 @@ func (m *integrationTesterModule) handleRequest(authReq *authRequest, r *Request case pam.StringConvResponse: res.ActionArgs = append(res.ActionArgs, SerializableStringConvResponse{value.Style(), value.Response()}) + case pam.BinaryConvResponse: + data, err := value.Decode(utils.TestBinaryDataDecoder) + if err != nil { + return nil, err + } + res.ActionArgs = append(res.ActionArgs, SerializableBinaryConvResponse{data}) case pam.Error: authReq.lastError = value res.ActionArgs = append(res.ActionArgs, value) diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index a8b9bb8..71fd5b9 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -71,6 +71,21 @@ func ensureEnv(tx *pam.Transaction, variable string, expected string) error { return nil } +func (r *Request) toBytes(t *testing.T) []byte { + t.Helper() + bytes, err := r.GOB() + if err != nil { + t.Fatalf("error: %v", err) + return nil + } + return bytes +} + +func (r *Request) toTransactionData(t *testing.T) []byte { + t.Helper() + return utils.TestBinaryDataEncoder(r.toBytes(t)) +} + func Test_Moduler_IntegrationTesterModule(t *testing.T) { t.Parallel() if !pam.CheckPamHasStartConfdir() { @@ -818,6 +833,104 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { }, }, }, + "start-conv-binary": { + credentials: utils.NewBinaryTransactionWithData([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!"), + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffYes it is!")), + }), + exp: []interface{}{SerializableBinaryConvResponse{ + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, nil}, + }, + { + r: NewRequest("StartBinaryConv", + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"))), + exp: []interface{}{SerializableBinaryConvResponse{ + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, nil}, + }, + }, + }, + "start-conv-binary-handle-failure-passed-data-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.NewBinaryTransactionWithData([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!"), + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + (&Request{"Not the expected binary data", nil}).toTransactionData(t), + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartBinaryConv", + (&Request{"Not the expected binary data", nil}).toTransactionData(t)), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-binary-handle-failure-returned-data-mismatch": { + expectedError: pam.ErrConv, + credentials: utils.NewBinaryTransactionWithRandomData(100, + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + (&Request{"Wrong binary data", nil}).toTransactionData(t), + }), + exp: []interface{}{nil, pam.ErrConv}, + }, + { + r: NewRequest("StartBinaryConv", + (&Request{"Wrong binary data", nil}).toTransactionData(t)), + exp: []interface{}{nil, pam.ErrConv}, + }, + }, + }, + "start-conv-binary-in-nil": { + credentials: utils.NewBinaryTransactionWithData(nil, + (&Request{"Binary data", []interface{}{true, 123, 0.5, "yay!"}}).toBytes(t)), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{}), + exp: []interface{}{SerializableBinaryConvResponse{ + (&Request{"Binary data", []interface{}{true, 123, 0.5, "yay!"}}).toBytes(t), + }, nil}, + }, + { + r: NewRequest("StartBinaryConv", nil), + exp: []interface{}{SerializableBinaryConvResponse{ + (&Request{"Binary data", []interface{}{true, 123, 0.5, "yay!"}}).toBytes(t), + }, nil}, + }, + }, + }, + "start-conv-binary-out-nil": { + credentials: utils.NewBinaryTransactionWithData([]byte( + "\x00This is a binary data request\xC5\x00\xffGimme nil!"), nil), + checkedRequests: []checkedRequest{ + { + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffGimme nil!")), + }), + exp: []interface{}{SerializableBinaryConvResponse{}, nil}, + }, + { + r: NewRequest("StartBinaryConv", + utils.TestBinaryDataEncoder( + []byte("\x00This is a binary data request\xC5\x00\xffGimme nil!"))), + exp: []interface{}{SerializableBinaryConvResponse{}, nil}, + }, + }, + }, } for name, tc := range tests { @@ -831,6 +944,13 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { Args: []string{socketPath}}, }) + switch tc.credentials.(type) { + case pam.BinaryConversationHandler: + if !pam.CheckPamHasBinaryProtocol() { + t.Skip("Binary protocol is not supported") + } + } + tx, err := pam.StartConfDir(name, tc.user, tc.credentials, ts.WorkDir()) if err != nil { t.Fatalf("start #error: %v", err) @@ -1085,6 +1205,15 @@ func Test_Moduler_IntegrationTesterModule_Authenticate(t *testing.T) { exp: []interface{}{nil, pam.ErrSystem}, }}, }, + "StartConv-Binary": { + expectedError: pam.ErrSystem, + checkedRequests: []checkedRequest{{ + r: NewRequest("StartConv", SerializableBinaryConvRequest{ + []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }), + exp: []interface{}{nil, pam.ErrSystem}, + }}, + }, } for name, tc := range tests { diff --git a/cmd/pam-moduler/tests/integration-tester-module/serialization.go b/cmd/pam-moduler/tests/integration-tester-module/serialization.go index 2eae17b..7a549c2 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/serialization.go +++ b/cmd/pam-moduler/tests/integration-tester-module/serialization.go @@ -36,6 +36,16 @@ type SerializableStringConvResponse struct { Response string } +// SerializableBinaryConvRequest is a serializable binary request. +type SerializableBinaryConvRequest struct { + Request []byte +} + +// SerializableBinaryConvResponse is a serializable binary response. +type SerializableBinaryConvResponse struct { + Response []byte +} + func init() { gob.Register(map[string]string{}) gob.Register(Request{}) @@ -49,5 +59,9 @@ func init() { SerializableStringConvRequest{}) gob.RegisterName("main.SerializableStringConvResponse", SerializableStringConvResponse{}) + gob.RegisterName("main.SerializableBinaryConvRequest", + SerializableBinaryConvRequest{}) + gob.RegisterName("main.SerializableBinaryConvResponse", + SerializableBinaryConvResponse{}) gob.Register(utils.SerializableError{}) } diff --git a/cmd/pam-moduler/tests/internal/utils/test-utils.go b/cmd/pam-moduler/tests/internal/utils/test-utils.go index ce2281b..fd6f11b 100644 --- a/cmd/pam-moduler/tests/internal/utils/test-utils.go +++ b/cmd/pam-moduler/tests/internal/utils/test-utils.go @@ -1,9 +1,16 @@ // Package utils contains the internal test utils package utils +//#include +import "C" + import ( + "crypto/rand" + "encoding/binary" "errors" "fmt" + "reflect" + "unsafe" "github.com/msteinert/pam/v2" ) @@ -167,3 +174,90 @@ func (c Credentials) RespondPAM(s pam.Style, msg string) (string, error) { return "", errors.Join(pam.ErrConv, &SerializableError{fmt.Sprintf("unhandled style: %v", s)}) } + +// BinaryTransaction represents a binary PAM transaction handler struct. +type BinaryTransaction struct { + data []byte + ExpectedNull bool + ReturnedData []byte +} + +// TestBinaryDataEncoder encodes a test binary data. +func TestBinaryDataEncoder(bytes []byte) []byte { + if len(bytes) > 0xff { + panic("Binary transaction size not supported") + } + + if bytes == nil { + return bytes + } + + data := make([]byte, 0, len(bytes)+1) + data = append(data, byte(len(bytes))) + data = append(data, bytes...) + return data +} + +// TestBinaryDataDecoder decodes a test binary data. +func TestBinaryDataDecoder(ptr pam.BinaryPointer) ([]byte, error) { + if ptr == nil { + return nil, nil + } + + length := uint8(*((*C.uint8_t)(ptr))) + if length == 0 { + return []byte{}, nil + } + return C.GoBytes(unsafe.Pointer(ptr), C.int(length+1))[1:], nil +} + +// NewBinaryTransactionWithData creates a new [pam.BinaryTransaction] from bytes. +func NewBinaryTransactionWithData(data []byte, retData []byte) BinaryTransaction { + t := BinaryTransaction{ReturnedData: retData} + t.data = TestBinaryDataEncoder(data) + t.ExpectedNull = data == nil + return t +} + +// NewBinaryTransactionWithRandomData creates a new [pam.BinaryTransaction] with random data. +func NewBinaryTransactionWithRandomData(size uint8, retData []byte) BinaryTransaction { + t := BinaryTransaction{ReturnedData: retData} + randomData := make([]byte, size) + if err := binary.Read(rand.Reader, binary.LittleEndian, &randomData); err != nil { + panic(err) + } + + t.data = TestBinaryDataEncoder(randomData) + return t +} + +// Data returns the bytes of the transaction. +func (b BinaryTransaction) Data() []byte { + return b.data +} + +// RespondPAM (not) handles the PAM string conversations. +func (b BinaryTransaction) RespondPAM(s pam.Style, msg string) (string, error) { + return "", errors.Join(pam.ErrConv, + &SerializableError{"unexpected non-binary request"}) +} + +// RespondPAMBinary handles the PAM binary conversations. +func (b BinaryTransaction) RespondPAMBinary(ptr pam.BinaryPointer) ([]byte, error) { + if ptr == nil && !b.ExpectedNull { + return nil, errors.Join(pam.ErrConv, + &SerializableError{"unexpected null binary data"}) + } else if ptr == nil { + return TestBinaryDataEncoder(b.ReturnedData), nil + } + + bytes, _ := TestBinaryDataDecoder(ptr) + if !reflect.DeepEqual(bytes, b.data[1:]) { + return nil, errors.Join(pam.ErrConv, + &SerializableError{ + fmt.Sprintf("data mismatch %#v vs %#v", bytes, b.data[1:]), + }) + } + + return TestBinaryDataEncoder(b.ReturnedData), nil +} diff --git a/module-transaction-mock.go b/module-transaction-mock.go index a789b1e..c76087d 100644 --- a/module-transaction-mock.go +++ b/module-transaction-mock.go @@ -15,6 +15,7 @@ import "C" import ( "errors" "fmt" + "reflect" "runtime" "runtime/cgo" "testing" @@ -40,10 +41,12 @@ type mockModuleTransaction struct { ConversationHandler ConversationHandler moduleData map[string]uintptr allocatedData []unsafe.Pointer + binaryProtocol bool } func newMockModuleTransaction(m *mockModuleTransaction) *mockModuleTransaction { m.moduleData = make(map[string]uintptr) + m.binaryProtocol = true runtime.SetFinalizer(m, func(m *mockModuleTransaction) { for _, ptr := range m.allocatedData { C.free(ptr) @@ -126,14 +129,21 @@ func (m *mockModuleTransaction) getConv() (*C.struct_pam_conv, error) { return nil, nil } +func (m *mockModuleTransaction) hasBinaryProtocol() bool { + return m.binaryProtocol +} + type mockConversationHandler struct { User string PromptEchoOn string PromptEchoOff string TextInfo string ErrorMsg string + Binary []byte ExpectedMessage string ExpectedMessagesByStyle map[Style]string + ExpectedNil bool + ExpectedBinary []byte CheckEmptyMessage bool ExpectedStyle Style CheckZeroStyle bool @@ -178,3 +188,46 @@ func (c mockConversationHandler) RespondPAM(s Style, msg string) (string, error) return "", fmt.Errorf("%w: unhandled style: %v", ErrConv, s) } + +func testBinaryDataEncoder(bytes []byte) []byte { + if len(bytes) > 0xff { + panic("Binary transaction size not supported") + } + + if bytes == nil { + return bytes + } + + data := make([]byte, 0, len(bytes)+1) + data = append(data, byte(len(bytes))) + data = append(data, bytes...) + return data +} + +func testBinaryDataDecoder(ptr BinaryPointer) ([]byte, error) { + if ptr == nil { + return nil, nil + } + + length := uint8(*((*C.uint8_t)(ptr))) + if length == 0 { + return []byte{}, nil + } + return C.GoBytes(unsafe.Pointer(ptr), C.int(length+1))[1:], nil +} + +func (c mockConversationHandler) RespondPAMBinary(ptr BinaryPointer) ([]byte, error) { + if ptr == nil && !c.ExpectedNil { + return nil, fmt.Errorf("%w: unexpected null binary data", ErrConv) + } else if ptr == nil { + return testBinaryDataEncoder(c.Binary), nil + } + + bytes, _ := testBinaryDataDecoder(ptr) + if !reflect.DeepEqual(bytes, c.ExpectedBinary) { + return nil, fmt.Errorf("%w: data mismatch %#v vs %#v", + ErrConv, bytes, c.ExpectedBinary) + } + + return testBinaryDataEncoder(c.Binary), nil +} diff --git a/module-transaction.go b/module-transaction.go index bbb1e13..df1bfa3 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -9,7 +9,10 @@ import "C" import ( "errors" "fmt" + "runtime" "runtime/cgo" + "sync" + "sync/atomic" "unsafe" ) @@ -29,6 +32,7 @@ type ModuleTransaction interface { StartStringConv(style Style, prompt string) (StringConvResponse, error) StartStringConvf(style Style, format string, args ...interface{}) ( StringConvResponse, error) + StartBinaryConv([]byte) (BinaryConvResponse, error) StartConv(ConvRequest) (ConvResponse, error) StartConvMulti([]ConvRequest) ([]ConvResponse, error) } @@ -110,6 +114,7 @@ type moduleTransactionIface interface { setData(key *C.char, handle C.uintptr_t) C.int getData(key *C.char, outHandle *C.uintptr_t) C.int getConv() (*C.struct_pam_conv, error) + hasBinaryProtocol() bool startConv(conv *C.struct_pam_conv, nMsg C.int, messages **C.struct_pam_message, outResponses **C.struct_pam_response) C.int @@ -261,6 +266,143 @@ func (s stringConvResponse) Response() string { return s.response } +// BinaryFinalizer is a type of function that can be used to release +// the binary when it's not required anymore +type BinaryFinalizer func(BinaryPointer) + +// BinaryConvRequester is the interface that binary ConvRequests should +// implement +type BinaryConvRequester interface { + ConvRequest + Pointer() BinaryPointer + CreateResponse(BinaryPointer) BinaryConvResponse + Release() +} + +// BinaryConvRequest is a ConvRequest for performing binary conversations. +type BinaryConvRequest struct { + ptr atomic.Uintptr + finalizer BinaryFinalizer + responseFinalizer BinaryFinalizer +} + +// NewBinaryConvRequestFull creates a new BinaryConvRequest with finalizer +// for response BinaryResponse. +func NewBinaryConvRequestFull(ptr BinaryPointer, finalizer BinaryFinalizer, + responseFinalizer BinaryFinalizer) *BinaryConvRequest { + b := &BinaryConvRequest{finalizer: finalizer, responseFinalizer: responseFinalizer} + b.ptr.Store(uintptr(ptr)) + if ptr == nil || finalizer == nil { + return b + } + + // The ownership of the data here is temporary + runtime.SetFinalizer(b, func(b *BinaryConvRequest) { b.Release() }) + return b +} + +// NewBinaryConvRequest creates a new BinaryConvRequest +func NewBinaryConvRequest(ptr BinaryPointer, finalizer BinaryFinalizer) *BinaryConvRequest { + return NewBinaryConvRequestFull(ptr, finalizer, finalizer) +} + +// NewBinaryConvRequestFromBytes creates a new BinaryConvRequest from an array +// of bytes. +func NewBinaryConvRequestFromBytes(bytes []byte) *BinaryConvRequest { + if bytes == nil { + return &BinaryConvRequest{} + } + return NewBinaryConvRequest(BinaryPointer(C.CBytes(bytes)), + func(ptr BinaryPointer) { C.free(unsafe.Pointer(ptr)) }) +} + +// Style returns the response style for the request, so always BinaryPrompt. +func (b *BinaryConvRequest) Style() Style { + return BinaryPrompt +} + +// Pointer returns the conversation style of the StringConvRequest. +func (b *BinaryConvRequest) Pointer() BinaryPointer { + ptr := b.ptr.Load() + return *(*BinaryPointer)(unsafe.Pointer(&ptr)) +} + +// CreateResponse creates a new BinaryConvResponse from the request +func (b *BinaryConvRequest) CreateResponse(ptr BinaryPointer) BinaryConvResponse { + bcr := &binaryConvResponse{ptr, b.responseFinalizer, &sync.Mutex{}} + runtime.SetFinalizer(bcr, func(bcr *binaryConvResponse) { + bcr.Release() + }) + return bcr +} + +// Release releases the resources allocated by the request +func (b *BinaryConvRequest) Release() { + ptr := b.ptr.Swap(0) + if b.finalizer != nil { + b.finalizer(*(*BinaryPointer)(unsafe.Pointer(&ptr))) + runtime.SetFinalizer(b, nil) + } +} + +// BinaryDecoder is a function type for decode the a binary pointer data into +// bytes +type BinaryDecoder func(BinaryPointer) ([]byte, error) + +// BinaryConvResponse is a subtype of ConvResponse used for binary +// conversation responses. +type BinaryConvResponse interface { + ConvResponse + Data() BinaryPointer + Decode(BinaryDecoder) ([]byte, error) + Release() +} + +type binaryConvResponse struct { + ptr BinaryPointer + finalizer BinaryFinalizer + mutex *sync.Mutex +} + +// Style returns the response style for the response, so always BinaryPrompt. +func (b binaryConvResponse) Style() Style { + return BinaryPrompt +} + +// Data returns the response native pointer, it's up to the protocol to parse +// it accordingly. +func (b *binaryConvResponse) Data() BinaryPointer { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.ptr +} + +// Decode decodes the binary data using the provided decoder function. +func (b *binaryConvResponse) Decode(decoder BinaryDecoder) ( + []byte, error) { + if decoder == nil { + return nil, errors.New("nil decoder provided") + } + b.mutex.Lock() + defer b.mutex.Unlock() + return decoder(b.ptr) +} + +// Release releases the binary conversation response data. +// This is also automatically via a finalizer, but applications may control +// this explicitly deferring execution of this. +func (b *binaryConvResponse) Release() { + b.mutex.Lock() + defer b.mutex.Unlock() + ptr := b.ptr + b.ptr = nil + if b.finalizer != nil { + b.finalizer(ptr) + } else { + C.free(unsafe.Pointer(ptr)) + } +} + // StartStringConv starts a text-based conversation using the provided style // and prompt. func (m *moduleTransaction) StartStringConv(style Style, prompt string) ( @@ -291,6 +433,29 @@ func (m *moduleTransaction) StartStringConvf(style Style, format string, args .. return m.StartStringConv(style, fmt.Sprintf(format, args...)) } +// HasBinaryProtocol checks if binary protocol is supported. +func (m *moduleTransaction) hasBinaryProtocol() bool { + return CheckPamHasBinaryProtocol() +} + +// StartBinaryConv starts a binary conversation using the provided bytes. +func (m *moduleTransaction) StartBinaryConv(bytes []byte) ( + BinaryConvResponse, error) { + return m.startBinaryConvImpl(m, bytes) +} + +func (m *moduleTransaction) startBinaryConvImpl(iface moduleTransactionIface, + bytes []byte) ( + BinaryConvResponse, error) { + res, err := m.startConvImpl(iface, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return nil, err + } + + binaryRes, _ := res.(BinaryConvResponse) + return binaryRes, nil +} + // StartConv initiates a PAM conversation using the provided ConvRequest. func (m *moduleTransaction) StartConv(req ConvRequest) ( ConvResponse, error) { @@ -360,14 +525,21 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, case StringConvRequest: cBytes = unsafe.Pointer(C.CString(r.Prompt())) defer C.free(cBytes) + case BinaryConvRequester: + if !iface.hasBinaryProtocol() { + return nil, errors.New("%w: binary protocol is not supported") + } + cBytes = unsafe.Pointer(r.Pointer()) default: return nil, fmt.Errorf("unsupported conversation type %#v", r) } - goMsgs[i] = &C.struct_pam_message{ - msg_style: C.int(req.Style()), - msg: (*C.char)(cBytes), - } + cMessage := (*C.struct_pam_message)(C.calloc(1, + (C.size_t)(unsafe.Sizeof(*goMsgs[i])))) + defer C.free(unsafe.Pointer(cMessage)) + cMessage.msg_style = C.int(req.Style()) + cMessage.msg = (*C.char)(cBytes) + goMsgs[i] = cMessage } var cResponses *C.struct_pam_response @@ -378,15 +550,26 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, goResponses := unsafe.Slice(cResponses, len(requests)) defer func() { - for _, resp := range goResponses { - C.free(unsafe.Pointer(resp.resp)) + for i, resp := range goResponses { + if resp.resp == nil { + continue + } + switch req := requests[i].(type) { + case BinaryConvRequester: + // In the binary prompt case, we need to rely on the provided + // finalizer to release the response, so let's create a new one. + req.CreateResponse(BinaryPointer(resp.resp)).Release() + default: + C.free(unsafe.Pointer(resp.resp)) + } } C.free(unsafe.Pointer(cResponses)) }() responses = make([]ConvResponse, 0, len(requests)) for i, resp := range goResponses { - msgStyle := requests[i].Style() + request := requests[i] + msgStyle := request.Style() switch msgStyle { case PromptEchoOff: fallthrough @@ -399,6 +582,13 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, style: msgStyle, response: C.GoString(resp.resp), }) + case BinaryPrompt: + // Let's steal the resp ownership here, so that the request + // finalizer won't act on it. + bcr, _ := request.(BinaryConvRequester) + resp := bcr.CreateResponse(BinaryPointer(resp.resp)) + goResponses[i].resp = nil + responses = append(responses, resp) default: return nil, fmt.Errorf("unsupported conversation type %v", msgStyle) diff --git a/module-transaction_test.go b/module-transaction_test.go index 9c4da20..85233d3 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -123,6 +123,11 @@ func Test_NewNullModuleTransaction(t *testing.T) { return mt.StartConvMulti([]ConvRequest{ NewStringConvRequest(TextInfo, "a prompt"), NewStringConvRequest(ErrorMsg, "another prompt"), + NewBinaryConvRequest(BinaryPointer(&mt), nil), + NewBinaryConvRequestFromBytes([]byte("These are bytes!")), + NewBinaryConvRequestFromBytes([]byte{}), + NewBinaryConvRequestFromBytes(nil), + NewBinaryConvRequest(nil, nil), }) }, }, @@ -620,31 +625,272 @@ func Test_MockModuleTransaction(t *testing.T) { }, }, "StartConvMulti-all-types": { - expectedValue: []ConvResponse{ - stringConvResponse{TextInfo, "nice to see you, Go!"}, - stringConvResponse{ErrorMsg, "ops, sorry..."}, - stringConvResponse{PromptEchoOn, "here's my public data"}, - stringConvResponse{PromptEchoOff, "here's my private data"}, + expectedValue: []any{ + []ConvResponse{ + stringConvResponse{TextInfo, "nice to see you, Go!"}, + stringConvResponse{ErrorMsg, "ops, sorry..."}, + stringConvResponse{PromptEchoOn, "here's my public data"}, + stringConvResponse{PromptEchoOff, "here's my private data"}, + }, + [][]byte{ + {0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, }, conversationHandler: mockConversationHandler{ TextInfo: "nice to see you, Go!", ErrorMsg: "ops, sorry...", PromptEchoOn: "here's my public data", PromptEchoOff: "here's my private data", + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, ExpectedMessagesByStyle: map[Style]string{ TextInfo: "hello PAM!", ErrorMsg: "This is wrong, PAM!", PromptEchoOn: "Give me your non-private infos", PromptEchoOff: "Give me your private secrets", }, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), }, testFunc: func(mock *mockModuleTransaction) (any, error) { - return mt.startConvMultiImpl(mock, []ConvRequest{ + requests := []ConvRequest{ NewStringConvRequest(TextInfo, "hello PAM!"), NewStringConvRequest(ErrorMsg, "This is wrong, PAM!"), NewStringConvRequest(PromptEchoOn, "Give me your non-private infos"), NewStringConvRequest(PromptEchoOff, "Give me your private secrets"), - }) + NewBinaryConvRequestFromBytes( + testBinaryDataEncoder([]byte("\x00This is a binary data request\xC5\x00\xffYes it is!"))), + } + + data, err := mt.startConvMultiImpl(mock, requests) + if err != nil { + return data, err + } + + stringResponses := []ConvResponse{} + binaryResponses := [][]byte{} + for i, r := range data { + if r.Style() != requests[i].Style() { + mock.T.Fatalf("unexpected style %#v vs %#v", + r.Style(), requests[i].Style()) + } + + switch rt := r.(type) { + case BinaryConvResponse: + decoded, err := rt.Decode(testBinaryDataDecoder) + if err != nil { + return data, err + } + binaryResponses = append(binaryResponses, decoded) + case StringConvResponse: + stringResponses = append(stringResponses, r) + default: + mock.T.Fatalf("unexpected value %v", rt) + } + } + return []any{ + stringResponses, + binaryResponses, + }, err + }, + }, + "StartConvMulti-all-types-some-failing": { + expectedError: ErrConv, + expectedValue: []ConvResponse(nil), + conversationHandler: mockConversationHandler{ + TextInfo: "nice to see you, Go!", + ErrorMsg: "ops, sorry...", + PromptEchoOn: "here's my public data", + PromptEchoOff: "here's my private data", + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + ExpectedMessagesByStyle: map[Style]string{ + TextInfo: "hello PAM!", + ErrorMsg: "This is wrong, PAM!", + PromptEchoOn: "Give me your non-private infos", + PromptEchoOff: "Give me your private secrets", + Style(0xfaaf): "This will fail", + }, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + IgnoreUnknownStyle: true, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + requests := []ConvRequest{ + NewStringConvRequest(TextInfo, "hello PAM!"), + NewStringConvRequest(ErrorMsg, "This is wrong, PAM!"), + NewStringConvRequest(PromptEchoOn, "Give me your non-private infos"), + NewStringConvRequest(PromptEchoOff, "Give me your private secrets"), + NewStringConvRequest(Style(0xfaaf), "This will fail"), + NewBinaryConvRequestFromBytes( + testBinaryDataEncoder([]byte("\x00This is a binary data request\xC5\x00\xffYes it is!"))), + } + + return mt.startConvMultiImpl(mock, requests) + }, + }, + "StartConv-Binary-unsupported": { + expectedValue: nil, + expectedError: ErrConv, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + mock.binaryProtocol = false + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + }, + }, + "StartConv-Binary": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-expected-data-mismatch": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is not the expected data!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + }, + }, + "StartConv-Binary-unexpected-nil": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(nil)) + }, + }, + "StartConv-Binary-expected-nil": { + expectedValue: []byte("\x1ASome binary Dat\xaa"), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedNil: true, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(nil)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-returns-nil": { + expectedValue: BinaryPointer(nil), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x1ASome binary Dat\xaa"), + Binary: nil, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte("\x1ASome binary Dat\xaa")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Data(), err + }, + }, + "StartBinaryConv": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is a binary data request\xC5\x00\xffYes it is!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + bcr, _ := data.(BinaryConvResponse) + return bcr.Decode(testBinaryDataDecoder) + }, + }, + "StartBinaryConv-expected-data-mismatch": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This is not the expected data!"), + Binary: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + return mt.startBinaryConvImpl(mock, bytes) + }, + }, + "StartBinaryConv-unexpected-nil": { + expectedError: ErrConv, + expectedValue: nil, + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startBinaryConvImpl(mock, nil) + }, + }, + "StartBinaryConv-expected-nil": { + expectedValue: []byte("\x1ASome binary Dat\xaa"), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedNil: true, + ExpectedBinary: []byte("\x00This should not be nil"), + Binary: []byte("\x1ASome binary Dat\xaa"), + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + data, err := mt.startBinaryConvImpl(mock, nil) + if err != nil { + return data, err + } + return data.Decode(testBinaryDataDecoder) + }, + }, + "StartBinaryConv-returns-nil": { + expectedValue: BinaryPointer(nil), + conversationHandler: mockConversationHandler{ + ExpectedStyle: BinaryPrompt, + ExpectedBinary: []byte("\x1ASome binary Dat\xaa"), + Binary: nil, + }, + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte("\x1ASome binary Dat\xaa")) + data, err := mt.startBinaryConvImpl(mock, bytes) + if err != nil { + return data, err + } + return data.Data(), err }, }, } diff --git a/transaction.h b/transaction.h index 4c9f000..292aa96 100644 --- a/transaction.h +++ b/transaction.h @@ -42,7 +42,10 @@ static inline int cb_pam_conv(int num_msg, PAM_CONST struct pam_message **msg, s error: for (size_t i = 0; i < num_msg; ++i) { if ((*resp)[i].resp) { - memset((*resp)[i].resp, 0, strlen((*resp)[i].resp)); +#ifdef PAM_BINARY_PROMPT + if (msg[i]->msg_style != PAM_BINARY_PROMPT) +#endif + memset((*resp)[i].resp, 0, strlen((*resp)[i].resp)); free((*resp)[i].resp); } } From a047550bedb5a28a35b117774f370fb535bb87b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 13 Oct 2023 19:24:49 +0200 Subject: [PATCH 15/24] module-transaction: Do not allow parallel conversations by default Pam conversations per se may also run in parallel, but this implies that the application supports this. Since this normally not the case, do not create modules that may invoke the pam conversations in parallel by default, adding a mutex to protect such calls. --- cmd/pam-moduler/moduler.go | 14 ++++++++-- .../integration-tester-module.go | 2 +- .../integration-tester-module/pam_module.go | 4 +-- module-transaction.go | 28 +++++++++++++++++-- module-transaction_test.go | 15 ++++++++-- 5 files changed, 51 insertions(+), 12 deletions(-) diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 165d5a6..0a74125 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -69,6 +69,7 @@ var ( moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") + parallelConv = flag.Bool("parallel-conv", false, "whether to support performing PAM conversations in parallel") ) // Usage is a replacement usage function for the flags package. @@ -137,6 +138,7 @@ func main() { generateTags: generateTags, noMain: *noMain, typeName: *typeName, + parallelConv: *parallelConv, } // Print the header and package clause. @@ -169,6 +171,7 @@ type Generator struct { generateTags []string buildFlags []string noMain bool + parallelConv bool } func (g *Generator) printf(format string, args ...interface{}) { @@ -186,6 +189,11 @@ func (g *Generator) generate() { buildTagsArg = fmt.Sprintf("-tags %s", strings.Join(g.generateTags, ",")) } + var transactionCreator = "NewModuleTransactionInvoker" + if g.parallelConv { + transactionCreator = "NewModuleTransactionInvokerParallelConv" + } + // We use a slice since we want to keep order, for reproducible builds. vFuncs := []struct { cName string @@ -257,8 +265,8 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) - err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), + mt := pam.%s(pam.NativeHandle(pamh)) + err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), sliceFromArgv(argc, argv)) if err == nil { return 0 @@ -275,7 +283,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrSystem) } -`) +`, transactionCreator) for _, f := range vFuncs { g.printf(` diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 7991d5b..fcdeaa9 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule +//go:generate go run github.com/msteinert/pam/v2/cmd/pam-moduler -type integrationTesterModule -parallel-conv //go:generate go generate --skip="pam_module.go" // Package main is the package for the integration tester module PAM shared library. diff --git a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go index 39a22b7..e64a4f9 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/pam_module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/pam_module.go @@ -1,4 +1,4 @@ -// Code generated by "pam-moduler -type integrationTesterModule"; DO NOT EDIT. +// Code generated by "pam-moduler -type integrationTesterModule -parallel-conv"; DO NOT EDIT. //go:generate go build "-ldflags=-extldflags -Wl,-soname,pam_go.so" -buildmode=c-shared -o pam_go.so -tags go_pam_module @@ -43,7 +43,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return C.int(pam.ErrIgnore) } - mt := pam.NewModuleTransactionInvoker(pam.NativeHandle(pamh)) + mt := pam.NewModuleTransactionInvokerParallelConv(pam.NativeHandle(pamh)) err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), sliceFromArgv(argc, argv)) if err == nil { diff --git a/module-transaction.go b/module-transaction.go index df1bfa3..fc754a1 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -43,6 +43,7 @@ type ModuleHandlerFunc func(ModuleTransaction, Flags, []string) error // ModuleTransaction is the module-side handle for a PAM transaction. type moduleTransaction struct { transactionBase + convMutex *sync.Mutex } // ModuleHandler is an interface for objects that can be used to create @@ -63,10 +64,27 @@ type ModuleTransactionInvoker interface { InvokeHandler(handler ModuleHandlerFunc, flags Flags, args []string) error } -// NewModuleTransactionInvoker allows initializing a transaction invoker from -// the module side. +// NewModuleTransactionParallelConv allows initializing a transaction from the +// module side. Conversations using this transaction can be multi-thread, but +// this requires the application loading the module to support this, otherwise +// we may just break their assumptions. +func NewModuleTransactionParallelConv(handle NativeHandle) ModuleTransaction { + return &moduleTransaction{transactionBase{handle: handle}, nil} +} + +// NewModuleTransactionInvoker allows initializing a transaction invoker from the +// module side. func NewModuleTransactionInvoker(handle NativeHandle) ModuleTransactionInvoker { - return &moduleTransaction{transactionBase{handle: handle}} + return &moduleTransaction{transactionBase{handle: handle}, &sync.Mutex{}} +} + +// NewModuleTransactionInvokerParallelConv allows initializing a transaction invoker +// from the module side. +// Conversations using this transaction can be multi-thread, but this requires +// the application loading the module to support this, otherwise we may just +// break their assumptions. +func NewModuleTransactionInvokerParallelConv(handle NativeHandle) ModuleTransactionInvoker { + return &moduleTransaction{transactionBase{handle: handle}, nil} } func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, @@ -542,6 +560,10 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, goMsgs[i] = cMessage } + if m.convMutex != nil { + m.convMutex.Lock() + defer m.convMutex.Unlock() + } var cResponses *C.struct_pam_response ret := iface.startConv(conv, C.int(len(requests)), cMessages, &cResponses) if ret != success { diff --git a/module-transaction_test.go b/module-transaction_test.go index 85233d3..2e678e0 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -305,11 +305,10 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { } } -func Test_MockModuleTransaction(t *testing.T) { +func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { + t.Helper() t.Parallel() - mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) - tests := map[string]struct { testFunc func(mock *mockModuleTransaction) (any, error) mockExpectations mockModuleTransactionExpectations @@ -914,3 +913,13 @@ func Test_MockModuleTransaction(t *testing.T) { }) } } + +func Test_MockModuleTransaction(t *testing.T) { + mt, _ := NewModuleTransactionInvoker(nil).(*moduleTransaction) + testMockModuleTransaction(t, mt) +} + +func Test_MockModuleTransactionParallelConv(t *testing.T) { + mt, _ := NewModuleTransactionInvokerParallelConv(nil).(*moduleTransaction) + testMockModuleTransaction(t, mt) +} From 3253288342dfb753968858de146be123bd82d688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Thu, 19 Oct 2023 03:02:46 +0200 Subject: [PATCH 16/24] github/test: Run tests with address sanitizer We have lots of cgo interaction here so better to check things fully. This also requires manually checking for leaks, so add support for this. --- .github/workflows/test.yaml | 15 +++++++++++++-- .gitignore | 2 +- module-transaction_test.go | 5 +++++ transaction_test.go | 30 ++++++++++++++++++++++++++++++ utils.go | 31 +++++++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 utils.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a4007dc..69c49ef 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,14 +29,25 @@ jobs: run: sudo useradd -d /tmp/test -p '$1$Qd8H95T5$RYSZQeoFbEB.gS19zS99A0' -s /bin/false test - name: Checkout code uses: actions/checkout@v4 + - name: Test + run: sudo go test -v -cover -coverprofile=coverage.out ./... + - name: Test with Address Sanitizer + env: + GO_PAM_TEST_WITH_ASAN: true + CGO_CFLAGS: "-O0 -g3 -fno-omit-frame-pointer" + run: | + # Do not run sudo-requiring go tests because as PAM has some leaks in 22.04 + go test -v -asan -cover -coverprofile=coverage-asan-tx.out -gcflags=all="-N -l" + + # Run the rest of tests normally + sudo go test -v -cover -coverprofile=coverage-asan-module.out -asan -gcflags=all="-N -l" -run Module + sudo go test -C cmd -coverprofile=coverage-asan.out -v -asan -gcflags=all="-N -l" ./... - name: Generate example module run: | rm -f example-module/pam_go.so go generate -C example-module -v test -e example-module/pam_go.so git diff --exit-code example-module - - name: Test - run: sudo go test -v -cover -coverprofile=coverage.out ./... - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 env: diff --git a/.gitignore b/.gitignore index 0700a89..8206d2f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -coverage.out +coverage*.out example-module/*.so example-module/*.h cmd/pam-moduler/tests/*/*.so diff --git a/module-transaction_test.go b/module-transaction_test.go index 2e678e0..cf64085 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -23,6 +23,7 @@ func ensureNoError(t *testing.T, err error) { func Test_NewNullModuleTransaction(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) mt := moduleTransaction{} if mt.handle != nil { @@ -137,6 +138,7 @@ func Test_NewNullModuleTransaction(t *testing.T) { tc := tc t.Run(name+"-error-check", func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) data, err := tc.testFunc(t) switch d := data.(type) { @@ -202,6 +204,7 @@ func Test_NewNullModuleTransaction(t *testing.T) { func Test_ModuleTransaction_InvokeHandler(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) mt := &moduleTransaction{} err := mt.InvokeHandler(nil, 0, nil) @@ -308,6 +311,7 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { t.Helper() t.Parallel() + t.Cleanup(maybeDoLeakCheck) tests := map[string]struct { testFunc func(mock *mockModuleTransaction) (any, error) @@ -898,6 +902,7 @@ func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { tc := tc t.Run(name, func(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) mock := newMockModuleTransaction(&mockModuleTransaction{T: t, Expectations: tc.mockExpectations, RetData: tc.mockRetData, ConversationHandler: tc.conversationHandler}) diff --git a/transaction_test.go b/transaction_test.go index d358b0e..3166159 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -39,6 +39,7 @@ func ensureTransactionEnds(t *testing.T, tx *Transaction) { } func TestPAM_001(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -67,6 +68,7 @@ func TestPAM_001(t *testing.T) { } func TestPAM_002(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -107,6 +109,7 @@ func (c Credentials) RespondPAM(s Style, msg string) (string, error) { } func TestPAM_003(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -128,6 +131,7 @@ func TestPAM_003(t *testing.T) { } func TestPAM_004(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -148,10 +152,14 @@ func TestPAM_004(t *testing.T) { } func TestPAM_005(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") } + if _, found := os.LookupEnv("GO_PAM_TEST_WITH_ASAN"); found { + t.Skip("test fails under ASAN") + } tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { return "secret", nil }) @@ -174,6 +182,7 @@ func TestPAM_005(t *testing.T) { } func TestPAM_006(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -197,6 +206,7 @@ func TestPAM_006(t *testing.T) { } func TestPAM_007(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() if u.Uid != "0" { t.Skip("run this test as root") @@ -223,6 +233,7 @@ func TestPAM_007(t *testing.T) { } func TestPAM_ConfDir(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() c := Credentials{ // the custom service always permits even with wrong password. @@ -258,6 +269,7 @@ func TestPAM_ConfDir(t *testing.T) { } func TestPAM_ConfDir_FailNoServiceOrUnsupported(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) if !CheckPamHasStartConfdir() { t.Skip("this requires PAM with Conf dir support") } @@ -286,6 +298,7 @@ func TestPAM_ConfDir_FailNoServiceOrUnsupported(t *testing.T) { } func TestPAM_ConfDir_InfoMessage(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) u, _ := user.Current() var infoText string tx, err := StartConfDir("echo-service", u.Username, @@ -319,6 +332,7 @@ func TestPAM_ConfDir_InfoMessage(t *testing.T) { } func TestPAM_ConfDir_Deny(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) if !CheckPamHasStartConfdir() { t.Skip("this requires PAM with Conf dir support") } @@ -350,6 +364,7 @@ func TestPAM_ConfDir_Deny(t *testing.T) { } func TestPAM_ConfDir_PromptForUserName(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) c := Credentials{ User: "testuser", // the custom service only cares about correct user name. @@ -375,6 +390,7 @@ func TestPAM_ConfDir_PromptForUserName(t *testing.T) { } func TestPAM_ConfDir_WrongUserName(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) c := Credentials{ User: "wronguser", Password: "wrongsecret", @@ -403,6 +419,7 @@ func TestPAM_ConfDir_WrongUserName(t *testing.T) { } func TestItem(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx, err := StartFunc("passwd", "test", func(s Style, msg string) (string, error) { return "", nil }) @@ -442,6 +459,7 @@ func TestItem(t *testing.T) { } func TestEnv(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx, err := StartFunc("", "", func(s Style, msg string) (string, error) { return "", nil }) @@ -511,6 +529,7 @@ func TestEnv(t *testing.T) { func Test_Error(t *testing.T) { t.Parallel() + t.Cleanup(maybeDoLeakCheck) if !CheckPamHasStartConfdir() { t.Skip("this requires PAM with Conf dir support") } @@ -624,6 +643,7 @@ func Test_Error(t *testing.T) { } func Test_Finalizer(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) if !CheckPamHasStartConfdir() { t.Skip("this requires PAM with Conf dir support") } @@ -643,6 +663,7 @@ func Test_Finalizer(t *testing.T) { } func TestFailure_001(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} _, err := tx.GetEnvList() if err == nil { @@ -651,6 +672,7 @@ func TestFailure_001(t *testing.T) { } func TestFailure_002(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.PutEnv("") if err == nil { @@ -659,6 +681,7 @@ func TestFailure_002(t *testing.T) { } func TestFailure_003(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.CloseSession(0) if err == nil { @@ -667,6 +690,7 @@ func TestFailure_003(t *testing.T) { } func TestFailure_004(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.OpenSession(0) if err == nil { @@ -675,6 +699,7 @@ func TestFailure_004(t *testing.T) { } func TestFailure_005(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.ChangeAuthTok(0) if err == nil { @@ -683,6 +708,7 @@ func TestFailure_005(t *testing.T) { } func TestFailure_006(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.AcctMgmt(0) if err == nil { @@ -691,6 +717,7 @@ func TestFailure_006(t *testing.T) { } func TestFailure_007(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.SetCred(0) if err == nil { @@ -699,6 +726,7 @@ func TestFailure_007(t *testing.T) { } func TestFailure_008(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.SetItem(User, "test") if err == nil { @@ -707,6 +735,7 @@ func TestFailure_008(t *testing.T) { } func TestFailure_009(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} _, err := tx.GetItem(User) if err == nil { @@ -715,6 +744,7 @@ func TestFailure_009(t *testing.T) { } func TestFailure_010(t *testing.T) { + t.Cleanup(maybeDoLeakCheck) tx := Transaction{} err := tx.End() if err != nil { diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..d094c30 --- /dev/null +++ b/utils.go @@ -0,0 +1,31 @@ +// Package pam provides a wrapper for the PAM application API. +package pam + +/* +#ifdef __SANITIZE_ADDRESS__ +#include +#endif + +static inline void +maybe_do_leak_check (void) +{ +#ifdef __SANITIZE_ADDRESS__ + __lsan_do_leak_check(); +#endif +} +*/ +import "C" + +import ( + "os" + "runtime" + "time" +) + +func maybeDoLeakCheck() { + runtime.GC() + time.Sleep(time.Millisecond * 20) + if os.Getenv("GO_PAM_SKIP_LEAK_CHECK") == "" { + C.maybe_do_leak_check() + } +} From 62b69f7d44dff56a546ed9341fef898dbdd1f61b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Thu, 26 Oct 2023 00:56:49 +0200 Subject: [PATCH 17/24] transaction: Add BinaryConversationFunc adapter --- app-transaction.go | 14 +++++++++++++ module-transaction_test.go | 42 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/app-transaction.go b/app-transaction.go index 671b48e..159a5cd 100644 --- a/app-transaction.go +++ b/app-transaction.go @@ -44,6 +44,20 @@ func (f ConversationFunc) RespondPAM(s Style, msg string) (string, error) { return f(s, msg) } +// BinaryConversationFunc is an adapter to allow the use of ordinary functions +// as binary (only) conversation callbacks. +type BinaryConversationFunc func(BinaryPointer) ([]byte, error) + +// RespondPAMBinary is a conversation callback adapter. +func (f BinaryConversationFunc) RespondPAMBinary(ptr BinaryPointer) ([]byte, error) { + return f(ptr) +} + +// RespondPAM is a dummy conversation callback adapter. +func (f BinaryConversationFunc) RespondPAM(Style, string) (string, error) { + return "", ErrConv +} + // _go_pam_conv_handler is a C wrapper for the conversation callback function. // //export _go_pam_conv_handler diff --git a/module-transaction_test.go b/module-transaction_test.go index cf64085..ed08922 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -3,6 +3,7 @@ package pam import ( "errors" + "fmt" "reflect" "strings" "testing" @@ -896,6 +897,47 @@ func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { return data.Data(), err }, }, + "StartConv-Binary-with-ConvFunc": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}, + conversationHandler: BinaryConversationFunc(func(ptr BinaryPointer) ([]byte, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, fmt.Errorf("%w, data mismatch %#v vs %#v", + ErrConv, bytes, expectedBinary) + } + return testBinaryDataEncoder([]byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x99}), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-ConvFunc-error": { + expectedError: ErrConv, + conversationHandler: BinaryConversationFunc(func(ptr BinaryPointer) ([]byte, error) { + return nil, errors.New("got an error") + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes([]byte{})) + }, + }, + "StartConv-String-with-ConvBinaryFunc": { + expectedError: ErrConv, + conversationHandler: BinaryConversationFunc(func(ptr BinaryPointer) ([]byte, error) { + return nil, nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewStringConvRequest(TextInfo, "prompt")) + }, + }, } for name, tc := range tests { From 5689405c0d3754aa44ee7b06e2f34908f36ba0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Tue, 7 Nov 2023 14:22:43 +0200 Subject: [PATCH 18/24] transaction: Add support for using raw binary pointers conversation handler This requires the allocating function to provide a binary pointer that will be free'd by the conversation handlers finalizers. This is for a more advanced usage scenario where the binary conversion may be handled manually. --- app-transaction.go | 42 ++++++ .../integration-tester-module_test.go | 4 + module-transaction_test.go | 139 ++++++++++++++++++ utils.go | 11 ++ 4 files changed, 196 insertions(+) diff --git a/app-transaction.go b/app-transaction.go index 159a5cd..39a3cf4 100644 --- a/app-transaction.go +++ b/app-transaction.go @@ -35,6 +35,18 @@ type BinaryConversationHandler interface { RespondPAMBinary(BinaryPointer) ([]byte, error) } +// BinaryPointerConversationHandler is an interface for objects that can be used as +// conversation callbacks during PAM authentication if binary protocol is going +// to be supported. +type BinaryPointerConversationHandler interface { + ConversationHandler + // RespondPAMBinary receives a pointer to the binary message. It's up to + // the receiver to parse it according to the protocol specifications. + // The function must return a pointer that is allocated via malloc or + // similar, as it's expected to be free'd by the conversation handler. + RespondPAMBinary(BinaryPointer) (BinaryPointer, error) +} + // ConversationFunc is an adapter to allow the use of ordinary functions as // conversation callbacks. type ConversationFunc func(Style, string) (string, error) @@ -58,6 +70,20 @@ func (f BinaryConversationFunc) RespondPAM(Style, string) (string, error) { return "", ErrConv } +// BinaryPointerConversationFunc is an adapter to allow the use of ordinary +// functions as binary pointer (only) conversation callbacks. +type BinaryPointerConversationFunc func(BinaryPointer) (BinaryPointer, error) + +// RespondPAMBinary is a conversation callback adapter. +func (f BinaryPointerConversationFunc) RespondPAMBinary(ptr BinaryPointer) (BinaryPointer, error) { + return f(ptr) +} + +// RespondPAM is a dummy conversation callback adapter. +func (f BinaryPointerConversationFunc) RespondPAM(Style, string) (string, error) { + return "", ErrConv +} + // _go_pam_conv_handler is a C wrapper for the conversation callback function. // //export _go_pam_conv_handler @@ -88,6 +114,16 @@ func pamConvHandler(style Style, msg *C.char, handler ConversationHandler) (*C.c return (*C.char)(C.CBytes(bytes)), success } handler = cb + case BinaryPointerConversationHandler: + if style == BinaryPrompt { + ptr, err := cb.RespondPAMBinary(BinaryPointer(msg)) + if err != nil { + defer C.free(unsafe.Pointer(ptr)) + return nil, C.int(ErrConv) + } + return (*C.char)(ptr), success + } + handler = cb case ConversationHandler: if style == BinaryPrompt { return nil, C.int(ErrConv) @@ -164,6 +200,12 @@ func start(service, user string, handler ConversationHandler, confDir string) (* return nil, fmt.Errorf("%w: BinaryConversationHandler was used, but it is not supported by this platform", ErrSystem) } + case BinaryPointerConversationHandler: + if !CheckPamHasBinaryProtocol() { + return nil, fmt.Errorf( + "%w: BinaryPointerConversationHandler was used, but it is not supported by this platform", + ErrSystem) + } } t := &Transaction{ conv: &C.struct_pam_conv{}, diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go index 71fd5b9..45acc70 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module_test.go @@ -949,6 +949,10 @@ func Test_Moduler_IntegrationTesterModule(t *testing.T) { if !pam.CheckPamHasBinaryProtocol() { t.Skip("Binary protocol is not supported") } + case pam.BinaryPointerConversationHandler: + if !pam.CheckPamHasBinaryProtocol() { + t.Skip("Binary protocol is not supported") + } } tx, err := pam.StartConfDir(name, tc.user, tc.credentials, ts.WorkDir()) diff --git a/module-transaction_test.go b/module-transaction_test.go index ed08922..0514694 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -938,6 +938,145 @@ func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { return mt.startConvImpl(mock, NewStringConvRequest(TextInfo, "prompt")) }, }, + "StartConv-Binary-with-PointerConvFunc": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x95}, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From bytes pointer.") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{ + 0x01, 0x02, 0x03, 0x05, 0x00, 0x95})), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From bytes pointer.")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-and-allocated-data": { + expectedValue: []byte{0x01, 0x02, 0x03, 0x05, 0x00, 0x95}, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{ + 0x01, 0x02, 0x03, 0x05, 0x00, 0x95})), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...")) + data, err := mt.startConvImpl(mock, + NewBinaryConvRequest(allocateCBytes(bytes), binaryPointerCBytesFinalizer)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-and-allocated-data-erroring": { + expectedValue: nil, + expectedError: ErrConv, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{ + 0x01, 0x02, 0x03, 0x05, 0x00, 0x95})), ErrConv + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a binary data request\xC5\x00\xffYes it is! From pointer...")) + data, err := mt.startConvImpl(mock, + NewBinaryConvRequest(allocateCBytes(bytes), binaryPointerCBytesFinalizer)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-empty": { + expectedValue: []byte{}, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is an empty binary data request\xC5\x00\xffYes it is!") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return allocateCBytes(testBinaryDataEncoder([]byte{})), nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is an empty binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-nil": { + expectedValue: []byte(nil), + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + bytes, _ := testBinaryDataDecoder(ptr) + expectedBinary := []byte( + "\x00This is a nil binary data request\xC5\x00\xffYes it is!") + if !reflect.DeepEqual(bytes, expectedBinary) { + return nil, + fmt.Errorf("%w: data mismatch %#v vs %#v", ErrConv, bytes, expectedBinary) + } + return nil, nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + bytes := testBinaryDataEncoder([]byte( + "\x00This is a nil binary data request\xC5\x00\xffYes it is!")) + data, err := mt.startConvImpl(mock, NewBinaryConvRequestFromBytes(bytes)) + if err != nil { + return data, err + } + resp, _ := data.(BinaryConvResponse) + return resp.Decode(testBinaryDataDecoder) + }, + }, + "StartConv-Binary-with-PointerConvFunc-error": { + expectedError: ErrConv, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + return nil, errors.New("got an error") + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewBinaryConvRequestFromBytes([]byte{})) + }, + }, + "StartConv-String-with-ConvPointerBinaryFunc": { + expectedError: ErrConv, + conversationHandler: BinaryPointerConversationFunc(func(ptr BinaryPointer) (BinaryPointer, error) { + return nil, nil + }), + testFunc: func(mock *mockModuleTransaction) (any, error) { + return mt.startConvImpl(mock, NewStringConvRequest(TextInfo, "prompt")) + }, + }, } for name, tc := range tests { diff --git a/utils.go b/utils.go index d094c30..ad61daa 100644 --- a/utils.go +++ b/utils.go @@ -2,6 +2,8 @@ package pam /* +#include + #ifdef __SANITIZE_ADDRESS__ #include #endif @@ -20,6 +22,7 @@ import ( "os" "runtime" "time" + "unsafe" ) func maybeDoLeakCheck() { @@ -29,3 +32,11 @@ func maybeDoLeakCheck() { C.maybe_do_leak_check() } } + +func allocateCBytes(bytes []byte) BinaryPointer { + return BinaryPointer(C.CBytes(bytes)) +} + +func binaryPointerCBytesFinalizer(ptr BinaryPointer) { + C.free(unsafe.Pointer(ptr)) +} From 5c4d79646798aef3b3e1a20ee5ac1b7483c6fac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 20 Oct 2023 00:17:39 +0200 Subject: [PATCH 19/24] README: Update how to run tests --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 8738e00..539deac 100644 --- a/README.md +++ b/README.md @@ -143,5 +143,8 @@ Then execute the tests: $ sudo GOPATH=$GOPATH $(which go) test -v ``` +Other tests can instead run as user without any setup with +normal `go test ./...` + [1]: http://godoc.org/github.com/msteinert/pam/v2 [2]: http://www.linux-pam.org/Linux-PAM-html/Linux-PAM_ADG.html From 61621ce9c3ae449de653df5f3f1d36051cd677d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Fri, 1 Dec 2023 23:41:06 +0100 Subject: [PATCH 20/24] ci: Show coverage for all packages We have test utils in other packages that are not shown as tested, while they definitely are. --- .github/workflows/test.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 69c49ef..43fde5d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -30,18 +30,18 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Test - run: sudo go test -v -cover -coverprofile=coverage.out ./... + run: sudo go test -v -cover -coverprofile=coverage.out -coverpkg=./... ./... - name: Test with Address Sanitizer env: GO_PAM_TEST_WITH_ASAN: true CGO_CFLAGS: "-O0 -g3 -fno-omit-frame-pointer" run: | # Do not run sudo-requiring go tests because as PAM has some leaks in 22.04 - go test -v -asan -cover -coverprofile=coverage-asan-tx.out -gcflags=all="-N -l" + go test -v -asan -cover -coverprofile=coverage-asan-tx.out -coverpkg=./... -gcflags=all="-N -l" # Run the rest of tests normally - sudo go test -v -cover -coverprofile=coverage-asan-module.out -asan -gcflags=all="-N -l" -run Module - sudo go test -C cmd -coverprofile=coverage-asan.out -v -asan -gcflags=all="-N -l" ./... + sudo go test -v -cover -coverprofile=coverage-asan-module.out -coverpkg=./... -asan -gcflags=all="-N -l" -run Module + sudo go test -C cmd -coverprofile=coverage-asan.out -v -coverpkg=./... -asan -gcflags=all="-N -l" ./... - name: Generate example module run: | rm -f example-module/pam_go.so From e182844e4e11401b515cf1c0964d66bae3753602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Sun, 18 Feb 2024 18:31:07 +0100 Subject: [PATCH 21/24] transaction: Fix typo in conversation doc --- transaction.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transaction.go b/transaction.go index bd2876d..b918a24 100644 --- a/transaction.go +++ b/transaction.go @@ -19,7 +19,7 @@ const success = C.PAM_SUCCESS // Style is the type of message that the conversation handler should display. type Style int -// Coversation handler style types. +// Conversation handler style types. const ( // PromptEchoOff indicates the conversation handler should obtain a // string without echoing any text. From f19903865176f469cf13504841965b401584831e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Thu, 21 Mar 2024 06:44:21 +0100 Subject: [PATCH 22/24] transaction: Add some missing flags for modules --- transaction.go | 7 +++++++ transaction.h | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/transaction.go b/transaction.go index b918a24..8ff6e86 100644 --- a/transaction.go +++ b/transaction.go @@ -138,6 +138,13 @@ const ( // ChangeExpiredAuthtok indicates that the authentication token // should be changed if it has expired. ChangeExpiredAuthtok Flags = C.PAM_CHANGE_EXPIRED_AUTHTOK + // PrelimCheck indicates that the password service should only + // perform preliminary checks. No passwords should be updated. + PrelimCheck Flags = C.PAM_PRELIM_CHECK + // UpdateAuthtok indicates that password service should update + // passwords Note: [pam.PrelimCheck] and [pam.UpdateAuthtok] cannot + // both be set simultaneously! + UpdateAuthtok Flags = C.PAM_UPDATE_AUTHTOK ) // PutEnv adds or changes the value of PAM environment variables. diff --git a/transaction.h b/transaction.h index 292aa96..f224d80 100644 --- a/transaction.h +++ b/transaction.h @@ -14,6 +14,14 @@ #define BINARY_PROMPT_IS_SUPPORTED 0 #endif +#ifndef PAM_PRELIM_CHECK +#define PAM_PRELIM_CHECK 0 +#endif + +#ifndef PAM_UPDATE_AUTHTOK +#define PAM_UPDATE_AUTHTOK 0 +#endif + #ifdef __sun #define PAM_CONST #else From 979a1109aa1d126fc8b17a6e2975e1377f892d53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Thu, 11 Apr 2024 19:42:22 +0200 Subject: [PATCH 23/24] fixup! pam-moduler: Add first implementation of a Go PAM Module generator --- cmd/pam-moduler/moduler.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 0a74125..6f5955a 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -65,7 +65,7 @@ var ( libName = flag.String("libname", "", "output library name; default pam_go.so") typeName = flag.String("type", "", "type name to be used as pam.ModuleHandler") buildTags = flag.String("tags", "", "build tags expression to append to use in the go:build directive") - skipGenerator = flag.Bool("no-generator", false, "whether to add go:generator directives to the generated source") + skipGenerator = flag.Bool("no-generator", false, "whether to add go:generate directives to the generated source") moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") @@ -89,14 +89,17 @@ func main() { if *libName != "" { fmt.Fprintf(os.Stderr, "Generator directives disabled, libname will have no effect\n") + *libName = "" } if *moduleBuildTags != "" { fmt.Fprintf(os.Stderr, "Generator directives disabled, build-tags will have no effect\n") + *moduleBuildTags = "" } if *moduleBuildFlags != "" { fmt.Fprintf(os.Stderr, "Generator directives disabled, build-flags will have no effect\n") + *moduleBuildFlags = "" } } From e0753f6dab82f0be934a5c792b252338987b3fa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Trevisan=20=28Trevi=C3=B1o=29?= Date: Thu, 11 Apr 2024 19:44:29 +0200 Subject: [PATCH 24/24] fixup! pam-moduler: Add first implementation of a Go PAM Module generator --- cmd/pam-moduler/moduler.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index 6f5955a..42d7acf 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -210,10 +210,10 @@ func (g *Generator) generate() { {"chauthtok", "ChangeAuthTok"}, } - g.printf(`//go:generate go build "-ldflags=-extldflags -Wl,-soname,%[2]s.so" `+ - `-buildmode=c-shared -o %[2]s.so %[3]s %[4]s + g.printf(`//go:generate go build "-ldflags=-extldflags -Wl,-soname,%[1]s.so" `+ + `-buildmode=c-shared -o %[1]s.so %[2]s %[3]s `, - g.outputName, g.libName, buildTagsArg, strings.Join(g.buildFlags, " ")) + g.libName, buildTagsArg, strings.Join(g.buildFlags, " ")) g.printf(` // Package main is the package for the PAM module library.