// ThreadSafeExtension.java

package io.github.jonasrutishauser.thread.context.impl;

import static jakarta.interceptor.Interceptor.Priority.LIBRARY_AFTER;

import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

import io.github.jonasrutishauser.thread.context.ThreadSafeScoped;
import jakarta.enterprise.context.spi.CreationalContext;
import jakarta.enterprise.event.Observes;
import jakarta.enterprise.inject.UnproxyableResolutionException;
import jakarta.enterprise.inject.spi.AfterBeanDiscovery;
import jakarta.enterprise.inject.spi.BeanManager;
import jakarta.enterprise.inject.spi.BeforeShutdown;
import jakarta.enterprise.inject.spi.Extension;
import jakarta.enterprise.inject.spi.InterceptionFactory;
import jakarta.enterprise.inject.spi.ProcessAnnotatedType;
import jakarta.enterprise.inject.spi.ProcessProducer;
import jakarta.enterprise.inject.spi.Producer;
import jakarta.enterprise.inject.spi.WithAnnotations;

/**
 * CDI portable extension that wires up {@link ThreadSafeScoped} support.
 * <p>
 * It registers a single {@link ThreadSafeContext} with the container, adds
 * {@code ThreadSafeScopedInterceptor.Literal.INSTANCE} to annotated types and
 * produced instances that carry {@link ThreadSafeScoped}, and asks the context
 * to shut down when the container goes down.
 */
public class ThreadSafeExtension implements Extension {

    /** The one context instance registered with the container and shut down by this extension. */
    private final ThreadSafeContext context = new ThreadSafeContext();

    /**
     * Adds the interceptor literal to every discovered type that is itself
     * annotated with {@link ThreadSafeScoped}.
     *
     * @param event the annotated type currently being processed
     */
    void addInterceptor(@Observes @WithAnnotations(ThreadSafeScoped.class) ProcessAnnotatedType<?> event) {
        // The observer is also notified for types that only use the annotation on
        // members; those are skipped here.
        if (!event.getAnnotatedType().isAnnotationPresent(ThreadSafeScoped.class)) {
            return;
        }
        event.configureAnnotatedType().add(ThreadSafeScopedInterceptor.Literal.INSTANCE);
    }

    /**
     * Replaces the producer of every {@link ThreadSafeScoped} producer member
     * with one that returns an intercepted instance of the originally produced
     * object.
     *
     * @param <T>         the type returned by the producer
     * @param event       the producer currently being processed
     * @param beanManager used to create the interception factory
     */
    <T> void addInterceptor(@Observes ProcessProducer<?, T> event, BeanManager beanManager) {
        var annotatedMember = event.getAnnotatedMember();
        if (!annotatedMember.isAnnotationPresent(ThreadSafeScoped.class)) {
            return;
        }
        Type declaredRawType = rawTypeOf(annotatedMember.getBaseType());
        // Keep a handle on the original producer; the replacement delegates to it.
        Producer<T> delegate = event.getProducer();
        event.configureProducer()
                .produceWith(ctx -> intercept(beanManager, ctx, declaredRawType, delegate.produce(ctx)));
    }

    /**
     * Strips type arguments from a parameterized type; any other type is
     * returned unchanged.
     */
    private static Type rawTypeOf(Type type) {
        if (type instanceof ParameterizedType parameterized) {
            return parameterized.getRawType();
        }
        return type;
    }

    /**
     * Wraps {@code instance} in an intercepted instance. The declared type of
     * the producer member is tried first; if that type is not a {@link Class}
     * (the cast fails) or an {@link UnproxyableResolutionException} is thrown,
     * the runtime class of the instance is used instead.
     *
     * @param <T>          the produced type
     * @param beanManager  used to create the interception factory
     * @param ctx          creational context of the instance being produced
     * @param declaredType raw declared type of the producer member
     * @param instance     the object returned by the original producer
     * @return the intercepted instance
     */
    private <T> T intercept(BeanManager beanManager, CreationalContext<T> ctx, Type declaredType, T instance) {
        try {
            @SuppressWarnings("unchecked")
            Class<T> declaredClass = (Class<T>) declaredType;
            return createThreadSafeScopedInterceptionFactory(beanManager, ctx, declaredClass)
                    .createInterceptedInstance(instance);
        } catch (UnproxyableResolutionException | ClassCastException e) {
            @SuppressWarnings("unchecked")
            Class<T> runtimeClass = (Class<T>) instance.getClass();
            return createThreadSafeScopedInterceptionFactory(beanManager, ctx, runtimeClass)
                    .createInterceptedInstance(instance);
        }
    }

    /**
     * Creates an interception factory for {@code type} that already has the
     * interceptor literal configured.
     *
     * @param <T>         the type to intercept
     * @param beanManager used to create the factory
     * @param ctx         creational context passed on to the factory
     * @param type        the class the factory is created for
     * @return the configured factory
     */
    private <T> InterceptionFactory<T> createThreadSafeScopedInterceptionFactory(BeanManager beanManager,
            CreationalContext<T> ctx, Class<T> type) {
        InterceptionFactory<T> interceptionFactory = beanManager.createInterceptionFactory(ctx, type);
        interceptionFactory.configure().add(ThreadSafeScopedInterceptor.Literal.INSTANCE);
        return interceptionFactory;
    }

    /**
     * Registers the context and, when the {@code jakarta.enterprise.event.Shutdown}
     * class can be loaded, a synthetic observer for it that shuts the context down.
     *
     * @param event the after-bean-discovery lifecycle event
     */
    void registerContext(@Observes AfterBeanDiscovery event) {
        event.addContext(context);
        Class<?> shutdownEventType;
        try {
            // Looked up by name (without initialising it) so that this class has no
            // hard reference to the event type.
            shutdownEventType = Class.forName("jakarta.enterprise.event.Shutdown", false,
                    getClass().getClassLoader());
        } catch (ClassNotFoundException ignored) {
            // Event type not available: no extra observer is registered; the
            // BeforeShutdown observer of this extension still shuts the context down.
            return;
        }
        event.addObserverMethod() //
                .observedType(shutdownEventType) //
                .priority(LIBRARY_AFTER + 500) //
                .notifyWith(notification -> context.shutdown());
    }

    /**
     * Shuts the context down when the container signals {@link BeforeShutdown}.
     *
     * @param event the before-shutdown lifecycle event
     */
    void shutdown(@Observes BeforeShutdown event) {
        context.shutdown();
    }

}