1
0
Fork 0
forked from External/grumble

pkg/cryptstate: cleanups for multiple crypto modes.

This commit is contained in:
Mikkel Krautz 2012-12-08 22:25:58 +01:00
parent cabe380244
commit 33a7e1304d
2 changed files with 34 additions and 41 deletions

View file

@ -17,10 +17,9 @@ import (
const DecryptHistorySize = 0x100 const DecryptHistorySize = 0x100
type CryptState struct { type CryptState struct {
RawKey [aes.BlockSize]byte Key []byte
EncryptIV [ocb2.NonceSize]byte EncryptIV []byte
DecryptIV [ocb2.NonceSize]byte DecryptIV []byte
decryptHistory [DecryptHistorySize]byte
LastGoodTime int64 LastGoodTime int64
@ -33,7 +32,8 @@ type CryptState struct {
RemoteLost uint32 RemoteLost uint32
RemoteResync uint32 RemoteResync uint32
cipher cipher.Block decryptHistory [DecryptHistorySize]byte
cipher cipher.Block
} }
// SupportedModes returns the list of supported CryptoModes. // SupportedModes returns the list of supported CryptoModes.
@ -42,22 +42,25 @@ func SupportedModes() []string {
} }
func (cs *CryptState) GenerateKey() error { func (cs *CryptState) GenerateKey() error {
_, err := io.ReadFull(rand.Reader, cs.RawKey[0:]) cs.Key = make([]byte, aes.BlockSize)
_, err := io.ReadFull(rand.Reader, cs.Key)
if err != nil { if err != nil {
return err return err
} }
_, err = io.ReadFull(rand.Reader, cs.EncryptIV[0:]) cs.EncryptIV = make([]byte, ocb2.NonceSize)
_, err = io.ReadFull(rand.Reader, cs.EncryptIV)
if err != nil { if err != nil {
return err return err
} }
_, err = io.ReadFull(rand.Reader, cs.DecryptIV[0:]) cs.DecryptIV = make([]byte, ocb2.NonceSize)
_, err = io.ReadFull(rand.Reader, cs.DecryptIV)
if err != nil { if err != nil {
return err return err
} }
cs.cipher, err = aes.NewCipher(cs.RawKey[0:]) cs.cipher, err = aes.NewCipher(cs.Key)
if err != nil { if err != nil {
return err return err
} }
@ -65,28 +68,18 @@ func (cs *CryptState) GenerateKey() error {
return nil return nil
} }
func (cs *CryptState) SetKey(key []byte, eiv []byte, div []byte) (err error) { func (cs *CryptState) SetKey(key []byte, eiv []byte, div []byte) error {
if copy(cs.RawKey[0:], key[0:]) != aes.BlockSize { cs.Key = key
err = errors.New("Unable to copy key") cs.EncryptIV = eiv
return cs.DecryptIV = div
}
if copy(cs.EncryptIV[0:], eiv[0:]) != ocb2.NonceSize { cipher, err := aes.NewCipher(cs.Key)
err = errors.New("Unable to copy EIV")
return
}
if copy(cs.DecryptIV[0:], div[0:]) != ocb2.NonceSize {
err = errors.New("Unable to copy DIV")
return
}
cs.cipher, err = aes.NewCipher(cs.RawKey[0:])
if err != nil { if err != nil {
return return err
} }
cs.cipher = cipher
return return nil
} }
func (cs *CryptState) Decrypt(dst, src []byte) (err error) { func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
@ -111,7 +104,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
ivbyte = src[0] ivbyte = src[0]
restore = false restore = false
if copy(saveiv[0:], cs.DecryptIV[0:]) != ocb2.NonceSize { if copy(saveiv[:], cs.DecryptIV) != ocb2.NonceSize {
err = errors.New("Copy failed") err = errors.New("Copy failed")
return return
} }
@ -179,18 +172,18 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
} }
if cs.decryptHistory[cs.DecryptIV[0]] == cs.DecryptIV[0] { if cs.decryptHistory[cs.DecryptIV[0]] == cs.DecryptIV[0] {
if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize { if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize {
err = errors.New("Failed to copy ocb2.NonceSize bytes") err = errors.New("Failed to copy ocb2.NonceSize bytes")
return return
} }
} }
} }
ocb2.Decrypt(cs.cipher, dst[0:], src[4:], cs.DecryptIV[0:], tag[0:]) ocb2.Decrypt(cs.cipher, dst, src[4:], cs.DecryptIV, tag[:])
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
if tag[i] != src[i+1] { if tag[i] != src[i+1] {
if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize { if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize {
err = errors.New("Error while trying to recover from error") err = errors.New("Error while trying to recover from error")
return return
} }
@ -202,7 +195,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
cs.decryptHistory[cs.DecryptIV[0]] = cs.DecryptIV[0] cs.decryptHistory[cs.DecryptIV[0]] = cs.DecryptIV[0]
if restore { if restore {
if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize { if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize {
err = errors.New("Error while trying to recover IV") err = errors.New("Error while trying to recover IV")
return return
} }
@ -236,7 +229,7 @@ func (cs *CryptState) Encrypt(dst, src []byte) {
} }
} }
ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV[0:], tag[0:]) ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV, tag[:])
dst[0] = cs.EncryptIV[0] dst[0] = cs.EncryptIV[0]
dst[1] = tag[0] dst[1] = tag[0]

View file

@ -32,14 +32,14 @@ func TestEncrypt(t *testing.T) {
cs := CryptState{} cs := CryptState{}
out := make([]byte, 19) out := make([]byte, 19)
cs.SetKey(key[0:], eiv[0:], div[0:]) cs.SetKey(key[:], eiv[:], div[:])
cs.Encrypt(out[0:], msg[0:]) cs.Encrypt(out, msg[:])
if !bytes.Equal(out[0:], expected[0:]) { if !bytes.Equal(out[:], expected[:]) {
t.Errorf("Mismatch in output") t.Errorf("Mismatch in output")
} }
if !bytes.Equal(cs.EncryptIV[0:], expected_eiv[0:]) { if !bytes.Equal(cs.EncryptIV[:], expected_eiv[:]) {
t.Errorf("EIV mismatch") t.Errorf("EIV mismatch")
} }
} }
@ -66,14 +66,14 @@ func TestDecrypt(t *testing.T) {
cs := CryptState{} cs := CryptState{}
out := make([]byte, 15) out := make([]byte, 15)
cs.SetKey(key[0:], div[0:], eiv[0:]) cs.SetKey(key[:], div[:], eiv[:])
cs.Decrypt(out[0:], crypted[0:]) cs.Decrypt(out, crypted[:])
if !bytes.Equal(out[0:], expected[0:]) { if !bytes.Equal(out, expected[:]) {
t.Errorf("Mismatch in output") t.Errorf("Mismatch in output")
} }
if !bytes.Equal(cs.DecryptIV[0:], post_div[0:]) { if !bytes.Equal(cs.DecryptIV, post_div[:]) {
t.Errorf("Mismatch in DIV") t.Errorf("Mismatch in DIV")
} }
} }