SFTP连接的重用

背景:目前的业务系统每天都需要生成大量的报表,生成的报表都需要上传到特定的SFTP服务器上,所以项目上用到SFTP连接的地方比较多。而每次上传文件都要经历登录FTP、上传文件、登出FTP这些重复的步骤,而每次登录都需要耗时2秒左右,当文件数量过多,其耗时也是相当巨大的。所以想通过重用SFTP连接来达到优化的效果。本文主要讲述基于Apache 的commons-pool2的池化技术来实现SFTP连接的重用。这样就不用每次上传文件都要先登录ftp、登出ftp。而是相同地址第一次登录一下,保持这个连接放进池里,后面要上传的话直接从池里拿一个连接,上传文件,省去了重复登录登出的时间。

一、Apache commons-pool2

先简单介绍一下Apache commons-pool2包。Redis 的常用客户端 Jedis、数据库连接池DBCP都是用的 Commons Pool实现。

Apache Common-pool2由三大模块组成:ObjectPool、PooledObject和PooledObjectFactory。

  • ObjectPool:提供所有对象的存取管理。
  • PooledObject:池化的对象,是对对象的一个包装,加上了对象的一些其他信息,包括对象的状态(已用、空闲),对象的创建时间等。
  • PooledObjectFactory:工厂类,负责池化对象的创建,对象的初始化,对象状态的销毁和对象状态的验证。

ObjectPool会持有PooledObjectFactory,将具体的对象的创建、初始化、销毁等任务交给它处理,其操作对象是PooledObject,即具体的Object的包装类。

二、具体实现

  1. 加入依赖
<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-pool2</artifactId>
</dependency>
<dependency>
    <groupId>com.jcraft</groupId>
    <artifactId>jsch</artifactId>
    <version>0.1.55</version>
</dependency>

我们使用使用 JSch 实现的 SFTP 功能。JSch 是 Java Secure Channel 的缩写。 JSch 是一个 SSH2 的纯 Java 实现。它允许你连接到一个 SSH 服务器,并且可以使用端口转发, X11 转发,文件传输等。

  1. SFTP 连接配置

配置 IP ,用户名,密码,连接池相关配置。

@Data
public class SftpContactProperties {
    /** SFTP 登录用户名*/
    private String username;
    /** SFTP 登录密码*/
    private String password;
    /** 私钥 */
    private String privateKey;
    /** SFTP 服务器地址IP地址*/
    private String host;
    /** SFTP 端口*/
    private int port;

    private SftpContactPoolProperties poolProperties;

    public SftpContactProperties(String username, String password, String host, int port, String privateKey) {
        this.username = username;
        this.password = password;
        this.privateKey = privateKey;
        this.host = host;
        this.port = port;
    }

    @Data
    public static class SftpContactPoolProperties {
        private String poolPrefix;
        /**
         * 最大空闲
         */
        private int maxIdle = 3;
        /**
         * 最大总数
         */
        private int maxTotal = 6;
        /**
         * 最小空闲
         */
        private int minIdle = 0;
        /**
         * 初始化连接数
         */
        private int initialSize = 1;
    }
}
  1. SFTP 连接对象创建工厂

继承 commons-pool2 的 BasePooledObjectFactory 类,实现 create , wrap 等方法就可以实现一个对象创建工厂。

public class SftpContactFactory extends BasePooledObjectFactory<SftpService> {
    private SftpContactProperties properties;

    public SftpContactFactory(SftpContactProperties properties){
        this.properties = properties;
    }

    @Override
    public SftpService create() throws Exception {
        // 创建新对象
        SftpService sftpService = new SftpService(properties);
        boolean success = sftpService.login();
        if (success) {
            return sftpService;
        } else {
            throw new Exception("login fail");
        }
    }

    @Override
    public PooledObject<SftpService> wrap(SftpService sftpService) {
        // 池化对象
        return new DefaultPooledObject<>(sftpService);
    }

    @Override
    public void passivateObject(PooledObject<SftpService> p) throws Exception {
        // 将对象返回池时进行的操作,此处将工作目录设置为根目录
        p.getObject().changeToHomeDir();
    }

    @Override
    public void destroyObject(PooledObject<SftpService> p) throws Exception {
        p.getObject().logout();
    }

    @Override
    public boolean validateObject(PooledObject<SftpService> p) {
        return p.getObject().isConnected();
    }
}
  1. SFTP 连接池

继承 GenericObjectPool,连接池将工厂对象与配置对象结合。 GenericObjectPool中有borrowObject, returnObject等方法供我们获取,回归对象。

public class SftpContactPool extends GenericObjectPool<SftpService> {
    public SftpContactPool(PooledObjectFactory<SftpService> factory) {
        super(factory);
    }

    public SftpContactPool(PooledObjectFactory<SftpService> factory, GenericObjectPoolConfig config) {
        super(factory, config);
    }

    public SftpContactPool(PooledObjectFactory<SftpService> factory, GenericObjectPoolConfig config, AbandonedConfig abandonedConfig) {
        super(factory, config, abandonedConfig);
    }
}
  1. SFTP 连接对象

具体操作SFTP的连接ChannelSftp、Session。

@Data
public class SftpService {
    private transient Logger log = LoggerFactory.getLogger(this.getClass());

    @Getter
    private ChannelSftp sftp;

    private Session session;
    /** SFTP 登录用户名*/
    private String username;
    /** SFTP 登录密码*/
    private String password;
    /** 私钥 */
    private String privateKey;
    /** SFTP 服务器地址IP地址*/
    private String host;
    /** SFTP 端口*/
    private int port;

    private ScheduledExecutorService timer = Executors.newScheduledThreadPool(1);;

    private static DistributedLockService distributedLockService = SpringContextUtil.getBean(SimpleDistributedLockServiceImpl.class);

    public SftpService(SftpContactProperties properties) {
        this.username = properties.getUsername();
        this.password = properties.getPassword();
        this.host = properties.getHost();
        this.port = properties.getPort();
        this.privateKey = properties.getPrivateKey();
        this.timer = Executors.newScheduledThreadPool(1);
    }

    public SftpService(){}


    /**
     * 连接sftp服务器
     */
    public boolean login(Integer timeout) {
        boolean success = false;
        try {
            JSch jsch = new JSch();
            if (privateKey != null) {
                if (StringUtils.isNotBlank(password)) {
                    jsch.addIdentity(privateKey, password);// 设置私钥+解析密码
                } else {
                    jsch.addIdentity(privateKey);// 设置私钥
                }
            }

            session = jsch.getSession(username, host, port);
            log.info("get sftp session success. {}", host);

            if (password != null && privateKey == null) {
                session.setPassword(password);
            }
            Properties config = new Properties();
            config.put("StrictHostKeyChecking", "no");

            session.setConfig(config);
            session.connect(timeout == null? 30000 : timeout);
            log.info("sftp session connect success. {}", host);

            Channel channel = session.openChannel("sftp");
            channel.connect();
            log.info("channel connect success. {}", host);

            sftp = (ChannelSftp) channel;
            if (timer != null) {
                timer.scheduleAtFixedRate(() -> {
                    try {
                        session.sendKeepAliveMsg();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }, 2, 5, TimeUnit.SECONDS);
            }
            success = true;
        } catch (JSchException e) {
            log.error("login to {} fail.", host, e);
        }
        return success;
    }/**
     * 连接sftp服务器
     */
    public boolean login() {
        return login(null);
    }

    /**
     * 关闭连接 server
     */
    public void logout() {
        if (timer != null) {
            timer.shutdown();
        }
        if (sftp != null) {
            if (sftp.isConnected()) {
                sftp.disconnect();
            }
        }
        if (session != null) {
            if (session.isConnected()) {
                session.disconnect();
            }
        }
    }


    /**
     * 将输入流的数据上传到sftp作为文件。文件完整路径=basePath+directory
     * @param basePath  服务器的基础路径
     * @param directory  上传到该目录
     * @param sftpFileName  sftp端文件名
     */
    public void upload(String basePath,String directory, String sftpFileName, ByteArrayInputStream input) throws SftpException {
        try {
            sftp.cd(basePath);
            sftp.cd(directory);
        } catch (SftpException e) {
            //目录不存在,则创建文件夹
            String [] dirs=directory.split("/");
            String tempPath=basePath;
            for(String dir:dirs){
                if(null== dir || "".equals(dir)) {
                    continue;
                }
                tempPath+="/"+dir;
                try{
                    sftp.cd(tempPath);
                }catch(SftpException ex){
                    try {
                        boolean lock = lock(tempPath);
                        log.info("--lock--{}--{}", tempPath, lock);
                        while (!lock) {
                            Thread.sleep(100);
                            lock = lock(tempPath);
                        }
                        if(!exist(tempPath)) {
                            log.info("--dir[{}] not exist--", tempPath);
                            sftp.mkdir(tempPath);
                        } else {
                            log.info("--dir[{}] exist--", tempPath);
                        }
                        sftp.cd(tempPath);
                    } catch (SftpException ex1) {
                        log.error("mkdir in {} error1.", host);
                        throw ex1;
                    } catch (Exception ex2) {
                        log.error("mkdir in {} error2.", host);
                        try {
                            throw ex2;
                        } catch (InterruptedException e1) {
                            log.error("mkdir in {} error3.", host, e1);
                            Thread.currentThread().interrupt();
                        }
                    } finally {
                        log.info("--unlock--{}", dir);
                        unlock(tempPath);
                    }
                }
            }
        }
        sftp.put(input, sftpFileName);  //上传文件
    }

    public boolean exist(String path) {
        try {
            sftp.stat(path);
        } catch (SftpException e) {
            log.info("---path[{}] not exist", path);
            return false;
        }
        return true;
    }

    /**
     * 将输入流的数据上传到sftp作为文件。文件完整路径=basePath+directory
     * @param directory  上传到该目录
     * @param sftpFileName  sftp端文件名
     * @return
     */
    public boolean upload(String directory, String sftpFileName, byte[] bytes) throws SftpException {
        ByteArrayInputStream input = new ByteArrayInputStream(bytes);
        try {
            sftp.cd(directory);
        } catch (Exception e) {
            //目录不存在,则创建文件夹
            String [] dirs=directory.split("/");
            String tempPath="";
            for(String dir:dirs){
                if(null== dir || "".equals(dir)) {
                    continue;
                }
                tempPath+="/"+dir;
                try{
                    sftp.cd(tempPath);
                }catch(SftpException ex){
                    try {
                        boolean lock = lock(tempPath);
                        log.info("--lock--{}--{}", tempPath, lock);
                        while (!lock) {
                            Thread.sleep(100);
                            lock = lock(tempPath);
                        }
                        if(!exist(tempPath)) {
                            log.info("--dir[{}] not exist--", tempPath);
                            sftp.mkdir(tempPath);
                        } else {
                            log.info("--dir[{}] exist--", tempPath);
                        }
                        sftp.cd(tempPath);
                    } catch (SftpException ex1) {
                        log.error("mkdir in {} error1.", host);
                        throw ex1;
                    } catch (Exception ex2) {
                        log.error("mkdir in {} error2.", host);
                        try {
                            throw ex2;
                        } catch (InterruptedException e1) {
                            log.error("mkdir in {} error3.", host, e1);
                            Thread.currentThread().interrupt();
                        }
                    } finally {
                        log.info("--unlock--{}", dir);
                        unlock(tempPath);
                    }
                }
            }
        }
        sftp.put(input, sftpFileName);  //上传文件
        return true;
    }

    /**
     * 下载文件。
     * @param directory 下载目录
     * @param downloadFile 下载的文件
     * @param saveFile 存在本地的路径
     */
    public void download(String directory, String downloadFile, String saveFile) throws SftpException, FileNotFoundException{
        if (directory != null && !"".equals(directory)) {
            sftp.cd(directory);
        }
        File file = new File(saveFile);
        sftp.get(downloadFile, new FileOutputStream(file));
    }

    /**
     * 下载文件
     * @param directory 下载目录
     * @param downloadFile 下载的文件名
     * @return 字节数组
     */
    public byte[] download(String directory, String downloadFile) throws SftpException, IOException {
        if (directory != null && !"".equals(directory)) {
            sftp.cd(directory);
        }
        InputStream is = sftp.get(downloadFile);

        byte[] fileData = IOUtils.toByteArray(is);

        return fileData;
    }


    /**
     * 删除文件
     * @param directory 要删除文件所在目录
     * @param deleteFile 要删除的文件
     */
    public void delete(String directory, String deleteFile) throws SftpException{
        sftp.cd(directory);
        sftp.rm(deleteFile);
    }


    /**
     * 列出目录下的文件
     * @param directory 要列出的目录
     */
    public Vector<?> listFiles(String directory) throws SftpException {
        return sftp.ls(directory);
    }

    public void cd(String dir) throws SftpException {
        sftp.cd(dir);
    }

    /**
     * 切换到根目录
     * @return boolean
     */
    public void changeToHomeDir() throws SftpException{
        String homeDir = sftp.getHome();
        cd(homeDir.replaceAll("\\\\", "/"));
    }

    public boolean isConnected() {
        return null != sftp && sftp.isConnected();
    }

    private boolean lock(String dir) {
        return distributedLockService.lock(dir, 10);
    }

    private void unlock(String dir) {
        distributedLockService.unlock(dir);
    }
}
  1. SFTP 工具类

Sftp工具类包装及测试实例:

@Slf4j
public class SftpUtils {
    private static final String SEP = "`";

    private static final Lock lock = new ReentrantLock();

    public static SftpContactPool getSftpContactPool(SftpContactProperties properties) {
        //设置对象池的相关参数
        SftpContactProperties.SftpContactPoolProperties poolProperties = new SftpContactProperties.SftpContactPoolProperties();
        String key = String.join(SEP, properties.getUsername(), properties.getHost(), String.valueOf(properties.getPort()), properties.getPassword(), properties.getPrivateKey(), poolProperties.getPoolPrefix());
        SftpContactPoolStorage poolStorage = ApplicationContextProvider.getBean(SftpContactPoolStorage.class);
        Map<String, SftpContactPool> sftpContactPoolMap = poolStorage.getSftpContactPoolMap();
        SftpContactPool sftpContactPool = sftpContactPoolMap.get(key);
        if (sftpContactPool == null) {
            lock.lock();
            try {
                sftpContactPool = sftpContactPoolMap.get(key);
                if (sftpContactPool == null) {
                    SftpContactFactory sftpContactFactory = new SftpContactFactory(properties);
                    GenericObjectPoolConfig poolConfig = new GenericObjectPoolConfig();
                    poolConfig.setMaxIdle(poolProperties.getMaxIdle());
                    poolConfig.setMaxTotal(poolProperties.getMaxTotal());
                    poolConfig.setMinIdle(poolProperties.getMinIdle());
                    poolConfig.setBlockWhenExhausted(true);
                    poolConfig.setTestOnBorrow(true);
                    poolConfig.setTestOnReturn(true);
                    poolConfig.setTestWhileIdle(true);
                    poolConfig.setTimeBetweenEvictionRunsMillis(1000 * 60 * 30);
                    //一定要关闭jmx,不然springboot启动会报已经注册了某个jmx的错误
                    poolConfig.setJmxEnabled(false);
                    //新建一个对象池,传入对象工厂和配置
                    sftpContactPool = new SftpContactPool(sftpContactFactory, poolConfig);
                    sftpContactPoolMap.put(key, sftpContactPool);
                }
            } catch (Exception e){
                log.error("getSftpService", e);
            } finally {
                lock.unlock();
            }
        }
        return sftpContactPool;
    }

    /**
     * 上传文件测试
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        SftpContactProperties properties = new SftpContactProperties("username", "password", "127.0.0.1", 6002, null);
        SftpContactPool sftpContactPool = getSftpContactPool(properties);
        SftpService sftpService = sftpContactPool.borrowObject();
        sftpService.upload("/test", "test1.txt", "test".getBytes());
        sftpContactPool.returnObject(sftpService);

    }
}

注意使用完一定要归还对象!!

原文链接:,转发请注明来源!