diff --git a/luhn/luhn.go b/luhn/luhn.go index ab87e42bf..bc4e15b36 100644 --- a/luhn/luhn.go +++ b/luhn/luhn.go @@ -1,7 +1,10 @@ // Package luhn generates and validates Luhn mod N check digits. package luhn -import "strings" +import ( + "fmt" + "strings" +) // An alphabet is a string of N characters, representing the digits of a given // base N. @@ -13,13 +16,20 @@ var ( // Generate returns a check digit for the string s, which should be composed // of characters from the Alphabet a. -func (a Alphabet) Generate(s string) rune { +func (a Alphabet) Generate(s string) (rune, error) { + if err:=a.check();err!=nil{ + return 0,err + } + factor := 1 sum := 0 n := len(a) for i := range s { codepoint := strings.IndexByte(string(a), s[i]) + if codepoint == -1 { + return 0, fmt.Errorf("Digit %q not valid in alphabet %q", s[i], a) + } addend := factor * codepoint if factor == 2 { factor = 1 @@ -31,13 +41,28 @@ func (a Alphabet) Generate(s string) rune { } remainder := sum % n checkCodepoint := (n - remainder) % n - return rune(a[checkCodepoint]) + return rune(a[checkCodepoint]), nil } // Validate returns true if the last character of the string s is correct, for // a string s composed of characters in the alphabet a. func (a Alphabet) Validate(s string) bool { t := s[:len(s)-1] - c := a.Generate(t) + c, err := a.Generate(t) + if err != nil { + return false + } return rune(s[len(s)-1]) == c } + +// check returns an error if the given alphabet does not consist of unique characters +func (a Alphabet) check() error { + cm := make(map[byte]bool, len(a)) + for i := range a { + if cm[a[i]] { + return fmt.Errorf("Digit %q non-unique in alphabet %q", a[i], a) + } + cm[a[i]] = true + } + return nil +} diff --git a/luhn/luhn_test.go b/luhn/luhn_test.go index dcf8513bf..7a439c222 100644 --- a/luhn/luhn_test.go +++ b/luhn/luhn_test.go @@ -9,19 +9,43 @@ import ( func TestGenerate(t *testing.T) { // Base 6 Luhn a := luhn.Alphabet("abcdef") - c := a.Generate("abcdef") + c, err := a.Generate("abcdef") + if err != nil { + t.Fatal(err) + } if c != 'e' { t.Errorf("Incorrect check digit %c != e", c) } // Base 10 Luhn a = luhn.Alphabet("0123456789") - c = a.Generate("7992739871") + c, err = a.Generate("7992739871") + if err != nil { + t.Fatal(err) + } if c != '3' { t.Errorf("Incorrect check digit %c != 3", c) } } +func TestInvalidString(t *testing.T) { + a := luhn.Alphabet("ABC") + _, err := a.Generate("7992739871") + t.Log(err) + if err == nil { + t.Error("Unexpected nil error") + } +} + +func TestBadAlphabet(t *testing.T) { + a := luhn.Alphabet("01234566789") + _, err := a.Generate("7992739871") + t.Log(err) + if err == nil { + t.Error("Unexpected nil error") + } +} + func TestValidate(t *testing.T) { a := luhn.Alphabet("abcdef") if !a.Validate("abcdefe") { diff --git a/protocol/nodeid.go b/protocol/nodeid.go index af62309b7..415d265d3 100644 --- a/protocol/nodeid.go +++ b/protocol/nodeid.go @@ -34,7 +34,11 @@ func NodeIDFromString(s string) (NodeID, error) { func (n NodeID) String() string { id := base32.StdEncoding.EncodeToString(n[:]) id = strings.Trim(id, "=") - id = luhnify(id) + id, err := luhnify(id) + if err != nil { + // Should never happen + panic(err) + } id = chunkify(id) return id } @@ -84,7 +88,7 @@ func (n *NodeID) UnmarshalText(bs []byte) error { } } -func luhnify(s string) string { +func luhnify(s string) (string, error) { if len(s) != 52 { panic("unsupported string length") } @@ -92,10 +96,13 @@ func luhnify(s string) string { res := make([]string, 0, 4) for i := 0; i < 4; i++ { p := s[i*13 : (i+1)*13] - l := luhn.Base32.Generate(p) + l, err := luhn.Base32.Generate(p) + if err != nil { + return "", err + } res = append(res, fmt.Sprintf("%s%c", p, l)) } - return res[0] + res[1] + res[2] + res[3] + return res[0] + res[1] + res[2] + res[3], nil } func unluhnify(s string) (string, error) { @@ -106,7 +113,10 @@ func unluhnify(s string) (string, error) { res := make([]string, 0, 4) for i := 0; i < 4; i++ { p := s[i*14 : (i+1)*14-1] - l := luhn.Base32.Generate(p) + l, err := luhn.Base32.Generate(p) + if err != nil { + return "", err + } if g := fmt.Sprintf("%s%c", p, l); g != s[i*14:(i+1)*14] { log.Printf("%q; %q", g, s[i*14:(i+1)*14]) return "", errors.New("check digit incorrect")