Merge pull request #61 from asdf-vm/tb/shim-generation-2

feat(golang-rewrite): shim generation 2
This commit is contained in:
Trevor Brown 2024-09-03 08:31:52 -04:00 committed by Trevor Brown
commit 43cd285727
5 changed files with 273 additions and 26 deletions

View File

@ -10,6 +10,7 @@ import (
"asdf/config"
"asdf/internal/info"
"asdf/internal/shims"
"asdf/internal/versions"
"asdf/plugins"
@ -135,6 +136,12 @@ func Execute(version string) {
},
},
},
{
Name: "reshim",
Action: func(_ *cli.Context) error {
return reshimCommand(logger)
},
},
},
Action: func(_ *cli.Context) error {
// TODO: flesh this out
@ -287,12 +294,15 @@ func installCommand(logger *log.Logger, toolName, version string) error {
errs := versions.InstallAll(conf, dir, os.Stdout, os.Stderr)
if len(errs) > 0 {
for _, err := range errs {
// write error stderr
os.Stderr.Write([]byte(err.Error()))
os.Stderr.Write([]byte("\n"))
}
return errs[0]
filtered := filterInstallErrors(errs)
if len(filtered) > 0 {
return filtered[0]
}
return nil
}
} else {
// Install specific version
@ -321,6 +331,16 @@ func installCommand(logger *log.Logger, toolName, version string) error {
return err
}
func filterInstallErrors(errs []error) []error {
var filtered []error
for _, err := range errs {
if _, ok := err.(versions.NoVersionSetError); !ok {
filtered = append(filtered, err)
}
}
return filtered
}
func parseInstallVersion(version string) (string, string) {
segments := strings.Split(version, ":")
if len(segments) > 1 && segments[0] == "latest" {
@ -369,6 +389,26 @@ func latestCommand(logger *log.Logger, all bool, toolName, pattern string) (err
return nil
}
func reshimCommand(logger *log.Logger) (err error) {
conf, err := config.LoadConfig()
if err != nil {
logger.Printf("error loading config: %s", err)
return err
}
err = shims.RemoveAll(conf)
if err != nil {
return err
}
err = shims.GenerateAll(conf, os.Stdout, os.Stderr)
if err != nil {
return err
}
return err
}
func latestForPlugin(conf config.Config, toolName, pattern string, showStatus bool) error {
// show single plugin
plugin := plugins.New(conf, toolName)
@ -385,7 +425,7 @@ func latestForPlugin(conf config.Config, toolName, pattern string, showStatus bo
}
if showStatus {
installed := versions.Installed(conf, plugin, latest)
installed := versions.IsInstalled(conf, plugin, latest)
fmt.Printf("%s\t%s\t%s\n", plugin.Name, latest, installedStatus(installed))
} else {
fmt.Printf("%s\n", latest)

View File

@ -3,11 +3,14 @@ package shims
import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"asdf/config"
"asdf/hook"
"asdf/internal/toolversions"
"asdf/internal/versions"
"asdf/plugins"
@ -15,6 +18,65 @@ import (
"golang.org/x/sys/unix"
)
const shimDirName = "shims"
// RemoveAll removes all shim scripts
func RemoveAll(conf config.Config) error {
shimDir := filepath.Join(conf.DataDir, shimDirName)
entries, err := os.ReadDir(shimDir)
if err != nil {
return err
}
for _, entry := range entries {
os.RemoveAll(path.Join(shimDir, entry.Name()))
}
return nil
}
// GenerateAll generates shims for all executables of every version of every
// plugin.
func GenerateAll(conf config.Config, stdOut io.Writer, stdErr io.Writer) error {
plugins, err := plugins.List(conf, false, false)
if err != nil {
return err
}
for _, plugin := range plugins {
err := GenerateForPluginVersions(conf, plugin, stdOut, stdErr)
if err != nil {
return err
}
}
return nil
}
// GenerateForPluginVersions generates all shims for all installed versions of
// a tool.
func GenerateForPluginVersions(conf config.Config, plugin plugins.Plugin, stdOut io.Writer, stdErr io.Writer) error {
installedVersions, err := versions.Installed(conf, plugin)
if err != nil {
return err
}
for _, version := range installedVersions {
err = hook.RunWithOutput(conf, fmt.Sprintf("pre_asdf_reshim_%s", plugin.Name), []string{version}, stdOut, stdErr)
if err != nil {
return err
}
GenerateForVersion(conf, plugin, version)
err = hook.RunWithOutput(conf, fmt.Sprintf("post_asdf_reshim_%s", plugin.Name), []string{version}, stdOut, stdErr)
if err != nil {
return err
}
}
return nil
}
// GenerateForVersion loops over all the executable files found for a tool and
// generates a shim for each one
func GenerateForVersion(conf config.Config, plugin plugins.Plugin, version string) error {
@ -59,11 +121,11 @@ func Write(conf config.Config, plugin plugins.Plugin, version, executablePath st
// Path returns the path for a shim script
func Path(conf config.Config, shimName string) string {
return filepath.Join(conf.DataDir, "shims", shimName)
return filepath.Join(conf.DataDir, shimDirName, shimName)
}
func ensureShimDirExists(conf config.Config) error {
return os.MkdirAll(filepath.Join(conf.DataDir, "shims"), 0o777)
return os.MkdirAll(filepath.Join(conf.DataDir, shimDirName), 0o777)
}
// ToolExecutables returns a slice of executables for a given tool version
@ -77,20 +139,20 @@ func ToolExecutables(conf config.Config, plugin plugins.Plugin, version string)
paths := dirsToPaths(dirs, installPath)
for _, path := range paths {
// Walk the directory and any sub directories
err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
entries, err := os.ReadDir(path)
if err != nil {
return executables, err
}
for _, entry := range entries {
// If entry is dir or cannot be executed by the current user ignore it
if info.IsDir() || unix.Access(path, unix.X_OK) != nil {
return nil
filePath := filepath.Join(path, entry.Name())
if entry.IsDir() || unix.Access(filePath, unix.X_OK) != nil {
return executables, nil
}
executables = append(executables, path)
return nil
})
executables = append(executables, filePath)
return executables, nil
}
if err != nil {
return executables, err
}

View File

@ -1,6 +1,8 @@
package shims
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
@ -17,6 +19,93 @@ import (
const testPluginName = "lua"
func TestRemoveAll(t *testing.T) {
version := "1.1.0"
conf, plugin := generateConfig(t)
installVersion(t, conf, plugin, version)
executables, err := ToolExecutables(conf, plugin, version)
assert.Nil(t, err)
stdout, stderr := buildOutputs()
t.Run("removes all files in shim directory", func(t *testing.T) {
assert.Nil(t, GenerateAll(conf, &stdout, &stderr))
assert.Nil(t, RemoveAll(conf))
// check for generated shims
for _, executable := range executables {
_, err := os.Stat(Path(conf, filepath.Base(executable)))
assert.True(t, errors.Is(err, os.ErrNotExist))
}
})
}
func TestGenerateAll(t *testing.T) {
version := "1.1.0"
version2 := "2.0.0"
conf, plugin := generateConfig(t)
installVersion(t, conf, plugin, version)
installPlugin(t, conf, "dummy_plugin", "ruby")
installVersion(t, conf, plugin, version2)
executables, err := ToolExecutables(conf, plugin, version)
assert.Nil(t, err)
stdout, stderr := buildOutputs()
t.Run("generates shim script for every executable in every version of every tool", func(t *testing.T) {
assert.Nil(t, GenerateAll(conf, &stdout, &stderr))
// check for generated shims
for _, executable := range executables {
shimName := filepath.Base(executable)
shimPath := Path(conf, shimName)
assert.Nil(t, unix.Access(shimPath, unix.X_OK))
// shim exists and has expected contents
content, err := os.ReadFile(shimPath)
assert.Nil(t, err)
want := fmt.Sprintf("#!/usr/bin/env bash\n# asdf-plugin: lua 2.0.0\n# asdf-plugin: lua 1.1.0\nexec asdf exec \"%s\" \"$@\"", shimName)
assert.Equal(t, want, string(content))
}
})
}
func TestGenerateForPluginVersions(t *testing.T) {
t.Setenv("ASDF_CONFIG_FILE", "testdata/asdfrc")
version := "1.1.0"
version2 := "2.0.0"
conf, plugin := generateConfig(t)
installVersion(t, conf, plugin, version)
installVersion(t, conf, plugin, version2)
executables, err := ToolExecutables(conf, plugin, version)
assert.Nil(t, err)
stdout, stderr := buildOutputs()
t.Run("generates shim script for every executable in every version the tool", func(t *testing.T) {
assert.Nil(t, GenerateForPluginVersions(conf, plugin, &stdout, &stderr))
// check for generated shims
for _, executable := range executables {
shimName := filepath.Base(executable)
shimPath := Path(conf, shimName)
assert.Nil(t, unix.Access(shimPath, unix.X_OK))
// shim exists and has expected contents
content, err := os.ReadFile(shimPath)
assert.Nil(t, err)
want := fmt.Sprintf("#!/usr/bin/env bash\n# asdf-plugin: lua 2.0.0\n# asdf-plugin: lua 1.1.0\nexec asdf exec \"%s\" \"$@\"", shimName)
assert.Equal(t, want, string(content))
}
})
t.Run("runs pre and post reshim hooks", func(t *testing.T) {
stdout, stderr := buildOutputs()
assert.Nil(t, GenerateForPluginVersions(conf, plugin, &stdout, &stderr))
want := "pre_reshim 1.1.0\npost_reshim 1.1.0\npre_reshim 2.0.0\npost_reshim 2.0.0\n"
assert.Equal(t, want, stdout.String())
})
}
func TestGenerateForVersion(t *testing.T) {
version := "1.1.0"
version2 := "2.0.0"
@ -125,7 +214,7 @@ func TestToolExecutables(t *testing.T) {
filenames = append(filenames, filepath.Base(executablePath))
}
assert.Equal(t, filenames, []string{"dummy", "other_bin"})
assert.Equal(t, filenames, []string{"dummy"})
})
}
@ -165,10 +254,14 @@ func generateConfig(t *testing.T) (config.Config, plugins.Plugin) {
assert.Nil(t, err)
conf.DataDir = testDataDir
_, err = repotest.InstallPlugin("dummy_plugin", testDataDir, testPluginName)
return conf, installPlugin(t, conf, "dummy_plugin", testPluginName)
}
func installPlugin(t *testing.T, conf config.Config, fixture, pluginName string) plugins.Plugin {
_, err := repotest.InstallPlugin(fixture, conf.DataDir, pluginName)
assert.Nil(t, err)
return conf, plugins.New(conf, testPluginName)
return plugins.New(conf, testPluginName)
}
func installVersion(t *testing.T, conf config.Config, plugin plugins.Plugin, version string) {

View File

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"regexp"
@ -48,6 +49,29 @@ func (e NoVersionSetError) Error() string {
return "no version set"
}
// Installed returns a slice of all installed versions for a given plugin
func Installed(conf config.Config, plugin plugins.Plugin) (versions []string, err error) {
installDirectory := pluginInstallPath(conf, plugin)
files, err := os.ReadDir(installDirectory)
if err != nil {
if _, ok := err.(*fs.PathError); ok {
return versions, nil
}
return versions, err
}
for _, file := range files {
if !file.IsDir() {
continue
}
versions = append(versions, file.Name())
}
return versions, err
}
// InstallAll installs all specified versions of every tool for the current
// directory. Typically this will just be a single version, if not already
// installed, but it may be multiple versions if multiple versions for the tool
@ -134,7 +158,7 @@ func InstallOneVersion(conf config.Config, plugin plugins.Plugin, version string
installDir := InstallPath(conf, plugin, version)
versionType, version := ParseString(version)
if Installed(conf, plugin, version) {
if IsInstalled(conf, plugin, version) {
return fmt.Errorf("version %s of %s is already installed", version, plugin.Name)
}
@ -198,8 +222,8 @@ func asdfConcurrency(conf config.Config) string {
return val
}
// Installed checks if a specific version of a tool is installed
func Installed(conf config.Config, plugin plugins.Plugin, version string) bool {
// IsInstalled checks if a specific version of a tool is installed
func IsInstalled(conf config.Config, plugin plugins.Plugin, version string) bool {
installDir := InstallPath(conf, plugin, version)
// Check if version already installed
@ -321,5 +345,9 @@ func downloadPath(conf config.Config, plugin plugins.Plugin, version string) str
// InstallPath returns the path to a tool installation
func InstallPath(conf config.Config, plugin plugins.Plugin, version string) string {
return filepath.Join(conf.DataDir, dataDirInstalls, plugin.Name, version)
return filepath.Join(pluginInstallPath(conf, plugin), version)
}
func pluginInstallPath(conf config.Config, plugin plugins.Plugin) string {
return filepath.Join(conf.DataDir, dataDirInstalls, plugin.Name)
}

View File

@ -16,6 +16,30 @@ import (
const testPluginName = "lua"
func TestInstalled(t *testing.T) {
conf, plugin := generateConfig(t)
//stdout, stderr := buildOutputs()
//currentDir := t.TempDir()
//secondPlugin := installPlugin(t, conf, "dummy_plugin", "another")
//version := "1.0.0"
t.Run("returns empty slice for newly installed plugin", func(t *testing.T) {
installedVersions, err := Installed(conf, plugin)
assert.Nil(t, err)
assert.Empty(t, installedVersions)
})
t.Run("returns slice of all installed versions for a tool", func(t *testing.T) {
stdout, stderr := buildOutputs()
err := InstallOneVersion(conf, plugin, "1.0.0", &stdout, &stderr)
assert.Nil(t, err)
installedVersions, err := Installed(conf, plugin)
assert.Nil(t, err)
assert.Equal(t, installedVersions, []string{"1.0.0"})
})
}
func TestInstallAll(t *testing.T) {
t.Run("installs multiple tools when multiple tool versions are specified", func(t *testing.T) {
conf, plugin := generateConfig(t)
@ -263,17 +287,17 @@ func TestInstallOneVersion(t *testing.T) {
})
}
func TestInstalled(t *testing.T) {
func TestIsInstalled(t *testing.T) {
conf, plugin := generateConfig(t)
stdout, stderr := buildOutputs()
err := InstallOneVersion(conf, plugin, "1.0.0", &stdout, &stderr)
assert.Nil(t, err)
t.Run("returns false when not installed", func(t *testing.T) {
assert.False(t, Installed(conf, plugin, "4.0.0"))
assert.False(t, IsInstalled(conf, plugin, "4.0.0"))
})
t.Run("returns true when installed", func(t *testing.T) {
assert.True(t, Installed(conf, plugin, "1.0.0"))
assert.True(t, IsInstalled(conf, plugin, "1.0.0"))
})
}