Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Worker] Fix will not kill the subprocess in remote when stop a remote-shell task #15570 #15629

Merged
merged 12 commits into from
Apr 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.SystemUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -117,33 +118,45 @@ public static boolean kill(@NonNull TaskExecutionContext request) {
* @throws Exception exception
*/
public static String getPidsStr(int processId) throws Exception {
StringBuilder sb = new StringBuilder();
Matcher mat = null;

String rawPidStr;

// pstree pid get sub pids
if (SystemUtils.IS_OS_MAC) {
String pids = OSUtils.exeCmd(String.format("%s -sp %d", TaskConstants.PSTREE, processId));
if (StringUtils.isNotEmpty(pids)) {
mat = MACPATTERN.matcher(pids);
rawPidStr = OSUtils.exeCmd(String.format("%s -sp %d", TaskConstants.PSTREE, processId));
} else if (SystemUtils.IS_OS_LINUX) {
rawPidStr = OSUtils.exeCmd(String.format("%s -p %d", TaskConstants.PSTREE, processId));
} else {
rawPidStr = OSUtils.exeCmd(String.format("%s -p %d", TaskConstants.PSTREE, processId));
}

return parsePidStr(rawPidStr);
}

public static String parsePidStr(String rawPidStr) {
alei1206 marked this conversation as resolved.
Show resolved Hide resolved

log.info("prepare to parse pid, raw pid string: {}", rawPidStr);
ArrayList<String> allPidList = new ArrayList<>();
Matcher mat = null;
if (SystemUtils.IS_OS_MAC) {
if (StringUtils.isNotEmpty(rawPidStr)) {
mat = MACPATTERN.matcher(rawPidStr);
}
} else if (SystemUtils.IS_OS_LINUX) {
String pids = OSUtils.exeCmd(String.format("%s -p %d", TaskConstants.PSTREE, processId));
if (StringUtils.isNotEmpty(pids)) {
mat = LINUXPATTERN.matcher(pids);
if (StringUtils.isNotEmpty(rawPidStr)) {
mat = LINUXPATTERN.matcher(rawPidStr);
}
} else {
String pids = OSUtils.exeCmd(String.format("%s -p %d", TaskConstants.PSTREE, processId));
if (StringUtils.isNotEmpty(pids)) {
mat = WINDOWSPATTERN.matcher(pids);
if (StringUtils.isNotEmpty(rawPidStr)) {
mat = WINDOWSPATTERN.matcher(rawPidStr);
}
}

if (null != mat) {
while (mat.find()) {
sb.append(mat.group(1)).append(" ");
allPidList.add(mat.group(1));
}
}

return sb.toString().trim();
return String.join(" ", allPidList).trim();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

package org.apache.dolphinscheduler.plugin.task.remoteshell;

import static org.apache.dolphinscheduler.plugin.task.remoteshell.RemoteExecutor.COMMAND.PSTREE_COMMAND;

import org.apache.dolphinscheduler.plugin.datasource.ssh.SSHUtils;
import org.apache.dolphinscheduler.plugin.datasource.ssh.param.SSHConnectionParam;
import org.apache.dolphinscheduler.plugin.task.api.TaskException;
import org.apache.dolphinscheduler.plugin.task.api.parser.TaskOutputParameterParser;
import org.apache.dolphinscheduler.plugin.task.api.utils.ProcessUtils;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ChannelExec;
import org.apache.sshd.client.channel.ClientChannelEvent;
Expand Down Expand Up @@ -50,7 +54,6 @@ public class RemoteExecutor implements AutoCloseable {
static final int TRACK_INTERVAL = 5000;

protected Map<String, String> taskOutputParams = new HashMap<>();

private SshClient sshClient;
private ClientSession session;
private SSHConnectionParam sshConnectionParam;
Expand Down Expand Up @@ -154,11 +157,45 @@ public void cleanData(String taskId) {

public void kill(String taskId) throws IOException {
String pid = getTaskPid(taskId);
String killCommand = String.format(COMMAND.KILL_COMMAND, pid);

if (StringUtils.isEmpty(pid)) {
log.warn("query remote-shell task remote process id with empty");
return;
}
if (!NumberUtils.isParsable(pid)) {
log.error("query remote-shell task remote process id error, pid {} can not parse to number", pid);
return;
}

// query all pid
String remotePidStr = getAllRemotePidStr(pid);
String killCommand = String.format(COMMAND.KILL_COMMAND, remotePidStr);
log.info("prepare to execute kill command in host: {}, kill cmd: {}", sshConnectionParam.getHost(),
killCommand);
runRemote(killCommand);
cleanData(taskId);
}

protected String getAllRemotePidStr(String pid) {

String remoteProcessIdStr = "";
String cmd = String.format(PSTREE_COMMAND, pid);
log.info("query all process id cmd: {}", cmd);

try {
String rawPidStr = runRemote(cmd);
remoteProcessIdStr = ProcessUtils.parsePidStr(rawPidStr);
if (!remoteProcessIdStr.startsWith(pid)) {
log.error("query remote process id error, [{}] first pid not equal [{}]", remoteProcessIdStr, pid);
remoteProcessIdStr = pid;
}
} catch (Exception e) {
log.error("query remote all process id error", e);
remoteProcessIdStr = pid;
}
return remoteProcessIdStr;
}

public String getTaskPid(String taskId) throws IOException {
String pidCommand = String.format(COMMAND.GET_PID_COMMAND, taskId);
return runRemote(pidCommand).trim();
Expand Down Expand Up @@ -238,6 +275,9 @@ private COMMAND() {
static final String ADD_STATUS_COMMAND = "\necho %s$?";

static final String CAT_FINAL_SCRIPT = "cat %s%s.sh";

static final String PSTREE_COMMAND = "pstree -p %s";

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.dolphinscheduler.plugin.task.remoteshell;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
Expand Down Expand Up @@ -135,4 +136,22 @@ void testGetTaskExitCode() throws IOException {
doReturn("DOLPHINSCHEDULER-REMOTE-SHELL-TASK-STATUS-1").when(remoteExecutor).runRemote(trackCommand);
Assertions.assertEquals(1, remoteExecutor.getTaskExitCode(taskId));
}

@Test
void getAllRemotePidStr() throws IOException {

RemoteExecutor remoteExecutor = spy(new RemoteExecutor(sshConnectionParam));
doReturn("bash(9527)───sleep(9528)").when(remoteExecutor).runRemote(anyString());
String allPidStr = remoteExecutor.getAllRemotePidStr("9527");
Assertions.assertEquals("9527 9528", allPidStr);

doReturn("systemd(1)───sleep(9528)").when(remoteExecutor).runRemote(anyString());
allPidStr = remoteExecutor.getAllRemotePidStr("9527");
Assertions.assertEquals("9527", allPidStr);

doThrow(new TaskException()).when(remoteExecutor).runRemote(anyString());
allPidStr = remoteExecutor.getAllRemotePidStr("9527");
Assertions.assertEquals("9527", allPidStr);

}
}
Loading