diff --git a/pom.xml b/pom.xml index c540daf..d4b8507 100644 --- a/pom.xml +++ b/pom.xml @@ -4,7 +4,7 @@ 4.0.0 com.example.sshd echo-sshd-server - 1.2.2 + 1.3.0 ECHO SSH SERVER Learning Apache Mina SSHD library diff --git a/src/main/java/com/example/sshd/core/EchoShell.java b/src/main/java/com/example/sshd/core/EchoShell.java new file mode 100644 index 0000000..d0771a3 --- /dev/null +++ b/src/main/java/com/example/sshd/core/EchoShell.java @@ -0,0 +1,145 @@ +package com.example.sshd.core; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Properties; + +import org.apache.sshd.server.Command; +import org.apache.sshd.server.Environment; +import org.apache.sshd.server.ExitCallback; +import org.apache.sshd.server.SessionAware; +import org.apache.sshd.server.session.ServerSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; +import org.springframework.context.annotation.Scope; +import org.springframework.stereotype.Component; + +import com.example.sshd.util.ReplyUtil; + +@Component +@Scope(ConfigurableBeanFactory.SCOPE_PROTOTYPE) +public class EchoShell implements Command, Runnable, SessionAware { + + private static final Logger logger = LoggerFactory.getLogger(EchoShell.class); + + @Autowired + ReplyUtil replyUtil; + + @Autowired + Properties hashReplies; + + protected InputStream in; + protected OutputStream out; + protected OutputStream err; + protected ExitCallback callback; + protected Environment environment; + protected Thread thread; + protected ServerSession session; + + @Override + public void setInputStream(InputStream in) { + this.in = in; + } + + @Override + public void setOutputStream(OutputStream out) { + this.out = out; + } + + @Override + public void setErrorStream(OutputStream err) { + this.err = err; + } + + @Override + public void setExitCallback(ExitCallback callback) { + this.callback = callback; + } + + @Override + public void start(Environment env) throws IOException { + environment = env; + thread = new Thread(this, remoteIpAddress()); + logger.info("environment: {}, thread-name: {}", environment.getEnv(), thread.getName()); + thread.start(); + } + + protected String remoteIpAddress() { + String remoteIpAddress = ""; + + if (session.getIoSession().getRemoteAddress() instanceof InetSocketAddress) { + InetSocketAddress remoteAddress = (InetSocketAddress) session.getIoSession().getRemoteAddress(); + remoteIpAddress = remoteAddress.getAddress().getHostAddress(); + } else { + remoteIpAddress = session.getIoSession().getRemoteAddress().toString(); + } + return remoteIpAddress; + } + + @Override + public void destroy() { + thread.interrupt(); + } + + @Override + public void run() { + String prompt = hashReplies.getProperty("prompt", "$ "); + try { + out.write(prompt.getBytes()); + out.flush(); + + BufferedReader r = new BufferedReader(new InputStreamReader(in)); + String command = ""; + + while (!Thread.currentThread().isInterrupted()) { + int s = r.read(); + if (s == 13 || s == 10) { + + boolean containsExit = Arrays.asList(command.split(";")).stream().map(cmd -> { + boolean wantsExit = false; + try { + wantsExit = replyUtil.replyToCommand(cmd.trim(), out, prompt, session); + out.flush(); + } catch (Exception e) { + logger.error("run error!", e); + } + return wantsExit; + }).reduce((a, b) -> a || b).get(); + + if (containsExit) { + break; + } + command = ""; + } else { + logger.trace("input character: {}", s); + if (s == 127) { + if (command.length() > 0) { + command = command.substring(0, command.length() - 1); + out.write(s); + } + } else if (s >= 32 && s < 127) { + command += (char) s; + out.write(s); + } + } + out.flush(); + } + } catch (Exception e) { + logger.error("run error!", e); + } finally { + callback.onExit(0); + } + } + + @Override + public void setSession(ServerSession session) { + this.session = session; + } +} \ No newline at end of file diff --git a/src/main/java/com/example/sshd/core/EchoShellFactory.java b/src/main/java/com/example/sshd/core/EchoShellFactory.java index 92359bf..8684bc2 100644 --- a/src/main/java/com/example/sshd/core/EchoShellFactory.java +++ b/src/main/java/com/example/sshd/core/EchoShellFactory.java @@ -1,139 +1,21 @@ package com.example.sshd.core; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.net.InetSocketAddress; -import java.util.Map; -import java.util.Properties; - import org.apache.sshd.common.Factory; import org.apache.sshd.server.Command; -import org.apache.sshd.server.Environment; -import org.apache.sshd.server.ExitCallback; -import org.apache.sshd.server.SessionAware; -import org.apache.sshd.server.session.ServerSession; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; import org.springframework.stereotype.Component; -import com.example.sshd.util.ReplyUtil; - @Component public class EchoShellFactory implements Factory { - private static final Logger logger = LoggerFactory.getLogger(EchoShellFactory.class); - - @Autowired - ReplyUtil replyUtil; - @Autowired - Properties hashReplies; - - @Autowired - Map ipInfoMapping; + ApplicationContext applicationContext; @Override public Command create() { - return new EchoShell(); + return (Command) applicationContext.getBean("echoShell"); } - public class EchoShell implements Command, Runnable, SessionAware { - - protected InputStream in; - protected OutputStream out; - protected OutputStream err; - protected ExitCallback callback; - protected Environment environment; - protected Thread thread; - protected ServerSession session; - - @Override - public void setInputStream(InputStream in) { - this.in = in; - } - @Override - public void setOutputStream(OutputStream out) { - this.out = out; - } - - @Override - public void setErrorStream(OutputStream err) { - this.err = err; - } - - @Override - public void setExitCallback(ExitCallback callback) { - this.callback = callback; - } - - @Override - public void start(Environment env) throws IOException { - environment = env; - - if (session.getIoSession().getRemoteAddress() instanceof InetSocketAddress) { - InetSocketAddress remoteAddress = (InetSocketAddress) session.getIoSession().getRemoteAddress(); - String remoteIpAddress = remoteAddress.getAddress().getHostAddress(); - thread = new Thread(this, remoteIpAddress); - } else { - thread = new Thread(this, session.getIoSession().getRemoteAddress().toString()); - } - - logger.info("environment: {}, thread-name: {}", environment.getEnv(), thread.getName()); - thread.start(); - } - - @Override - public void destroy() { - thread.interrupt(); - } - - @Override - public void run() { - String prompt = hashReplies.getProperty("prompt", "$ "); - try { - out.write(prompt.getBytes()); - out.flush(); - - BufferedReader r = new BufferedReader(new InputStreamReader(in)); - String command = ""; - - while (!Thread.currentThread().isInterrupted()) { - int s = r.read(); - if (s == 13 || s == 10) { - if (!replyUtil.replyToCommand(command, out, prompt, session)) { - out.flush(); - return; - } - command = ""; - } else { - logger.trace("input character: {}", s); - if (s == 127) { - if (command.length() > 0) { - command = command.substring(0, command.length() - 1); - out.write(s); - } - } else if (s >= 32 && s < 127) { - command += (char) s; - out.write(s); - } - } - out.flush(); - } - } catch (Exception e) { - logger.error("run error!", e); - } finally { - callback.onExit(0); - } - } - - @Override - public void setSession(ServerSession session) { - this.session = session; - } - } } \ No newline at end of file diff --git a/src/main/java/com/example/sshd/core/OnetimeCommand.java b/src/main/java/com/example/sshd/core/OnetimeCommand.java index 7422d71..ae73641 100644 --- a/src/main/java/com/example/sshd/core/OnetimeCommand.java +++ b/src/main/java/com/example/sshd/core/OnetimeCommand.java @@ -1,94 +1,40 @@ package com.example.sshd.core; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; +import java.util.Arrays; -import org.apache.sshd.server.Command; -import org.apache.sshd.server.Environment; -import org.apache.sshd.server.ExitCallback; -import org.apache.sshd.server.SessionAware; -import org.apache.sshd.server.session.ServerSession; -import org.springframework.beans.factory.annotation.Autowired; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.context.annotation.Scope; import org.springframework.stereotype.Component; -import com.example.sshd.util.ReplyUtil; - @Component @Scope(ConfigurableBeanFactory.SCOPE_PROTOTYPE) -public class OnetimeCommand implements Command, SessionAware { - - @Autowired - ReplyUtil replyUtil; +public class OnetimeCommand extends EchoShell { + + private static final Logger logger = LoggerFactory.getLogger(OnetimeCommand.class); - private InputStream in; - private OutputStream out; - private OutputStream err; - private ExitCallback callback; - private Environment environment; private String command; - private ServerSession session; public OnetimeCommand(String cmd) { command = cmd; } - public InputStream getIn() { - return in; - } - - public OutputStream getOut() { - return out; - } - - public OutputStream getErr() { - return err; - } - - public Environment getEnvironment() { - return environment; - } - - @Override - public void setInputStream(InputStream in) { - this.in = in; - } - - @Override - public void setOutputStream(OutputStream out) { - this.out = out; - } - - @Override - public void setErrorStream(OutputStream err) { - this.err = err; - } - - @Override - public void setExitCallback(ExitCallback callback) { - this.callback = callback; - } - - @Override - public void start(Environment env) throws IOException { - environment = env; - replyUtil.replyToCommand(command, out, "", session); - out.flush(); - callback.onExit(0); - } - - @Override - public void destroy() { - } - - public ExitCallback getCallback() { - return callback; - } - @Override - public void setSession(ServerSession session) { - this.session = session; + public void run() { + try { + Arrays.asList(command.split(";")).stream().forEach(cmd -> { + try { + replyUtil.replyToCommand(cmd.trim(), out, "", session); + out.flush(); + } catch (Exception e) { + logger.error("run error!", e); + } + }); + } catch (Exception e) { + logger.error("run error!", e); + } finally { + callback.onExit(0); + } } } diff --git a/src/main/java/com/example/sshd/util/ReplyUtil.java b/src/main/java/com/example/sshd/util/ReplyUtil.java index 474bc0d..1b8020f 100644 --- a/src/main/java/com/example/sshd/util/ReplyUtil.java +++ b/src/main/java/com/example/sshd/util/ReplyUtil.java @@ -2,7 +2,6 @@ package com.example.sshd.util; import java.io.IOException; import java.io.OutputStream; -import java.net.InetSocketAddress; import java.util.Map; import java.util.Optional; import java.util.Properties; @@ -34,34 +33,28 @@ public class ReplyUtil { public boolean replyToCommand(String command, OutputStream out, String prompt, ServerSession session) throws IOException { - String remoteIpAddress = ""; - String cmdHash = DigestUtils.md5Hex(command.trim()).toUpperCase(); - if (session.getIoSession().getRemoteAddress() instanceof InetSocketAddress) { - InetSocketAddress remoteAddress = (InetSocketAddress) session.getIoSession().getRemoteAddress(); - remoteIpAddress = remoteAddress.getAddress().getHostAddress(); - } else { - remoteIpAddress = session.getIoSession().getRemoteAddress().toString(); - } + String cmdHash = DigestUtils.md5Hex(command.trim()).toUpperCase(); if (StringUtils.equals(command.trim(), "about")) { - logger.info("[{}] {} About command detected: {}", remoteIpAddress, cmdHash, command.trim()); - out.write(String.format("\r\n%s\r\n%s", ipInfoMapping.get(remoteIpAddress), prompt).getBytes()); + logger.info("[{}] About command detected: {}", cmdHash, command.trim()); + out.write(String.format("\r\n%s\r\n%s", ipInfoMapping.get(Thread.currentThread().getName()), prompt) + .getBytes()); } else if (StringUtils.equals(command.trim(), "exit")) { - logger.info("[{}] {} Exiting command detected: {}", remoteIpAddress, cmdHash, command.trim()); + logger.info("[{}] Exiting command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\nExiting...\r\n%s", prompt).getBytes()); - return false; + return true; } else if (hashReplies.containsKey(command.trim())) { - logger.info("[{}] {} Known command detected: {}", remoteIpAddress, cmdHash, command.trim()); + logger.info("[{}] Known command detected: {}", cmdHash, command.trim()); String reply = hashReplies.getProperty(command.trim()).replace("\\r", "\r").replace("\\n", "\n") .replace("\\t", "\t"); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else if (hashReplies.containsKey(cmdHash)) { - logger.info("[{}] {} Known command-hash detected: {}", remoteIpAddress, cmdHash, command.trim()); + logger.info("[{}] Known command-hash detected: {}", cmdHash, command.trim()); String reply = hashReplies.getProperty(cmdHash).replace("\\r", "\r").replace("\\n", "\n").replace("\\t", "\t"); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else if (hashReplies.containsKey(String.format("base64(%s)", cmdHash))) { - logger.info("[{}] {} Known base64-hash detected: {}", remoteIpAddress, cmdHash, command.trim()); + logger.info("[{}] Known base64-hash detected: {}", cmdHash, command.trim()); String reply = hashReplies.getProperty(String.format("base64(%s)", cmdHash)); reply = new String(Base64.decode(reply)); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); @@ -70,18 +63,17 @@ public class ReplyUtil { .filter(e -> command.trim().matches(((String) e.getKey()))) .map(e -> Pair.of((String) e.getKey(), (String) e.getValue())).findAny(); if (o.isPresent()) { - logger.info("[{}] {} Known pattern detected: {} ({})", remoteIpAddress, cmdHash, command.trim(), - o.get()); + logger.info("[{}] Known pattern detected: {} ({})", cmdHash, command.trim(), o.get()); String reply = hashReplies.getProperty(o.get().getRight(), "").replace("\\r", "\r").replace("\\n", "\n") .replace("\\t", "\t"); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else { - logger.info("[{}] {} Command not found: {}", remoteIpAddress, cmdHash, command.trim()); - notFoundLogger.info("[{}] {} Command not found: {}", remoteIpAddress, cmdHash, command.trim()); + logger.info("[{}] Command not found: {}", cmdHash, command.trim()); + notFoundLogger.info("[{}] Command not found: {}", cmdHash, command.trim()); out.write(String.format("\r\nCommand '%s' not found. Try 'exit'.\r\n%s", command.trim(), prompt) .getBytes()); } } - return true; + return false; } }