From 0f058d3b76973d11d8cc23e892252ebced6877b7 Mon Sep 17 00:00:00 2001 From: Mikkel Krautz Date: Sat, 1 Dec 2012 11:42:02 +0100 Subject: [PATCH] pkg/cryptstate, pkg/cryptstate/ocb2: split OCB2 implementation out into its own package. --- pkg/cryptstate/cryptstate.go | 135 +----------------------------- pkg/cryptstate/cryptstate_test.go | 75 ++--------------- pkg/cryptstate/ocb2/ocb2.go | 129 ++++++++++++++++++++++++++++ pkg/cryptstate/ocb2/ocb2_test.go | 58 +++++++++++++ 4 files changed, 196 insertions(+), 201 deletions(-) create mode 100644 pkg/cryptstate/ocb2/ocb2.go create mode 100644 pkg/cryptstate/ocb2/ocb2_test.go diff --git a/pkg/cryptstate/cryptstate.go b/pkg/cryptstate/cryptstate.go index 8d6d998..00d2821 100644 --- a/pkg/cryptstate/cryptstate.go +++ b/pkg/cryptstate/cryptstate.go @@ -10,6 +10,7 @@ import ( "crypto/cipher" "crypto/rand" "errors" + "mumbleapp.com/grumble/pkg/cryptstate/ocb2" "time" ) @@ -175,7 +176,7 @@ func (cs *CryptState) Decrypt(dst, src []byte) (err error) { } } - cs.OCBDecrypt(dst[0:], src[4:], cs.DecryptIV[0:], tag[0:]) + ocb2.Decrypt(cs.cipher, dst[0:], src[4:], cs.DecryptIV[0:], tag[0:]) for i := 0; i < 3; i++ { if tag[i] != src[i+1] { @@ -225,7 +226,7 @@ func (cs *CryptState) Encrypt(dst, src []byte) { } } - cs.OCBEncrypt(dst[4:], src, cs.EncryptIV[0:], tag[0:]) + ocb2.Encrypt(cs.cipher, dst[4:], src, cs.EncryptIV[0:], tag[0:]) dst[0] = cs.EncryptIV[0] dst[1] = tag[0] @@ -233,132 +234,4 @@ func (cs *CryptState) Encrypt(dst, src []byte) { dst[3] = tag[2] return -} - -func zeros(block []byte) { - for i := range block { - block[i] = 0 - } -} - -func xor(dst []byte, a []byte, b []byte) { - for i := 0; i < aes.BlockSize; i++ { - dst[i] = a[i] ^ b[i] - } -} - -func times2(block []byte) { - carry := (block[0] >> 7) & 0x1 - for i := 0; i < aes.BlockSize-1; i++ { - block[i] = (block[i] << 1) | ((block[i+1] >> 7) & 0x1) - } - block[aes.BlockSize-1] = (block[aes.BlockSize-1] << 1) ^ (carry * 135) -} - -func times3(block []byte) { - carry := (block[0] >> 7) & 0x1 - for i := 0; i < aes.BlockSize-1; i++ { - block[i] ^= (block[i] << 1) | ((block[i+1] >> 7) & 0x1) - } - block[aes.BlockSize-1] ^= ((block[aes.BlockSize-1] << 1) ^ (carry * 135)) -} - -func (cs *CryptState) OCBEncrypt(dst []byte, src []byte, nonce []byte, tag []byte) (err error) { - var delta [aes.BlockSize]byte - var checksum [aes.BlockSize]byte - var tmp [aes.BlockSize]byte - var pad [aes.BlockSize]byte - off := 0 - - cs.cipher.Encrypt(delta[0:], nonce[0:]) - zeros(checksum[0:]) - - remain := len(src) - for remain > aes.BlockSize { - times2(delta[0:]) - xor(tmp[0:], delta[0:], src[off:off+aes.BlockSize]) - cs.cipher.Encrypt(tmp[0:], tmp[0:]) - xor(dst[off:off+aes.BlockSize], delta[0:], tmp[0:]) - xor(checksum[0:], checksum[0:], src[off:off+aes.BlockSize]) - remain -= aes.BlockSize - off += aes.BlockSize - } - - times2(delta[0:]) - zeros(tmp[0:]) - num := remain * 8 - tmp[aes.BlockSize-2] = uint8((uint32(num) >> 8) & 0xff) - tmp[aes.BlockSize-1] = uint8(num & 0xff) - xor(tmp[0:], tmp[0:], delta[0:]) - cs.cipher.Encrypt(pad[0:], tmp[0:]) - copied := copy(tmp[0:], src[off:]) - if copied != remain { - err = errors.New("Copy failed") - return - } - if copy(tmp[copied:], pad[copied:]) != (aes.BlockSize - remain) { - err = errors.New("Copy failed") - return - } - xor(checksum[0:], checksum[0:], tmp[0:]) - xor(tmp[0:], pad[0:], tmp[0:]) - if copy(dst[off:], tmp[0:]) != remain { - err = errors.New("Copy failed") - return - } - - times3(delta[0:]) - xor(tmp[0:], delta[0:], checksum[0:]) - cs.cipher.Encrypt(tag[0:], tmp[0:]) - - return -} - -func (cs *CryptState) OCBDecrypt(plain []byte, encrypted []byte, nonce []byte, tag []byte) (err error) { - var checksum [aes.BlockSize]byte - var delta [aes.BlockSize]byte - var tmp [aes.BlockSize]byte - var pad [aes.BlockSize]byte - off := 0 - - cs.cipher.Encrypt(delta[0:], nonce[0:]) - zeros(checksum[0:]) - - remain := len(encrypted) - for remain > aes.BlockSize { - times2(delta[0:]) - xor(tmp[0:], delta[0:], encrypted[off:off+aes.BlockSize]) - cs.cipher.Decrypt(tmp[0:], tmp[0:]) - xor(plain[off:off+aes.BlockSize], delta[0:], tmp[0:]) - xor(checksum[0:], checksum[0:], plain[off:off+aes.BlockSize]) - off += aes.BlockSize - remain -= aes.BlockSize - } - - times2(delta[0:]) - zeros(tmp[0:]) - num := remain * 8 - tmp[aes.BlockSize-2] = uint8((uint32(num) >> 8) & 0xff) - tmp[aes.BlockSize-1] = uint8(num & 0xff) - xor(tmp[0:], tmp[0:], delta[0:]) - cs.cipher.Encrypt(pad[0:], tmp[0:]) - zeros(tmp[0:]) - copied := copy(tmp[0:remain], encrypted[off:off+remain]) - if copied != remain { - err = errors.New("Copy failed") - return - } - xor(tmp[0:], tmp[0:], pad[0:]) - xor(checksum[0:], checksum[0:], tmp[0:]) - copied = copy(plain[off:off+remain], tmp[0:remain]) - if copied != remain { - err = errors.New("Copy failed") - return - } - - times3(delta[0:]) - xor(tmp[0:], delta[0:], checksum[0:]) - cs.cipher.Encrypt(tag[0:], tmp[0:]) - - return -} +} \ No newline at end of file diff --git a/pkg/cryptstate/cryptstate_test.go b/pkg/cryptstate/cryptstate_test.go index 455794a..28248b2 100644 --- a/pkg/cryptstate/cryptstate_test.go +++ b/pkg/cryptstate/cryptstate_test.go @@ -1,76 +1,11 @@ package cryptstate import ( + "bytes" "crypto/aes" "testing" ) -func BlockCompare(a []byte, b []byte) (match bool) { - if len(a) != len(b) { - return - } - - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return - } - } - - match = true - return -} - -func TestTimes2(t *testing.T) { - msg := [aes.BlockSize]byte{ - 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, - } - expected := [aes.BlockSize]byte{ - 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7b, - } - - times2(msg[0:]) - if BlockCompare(msg[0:], expected[0:]) == false { - t.Errorf("times2 produces invalid output: %v, expected: %v", msg, expected) - } -} - -func TestTimes3(t *testing.T) { - msg := [aes.BlockSize]byte{ - 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, - } - expected := [aes.BlockSize]byte{ - 0x81, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x85, - } - - times3(msg[0:]) - if BlockCompare(msg[0:], expected[0:]) == false { - t.Errorf("times3 produces invalid output: %v, expected: %v", msg, expected) - } -} - -func TestZeros(t *testing.T) { - var msg [aes.BlockSize]byte - zeros(msg[0:]) - for i := 0; i < len(msg); i++ { - if msg[i] != 0 { - t.Errorf("zeros does not zero slice.") - } - } -} - -func TestXor(t *testing.T) { - msg := [aes.BlockSize]byte{ - 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, - } - var out [aes.BlockSize]byte - xor(out[0:], msg[0:], msg[0:]) - for i := 0; i < len(out); i++ { - if out[i] != 0 { - t.Errorf("XOR broken") - } - } -} - func TestEncrypt(t *testing.T) { msg := [15]byte{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, @@ -100,11 +35,11 @@ func TestEncrypt(t *testing.T) { cs.SetKey(key[0:], eiv[0:], div[0:]) cs.Encrypt(out[0:], msg[0:]) - if BlockCompare(out[0:], expected[0:]) == false { + if !bytes.Equal(out[0:], expected[0:]) { t.Errorf("Mismatch in output") } - if BlockCompare(cs.EncryptIV[0:], expected_eiv[0:]) == false { + if !bytes.Equal(cs.EncryptIV[0:], expected_eiv[0:]) { t.Errorf("EIV mismatch") } } @@ -138,11 +73,11 @@ func TestDecrypt(t *testing.T) { cs.SetKey(key[0:], div[0:], eiv[0:]) cs.Decrypt(out[0:], crypted[0:]) - if BlockCompare(out[0:], expected[0:]) == false { + if !bytes.Equal(out[0:], expected[0:]) { t.Errorf("Mismatch in output") } - if BlockCompare(cs.DecryptIV[0:], post_div[0:]) == false { + if !bytes.Equal(cs.DecryptIV[0:], post_div[0:]) { t.Errorf("Mismatch in DIV") } } diff --git a/pkg/cryptstate/ocb2/ocb2.go b/pkg/cryptstate/ocb2/ocb2.go new file mode 100644 index 0000000..b47bbb5 --- /dev/null +++ b/pkg/cryptstate/ocb2/ocb2.go @@ -0,0 +1,129 @@ +package ocb2 + +import ( + "crypto/aes" + "crypto/cipher" +) + +func zeros(block []byte) { + for i := range block { + block[i] = 0 + } +} + +func xor(dst []byte, a []byte, b []byte) { + for i := 0; i < aes.BlockSize; i++ { + dst[i] = a[i] ^ b[i] + } +} + +func times2(block []byte) { + carry := (block[0] >> 7) & 0x1 + for i := 0; i < aes.BlockSize-1; i++ { + block[i] = (block[i] << 1) | ((block[i+1] >> 7) & 0x1) + } + block[aes.BlockSize-1] = (block[aes.BlockSize-1] << 1) ^ (carry * 135) +} + +func times3(block []byte) { + carry := (block[0] >> 7) & 0x1 + for i := 0; i < aes.BlockSize-1; i++ { + block[i] ^= (block[i] << 1) | ((block[i+1] >> 7) & 0x1) + } + block[aes.BlockSize-1] ^= ((block[aes.BlockSize-1] << 1) ^ (carry * 135)) +} + +func Encrypt(cipher cipher.Block, dst []byte, src []byte, nonce []byte, tag []byte) (err error) { + var delta [aes.BlockSize]byte + var checksum [aes.BlockSize]byte + var tmp [aes.BlockSize]byte + var pad [aes.BlockSize]byte + off := 0 + + cipher.Encrypt(delta[0:], nonce[0:]) + zeros(checksum[0:]) + + remain := len(src) + for remain > aes.BlockSize { + times2(delta[0:]) + xor(tmp[0:], delta[0:], src[off:off+aes.BlockSize]) + cipher.Encrypt(tmp[0:], tmp[0:]) + xor(dst[off:off+aes.BlockSize], delta[0:], tmp[0:]) + xor(checksum[0:], checksum[0:], src[off:off+aes.BlockSize]) + remain -= aes.BlockSize + off += aes.BlockSize + } + + times2(delta[0:]) + zeros(tmp[0:]) + num := remain * 8 + tmp[aes.BlockSize-2] = uint8((uint32(num) >> 8) & 0xff) + tmp[aes.BlockSize-1] = uint8(num & 0xff) + xor(tmp[0:], tmp[0:], delta[0:]) + cipher.Encrypt(pad[0:], tmp[0:]) + copied := copy(tmp[0:], src[off:]) + if copied != remain { + panic("ocb2: copy failed") + } + if copy(tmp[copied:], pad[copied:]) != (aes.BlockSize - remain) { + panic("ocb2: copy failed") + } + xor(checksum[0:], checksum[0:], tmp[0:]) + xor(tmp[0:], pad[0:], tmp[0:]) + if copy(dst[off:], tmp[0:]) != remain { + panic("ocb2: copy failed") + } + + times3(delta[0:]) + xor(tmp[0:], delta[0:], checksum[0:]) + cipher.Encrypt(tag[0:], tmp[0:]) + + return +} + +func Decrypt(cipher cipher.Block, plain []byte, encrypted []byte, nonce []byte, tag []byte) (err error) { + var checksum [aes.BlockSize]byte + var delta [aes.BlockSize]byte + var tmp [aes.BlockSize]byte + var pad [aes.BlockSize]byte + off := 0 + + cipher.Encrypt(delta[0:], nonce[0:]) + zeros(checksum[0:]) + + remain := len(encrypted) + for remain > aes.BlockSize { + times2(delta[0:]) + xor(tmp[0:], delta[0:], encrypted[off:off+aes.BlockSize]) + cipher.Decrypt(tmp[0:], tmp[0:]) + xor(plain[off:off+aes.BlockSize], delta[0:], tmp[0:]) + xor(checksum[0:], checksum[0:], plain[off:off+aes.BlockSize]) + off += aes.BlockSize + remain -= aes.BlockSize + } + + times2(delta[0:]) + zeros(tmp[0:]) + num := remain * 8 + tmp[aes.BlockSize-2] = uint8((uint32(num) >> 8) & 0xff) + tmp[aes.BlockSize-1] = uint8(num & 0xff) + xor(tmp[0:], tmp[0:], delta[0:]) + cipher.Encrypt(pad[0:], tmp[0:]) + zeros(tmp[0:]) + copied := copy(tmp[0:remain], encrypted[off:off+remain]) + if copied != remain { + panic("ocb2: copy failed") + } + xor(tmp[0:], tmp[0:], pad[0:]) + xor(checksum[0:], checksum[0:], tmp[0:]) + copied = copy(plain[off:off+remain], tmp[0:remain]) + if copied != remain { + panic("ocb2: copy failed") + } + + times3(delta[0:]) + xor(tmp[0:], delta[0:], checksum[0:]) + cipher.Encrypt(tag[0:], tmp[0:]) + + return +} \ No newline at end of file diff --git a/pkg/cryptstate/ocb2/ocb2_test.go b/pkg/cryptstate/ocb2/ocb2_test.go new file mode 100644 index 0000000..954fb2f --- /dev/null +++ b/pkg/cryptstate/ocb2/ocb2_test.go @@ -0,0 +1,58 @@ +package ocb2 + +import ( + "bytes" + "crypto/aes" + "testing" +) + +func TestTimes2(t *testing.T) { + msg := [aes.BlockSize]byte{ + 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, + } + expected := [aes.BlockSize]byte{ + 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7b, + } + + times2(msg[0:]) + if !bytes.Equal(msg[0:], expected[0:]) { + t.Fatalf("times2 produces invalid output: %v, expected: %v", msg, expected) + } +} + +func TestTimes3(t *testing.T) { + msg := [aes.BlockSize]byte{ + 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, + } + expected := [aes.BlockSize]byte{ + 0x81, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x85, + } + + times3(msg[0:]) + if !bytes.Equal(msg[0:], expected[0:]) { + t.Errorf("times3 produces invalid output: %v, expected: %v", msg, expected) + } +} + +func TestZeros(t *testing.T) { + var msg [aes.BlockSize]byte + zeros(msg[0:]) + for i := 0; i < len(msg); i++ { + if msg[i] != 0 { + t.Fatalf("zeros does not zero slice.") + } + } +} + +func TestXor(t *testing.T) { + msg := [aes.BlockSize]byte{ + 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, + } + var out [aes.BlockSize]byte + xor(out[0:], msg[0:], msg[0:]) + for i := 0; i < len(out); i++ { + if out[i] != 0 { + t.Fatalf("XOR broken") + } + } +} \ No newline at end of file