Skip to content

Commit

Permalink
feat(spring): support thread-safe Spring Boot WebApplicationContext i…
Browse files Browse the repository at this point in the history
…nitialization and lookup
  • Loading branch information
lincolnthree committed Nov 1, 2023
1 parent e6d94da commit 1614e5d
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
*/
package org.ocpsoft.rewrite.el;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.ocpsoft.common.services.ServiceLoader;
import org.ocpsoft.logging.Logger;
Expand Down Expand Up @@ -62,43 +64,55 @@ public String getExpression()
@SuppressWarnings("unchecked")
private String lookupBeanName()
{

// load the available SPI implementations
Iterator<BeanNameResolver> iterator = ServiceLoader.load(BeanNameResolver.class).iterator();

List<Exception> deferred = new ArrayList<>();
while (iterator.hasNext()) {
BeanNameResolver resolver = iterator.next();

// check if this implementation is able to tell the name
String beanName = resolver.getBeanName(clazz);

if (log.isTraceEnabled()) {
log.trace("Service provider [{}] returned [{}] for class [{}]", new Object[] {
resolver.getClass().getSimpleName(), beanName, clazz.getName()
});
}

// the first result is accepted
if (beanName != null) {

// create the complete EL expression including the component
String el = new StringBuilder()
.append(beanName).append('.').append(component)
.toString();
try {
// check if this implementation is able to tell the name
String beanName = resolver.getBeanName(clazz);

if (log.isTraceEnabled()) {
log.debug("Creation of EL expression for component [{}] of class [{}] successful: {}", new Object[] {
component, clazz.getName(), el
log.trace("Service provider [{}] returned [{}] for class [{}]", new Object[] {
resolver.getClass().getSimpleName(), beanName, clazz.getName()
});
}

return el;
// the first result is accepted
if (beanName != null) {

// create the complete EL expression including the component
String el = new StringBuilder()
.append(beanName).append('.').append(component)
.toString();

if (log.isTraceEnabled()) {
log.debug("Creation of EL expression for component [{}] of class [{}] successful: {}", new Object[] {
component, clazz.getName(), el
});
}

return el;
}
}
catch (Exception e) {
log.debug("Failed to resolve bean names using [" + resolver.getClass().getName() + "]", e);
deferred.add(e);
}

}

if (deferred.size() > 1) {
for (Exception e : deferred) {
log.error("Failed to resolve bean names.", e);
}
}
throw new IllegalStateException("Unable to obtain EL name for bean of type [" + clazz.getName()
+ "] from any of the SPI implementations. You should conside placing a @"
+ ELBeanName.class.getSimpleName() + " on the class.");
+ ELBeanName.class.getSimpleName() + " on the class.", (deferred.size() == 1 ? deferred.get(0) : null));

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,41 @@
package org.ocpsoft.rewrite.spring;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.ocpsoft.logging.Logger;
import org.ocpsoft.rewrite.el.spi.BeanNameResolver;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.context.ContextLoader;
import org.springframework.web.context.WebApplicationContext;

/**
* {@link BeanNameResolver} implementation for Spring.
*
* @author Christian Kaltepoth
* @author <a href="mailto:[email protected]">Lincoln Baxter, III</a>
*/
public class SpringBeanNameResolver implements BeanNameResolver
{

private final Logger log = Logger.getLogger(SpringBeanNameResolver.class);

@Autowired
private WebApplicationContext applicationContext;

@Override
public String getBeanName(Class<?> clazz)
{

// try to obtain the WebApplicationContext using ContextLoader
WebApplicationContext context = ContextLoader.getCurrentWebApplicationContext();
if (context == null) {
throw new IllegalStateException("Unable to get current WebApplicationContext");
if (applicationContext == null) {
applicationContext = ContextLoader.getCurrentWebApplicationContext();
if (applicationContext == null) {
throw new IllegalStateException("Unable to get current WebApplicationContext");
}
}

// obtain a map of bean names
Set<String> beanNames = resolveBeanNames(context, clazz);
Set<String> beanNames = resolveBeanNames(applicationContext, clazz);

// no beans of that type, nothing we can do
if (beanNames == null || beanNames.size() == 0) {
Expand Down Expand Up @@ -76,15 +80,14 @@ private Set<String> resolveBeanNames(ListableBeanFactory beanFactory, Class<?> c

final Set<String> result = new HashSet<String>();

Map<String, ?> beanMap = beanFactory.getBeansOfType(clazz);
if (beanMap != null) {
for (String name : beanMap.keySet()) {
String[] names = beanFactory.getBeanNamesForType(clazz);
if (names != null) {
for (String name : names) {
if (name != null && !name.startsWith("scopedTarget.")) {
result.add(name);
}
}
}

return result;

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.expression.spel.support.StandardTypeConverter;
import org.springframework.expression.spel.support.StandardTypeLocator;
import org.springframework.web.context.ContextLoader;
import org.springframework.web.context.WebApplicationContext;

/**
Expand Down Expand Up @@ -136,6 +137,13 @@ public EvaluationContext getEvaluationContext()
// we need a ConfigurableBeanFactory to build the BeanExpressionContext
ConfigurableBeanFactory beanFactory = null;

if (applicationContext == null) {
applicationContext = ContextLoader.getCurrentWebApplicationContext();
if (applicationContext == null) {
throw new IllegalStateException("Unable to get current WebApplicationContext");
}
}

// the WebApplicationContext MAY implement ConfigurableBeanFactory
if (applicationContext instanceof ConfigurableBeanFactory) {
beanFactory = (ConfigurableBeanFactory) applicationContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.stream.Collectors;

import javax.servlet.ServletContext;

import org.ocpsoft.common.spi.ServiceEnricher;
import org.ocpsoft.logging.Logger;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.SpringBeanAutowiringSupport;

/**
Expand All @@ -35,7 +39,13 @@ public class SpringServiceEnricher implements ServiceEnricher
@Override
public <T> void enrich(final T service)
{
SpringBeanAutowiringSupport.processInjectionBasedOnCurrentContext(service);
ServletContext context = SpringServletContextLoader.findCurrentServletContext();
if (context != null) {
SpringBeanAutowiringSupport.processInjectionBasedOnServletContext(service, context);
}
else {
SpringBeanAutowiringSupport.processInjectionBasedOnCurrentContext(service);
}
if (log.isDebugEnabled())
log.debug("Enriched instance of service [" + service.getClass().getName() + "]");

Expand All @@ -44,8 +54,15 @@ public <T> void enrich(final T service)
@Override
public <T> Collection<T> produce(final Class<T> type)
{
// TODO implement
return new ArrayList<T>();
WebApplicationContext webApplicationContext = SpringServletContextLoader.findCurrentApplicationContext();

if(webApplicationContext == null) {
return new ArrayList<>();
}

return webApplicationContext.getBeanProvider(type)
.stream()
.collect(Collectors.toList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,29 @@
import java.util.Set;

import org.ocpsoft.common.spi.ServiceLocator;
import org.springframework.web.context.ContextLoader;
import org.springframework.web.context.WebApplicationContext;

/**
* {@link ServiceLocator} implementation for Spring.
*
* @author Christian Kaltepoth
* @author <a href="mailto:[email protected]">Lincoln Baxter, III</a>
*/
public class SpringServiceLocator implements ServiceLocator
{

@Override
@SuppressWarnings("unchecked")
public <T> Collection<Class<T>> locate(Class<T> clazz)
{
Set<Class<T>> result = new LinkedHashSet<Class<T>>();

// use the Spring API to obtain the WebApplicationContext
WebApplicationContext context = ContextLoader.getCurrentWebApplicationContext();
WebApplicationContext applicationContext = SpringServletContextLoader.findCurrentApplicationContext();

// may be null if Spring hasn't started yet
if (context != null) {
if (applicationContext != null) {

// ask spring about SPI implementations
Map<String, T> beans = context.getBeansOfType(clazz);
Map<String, T> beans = applicationContext.getBeansOfType(clazz);

// add the implementations Class objects to the result set
for (T type : beans.values()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2011 <a href="mailto:[email protected]">Lincoln Baxter, III</a>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.ocpsoft.rewrite.spring;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.ServletContext;
import javax.servlet.ServletContextEvent;

import org.ocpsoft.rewrite.servlet.spi.ContextListener;
import org.springframework.web.context.ContextLoader;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.context.support.WebApplicationContextUtils;

/**
* Thread-safe {@link ServletContext} loader implementation for Spring.
*
* @author <a href="mailto:[email protected]">Lincoln Baxter, III</a>
*/
public class SpringServletContextLoader implements ContextListener {
private static final Map<ClassLoader, ServletContext> contextMap = new ConcurrentHashMap<>(1);

@Override
public void contextInitialized(ServletContextEvent event)
{
ServletContext servletContext = event.getServletContext();
contextMap.put(Thread.currentThread().getContextClassLoader(), servletContext);
contextMap.put(servletContext.getClassLoader(), servletContext);
}

@Override
public void contextDestroyed(ServletContextEvent event)
{
ServletContext context = event.getServletContext();
contextMap.entrySet().removeIf(entry -> entry.getValue() == context);
}

public static ServletContext findCurrentServletContext()
{
RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();

if (requestAttributes instanceof ServletRequestAttributes) {
return ((ServletRequestAttributes) requestAttributes).getRequest().getServletContext();
}

return contextMap.get(Thread.currentThread().getContextClassLoader());
}

public static WebApplicationContext findCurrentApplicationContext()
{
ServletContext currentServletContext = findCurrentServletContext();

if (currentServletContext != null) {
WebApplicationContext webApplicationContext = WebApplicationContextUtils.findWebApplicationContext(currentServletContext);
if (webApplicationContext != null) {
return webApplicationContext;
}
}

return ContextLoader.getCurrentWebApplicationContext();
}

@Override
public int priority()
{
return 0;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
org.ocpsoft.rewrite.spring.SpringServletContextLoader
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
org.ocpsoft.rewrite.spring.SpringServletContextLoader

0 comments on commit 1614e5d

Please sign in to comment.