Merge pull request #65 from asdf-vm/tb/shim-exec-1

feat(golang-rewrite): create `shims.FindExecutable` function for shim execution
This commit is contained in:
Trevor Brown 2024-09-12 08:50:05 -04:00 committed by Trevor Brown
commit 05b9c37232
6 changed files with 266 additions and 4 deletions

21
internal/paths/paths.go Normal file
View File

@ -0,0 +1,21 @@
// Package paths contains a variety of helper functions responsible for
// computing paths to various things. This package should not depend on any
// other asdf packages.
package paths
import (
"strings"
)
// RemoveFromPath returns the PATH without asdf shims path
func RemoveFromPath(currentPath, pathToRemove string) string {
var newPaths []string
for _, fspath := range strings.Split(currentPath, ":") {
if fspath != pathToRemove {
newPaths = append(newPaths, fspath)
}
}
return strings.Join(newPaths, ":")
}

View File

@ -0,0 +1,19 @@
package paths
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestRemoveFromPath(t *testing.T) {
t.Run("returns PATH string with matching path removed", func(t *testing.T) {
got := RemoveFromPath("/foo/bar:/baz/bim:/home/user/bin", "/baz/bim")
assert.Equal(t, got, "/foo/bar:/home/user/bin")
})
t.Run("returns PATH string unchanged when no matching path found", func(t *testing.T) {
got := RemoveFromPath("/foo/bar:/baz/bim:/home/user/bin", "/path-not-present/")
assert.Equal(t, got, "/foo/bar:/baz/bim:/home/user/bin")
})
}

View File

@ -5,14 +5,18 @@ import (
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"slices"
"strings"
"asdf/internal/config"
"asdf/internal/hook"
"asdf/internal/installs"
"asdf/internal/paths"
"asdf/internal/plugins"
"asdf/internal/resolve"
"asdf/internal/toolversions"
"golang.org/x/sys/unix"
@ -20,6 +24,134 @@ import (
const shimDirName = "shims"
// UnknownCommandError is an error returned when a shim is not found
type UnknownCommandError struct {
shim string
}
func (e UnknownCommandError) Error() string {
return fmt.Sprintf("unknown command: %s", e.shim)
}
// NoVersionSetError is returned when shim is found but no version matches
type NoVersionSetError struct {
shim string
}
func (e NoVersionSetError) Error() string {
return fmt.Sprintf("no versions set for %s", e.shim)
}
// NoExecutableForPluginError is returned when a compatible version is found
// but no executable matching the name is located.
type NoExecutableForPluginError struct {
shim string
}
func (e NoExecutableForPluginError) Error() string {
return fmt.Sprintf("no %s executable for plugin %s", e.shim, "")
}
// FindExecutable takes a shim name and a current directory and returns the path
// to the executable that the shim resolves to.
func FindExecutable(conf config.Config, shimName, currentDirectory string) (string, bool, error) {
shimPath := Path(conf, shimName)
if _, err := os.Stat(shimPath); err != nil {
return "", false, UnknownCommandError{shim: shimName}
}
toolVersions, err := getToolsAndVersionsFromShimFile(shimPath)
if err != nil {
return "", false, err
}
existingPluginToolVersions := make(map[plugins.Plugin]resolve.ToolVersions)
// loop over tools and check if the plugin for them still exists
for _, shimToolVersion := range toolVersions {
plugin := plugins.New(conf, shimToolVersion.Name)
if plugin.Exists() == nil {
versions, found, err := resolve.Version(conf, plugin, currentDirectory)
if err != nil {
return "", false, nil
}
if found {
tempVersions := toolversions.Intersect(shimToolVersion.Versions, versions.Versions)
if slices.Contains(versions.Versions, "system") {
tempVersions = append(tempVersions, "system")
}
versions.Versions = tempVersions
existingPluginToolVersions[plugin] = versions
}
}
}
if len(existingPluginToolVersions) == 0 {
return "", false, NoVersionSetError{shim: shimName}
}
for plugin, toolVersions := range existingPluginToolVersions {
for _, version := range toolVersions.Versions {
if version == "system" {
if executablePath, found := FindSystemExecutable(conf, shimName); found {
return executablePath, true, nil
}
break
}
executablePath, err := GetExecutablePath(conf, plugin, shimName, version)
if err == nil {
return executablePath, true, nil
}
}
}
return "", false, NoExecutableForPluginError{shim: shimName}
}
// FindSystemExecutable returns the path to the system
// executable if found
func FindSystemExecutable(conf config.Config, executableName string) (string, bool) {
currentPath := os.Getenv("PATH")
defer os.Setenv("PATH", currentPath)
os.Setenv("PATH", paths.RemoveFromPath(currentPath, shimsDirectory(conf)))
executablePath, err := exec.LookPath(executableName)
return executablePath, err == nil
}
// GetExecutablePath returns the path of the executable
func GetExecutablePath(conf config.Config, plugin plugins.Plugin, shimName, version string) (string, error) {
executables, err := ToolExecutables(conf, plugin, "version", version)
if err != nil {
return "", err
}
for _, executablePath := range executables {
executableName := filepath.Base(executablePath)
if executableName == shimName {
return executablePath, nil
}
}
return "", fmt.Errorf("executable not found")
}
func getToolsAndVersionsFromShimFile(shimPath string) (versions []toolversions.ToolVersions, err error) {
contents, err := os.ReadFile(shimPath)
if err != nil {
return versions, err
}
versions = parse(string(contents))
versions = toolversions.Unique(versions)
return versions, err
}
// RemoveAll removes all shim scripts
func RemoveAll(conf config.Config) error {
shimDir := filepath.Join(conf.DataDir, shimDirName)
@ -109,12 +241,10 @@ func Write(conf config.Config, plugin plugins.Plugin, version, executablePath st
versions := []toolversions.ToolVersions{{Name: plugin.Name, Versions: []string{version}}}
if _, err := os.Stat(shimPath); err == nil {
contents, err := os.ReadFile(shimPath)
oldVersions, err := getToolsAndVersionsFromShimFile(shimPath)
if err != nil {
return err
}
oldVersions := parse(string(contents))
versions = toolversions.Unique(append(versions, oldVersions...))
}
@ -126,6 +256,10 @@ func Path(conf config.Config, shimName string) string {
return filepath.Join(conf.DataDir, shimDirName, shimName)
}
func shimsDirectory(conf config.Config) string {
return filepath.Join(conf.DataDir, shimDirName)
}
func ensureShimDirExists(conf config.Config) error {
return os.MkdirAll(filepath.Join(conf.DataDir, shimDirName), 0o777)
}
@ -153,7 +287,6 @@ func ToolExecutables(conf config.Config, plugin plugins.Plugin, versionType, ver
}
executables = append(executables, filePath)
return executables, nil
}
if err != nil {
return executables, err

View File

@ -20,6 +20,60 @@ import (
const testPluginName = "lua"
func TestFindExecutable(t *testing.T) {
version := "1.1.0"
conf, plugin := generateConfig(t)
installVersion(t, conf, plugin, version)
stdout, stderr := buildOutputs()
assert.Nil(t, GenerateAll(conf, &stdout, &stderr))
currentDir := t.TempDir()
t.Run("returns error when shim with name does not exist", func(t *testing.T) {
executable, found, err := FindExecutable(conf, "foo", currentDir)
assert.Empty(t, executable)
assert.False(t, found)
assert.Equal(t, err.(UnknownCommandError).Error(), "unknown command: foo")
})
t.Run("returns error when shim is present but no version is set", func(t *testing.T) {
executable, found, err := FindExecutable(conf, "dummy", currentDir)
assert.Empty(t, executable)
assert.False(t, found)
assert.Equal(t, err.(NoVersionSetError).Error(), "no versions set for dummy")
})
t.Run("returns string containing path to executable when found", func(t *testing.T) {
// write a version file
data := []byte("lua 1.1.0")
assert.Nil(t, os.WriteFile(filepath.Join(currentDir, ".tool-versions"), data, 0o666))
executable, found, err := FindExecutable(conf, "dummy", currentDir)
assert.Equal(t, filepath.Base(filepath.Dir(filepath.Dir(executable))), "1.1.0")
assert.Equal(t, filepath.Base(executable), "dummy")
assert.True(t, found)
assert.Nil(t, err)
})
t.Run("returns string containing path to system executable when system version set", func(t *testing.T) {
// Create dummy `ls` executable
path := filepath.Join(installs.InstallPath(conf, plugin, "version", version), "bin", "ls")
assert.Nil(t, os.WriteFile(path, []byte("echo 'I'm ls'"), 0o777))
// write system version to version file
toolpath := filepath.Join(currentDir, ".tool-versions")
assert.Nil(t, os.WriteFile(toolpath, []byte("lua system\n"), 0o666))
assert.Nil(t, GenerateAll(conf, &stdout, &stderr))
executable, found, err := FindExecutable(conf, "ls", currentDir)
assert.True(t, found)
assert.Nil(t, err)
// see that it actually returns path to system ls
assert.Equal(t, filepath.Base(executable), "ls")
assert.NotEqual(t, executable, path)
})
}
func TestRemoveAll(t *testing.T) {
version := "1.1.0"
conf, plugin := generateConfig(t)

View File

@ -38,6 +38,19 @@ func GetAllToolsAndVersions(filepath string) (toolVersions []ToolVersions, err e
return toolVersions, nil
}
// Intersect takes two slices of versions and returns a new slice containing
// only the versions found in both.
func Intersect(versions1 []string, versions2 []string) (versions []string) {
for _, version1 := range versions1 {
for _, version2 := range versions2 {
if version2 == version1 {
versions = append(versions, version1)
}
}
}
return versions
}
// Unique takes a slice of ToolVersions and returns a slice of unique tools and
// versions.
func Unique(versions []ToolVersions) (uniques []ToolVersions) {

View File

@ -51,6 +51,28 @@ func TestFindToolVersions(t *testing.T) {
})
}
func TestIntersect(t *testing.T) {
t.Run("when provided two empty ToolVersions returns empty ToolVersions", func(t *testing.T) {
got := Intersect([]string{}, []string{})
want := []string(nil)
assert.Equal(t, got, want)
})
t.Run("when provided ToolVersions with no matching versions return empty ToolVersions", func(t *testing.T) {
got := Intersect([]string{"1", "2"}, []string{"3", "4"})
assert.Equal(t, got, []string(nil))
})
t.Run("when provided ToolVersions with different versions return new ToolVersions only containing versions in both", func(t *testing.T) {
got := Intersect([]string{"1", "2"}, []string{"2", "3"})
want := []string{"2"}
assert.Equal(t, got, want)
})
}
func TestUnique(t *testing.T) {
t.Run("returns unique slice of tool versions when tool appears multiple times in slice", func(t *testing.T) {
got := Unique([]ToolVersions{