From 985c181118b39bb555a2bd2a1cdbf111a25bd512 Mon Sep 17 00:00:00 2001 From: Trevor Brown Date: Mon, 2 Sep 2024 14:42:54 -0400 Subject: [PATCH] feat(golang-rewrite): shim generation 2 * Rename `versions.Installed` function to `IsInstalled` * Create `versions.Installed` function * Create `shims.GenerateForPluginVersions` and `shims.GenerateAll` functions * Address linter warnings * Create asdf reshim command * Run asdf hook from new `shims` functions --- cmd/cmd.go | 46 +++++++++++++- internal/shims/shims.go | 88 ++++++++++++++++++++++---- internal/shims/shims_test.go | 99 +++++++++++++++++++++++++++++- internal/versions/versions.go | 36 +++++++++-- internal/versions/versions_test.go | 30 ++++++++- 5 files changed, 273 insertions(+), 26 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 6306dc90..de948997 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -10,6 +10,7 @@ import ( "asdf/config" "asdf/internal/info" + "asdf/internal/shims" "asdf/internal/versions" "asdf/plugins" @@ -135,6 +136,12 @@ func Execute(version string) { }, }, }, + { + Name: "reshim", + Action: func(_ *cli.Context) error { + return reshimCommand(logger) + }, + }, }, Action: func(_ *cli.Context) error { // TODO: flesh this out @@ -287,12 +294,15 @@ func installCommand(logger *log.Logger, toolName, version string) error { errs := versions.InstallAll(conf, dir, os.Stdout, os.Stderr) if len(errs) > 0 { for _, err := range errs { - // write error stderr os.Stderr.Write([]byte(err.Error())) os.Stderr.Write([]byte("\n")) } - return errs[0] + filtered := filterInstallErrors(errs) + if len(filtered) > 0 { + return filtered[0] + } + return nil } } else { // Install specific version @@ -321,6 +331,16 @@ func installCommand(logger *log.Logger, toolName, version string) error { return err } +func filterInstallErrors(errs []error) []error { + var filtered []error + for _, err := range errs { + if _, ok := err.(versions.NoVersionSetError); !ok { + filtered = append(filtered, err) + } + } + return filtered +} + func parseInstallVersion(version string) (string, string) { segments := strings.Split(version, ":") if len(segments) > 1 && segments[0] == "latest" { @@ -369,6 +389,26 @@ func latestCommand(logger *log.Logger, all bool, toolName, pattern string) (err return nil } +func reshimCommand(logger *log.Logger) (err error) { + conf, err := config.LoadConfig() + if err != nil { + logger.Printf("error loading config: %s", err) + return err + } + + err = shims.RemoveAll(conf) + if err != nil { + return err + } + + err = shims.GenerateAll(conf, os.Stdout, os.Stderr) + if err != nil { + return err + } + + return err +} + func latestForPlugin(conf config.Config, toolName, pattern string, showStatus bool) error { // show single plugin plugin := plugins.New(conf, toolName) @@ -385,7 +425,7 @@ func latestForPlugin(conf config.Config, toolName, pattern string, showStatus bo } if showStatus { - installed := versions.Installed(conf, plugin, latest) + installed := versions.IsInstalled(conf, plugin, latest) fmt.Printf("%s\t%s\t%s\n", plugin.Name, latest, installedStatus(installed)) } else { fmt.Printf("%s\n", latest) diff --git a/internal/shims/shims.go b/internal/shims/shims.go index 48f8c09d..a44afd1d 100644 --- a/internal/shims/shims.go +++ b/internal/shims/shims.go @@ -3,11 +3,14 @@ package shims import ( "fmt" + "io" "os" + "path" "path/filepath" "strings" "asdf/config" + "asdf/hook" "asdf/internal/toolversions" "asdf/internal/versions" "asdf/plugins" @@ -15,6 +18,65 @@ import ( "golang.org/x/sys/unix" ) +const shimDirName = "shims" + +// RemoveAll removes all shim scripts +func RemoveAll(conf config.Config) error { + shimDir := filepath.Join(conf.DataDir, shimDirName) + entries, err := os.ReadDir(shimDir) + if err != nil { + return err + } + + for _, entry := range entries { + os.RemoveAll(path.Join(shimDir, entry.Name())) + } + + return nil +} + +// GenerateAll generates shims for all executables of every version of every +// plugin. +func GenerateAll(conf config.Config, stdOut io.Writer, stdErr io.Writer) error { + plugins, err := plugins.List(conf, false, false) + if err != nil { + return err + } + + for _, plugin := range plugins { + err := GenerateForPluginVersions(conf, plugin, stdOut, stdErr) + if err != nil { + return err + } + } + + return nil +} + +// GenerateForPluginVersions generates all shims for all installed versions of +// a tool. +func GenerateForPluginVersions(conf config.Config, plugin plugins.Plugin, stdOut io.Writer, stdErr io.Writer) error { + installedVersions, err := versions.Installed(conf, plugin) + if err != nil { + return err + } + + for _, version := range installedVersions { + err = hook.RunWithOutput(conf, fmt.Sprintf("pre_asdf_reshim_%s", plugin.Name), []string{version}, stdOut, stdErr) + if err != nil { + return err + } + + GenerateForVersion(conf, plugin, version) + + err = hook.RunWithOutput(conf, fmt.Sprintf("post_asdf_reshim_%s", plugin.Name), []string{version}, stdOut, stdErr) + if err != nil { + return err + } + } + return nil +} + // GenerateForVersion loops over all the executable files found for a tool and // generates a shim for each one func GenerateForVersion(conf config.Config, plugin plugins.Plugin, version string) error { @@ -59,11 +121,11 @@ func Write(conf config.Config, plugin plugins.Plugin, version, executablePath st // Path returns the path for a shim script func Path(conf config.Config, shimName string) string { - return filepath.Join(conf.DataDir, "shims", shimName) + return filepath.Join(conf.DataDir, shimDirName, shimName) } func ensureShimDirExists(conf config.Config) error { - return os.MkdirAll(filepath.Join(conf.DataDir, "shims"), 0o777) + return os.MkdirAll(filepath.Join(conf.DataDir, shimDirName), 0o777) } // ToolExecutables returns a slice of executables for a given tool version @@ -77,20 +139,20 @@ func ToolExecutables(conf config.Config, plugin plugins.Plugin, version string) paths := dirsToPaths(dirs, installPath) for _, path := range paths { - // Walk the directory and any sub directories - err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - + entries, err := os.ReadDir(path) + if err != nil { + return executables, err + } + for _, entry := range entries { // If entry is dir or cannot be executed by the current user ignore it - if info.IsDir() || unix.Access(path, unix.X_OK) != nil { - return nil + filePath := filepath.Join(path, entry.Name()) + if entry.IsDir() || unix.Access(filePath, unix.X_OK) != nil { + return executables, nil } - executables = append(executables, path) - return nil - }) + executables = append(executables, filePath) + return executables, nil + } if err != nil { return executables, err } diff --git a/internal/shims/shims_test.go b/internal/shims/shims_test.go index 07a5bb32..549d885e 100644 --- a/internal/shims/shims_test.go +++ b/internal/shims/shims_test.go @@ -1,6 +1,8 @@ package shims import ( + "errors" + "fmt" "os" "path/filepath" "strings" @@ -17,6 +19,93 @@ import ( const testPluginName = "lua" +func TestRemoveAll(t *testing.T) { + version := "1.1.0" + conf, plugin := generateConfig(t) + installVersion(t, conf, plugin, version) + executables, err := ToolExecutables(conf, plugin, version) + assert.Nil(t, err) + stdout, stderr := buildOutputs() + + t.Run("removes all files in shim directory", func(t *testing.T) { + assert.Nil(t, GenerateAll(conf, &stdout, &stderr)) + assert.Nil(t, RemoveAll(conf)) + + // check for generated shims + for _, executable := range executables { + _, err := os.Stat(Path(conf, filepath.Base(executable))) + assert.True(t, errors.Is(err, os.ErrNotExist)) + } + }) +} + +func TestGenerateAll(t *testing.T) { + version := "1.1.0" + version2 := "2.0.0" + conf, plugin := generateConfig(t) + installVersion(t, conf, plugin, version) + installPlugin(t, conf, "dummy_plugin", "ruby") + installVersion(t, conf, plugin, version2) + executables, err := ToolExecutables(conf, plugin, version) + assert.Nil(t, err) + stdout, stderr := buildOutputs() + + t.Run("generates shim script for every executable in every version of every tool", func(t *testing.T) { + assert.Nil(t, GenerateAll(conf, &stdout, &stderr)) + + // check for generated shims + for _, executable := range executables { + shimName := filepath.Base(executable) + shimPath := Path(conf, shimName) + assert.Nil(t, unix.Access(shimPath, unix.X_OK)) + + // shim exists and has expected contents + content, err := os.ReadFile(shimPath) + assert.Nil(t, err) + want := fmt.Sprintf("#!/usr/bin/env bash\n# asdf-plugin: lua 2.0.0\n# asdf-plugin: lua 1.1.0\nexec asdf exec \"%s\" \"$@\"", shimName) + assert.Equal(t, want, string(content)) + } + }) +} + +func TestGenerateForPluginVersions(t *testing.T) { + t.Setenv("ASDF_CONFIG_FILE", "testdata/asdfrc") + version := "1.1.0" + version2 := "2.0.0" + conf, plugin := generateConfig(t) + installVersion(t, conf, plugin, version) + installVersion(t, conf, plugin, version2) + executables, err := ToolExecutables(conf, plugin, version) + assert.Nil(t, err) + stdout, stderr := buildOutputs() + + t.Run("generates shim script for every executable in every version the tool", func(t *testing.T) { + assert.Nil(t, GenerateForPluginVersions(conf, plugin, &stdout, &stderr)) + + // check for generated shims + for _, executable := range executables { + shimName := filepath.Base(executable) + shimPath := Path(conf, shimName) + assert.Nil(t, unix.Access(shimPath, unix.X_OK)) + + // shim exists and has expected contents + content, err := os.ReadFile(shimPath) + assert.Nil(t, err) + + want := fmt.Sprintf("#!/usr/bin/env bash\n# asdf-plugin: lua 2.0.0\n# asdf-plugin: lua 1.1.0\nexec asdf exec \"%s\" \"$@\"", shimName) + assert.Equal(t, want, string(content)) + } + }) + + t.Run("runs pre and post reshim hooks", func(t *testing.T) { + stdout, stderr := buildOutputs() + assert.Nil(t, GenerateForPluginVersions(conf, plugin, &stdout, &stderr)) + + want := "pre_reshim 1.1.0\npost_reshim 1.1.0\npre_reshim 2.0.0\npost_reshim 2.0.0\n" + assert.Equal(t, want, stdout.String()) + }) +} + func TestGenerateForVersion(t *testing.T) { version := "1.1.0" version2 := "2.0.0" @@ -125,7 +214,7 @@ func TestToolExecutables(t *testing.T) { filenames = append(filenames, filepath.Base(executablePath)) } - assert.Equal(t, filenames, []string{"dummy", "other_bin"}) + assert.Equal(t, filenames, []string{"dummy"}) }) } @@ -165,10 +254,14 @@ func generateConfig(t *testing.T) (config.Config, plugins.Plugin) { assert.Nil(t, err) conf.DataDir = testDataDir - _, err = repotest.InstallPlugin("dummy_plugin", testDataDir, testPluginName) + return conf, installPlugin(t, conf, "dummy_plugin", testPluginName) +} + +func installPlugin(t *testing.T, conf config.Config, fixture, pluginName string) plugins.Plugin { + _, err := repotest.InstallPlugin(fixture, conf.DataDir, pluginName) assert.Nil(t, err) - return conf, plugins.New(conf, testPluginName) + return plugins.New(conf, testPluginName) } func installVersion(t *testing.T, conf config.Config, plugin plugins.Plugin, version string) { diff --git a/internal/versions/versions.go b/internal/versions/versions.go index f1e6622c..0ffd8d6d 100644 --- a/internal/versions/versions.go +++ b/internal/versions/versions.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "path/filepath" "regexp" @@ -48,6 +49,29 @@ func (e NoVersionSetError) Error() string { return "no version set" } +// Installed returns a slice of all installed versions for a given plugin +func Installed(conf config.Config, plugin plugins.Plugin) (versions []string, err error) { + installDirectory := pluginInstallPath(conf, plugin) + files, err := os.ReadDir(installDirectory) + if err != nil { + if _, ok := err.(*fs.PathError); ok { + return versions, nil + } + + return versions, err + } + + for _, file := range files { + if !file.IsDir() { + continue + } + + versions = append(versions, file.Name()) + } + + return versions, err +} + // InstallAll installs all specified versions of every tool for the current // directory. Typically this will just be a single version, if not already // installed, but it may be multiple versions if multiple versions for the tool @@ -134,7 +158,7 @@ func InstallOneVersion(conf config.Config, plugin plugins.Plugin, version string installDir := InstallPath(conf, plugin, version) versionType, version := ParseString(version) - if Installed(conf, plugin, version) { + if IsInstalled(conf, plugin, version) { return fmt.Errorf("version %s of %s is already installed", version, plugin.Name) } @@ -198,8 +222,8 @@ func asdfConcurrency(conf config.Config) string { return val } -// Installed checks if a specific version of a tool is installed -func Installed(conf config.Config, plugin plugins.Plugin, version string) bool { +// IsInstalled checks if a specific version of a tool is installed +func IsInstalled(conf config.Config, plugin plugins.Plugin, version string) bool { installDir := InstallPath(conf, plugin, version) // Check if version already installed @@ -321,5 +345,9 @@ func downloadPath(conf config.Config, plugin plugins.Plugin, version string) str // InstallPath returns the path to a tool installation func InstallPath(conf config.Config, plugin plugins.Plugin, version string) string { - return filepath.Join(conf.DataDir, dataDirInstalls, plugin.Name, version) + return filepath.Join(pluginInstallPath(conf, plugin), version) +} + +func pluginInstallPath(conf config.Config, plugin plugins.Plugin) string { + return filepath.Join(conf.DataDir, dataDirInstalls, plugin.Name) } diff --git a/internal/versions/versions_test.go b/internal/versions/versions_test.go index ce2eeb76..dc314cf4 100644 --- a/internal/versions/versions_test.go +++ b/internal/versions/versions_test.go @@ -16,6 +16,30 @@ import ( const testPluginName = "lua" +func TestInstalled(t *testing.T) { + conf, plugin := generateConfig(t) + //stdout, stderr := buildOutputs() + //currentDir := t.TempDir() + //secondPlugin := installPlugin(t, conf, "dummy_plugin", "another") + //version := "1.0.0" + + t.Run("returns empty slice for newly installed plugin", func(t *testing.T) { + installedVersions, err := Installed(conf, plugin) + assert.Nil(t, err) + assert.Empty(t, installedVersions) + }) + + t.Run("returns slice of all installed versions for a tool", func(t *testing.T) { + stdout, stderr := buildOutputs() + err := InstallOneVersion(conf, plugin, "1.0.0", &stdout, &stderr) + assert.Nil(t, err) + + installedVersions, err := Installed(conf, plugin) + assert.Nil(t, err) + assert.Equal(t, installedVersions, []string{"1.0.0"}) + }) +} + func TestInstallAll(t *testing.T) { t.Run("installs multiple tools when multiple tool versions are specified", func(t *testing.T) { conf, plugin := generateConfig(t) @@ -263,17 +287,17 @@ func TestInstallOneVersion(t *testing.T) { }) } -func TestInstalled(t *testing.T) { +func TestIsInstalled(t *testing.T) { conf, plugin := generateConfig(t) stdout, stderr := buildOutputs() err := InstallOneVersion(conf, plugin, "1.0.0", &stdout, &stderr) assert.Nil(t, err) t.Run("returns false when not installed", func(t *testing.T) { - assert.False(t, Installed(conf, plugin, "4.0.0")) + assert.False(t, IsInstalled(conf, plugin, "4.0.0")) }) t.Run("returns true when installed", func(t *testing.T) { - assert.True(t, Installed(conf, plugin, "1.0.0")) + assert.True(t, IsInstalled(conf, plugin, "1.0.0")) }) }