diff --git a/cobra/cmd/add_test.go b/cobra/cmd/add_test.go index 0de1d22..de92fce 100644 --- a/cobra/cmd/add_test.go +++ b/cobra/cmd/add_test.go @@ -7,31 +7,14 @@ import ( ) func TestGoldenAddCmd(t *testing.T) { - - wd, _ := os.Getwd() command := &Command{ CmdName: "test", CmdParent: parentName, - Project: &Project{ - AbsolutePath: fmt.Sprintf("%s/testproject", wd), - Legal: getLicense(), - Copyright: copyrightLine(), - - // required to init - AppName: "testproject", - PkgName: "github.com/spf13/testproject", - Viper: true, - }, + Project: getProject(), } + defer os.RemoveAll(command.AbsolutePath) - // init project first command.Project.Create() - defer func() { - if _, err := os.Stat(command.AbsolutePath); err == nil { - os.RemoveAll(command.AbsolutePath) - } - }() - if err := command.Create(); err != nil { t.Fatal(err) } diff --git a/cobra/cmd/init_test.go b/cobra/cmd/init_test.go index 9540b2d..8ee3910 100644 --- a/cobra/cmd/init_test.go +++ b/cobra/cmd/init_test.go @@ -7,29 +7,26 @@ import ( "testing" ) -func TestGoldenInitCmd(t *testing.T) { - +func getProject() *Project { wd, _ := os.Getwd() - project := &Project{ + return &Project{ AbsolutePath: fmt.Sprintf("%s/testproject", wd), - PkgName: "github.com/spf13/testproject", Legal: getLicense(), Copyright: copyrightLine(), - Viper: true, AppName: "testproject", + PkgName: "github.com/spf13/testproject", + Viper: true, } +} - err := project.Create() - if err != nil { +func TestGoldenInitCmd(t *testing.T) { + project := getProject() + defer os.RemoveAll(project.AbsolutePath) + + if err := project.Create(); err != nil { t.Fatal(err) } - defer func() { - if _, err := os.Stat(project.AbsolutePath); err == nil { - os.RemoveAll(project.AbsolutePath) - } - }() - expectedFiles := []string{"LICENSE", "main.go", "cmd/root.go"} for _, f := range expectedFiles { generatedFile := fmt.Sprintf("%s/%s", project.AbsolutePath, f) diff --git a/cobra/cmd/project.go b/cobra/cmd/project.go index a53893c..ecd783d 100644 --- a/cobra/cmd/project.go +++ b/cobra/cmd/project.go @@ -75,6 +75,7 @@ func (p *Project) createLicenseFile() error { if err != nil { return err } + defer licenseFile.Close() licenseTemplate := template.Must(template.New("license").Parse(p.Legal.Text)) return licenseTemplate.Execute(licenseFile, data)