1
0
Fork 0
forked from External/ergo
ergo/src/irc/persistence.go
2013-05-26 13:28:22 -07:00

260 lines
5.2 KiB
Go

package irc
import (
"database/sql"
//"fmt"
"bufio"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
"path/filepath"
"strings"
)
type Database struct {
*sql.DB
}
type Transaction struct {
*sql.Tx
}
type RowId uint64
type Queryable interface {
Exec(string, ...interface{}) (sql.Result, error)
Query(string, ...interface{}) (*sql.Rows, error)
QueryRow(string, ...interface{}) *sql.Row
}
type Savable interface {
Save(q Queryable) bool
}
//
// general
//
func NewDatabase() *Database {
db, err := sql.Open("sqlite3", "ergonomadic.db")
if err != nil {
log.Fatalln("cannot open database")
}
return &Database{db}
}
func NewTransaction(tx *sql.Tx) *Transaction {
return &Transaction{tx}
}
func readLines(filename string) <-chan string {
file, err := os.Open(filename)
if err != nil {
log.Fatalln(err)
}
reader := bufio.NewReader(file)
lines := make(chan string)
go func(lines chan<- string) {
defer file.Close()
defer close(lines)
for {
line, err := reader.ReadString(';')
if err != nil {
break
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
lines <- line
}
}(lines)
return lines
}
func (db *Database) ExecSqlFile(filename string) *Database {
db.Transact(func(q Queryable) bool {
for line := range readLines(filepath.Join("sql", filename)) {
log.Println(line)
_, err := q.Exec(line)
if err != nil {
log.Fatalln(err)
}
}
return true
})
return db
}
func (db *Database) Transact(txf func(Queryable) bool) {
tx, err := db.Begin()
if err != nil {
log.Panicln(err)
}
if txf(tx) {
tx.Commit()
} else {
tx.Rollback()
}
}
func (db *Database) Save(s Savable) {
db.Transact(func(tx Queryable) bool {
return s.Save(tx)
})
}
//
// general purpose sql
//
func FindId(q Queryable, sql string, args ...interface{}) (rowId RowId, err error) {
row := q.QueryRow(sql, args...)
err = row.Scan(&rowId)
return
}
func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) {
row := q.QueryRow(sql, args...)
err = row.Scan(&count)
return
}
//
// data
//
type UserRow struct {
id RowId
nick string
hash []byte
}
type ChannelRow struct {
id RowId
name string
}
// user
func FindAllUsers(q Queryable) (urs []UserRow, err error) {
var rows *sql.Rows
rows, err = q.Query("SELECT id, nick, hash FROM user")
if err != nil {
return
}
urs = make([]UserRow, 0)
for rows.Next() {
ur := UserRow{}
err = rows.Scan(&(ur.id), &(ur.nick), &(ur.hash))
if err != nil {
return
}
urs = append(urs, ur)
}
return
}
func FindUserByNick(q Queryable, nick string) (ur *UserRow, err error) {
ur = &UserRow{}
row := q.QueryRow("SELECT id, nick, hash FROM user LIMIT 1 WHERE nick = ?",
nick)
err = row.Scan(&(ur.id), &(ur.nick), &(ur.hash))
return
}
func FindUserIdByNick(q Queryable, nick string) (RowId, error) {
return FindId(q, "SELECT id FROM user WHERE nick = ?", nick)
}
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
cr = new(ChannelRow)
row := q.QueryRow("SELECT id, name FROM channel LIMIT 1 WHERE name = ?", name)
err := row.Scan(&(cr.id), &(cr.name))
if err != nil {
cr = nil
}
return
}
func InsertUser(q Queryable, user *User) (err error) {
_, err = q.Exec("INSERT INTO user (nick, hash) VALUES (?, ?)",
user.nick, user.hash)
return
}
func UpdateUser(q Queryable, user *User) (err error) {
_, err = q.Exec("UPDATE user SET nick = ?, hash = ? WHERE id = ?",
user.nick, user.hash, *(user.id))
return
}
// user-channel
func DeleteAllUserChannels(q Queryable, rowId RowId) (err error) {
_, err = q.Exec("DELETE FROM user_channel WHERE user_id = ?", rowId)
return
}
func DeleteOtherUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) {
_, err = q.Exec(`DELETE FROM user_channel WHERE
user_id = ? AND channel_id NOT IN ?`, userId, channelIds)
return
}
func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err error) {
ins := "INSERT OR IGNORE INTO user_channel (user_id, channel_id) VALUES "
vals := strings.Repeat("(?, ?), ", len(channelIds))
vals = vals[0 : len(vals)-2]
args := make([]interface{}, 2*len(channelIds))
var i = 0
for channelId := range channelIds {
args[i] = userId
args[i+1] = channelId
i += 2
}
_, err = q.Exec(ins+vals, args)
return
}
// channel
func FindChannelIdByName(q Queryable, name string) (RowId, error) {
return FindId(q, "SELECT id FROM channel WHERE name = ?", name)
}
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow, err error) {
query := ` FROM channel WHERE id IN
(SELECT channel_id from user_channel WHERE user_id = ?)`
count, err := Count(q, "SELECT COUNT(id)"+query, userId)
if err != nil {
return
}
rows, err := q.Query("SELECT id, name"+query, userId)
if err != nil {
return
}
crs = make([]ChannelRow, count)
var i = 0
for rows.Next() {
cr := ChannelRow{}
err = rows.Scan(&(cr.id), &(cr.name))
if err != nil {
return
}
crs[i] = cr
i++
}
return
}
func InsertChannel(q Queryable, channel *Channel) (err error) {
_, err = q.Exec("INSERT INTO channel (name) VALUES (?)", channel.name)
return
}
func UpdateChannel(q Queryable, channel *Channel) (err error) {
_, err = q.Exec("UPDATE channel SET name = ? WHERE id = ?",
channel.name, *(channel.id))
return
}