mirror of
https://github.com/ergochat/ergo.git
synced 2025-12-20 18:20:13 -08:00
implement draft/webpush (#2205)
This commit is contained in:
parent
efd3764337
commit
36e5451aa5
44 changed files with 2091 additions and 100 deletions
323
vendor/github.com/ergochat/webpush-go/v2/webpush.go
generated
vendored
Normal file
323
vendor/github.com/ergochat/webpush-go/v2/webpush.go
generated
vendored
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
package webpush
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
const MaxRecordSize uint32 = 4096
|
||||
|
||||
var (
|
||||
ErrRecordSizeTooSmall = errors.New("record size too small for message")
|
||||
|
||||
invalidAuthKeyLength = errors.New("invalid auth key length (must be 16)")
|
||||
|
||||
defaultHTTPClient = &http.Client{}
|
||||
)
|
||||
|
||||
// HTTPClient is an interface for sending the notification HTTP request / testing
|
||||
type HTTPClient interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// Options are config and extra params needed to send a notification
|
||||
type Options struct {
|
||||
HTTPClient HTTPClient // Will replace with *http.Client by default if not included
|
||||
RecordSize uint32 // Limit the record size
|
||||
Subscriber string // Sub in VAPID JWT token
|
||||
Topic string // Set the Topic header to collapse a pending messages (Optional)
|
||||
TTL int // Set the TTL on the endpoint POST request, in seconds
|
||||
Urgency Urgency // Set the Urgency header to change a message priority (Optional)
|
||||
VAPIDKeys *VAPIDKeys // VAPID public-private keypair to generate the VAPID Authorization header
|
||||
VapidExpiration time.Time // optional expiration for VAPID JWT token (defaults to now + 12 hours)
|
||||
}
|
||||
|
||||
// Keys represents a subscription's keys (its ECDH public key on the P-256 curve
|
||||
// and its 16-byte authentication secret).
|
||||
type Keys struct {
|
||||
Auth [16]byte
|
||||
P256dh *ecdh.PublicKey
|
||||
}
|
||||
|
||||
// Equal compares two Keys for equality.
|
||||
func (k *Keys) Equal(o Keys) bool {
|
||||
return k.Auth == o.Auth && k.P256dh.Equal(o.P256dh)
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*Keys)(nil)
|
||||
var _ json.Unmarshaler = (*Keys)(nil)
|
||||
|
||||
type marshaledKeys struct {
|
||||
Auth string `json:"auth"`
|
||||
P256dh string `json:"p256dh"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler, allowing serialization to JSON.
|
||||
func (k *Keys) MarshalJSON() ([]byte, error) {
|
||||
m := marshaledKeys{
|
||||
Auth: base64.RawStdEncoding.EncodeToString(k.Auth[:]),
|
||||
P256dh: base64.RawStdEncoding.EncodeToString(k.P256dh.Bytes()),
|
||||
}
|
||||
return json.Marshal(&m)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Unmarshaler, allowing deserialization from JSON.
|
||||
func (k *Keys) UnmarshalJSON(b []byte) (err error) {
|
||||
var m marshaledKeys
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
authBytes, err := decodeSubscriptionKey(m.Auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(authBytes) != 16 {
|
||||
return fmt.Errorf("invalid auth bytes length %d (must be 16)", len(authBytes))
|
||||
}
|
||||
copy(k.Auth[:], authBytes)
|
||||
rawDHKey, err := decodeSubscriptionKey(m.P256dh)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.P256dh, err = ecdh.P256().NewPublicKey(rawDHKey)
|
||||
return err
|
||||
}
|
||||
|
||||
// DecodeSubscriptionKeys decodes and validates a base64-encoded pair of subscription keys
|
||||
// (the authentication secret and ECDH public key).
|
||||
func DecodeSubscriptionKeys(auth, p256dh string) (keys Keys, err error) {
|
||||
authBytes, err := decodeSubscriptionKey(auth)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(authBytes) != 16 {
|
||||
err = invalidAuthKeyLength
|
||||
return
|
||||
}
|
||||
copy(keys.Auth[:], authBytes)
|
||||
dhBytes, err := decodeSubscriptionKey(p256dh)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keys.P256dh, err = ecdh.P256().NewPublicKey(dhBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Subscription represents a PushSubscription object from the Push API
|
||||
type Subscription struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Keys Keys `json:"keys"`
|
||||
}
|
||||
|
||||
// SendNotification sends a push notification to a subscription's endpoint,
|
||||
// applying encryption (RFC 8291) and adding a VAPID header (RFC 8292).
|
||||
func SendNotification(ctx context.Context, message []byte, s *Subscription, options *Options) (*http.Response, error) {
|
||||
// Compose message body (RFC8291 encryption of the message)
|
||||
body, err := EncryptNotification(message, s.Keys, options.RecordSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get VAPID Authorization header
|
||||
vapidAuthHeader, err := getVAPIDAuthorizationHeader(
|
||||
s.Endpoint,
|
||||
options.Subscriber,
|
||||
options.VAPIDKeys,
|
||||
options.VapidExpiration,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Compose and send the HTTP request
|
||||
return sendNotification(ctx, s.Endpoint, options, vapidAuthHeader, body)
|
||||
}
|
||||
|
||||
// EncryptNotification implements the encryption algorithm specified by RFC 8291 for web push
|
||||
// (RFC 8188's aes128gcm content-encoding, with the key material derived from
|
||||
// elliptic curve Diffie-Hellman over the P-256 curve).
|
||||
func EncryptNotification(message []byte, keys Keys, recordSize uint32) ([]byte, error) {
|
||||
// Get the record size
|
||||
if recordSize == 0 {
|
||||
recordSize = MaxRecordSize
|
||||
} else if recordSize < 128 {
|
||||
return nil, ErrRecordSizeTooSmall
|
||||
}
|
||||
|
||||
// Allocate buffer to hold the eventual message
|
||||
// [ header block ] [ ciphertext ] [ 16 byte AEAD tag ], totaling RecordSize bytes
|
||||
// the ciphertext is the encryption of: [ message ] [ \x02 ] [ 0 or more \x00 as needed ]
|
||||
recordBuf := make([]byte, recordSize)
|
||||
// remainingBuf tracks our current writing position in recordBuf:
|
||||
remainingBuf := recordBuf
|
||||
|
||||
// Application server key pairs (single use)
|
||||
localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localPublicKey := localPrivateKey.PublicKey()
|
||||
|
||||
// Encryption Content-Coding Header
|
||||
// +-----------+--------+-----------+---------------+
|
||||
// | salt (16) | rs (4) | idlen (1) | keyid (idlen) |
|
||||
// +-----------+--------+-----------+---------------+
|
||||
// in our case the keyid is localPublicKey.Bytes(), so 65 bytes
|
||||
// First, generate the salt
|
||||
_, err = rand.Read(remainingBuf[:16])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
salt := remainingBuf[:16]
|
||||
remainingBuf = remainingBuf[16:]
|
||||
binary.BigEndian.PutUint32(remainingBuf[:], recordSize)
|
||||
remainingBuf = remainingBuf[4:]
|
||||
localPublicKeyBytes := localPublicKey.Bytes()
|
||||
remainingBuf[0] = byte(len(localPublicKeyBytes))
|
||||
remainingBuf = remainingBuf[1:]
|
||||
copy(remainingBuf[:], localPublicKeyBytes)
|
||||
remainingBuf = remainingBuf[len(localPublicKeyBytes):]
|
||||
|
||||
// Combine application keys with receiver's EC public key to derive ECDH shared secret
|
||||
sharedECDHSecret, err := localPrivateKey.ECDH(keys.P256dh)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deriving shared secret: %w", err)
|
||||
}
|
||||
|
||||
// ikm
|
||||
prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00"))
|
||||
prkInfoBuf.Write(keys.P256dh.Bytes())
|
||||
prkInfoBuf.Write(localPublicKey.Bytes())
|
||||
|
||||
prkHKDF := hkdf.New(sha256.New, sharedECDHSecret, keys.Auth[:], prkInfoBuf.Bytes())
|
||||
ikm, err := getHKDFKey(prkHKDF, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Derive Content Encryption Key
|
||||
contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00")
|
||||
contentHKDF := hkdf.New(sha256.New, ikm, salt, contentEncryptionKeyInfo)
|
||||
contentEncryptionKey, err := getHKDFKey(contentHKDF, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Derive the Nonce
|
||||
nonceInfo := []byte("Content-Encoding: nonce\x00")
|
||||
nonceHKDF := hkdf.New(sha256.New, ikm, salt, nonceInfo)
|
||||
nonce, err := getHKDFKey(nonceHKDF, 12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cipher
|
||||
c, err := aes.NewCipher(contentEncryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// need 1 byte for the 0x02 delimiter, 16 bytes for the AEAD tag
|
||||
if len(remainingBuf) < len(message)+17 {
|
||||
return nil, ErrRecordSizeTooSmall
|
||||
}
|
||||
// Copy the message plaintext into the buffer
|
||||
copy(remainingBuf[:], message[:])
|
||||
// The plaintext to be encrypted will include the padding delimiter and the padding;
|
||||
// cut off the final 16 bytes that are reserved for the AEAD tag
|
||||
plaintext := remainingBuf[:len(remainingBuf)-16]
|
||||
remainingBuf = remainingBuf[len(message):]
|
||||
// Add padding delimiter
|
||||
remainingBuf[0] = '\x02'
|
||||
remainingBuf = remainingBuf[1:]
|
||||
// The rest of the buffer is already zero-padded
|
||||
|
||||
// Encipher the plaintext in place, then add the AEAD tag at the end.
|
||||
// "To reuse plaintext's storage for the encrypted output, use plaintext[:0]
|
||||
// as dst. Otherwise, the remaining capacity of dst must not overlap plaintext."
|
||||
gcm.Seal(plaintext[:0], nonce, plaintext, nil)
|
||||
|
||||
return recordBuf, nil
|
||||
}
|
||||
|
||||
func sendNotification(ctx context.Context, endpoint string, options *Options, vapidAuthHeader string, body []byte) (*http.Response, error) {
|
||||
// POST request
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
req = req.WithContext(ctx)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Encoding", "aes128gcm")
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
req.Header.Set("TTL", strconv.Itoa(options.TTL))
|
||||
|
||||
// Сheck the optional headers
|
||||
if len(options.Topic) > 0 {
|
||||
req.Header.Set("Topic", options.Topic)
|
||||
}
|
||||
|
||||
if isValidUrgency(options.Urgency) {
|
||||
req.Header.Set("Urgency", string(options.Urgency))
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", vapidAuthHeader)
|
||||
|
||||
// Send the request
|
||||
var client HTTPClient
|
||||
if options.HTTPClient != nil {
|
||||
client = options.HTTPClient
|
||||
} else {
|
||||
client = defaultHTTPClient
|
||||
}
|
||||
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
// decodeSubscriptionKey decodes a base64 subscription key.
|
||||
func decodeSubscriptionKey(key string) ([]byte, error) {
|
||||
key = strings.TrimRight(key, "=")
|
||||
|
||||
if strings.IndexByte(key, '+') != -1 || strings.IndexByte(key, '/') != -1 {
|
||||
return base64.RawStdEncoding.DecodeString(key)
|
||||
}
|
||||
return base64.RawURLEncoding.DecodeString(key)
|
||||
}
|
||||
|
||||
// Returns a key of length "length" given an hkdf function
|
||||
func getHKDFKey(hkdf io.Reader, length int) ([]byte, error) {
|
||||
key := make([]byte, length)
|
||||
n, err := io.ReadFull(hkdf, key)
|
||||
if n != len(key) || err != nil {
|
||||
return key, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue