Skip to content

Commit

Permalink
limit auth schemes removing warnings, and tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
ahgittin committed Aug 30, 2021
1 parent 12c2e16 commit c0f3f30
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ShellCommand implements AutoCloseable {
* If no output is available before the wsman:OperationTimeout expires, the server MUST return a WSManFault with the Code attribute equal to "2150858793"
* https://msdn.microsoft.com/en-us/library/cc251676.aspx
*/
private static final String WSMAN_FAULT_CODE_OPERATION_TIMEOUT_EXPIRED = "2150858793";
static final String WSMAN_FAULT_CODE_OPERATION_TIMEOUT_EXPIRED = "2150858793";

/**
* Example response:
Expand Down
49 changes: 36 additions & 13 deletions client/src/main/java/io/cloudsoft/winrm4j/client/WinRmClient.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.cloudsoft.winrm4j.client;

import io.cloudsoft.winrm4j.client.encryption.AsyncHttpEncryptionAwareConduitFactory;
import io.cloudsoft.winrm4j.client.ntlm.NTCredentialsWithEncryption;
import io.cloudsoft.winrm4j.client.spnego.WsmanViaSpnegoSchemeFactory;
import java.io.Writer;
Expand All @@ -12,10 +13,13 @@
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Predicate;

import java.util.function.Supplier;
Expand Down Expand Up @@ -56,6 +60,7 @@
import org.apache.http.config.RegistryBuilder;
import org.apache.http.impl.auth.KerberosSchemeFactory;
import org.apache.http.impl.auth.NTLMSchemeFactory;
import org.apache.http.impl.client.TargetAuthenticationStrategy;
import org.apache.neethi.Policy;
import org.apache.neethi.builders.PrimitiveAssertion;
import org.ietf.jgss.GSSContext;
Expand All @@ -76,6 +81,7 @@
import io.cloudsoft.winrm4j.client.wsman.Locale;
import io.cloudsoft.winrm4j.client.wsman.OptionSetType;
import io.cloudsoft.winrm4j.client.wsman.OptionType;
import sun.awt.image.ImageWatched.Link;

/**
* TODO confirm if commands can be called in parallel in one shell (probably not)!
Expand Down Expand Up @@ -299,21 +305,24 @@ private static void initializeClientAndService(WinRm winrm, WinRmClientBuilder b

Supplier<Credentials> creds = () -> new NTCredentialsWithEncryption(username, password, null, domain);

Map<String,AuthSchemeProvider> authSchemeRegistry = null;
Set<String> authSchemes = null;

switch (authenticationScheme) {
case AuthSchemes.BASIC:
if (builder.payloadEncryptionMode().isRequired()) {
throw new IllegalStateException("Encryption is required, which is not compatible with auth");
}
bp.getRequestContext().put(BindingProvider.USERNAME_PROPERTY, username);
bp.getRequestContext().put(BindingProvider.PASSWORD_PROPERTY, password);
authSchemes = Collections.singleton(AuthSchemes.BASIC);

break;

case AuthSchemes.NTLM:
Registry<AuthSchemeProvider> authSchemeRegistry = RegistryBuilder.<AuthSchemeProvider>create()
.register(AuthSchemes.NTLM, new NTLMSchemeFactory())
.register(AuthSchemes.SPNEGO, new NtlmMasqAsSpnegoSchemeFactory(builder.payloadEncryptionMode()))
.build();
bp.getRequestContext().put(AuthSchemeProvider.class.getName(), authSchemeRegistry);
authSchemeRegistry = new LinkedHashMap<>();
authSchemeRegistry.put(AuthSchemes.NTLM, new NTLMSchemeFactory());
authSchemeRegistry.put(AuthSchemes.SPNEGO, new NtlmMasqAsSpnegoSchemeFactory(builder.payloadEncryptionMode()));

advancedHttpConfigNeeded = true;
break;
Expand All @@ -337,25 +346,39 @@ private static void initializeClientAndService(WinRm winrm, WinRmClientBuilder b
creds = () -> newCreds;
}

authSchemeRegistry = RegistryBuilder.<AuthSchemeProvider>create()
.register(AuthSchemes.KERBEROS, new KerberosSchemeFactory())
.build();
bp.getRequestContext().put(AuthSchemeProvider.class.getName(), authSchemeRegistry);
authSchemeRegistry = new LinkedHashMap<>();
authSchemeRegistry.put(AuthSchemes.KERBEROS, new KerberosSchemeFactory());

advancedHttpConfigNeeded = true;
break;

case AuthSchemes.SPNEGO:
authSchemeRegistry = RegistryBuilder.<AuthSchemeProvider>create()
.register(AuthSchemes.SPNEGO, new WsmanViaSpnegoSchemeFactory())
.build();
bp.getRequestContext().put(AuthSchemeProvider.class.getName(), authSchemeRegistry);
authSchemeRegistry = new LinkedHashMap<>();
authSchemeRegistry.put(AuthSchemes.SPNEGO, new WsmanViaSpnegoSchemeFactory());

advancedHttpConfigNeeded = true;
break;
default:
throw new UnsupportedOperationException("No such authentication scheme " + authenticationScheme+"; " +
"options are "+Arrays.asList(AuthSchemes.BASIC, AuthSchemes.NTLM, AuthSchemes.SPNEGO, AuthSchemes.KERBEROS));
}

if (authSchemeRegistry!=null) {
if (authSchemes==null) authSchemes = authSchemeRegistry.keySet();
RegistryBuilder<AuthSchemeProvider> rb = RegistryBuilder.<AuthSchemeProvider>create();
authSchemeRegistry.forEach(rb::register);
bp.getRequestContext().put(AuthSchemeProvider.class.getName(), rb.build());
}

if (authSchemes!=null) {
if (builder.endpointConduitFactory==null) {
// set this again mainly so we can set the target auth schemes; but also so we can fail if interceptors did not apply
builder.endpointConduitFactory = new AsyncHttpEncryptionAwareConduitFactory(builder.payloadEncryptionMode(), builder.targetAuthSchemes(), null);
} else {
builder.endpointConduitFactory.targetAuthSchemes(authSchemes);
}
}

if (advancedHttpConfigNeeded) {
bp.getRequestContext().put(Credentials.class.getName(), creds.get());
bp.getRequestContext().put("http.autoredirect", true);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.cloudsoft.winrm4j.client;

import io.cloudsoft.winrm4j.client.encryption.AsyncHttpEncryptionAwareConduitFactory;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Collection;
import java.util.Map;
import java.util.function.Predicate;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -65,7 +67,8 @@ public class WinRmClientBuilder {
protected SSLContext sslContext;
protected boolean requestNewKerberosTicket;
protected PayloadEncryptionMode payloadEncryptionMode;
protected HTTPConduitFactory endpointConduitFactory;
protected Collection<String> targetAuthSchemes;
protected AsyncHttpEncryptionAwareConduitFactory endpointConduitFactory;

WinRmClientBuilder(String endpoint) {
this(toUrlUnchecked(WinRmClient.checkNotNull(endpoint, "endpoint")));
Expand Down Expand Up @@ -164,7 +167,7 @@ public WinRmClientBuilder receiveTimeout(Long receiveTimeout) {
/**
* @param retryReceiveAfterOperationTimeout define if a new Receive request will be send when the server returns
* a fault with the code {@link ShellCommand#WSMAN_FAULT_CODE_OPERATION_TIMEOUT_EXPIRED}.
* Default value {@link #ALWAYS_RETRY_AFTER_OPERATION_TIMEOUT_EXPIRED}.
* Default value {@link #alwaysRetryReceiveAfterOperationTimeout()}.
*/
public WinRmClientBuilder retryReceiveAfterOperationTimeout(Predicate<String> retryReceiveAfterOperationTimeout) {
this.retryReceiveAfterOperationTimeout = retryReceiveAfterOperationTimeout;
Expand Down Expand Up @@ -299,4 +302,13 @@ public PayloadEncryptionMode payloadEncryptionMode() {
return payloadEncryptionMode!=null ? payloadEncryptionMode : PayloadEncryptionMode.OPTIONAL;
}

public WinRmClientBuilder targetAuthSchemes(Collection<String> targetAuthSchemes) {
this.targetAuthSchemes = targetAuthSchemes;
return this;
}

public Collection<String> targetAuthSchemes() {
return targetAuthSchemes;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private static WinRm doCreateServiceWithBean(Bus bus, WinRmClientBuilder builder
// it doesn't work to override the conduit factory, ie this doesn't work:
// properties.put(HTTPConduitFactory.class.getName(), new HttpEncryptingConduitFactory(builder.payloadEncryptionMode(), (Map) null));
// we need to set it explicitly on the input, which this special property does:
builder.endpointConduitFactory = new AsyncHttpEncryptionAwareConduitFactory(builder.payloadEncryptionMode(), (Map) null);
builder.endpointConduitFactory = new AsyncHttpEncryptionAwareConduitFactory(builder.payloadEncryptionMode(), builder.targetAuthSchemes(), (Map) null);
}

factory.setInInterceptors(inInterceptors);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.util.Collection;
import org.apache.cxf.Bus;
import org.apache.cxf.io.CacheAndWriteOutputStream;
import org.apache.cxf.message.Message;
Expand All @@ -20,6 +21,7 @@
import org.apache.cxf.ws.addressing.EndpointReferenceType;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.auth.Credentials;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.protocol.HTTP;
import org.slf4j.Logger;
Expand Down Expand Up @@ -47,10 +49,12 @@ static ContentWithType getAppropriate(Message msg) {
}

private final PayloadEncryptionMode payloadEncryptionMode;
private final Collection<String> targetAuthSchemes;

public AsyncHttpEncryptionAwareConduit(PayloadEncryptionMode payloadEncryptionMode, Bus b, EndpointInfo ei, EndpointReferenceType t, AsyncHTTPConduitFactory factory) throws IOException {
public AsyncHttpEncryptionAwareConduit(PayloadEncryptionMode payloadEncryptionMode, Bus b, EndpointInfo ei, EndpointReferenceType t, AsyncHttpEncryptionAwareConduitFactory factory) throws IOException {
super(b, ei, t, factory);
this.payloadEncryptionMode = payloadEncryptionMode;
this.payloadEncryptionMode = factory.payloadEncryptionMode;
this.targetAuthSchemes = factory.targetAuthSchemes;
}

protected OutputStream createOutputStream(Message message,
Expand Down Expand Up @@ -109,6 +113,10 @@ protected ContentWithType getAppropriate() {
entity.setChunked(true);
entity.setContentType((String)message.get(Message.CONTENT_TYPE));
requestEntity.setEntity(entity);

requestEntity.setConfig(RequestConfig.copy( requestEntity.getConfig() )
.setTargetPreferredAuthSchemes(targetAuthSchemes)
.build());
}

public abstract static class EncryptionAwareHttpEntity extends BasicHttpEntity {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import io.cloudsoft.winrm4j.client.PayloadEncryptionMode;
import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import org.apache.cxf.Bus;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.transport.http.HTTPConduit;
Expand All @@ -11,11 +13,13 @@

public class AsyncHttpEncryptionAwareConduitFactory extends AsyncHTTPConduitFactory {

private final PayloadEncryptionMode payloadEncryptionMode;
final PayloadEncryptionMode payloadEncryptionMode;
Collection<String> targetAuthSchemes;

public AsyncHttpEncryptionAwareConduitFactory(PayloadEncryptionMode payloadEncryptionMode, Map<String, Object> conf) {
public AsyncHttpEncryptionAwareConduitFactory(PayloadEncryptionMode payloadEncryptionMode, Collection<String> targetAuthSchemes, Map<String, Object> conf) {
super(conf);
this.payloadEncryptionMode = payloadEncryptionMode;
this.targetAuthSchemes = targetAuthSchemes;
}

@Override
Expand All @@ -28,4 +32,7 @@ public HTTPConduit createConduit(Bus bus,
return new AsyncHttpEncryptionAwareConduit(payloadEncryptionMode, bus, localInfo, target, this);
}

public void targetAuthSchemes(Set<String> authSchemes) {
this.targetAuthSchemes = authSchemes;
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
package io.cloudsoft.winrm4j.client.encryption;

import io.cloudsoft.winrm4j.client.PayloadEncryptionMode;
import io.cloudsoft.winrm4j.client.encryption.NtlmEncryptionUtils.Decryptor;
import io.cloudsoft.winrm4j.client.ntlm.NTCredentialsWithEncryption;
import io.cloudsoft.winrm4j.client.ntlm.NtlmKeys.NegotiateFlags;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import org.apache.cxf.helpers.IOUtils;
import org.apache.cxf.interceptor.StaxInInterceptor;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.AbstractPhaseInterceptor;
import org.apache.cxf.phase.Phase;
import org.apache.http.auth.Credentials;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -24,10 +14,6 @@ public class DecryptAndVerifyInInterceptor extends AbstractPhaseInterceptor<Mess

public static final String APPLIED = DecryptAndVerifyInInterceptor.class.getSimpleName()+".APPLIED";

public static final String ENCRYPTED_BOUNDARY_PREFIX = "--Encrypted Boundary";
public static final String ENCRYPTED_BOUNDARY_CR = ENCRYPTED_BOUNDARY_PREFIX+"\r\n";
public static final String ENCRYPTED_BOUNDARY_END = ENCRYPTED_BOUNDARY_PREFIX+"--\r\n";

private final PayloadEncryptionMode payloadEncryptionMode;

public DecryptAndVerifyInInterceptor(PayloadEncryptionMode payloadEncryptionMode) {
Expand All @@ -37,7 +23,8 @@ public DecryptAndVerifyInInterceptor(PayloadEncryptionMode payloadEncryptionMode
}

public void handleMessage(Message message) {
NtlmEncryptionUtils.of(message, payloadEncryptionMode).decrypt(message);
NtlmEncryptionUtils utils = NtlmEncryptionUtils.of(message, payloadEncryptionMode);
if (utils!=null) utils.decrypt(message);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,12 @@ public static byte[] md5digest(byte[] bytes) {
}
}

public static CryptoHandler encryptorArc4(byte[] key) {
return cryptorArc4(true, key);
}

public static CryptoHandler decryptorArc4(byte[] key) {
return cryptorArc4(false, key);
}

public static CryptoHandler cryptorArc4(boolean forEncryption, byte[] key) {
public static Cipher arc4(byte[] key) {
// engine needs to be stateful
try {
final Cipher rc4 = Cipher.getInstance("RC4");
rc4.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "RC4"));

return new CryptoHandler() {
@Override
public byte[] update(byte[] input) {
try {
return rc4.update(input);

} catch (Exception e) {
throw new IllegalStateException(e);
}
}
};
return rc4;

} catch (Exception e) {
throw new IllegalStateException(e);
Expand Down Expand Up @@ -81,8 +62,4 @@ public static byte[] hmacMd5(byte[] key, byte[] body) {
}
}

public interface CryptoHandler {
byte[] update(byte[] input);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import io.cloudsoft.winrm4j.client.encryption.AsyncHttpEncryptionAwareConduit.EncryptionAwareHttpEntity;
import io.cloudsoft.winrm4j.client.encryption.WinrmEncryptionUtils;
import io.cloudsoft.winrm4j.client.encryption.WinrmEncryptionUtils.CryptoHandler;
import io.cloudsoft.winrm4j.client.ntlm.forks.httpclient.NTLMEngineImpl.Type3Message;
import java.util.concurrent.atomic.AtomicLong;
import javax.crypto.Cipher;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.HttpRequest;
import org.apache.http.auth.NTCredentials;
Expand Down Expand Up @@ -84,15 +84,15 @@ public AtomicLong getSequenceNumberOutgoing() {
return sequenceNumberOutgoing;
}

CryptoHandler encryptor = null;
public CryptoHandler getStatefulEncryptor() {
if (encryptor==null) encryptor = WinrmEncryptionUtils.encryptorArc4(getClientSealingKey());
Cipher encryptor = null;
public Cipher getStatefulEncryptor() {
if (encryptor==null) encryptor = WinrmEncryptionUtils.arc4(getClientSealingKey());
return encryptor;
}

CryptoHandler decryptor = null;
public CryptoHandler getStatefulDecryptor() {
if (decryptor==null) decryptor = WinrmEncryptionUtils.decryptorArc4(getServerSealingKey());
Cipher decryptor = null;
public Cipher getStatefulDecryptor() {
if (decryptor==null) decryptor = WinrmEncryptionUtils.arc4(getServerSealingKey());
return decryptor;
}

Expand Down
Loading

0 comments on commit c0f3f30

Please sign in to comment.