1
0
Fork 0
forked from External/grumble

pkg/cryptstate: cleanups for multiple crypto modes.

This commit is contained in:
Mikkel Krautz 2012-12-08 22:25:58 +01:00
parent cabe380244
commit 33a7e1304d
2 changed files with 34 additions and 41 deletions

View file

@ -17,10 +17,9 @@ import (
const DecryptHistorySize = 0x100
type CryptState struct {
RawKey [aes.BlockSize]byte
EncryptIV [ocb2.NonceSize]byte
DecryptIV [ocb2.NonceSize]byte
decryptHistory [DecryptHistorySize]byte
Key []byte
EncryptIV []byte
DecryptIV []byte
LastGoodTime int64
@ -33,6 +32,7 @@ type CryptState struct {
RemoteLost uint32
RemoteResync uint32
decryptHistory [DecryptHistorySize]byte
cipher cipher.Block
}
@ -42,22 +42,25 @@ func SupportedModes() []string {
}
func (cs *CryptState) GenerateKey() error {
_, err := io.ReadFull(rand.Reader, cs.RawKey[0:])
cs.Key = make([]byte, aes.BlockSize)
_, err := io.ReadFull(rand.Reader, cs.Key)
if err != nil {
return err
}
_, err = io.ReadFull(rand.Reader, cs.EncryptIV[0:])
cs.EncryptIV = make([]byte, ocb2.NonceSize)
_, err = io.ReadFull(rand.Reader, cs.EncryptIV)
if err != nil {
return err
}
_, err = io.ReadFull(rand.Reader, cs.DecryptIV[0:])
cs.DecryptIV = make([]byte, ocb2.NonceSize)
_, err = io.ReadFull(rand.Reader, cs.DecryptIV)
if err != nil {
return err
}
cs.cipher, err = aes.NewCipher(cs.RawKey[0:])
cs.cipher, err = aes.NewCipher(cs.Key)
if err != nil {
return err
}
@ -65,28 +68,18 @@ func (cs *CryptState) GenerateKey() error {
return nil
}
func (cs *CryptState) SetKey(key []byte, eiv []byte, div []byte) (err error) {
if copy(cs.RawKey[0:], key[0:]) != aes.BlockSize {
err = errors.New("Unable to copy key")
return
}
func (cs *CryptState) SetKey(key []byte, eiv []byte, div []byte) error {
cs.Key = key
cs.EncryptIV = eiv
cs.DecryptIV = div
if copy(cs.EncryptIV[0:], eiv[0:]) != ocb2.NonceSize {
err = errors.New("Unable to copy EIV")
return
}
if copy(cs.DecryptIV[0:], div[0:]) != ocb2.NonceSize {
err = errors.New("Unable to copy DIV")
return
}
cs.cipher, err = aes.NewCipher(cs.RawKey[0:])
cipher, err := aes.NewCipher(cs.Key)
if err != nil {
return
return err
}
cs.cipher = cipher
return
return nil
}
func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
@ -111,7 +104,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
ivbyte = src[0]
restore = false
if copy(saveiv[0:], cs.DecryptIV[0:]) != ocb2.NonceSize {
if copy(saveiv[:], cs.DecryptIV) != ocb2.NonceSize {
err = errors.New("Copy failed")
return
}
@ -179,18 +172,18 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
}
if cs.decryptHistory[cs.DecryptIV[0]] == cs.DecryptIV[0] {
if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize {
if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize {
err = errors.New("Failed to copy ocb2.NonceSize bytes")
return
}
}
}
ocb2.Decrypt(cs.cipher, dst[0:], src[4:], cs.DecryptIV[0:], tag[0:])
ocb2.Decrypt(cs.cipher, dst, src[4:], cs.DecryptIV, tag[:])
for i := 0; i < 3; i++ {
if tag[i] != src[i+1] {
if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize {
if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize {
err = errors.New("Error while trying to recover from error")
return
}
@ -202,7 +195,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) {
cs.decryptHistory[cs.DecryptIV[0]] = cs.DecryptIV[0]
if restore {
if copy(cs.DecryptIV[0:], saveiv[0:]) != ocb2.NonceSize {
if copy(cs.DecryptIV, saveiv[:]) != ocb2.NonceSize {
err = errors.New("Error while trying to recover IV")
return
}
@ -236,7 +229,7 @@ func (cs *CryptState) Encrypt(dst, src []byte) {
}
}
ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV[0:], tag[0:])
ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV, tag[:])
dst[0] = cs.EncryptIV[0]
dst[1] = tag[0]

View file

@ -32,14 +32,14 @@ func TestEncrypt(t *testing.T) {
cs := CryptState{}
out := make([]byte, 19)
cs.SetKey(key[0:], eiv[0:], div[0:])
cs.Encrypt(out[0:], msg[0:])
cs.SetKey(key[:], eiv[:], div[:])
cs.Encrypt(out, msg[:])
if !bytes.Equal(out[0:], expected[0:]) {
if !bytes.Equal(out[:], expected[:]) {
t.Errorf("Mismatch in output")
}
if !bytes.Equal(cs.EncryptIV[0:], expected_eiv[0:]) {
if !bytes.Equal(cs.EncryptIV[:], expected_eiv[:]) {
t.Errorf("EIV mismatch")
}
}
@ -66,14 +66,14 @@ func TestDecrypt(t *testing.T) {
cs := CryptState{}
out := make([]byte, 15)
cs.SetKey(key[0:], div[0:], eiv[0:])
cs.Decrypt(out[0:], crypted[0:])
cs.SetKey(key[:], div[:], eiv[:])
cs.Decrypt(out, crypted[:])
if !bytes.Equal(out[0:], expected[0:]) {
if !bytes.Equal(out, expected[:]) {
t.Errorf("Mismatch in output")
}
if !bytes.Equal(cs.DecryptIV[0:], post_div[0:]) {
if !bytes.Equal(cs.DecryptIV, post_div[:]) {
t.Errorf("Mismatch in DIV")
}
}