From 33a7e1304da4bfd194318628a277d64d13d43c89 Mon Sep 17 00:00:00 2001 From: Mikkel Krautz Date: Sat, 8 Dec 2012 22:25:58 +0100 Subject: [PATCH] pkg/cryptstate: cleanups for multiple crypto modes. --- pkg/cryptstate/cryptstate.go | 59 ++++++++++++++----------------- pkg/cryptstate/cryptstate_test.go | 16 ++++----- 2 files changed, 34 insertions(+), 41 deletions(-) diff --git a/pkg/cryptstate/cryptstate.go b/pkg/cryptstate/cryptstate.go index 4d12f04..7f46154 100644 --- a/pkg/cryptstate/cryptstate.go +++ b/pkg/cryptstate/cryptstate.go @@ -17,10 +17,9 @@ import ( const DecryptHistorySize = 0x100 type CryptState struct { - RawKey [aes.BlockSize]byte - EncryptIV [ocb2.NonceSize]byte - DecryptIV [ocb2.NonceSize]byte - decryptHistory [DecryptHistorySize]byte + Key []byte + EncryptIV []byte + DecryptIV []byte LastGoodTime int64 @@ -33,7 +32,8 @@ type CryptState struct { RemoteLost uint32 RemoteResync uint32 - cipher cipher.Block + decryptHistory [DecryptHistorySize]byte + cipher cipher.Block } // SupportedModes returns the list of supported CryptoModes. @@ -42,22 +42,25 @@ func SupportedModes() []string { } func (cs *CryptState) GenerateKey() error { - _, err := io.ReadFull(rand.Reader, cs.RawKey[0:]) + cs.Key = make([]byte, aes.BlockSize) + _, err := io.ReadFull(rand.Reader, cs.Key) if err != nil { return err } - _, err = io.ReadFull(rand.Reader, cs.EncryptIV[0:]) + cs.EncryptIV = make([]byte, ocb2.NonceSize) + _, err = io.ReadFull(rand.Reader, cs.EncryptIV) if err != nil { return err } - _, err = io.ReadFull(rand.Reader, cs.DecryptIV[0:]) + cs.DecryptIV = make([]byte, ocb2.NonceSize) + _, err = io.ReadFull(rand.Reader, cs.DecryptIV) if err != nil { return err } - cs.cipher, err = aes.NewCipher(cs.RawKey[0:]) + cs.cipher, err = aes.NewCipher(cs.Key) if err != nil { return err } @@ -65,28 +68,18 @@ func (cs *CryptState) GenerateKey() error { return nil } -func (cs *CryptState) SetKey(key []byte, eiv []byte, div []byte) (err error) { - if copy(cs.RawKey[0:], key[0:]) != aes.BlockSize { - err = errors.New("Unable to copy key") - return - } +func (cs *CryptState) SetKey(key []byte, eiv []byte, div []byte) error { + cs.Key = key + cs.EncryptIV = eiv + cs.DecryptIV = div - if copy(cs.EncryptIV[0:], eiv[0:]) != ocb2.NonceSize { - err = errors.New("Unable to copy EIV") - return - } - - if copy(cs.DecryptIV[0:], div[0:]) != ocb2.NonceSize { - err = errors.New("Unable to copy DIV") - return - } - - cs.cipher, err = aes.NewCipher(cs.RawKey[0:]) + cipher, err := aes.NewCipher(cs.Key) if err != nil { - return + return err } + cs.cipher = cipher - return + return nil } func (cs *CryptState) Decrypt(dst, src []byte) (err error) { @@ -111,7 +104,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) { ivbyte = src[0] restore = false - if copy(saveiv[0:], cs.DecryptIV[0:]) != ocb2.NonceSize { + if copy(saveiv[:], cs.DecryptIV) != ocb2.NonceSize { err = errors.New("Copy failed") return } @@ -179,18 +172,18 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) { } if cs.decryptHistory[cs.DecryptIV[0]] == cs.DecryptIV[0] { - if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize { + if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize { err = errors.New("Failed to copy ocb2.NonceSize bytes") return } } } - ocb2.Decrypt(cs.cipher, dst[0:], src[4:], cs.DecryptIV[0:], tag[0:]) + ocb2.Decrypt(cs.cipher, dst, src[4:], cs.DecryptIV, tag[:]) for i := 0; i < 3; i++ { if tag[i] != src[i+1] { - if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize { + if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize { err = errors.New("Error while trying to recover from error") return } @@ -202,7 +195,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) { cs.decryptHistory[cs.DecryptIV[0]] = cs.DecryptIV[0] if restore { - if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize { + if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize { err = errors.New("Error while trying to recover IV") return } @@ -236,7 +229,7 @@ func (cs *CryptState) Encrypt(dst, src []byte) { } } - ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV[0:], tag[0:]) + ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV, tag[:]) dst[0] = cs.EncryptIV[0] dst[1] = tag[0] diff --git a/pkg/cryptstate/cryptstate_test.go b/pkg/cryptstate/cryptstate_test.go index 643df3b..8be8eae 100644 --- a/pkg/cryptstate/cryptstate_test.go +++ b/pkg/cryptstate/cryptstate_test.go @@ -32,14 +32,14 @@ func TestEncrypt(t *testing.T) { cs := CryptState{} out := make([]byte, 19) - cs.SetKey(key[0:], eiv[0:], div[0:]) - cs.Encrypt(out[0:], msg[0:]) + cs.SetKey(key[:], eiv[:], div[:]) + cs.Encrypt(out, msg[:]) - if !bytes.Equal(out[0:], expected[0:]) { + if !bytes.Equal(out[:], expected[:]) { t.Errorf("Mismatch in output") } - if !bytes.Equal(cs.EncryptIV[0:], expected_eiv[0:]) { + if !bytes.Equal(cs.EncryptIV[:], expected_eiv[:]) { t.Errorf("EIV mismatch") } } @@ -66,14 +66,14 @@ func TestDecrypt(t *testing.T) { cs := CryptState{} out := make([]byte, 15) - cs.SetKey(key[0:], div[0:], eiv[0:]) - cs.Decrypt(out[0:], crypted[0:]) + cs.SetKey(key[:], div[:], eiv[:]) + cs.Decrypt(out, crypted[:]) - if !bytes.Equal(out[0:], expected[0:]) { + if !bytes.Equal(out, expected[:]) { t.Errorf("Mismatch in output") } - if !bytes.Equal(cs.DecryptIV[0:], post_div[0:]) { + if !bytes.Equal(cs.DecryptIV, post_div[:]) { t.Errorf("Mismatch in DIV") } }