ThreadSafeContext.java
package io.github.jonasrutishauser.thread.context.impl;
import java.lang.annotation.Annotation;
import java.util.HashMap;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import io.github.jonasrutishauser.thread.context.ThreadSafeScoped;
import jakarta.enterprise.context.ContextNotActiveException;
import jakarta.enterprise.context.spi.AlterableContext;
import jakarta.enterprise.context.spi.Contextual;
import jakarta.enterprise.context.spi.CreationalContext;
/**
 * {@link AlterableContext} backing the {@link ThreadSafeScoped} scope.
 * <p>
 * The context keeps three pieces of state:
 * <ul>
 * <li>{@code beans}: per contextual, a queue of idle instances that are not bound to any thread,</li>
 * <li>{@code boundBeans}: per thread, the instance currently bound to that thread for each contextual,</li>
 * <li>{@code beanUsage}: a usage counter for every instance that is currently bound to a thread.</li>
 * </ul>
 * An instance is bound to the calling thread by {@link #get(Contextual, CreationalContext)} and is handed
 * back to the idle queue by {@link #decrementUsage(Contextual)} once its counter drops to zero.
 */
public class ThreadSafeContext implements AlterableContext {

    /** Idle instances per contextual, available for reuse by any thread. */
    private final ConcurrentMap<Contextual<?>, Queue<ContextualInstance<?>>> beans = new ConcurrentHashMap<>();

    /** Instances bound to the current thread; the map is created lazily on first access. */
    private final ThreadLocal<Map<Contextual<?>, ContextualInstance<?>>> boundBeans = ThreadLocal.withInitial(HashMap::new);

    /** Usage counter of every instance that is currently bound to some thread. */
    private final ConcurrentMap<ContextualInstance<?>, Integer> beanUsage = new ConcurrentHashMap<>();

    /** {@code true} until {@link #shutdown()} has been called. */
    private final AtomicBoolean active = new AtomicBoolean(true);

    /**
     * Increments the usage counter of the instance bound to the calling thread for the given contextual.
     * <p>
     * Precondition: an instance for {@code contextual} is bound to the calling thread and has a counter
     * entry (both are established by {@link #get(Contextual, CreationalContext)}). Otherwise this method
     * fails with a {@link NullPointerException}.
     *
     * @param contextual the contextual whose bound instance is about to be used
     */
    void incrementUsage(Contextual<?> contextual) {
        beanUsage.compute(boundBeans.get().get(contextual), (key, value) -> Integer.valueOf(value.intValue() + 1));
    }

    /**
     * Decrements the usage counter of the instance bound to the calling thread for the given contextual.
     * <p>
     * When the counter reaches zero (or less) the instance is unbound from the thread, its counter entry
     * is dropped and the instance is appended to the idle queue so that it can be reused. If the thread
     * has no bound instances left afterwards, its thread-local map is removed as well.
     * <p>
     * Precondition: same as for {@link #incrementUsage(Contextual)}.
     *
     * @param contextual the contextual whose bound instance is no longer used by the caller
     */
    void decrementUsage(Contextual<?> contextual) {
        Map<Contextual<?>, ContextualInstance<?>> instances = boundBeans.get();
        Integer usage = beanUsage.compute(instances.get(contextual), (key, value) -> Integer.valueOf(value.intValue() - 1));
        if (usage.intValue() <= 0) {
            ContextualInstance<?> contextualInstance = instances.remove(contextual);
            beanUsage.remove(contextualInstance);
            beans.computeIfAbsent(contextual, key -> new ConcurrentLinkedQueue<>()).add(contextualInstance);
            if (instances.isEmpty()) {
                // do not keep an empty map attached to the thread
                boundBeans.remove();
            }
        }
    }

    /**
     * @return the scope annotation served by this context, {@link ThreadSafeScoped}
     */
    @Override
    public Class<? extends Annotation> getScope() {
        return ThreadSafeScoped.class;
    }

    /**
     * Returns the instance bound to the calling thread for the given contextual, binding one first if
     * necessary. An idle instance from the queue is preferred; a new {@code ContextualInstance} is only
     * created when none is available. A usage counter starting at zero is registered for the instance
     * if it does not have one yet.
     *
     * @param contextual        the contextual to look up
     * @param creationalContext the creational context passed on to a newly created instance
     * @return the result of {@code getInstance()} on the bound instance
     * @throws ContextNotActiveException if {@link #shutdown()} has already been called
     */
    @Override
    public <T> T get(Contextual<T> contextual, CreationalContext<T> creationalContext) {
        if (!isActive()) {
            throw new ContextNotActiveException();
        }
        Map<Contextual<?>, ContextualInstance<?>> instances = boundBeans.get();
        // NOTE(review): the mapping function runs inside HashMap.computeIfAbsent; this presumes that
        // constructing a ContextualInstance does not re-enter this context on the same thread - confirm.
        @SuppressWarnings("unchecked")
        ContextualInstance<T> contextualInstance = (ContextualInstance<T>) instances.computeIfAbsent(contextual, key -> {
            Queue<ContextualInstance<?>> queue = beans.get(key);
            ContextualInstance<?> instance = queue == null ? null : queue.poll();
            return instance == null ? new ContextualInstance<>(contextual, creationalContext) : instance;
        });
        beanUsage.putIfAbsent(contextualInstance, Integer.valueOf(0));
        return contextualInstance.getInstance();
    }

    /**
     * Returns the instance already bound to the calling thread for the given contextual, without
     * binding or creating one.
     *
     * @param contextual the contextual to look up
     * @return the result of {@code getInstance()} on the bound instance, or {@code null} if no instance
     *         is bound to the calling thread
     * @throws ContextNotActiveException if {@link #shutdown()} has already been called
     */
    @Override
    public <T> T get(Contextual<T> contextual) {
        if (!isActive()) {
            throw new ContextNotActiveException();
        }
        Map<Contextual<?>, ContextualInstance<?>> instances = boundBeans.get();
        if (instances.isEmpty()) {
            // boundBeans.get() may just have created the map; drop it again to not leave it on the thread
            boundBeans.remove();
            return null;
        }
        @SuppressWarnings("unchecked")
        ContextualInstance<T> contextualInstance = (ContextualInstance<T>) instances.get(contextual);
        return contextualInstance == null ? null : contextualInstance.getInstance();
    }

    /**
     * @return {@code true} as long as {@link #shutdown()} has not been called
     */
    @Override
    public boolean isActive() {
        return active.get();
    }

    /**
     * Destroys all idle instances of the given contextual and removes its queue once it is empty.
     * Instances that are currently bound to a thread are not touched by this method. If there is no
     * idle instance for the contextual, no action is taken.
     *
     * @param contextual the contextual whose idle instances should be destroyed
     */
    @Override
    public void destroy(Contextual<?> contextual) {
        Queue<ContextualInstance<?>> queue = beans.get(contextual);
        if (queue == null) {
            // nothing was ever returned to the pool for this contextual (or it was already destroyed)
            return;
        }
        ContextualInstance<?> instance;
        while ((instance = queue.poll()) != null) {
            instance.destroy();
        }
        // only drop the mapping if no instance has been returned to the queue in the meantime
        beans.computeIfPresent(contextual, (key, pooled) -> pooled.isEmpty() ? null : pooled);
    }

    /**
     * Deactivates the context and destroys every instance it still knows about. Only the first call
     * has an effect; subsequent calls return immediately.
     */
    void shutdown() {
        if (active.compareAndSet(true, false)) {
            beans.forEach((key, queue) -> queue.forEach(ContextualInstance::destroy));
            beans.clear();
            // there shouldn't be any boundBeans therefore the next shouldn't do anything
            beanUsage.keySet().forEach(ContextualInstance::destroy);
            beanUsage.clear();
        }
    }
}