package main

import (
	"bytes"
	"crypto/rand"
	"errors"
	"exp/ssh"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"path/filepath"
	"strconv"
	"strings"
)

// passwordAuth is installed as the SSH server's password callback. It accepts
// only the username "admin" together with the password "admin".
//
// NOTE(review): these credentials are hard-coded; anyone who can reach
// Args.SshAddr can log in with them. Consider making them configurable.
func passwordAuth(username, password string) bool {
	return username == "admin" && password == "admin"
}

// SshCmdReply is the sink a command writes its output to. The shell session
// passes a *bytes.Buffer and forwards the buffered text to the client.
type SshCmdReply interface {
	WriteString(s string) (int, error)
}

// SshCmdFunc implements a shell command. args[0] is the command name itself;
// the remaining elements are its whitespace-separated arguments. A non-nil
// error is reported to the client as a line starting with "error: ".
type SshCmdFunc func(reply SshCmdReply, args []string) error

// SshCmd describes a command available in the SSH admin shell.
type SshCmd struct {
	Name        string     // word typed at the prompt to invoke the command
	CmdFunc     SshCmdFunc // implementation
	Args        string     // argument synopsis shown by "help"
	Description string     // one-line description shown by "help"
}

// Call invokes the command's implementation with the given reply sink and
// arguments.
func (c SshCmd) Call(reply SshCmdReply, args []string) error {
	return c.CmdFunc(reply, args)
}

// commands holds every command registered via RegisterSSHCmd, in
// registration order (which is also the order "help" lists them in).
var commands = []SshCmd{}

// RegisterSSHCmd appends a command to the shell's command table.
func RegisterSSHCmd(name string, cmdFunc SshCmdFunc, args string, desc string) {
	commands = append(commands, SshCmd{
		Name:        name,
		CmdFunc:     cmdFunc,
		Args:        args,
		Description: desc,
	})
}

// RunSSH registers the built-in shell commands, loads the host key from
// key.pem in Args.DataDir and starts listening for SSH connections on
// Args.SshAddr. It returns once the listener is set up; connections are
// served on background goroutines. Failing to read or load the key, or to
// listen, is fatal.
func RunSSH() {
	RegisterSSHCmd("help", HelpCmd, "[cmd]", "Shows this help (or help for a given command)")
	RegisterSSHCmd("start", StartServerCmd, "ID", "Starts the server with the given id")
	RegisterSSHCmd("stop", StopServerCmd, "ID", "Stops the server with the given id")
	RegisterSSHCmd("supw", SetSuperUserPasswordCmd, "ID PASSWORD", "Set the SuperUser password for server with the given id")
	RegisterSSHCmd("setconf", SetConfCmd, "ID KEY VALUE", "Set a config value for the server with the given id")
	RegisterSSHCmd("getconf", GetConfCmd, "ID KEY", "Get a config value for the server with the given id")

	pemBytes, err := ioutil.ReadFile(filepath.Join(Args.DataDir, "key.pem"))
	if err != nil {
		log.Fatal(err)
	}

	cfg := new(ssh.ServerConfig)
	cfg.Rand = rand.Reader
	cfg.PasswordCallback = passwordAuth
	// Check the error instead of ignoring it so an unusable key.pem is
	// reported at startup.
	if err := cfg.SetRSAPrivateKey(pemBytes); err != nil {
		log.Fatalf("ssh: unable to load host key: %v", err)
	}

	listener, err := ssh.Listen("tcp", Args.SshAddr, cfg)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("Listening for SSH connections on '%v'", Args.SshAddr)

	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				log.Fatalf("ssh: unable to accept incoming connection: %v", err)
			}
			// The handshake (including authentication) runs on the
			// per-connection goroutine so a slow or misbehaving client
			// cannot stall this accept loop.
			go sshServeConn(conn)
		}
	}()
}

// sshServeConn completes the SSH handshake on conn and then hands every
// channel the client opens to handleChannel. A failed handshake or channel
// accept is logged and ends only this connection, never the process.
func sshServeConn(conn *ssh.ServerConn) {
	defer conn.Close()

	err := conn.Handshake()
	if err == io.EOF {
		return
	} else if err != nil {
		log.Printf("ssh: unable to perform handshake: %v", err)
		return
	}

	for {
		channel, err := conn.Accept()
		if err == io.EOF {
			return
		} else if err != nil {
			log.Printf("ssh: unable to accept channel: %v", err)
			return
		}
		go handleChannel(channel)
	}
}

// handleChannel serves one channel opened by an authenticated client. Only
// "session" channels are accepted; they get an interactive shell (prompt
// "G> ") that runs registered commands. Any other channel type is rejected.
func handleChannel(channel ssh.Channel) {
	if channel.ChannelType() != "session" {
		channel.Reject(ssh.UnknownChannelType, "unknown channel type")
		return
	}
	if err := channel.Accept(); err != nil {
		log.Printf("ssh: unable to accept session channel: %v", err)
		return
	}
	go sshRunShell(channel)
}

// sshRunShell reads command lines from channel until the client disconnects
// or types "exit"/"quit", runs the matching registered command and writes
// its output (or an "error: ..." line) back to the client. The channel is
// closed on return.
func sshRunShell(channel ssh.Channel) {
	defer channel.Close()
	shell := ssh.NewServerShell(channel, "G> ")

	for {
		line, err := shell.ReadLine()
		if err == io.EOF {
			return
		} else if err != nil {
			log.Printf("ssh: error in reading from channel: %v", err)
			return
		}

		// Fields (unlike Split on " ") collapses repeated whitespace and
		// yields no args at all for a blank line.
		args := strings.Fields(line)
		if len(args) == 0 {
			continue
		}
		if args[0] == "exit" || args[0] == "quit" {
			return
		}

		cmd := sshFindCmd(args[0])
		if cmd == nil {
			if _, err := shell.Write([]byte("error: unknown command\r\n")); err != nil {
				return
			}
			continue
		}

		buf := new(bytes.Buffer)
		if err := cmd.Call(buf, args); err != nil {
			if _, err := shell.Write([]byte(fmt.Sprintf("error: %v\r\n", err.Error()))); err != nil {
				return
			}
			continue
		}
		// NOTE(review): the reason for 64-byte chunks is not visible here;
		// kept as in the original implementation.
		if err := sshWriteChunked(shell, buf.Bytes(), 64); err != nil {
			return
		}
	}
}

// sshFindCmd returns the registered command called name, or nil if there is
// none.
func sshFindCmd(name string) *SshCmd {
	for i := range commands {
		if commands[i].Name == name {
			return &commands[i]
		}
	}
	return nil
}

// sshWriteChunked writes b to w in pieces of at most chunkSize bytes,
// continuing after short writes, and returns the first write error.
func sshWriteChunked(w io.Writer, b []byte, chunkSize int) error {
	for len(b) > 0 {
		n := chunkSize
		if len(b) < n {
			n = len(b)
		}
		written, err := w.Write(b[:n])
		if err != nil {
			return err
		}
		b = b[written:]
	}
	return nil
}

// HelpCmd lists every registered command with its argument synopsis and
// description, or only the command named by args[1] when one is given. It
// returns an error if args[1] names no registered command.
func HelpCmd(reply SshCmdReply, args []string) error {
	onlyShow := ""
	if len(args) > 1 {
		onlyShow = args[1]
	}

	didShow := false
	for _, cmd := range commands {
		if onlyShow != "" && cmd.Name != onlyShow {
			continue
		}
		reply.WriteString("\r\n")
		reply.WriteString(" " + cmd.Name + " " + cmd.Args + "\r\n")
		reply.WriteString(" " + cmd.Description + "\r\n")
		didShow = true
	}

	if onlyShow != "" && !didShow {
		return errors.New("no such command")
	}
	reply.WriteString("\r\n")
	return nil
}

// sshParseServerId parses a server id argument. It returns the
// "bad server id" error shared by all server commands if arg is not an
// integer.
func sshParseServerId(arg string) (int64, error) {
	serverId, err := strconv.Atoi64(arg)
	if err != nil {
		return 0, errors.New("bad server id")
	}
	return serverId, nil
}

// StartServerCmd starts the server whose id is args[1].
// Usage: start ID
func StartServerCmd(reply SshCmdReply, args []string) error {
	if len(args) != 2 {
		return errors.New("argument count mismatch")
	}
	serverId, err := sshParseServerId(args[1])
	if err != nil {
		return err
	}
	server, exists := servers[serverId]
	if !exists {
		return errors.New("no such server")
	}
	if err := server.Start(); err != nil {
		return fmt.Errorf("unable to start: %v", err.Error())
	}
	reply.WriteString(fmt.Sprintf("[%v] Started\r\n", serverId))
	return nil
}

// StopServerCmd stops the server whose id is args[1].
// Usage: stop ID
func StopServerCmd(reply SshCmdReply, args []string) error {
	if len(args) != 2 {
		return errors.New("argument count mismatch")
	}
	serverId, err := sshParseServerId(args[1])
	if err != nil {
		return err
	}
	server, exists := servers[serverId]
	if !exists {
		return errors.New("no such server")
	}
	if err := server.Stop(); err != nil {
		return fmt.Errorf("unable to stop: %v", err.Error())
	}
	reply.WriteString(fmt.Sprintf("[%v] Stopped\r\n", serverId))
	return nil
}

// SetSuperUserPasswordCmd sets the SuperUser password (args[2]) of the
// server whose id is args[1]. Because arguments are whitespace-separated,
// the password cannot contain spaces.
// Usage: supw ID PASSWORD
func SetSuperUserPasswordCmd(reply SshCmdReply, args []string) error {
	if len(args) != 3 {
		return errors.New("argument count mismatch")
	}
	serverId, err := sshParseServerId(args[1])
	if err != nil {
		return err
	}
	server, exists := servers[serverId]
	if !exists {
		return errors.New("no such server")
	}
	server.SetSuperUserPassword(args[2])
	reply.WriteString(fmt.Sprintf("[%v] SuperUser password updated.\r\n", serverId))
	return nil
}

// SetConfCmd stores key=value (args[2], args[3]) in the config of the server
// whose id is args[1], then sends the pair on the server's cfgUpdate channel.
// Usage: setconf ID KEY VALUE
//
// NOTE(review): the cfgUpdate send blocks until it is received; whether that
// can stall this shell session depends on code not visible here.
func SetConfCmd(reply SshCmdReply, args []string) error {
	if len(args) != 4 {
		return errors.New("argument count mismatch")
	}
	serverId, err := sshParseServerId(args[1])
	if err != nil {
		return err
	}
	server, exists := servers[serverId]
	if !exists {
		return errors.New("no such server")
	}
	key, value := args[2], args[3]
	server.cfg.Set(key, value)
	server.cfgUpdate <- &KeyValuePair{Key: key, Value: value}
	reply.WriteString(fmt.Sprintf("[%v] %v = %v\r\n", serverId, key, value))
	return nil
}

// GetConfCmd reports the config value for key args[2] of the server whose
// id is args[1].
// Usage: getconf ID KEY
func GetConfCmd(reply SshCmdReply, args []string) error {
	if len(args) != 3 {
		return errors.New("argument count mismatch")
	}
	serverId, err := sshParseServerId(args[1])
	if err != nil {
		return err
	}
	server, exists := servers[serverId]
	if !exists {
		return errors.New("no such server")
	}
	key := args[2]
	value := server.cfg.StringValue(key)
	reply.WriteString(fmt.Sprintf("[%v] %v = %v\r\n", serverId, key, value))
	return nil
}