package org.jee6unit; import org.jee6unit.annotations.RollbackAfterTest; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.rules.TestWatchman; import org.junit.runners.model.FrameworkMethod; import javax.ejb.TransactionAttribute; import javax.ejb.TransactionAttributeType; import javax.ejb.embeddable.EJBContainer; import javax.naming.NamingException; import javax.persistence.Entity; import javax.persistence.Table; import javax.sql.DataSource; import javax.transaction.UserTransaction; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.Properties; import java.util.logging.Logger; /** * Abstract base class for test classes. Sets up a container and manages transactions around tests. * Also provides utility methods to simplify testing. */ public abstract class AbstractJPATest { private final Logger log = Logger.getLogger(AbstractJPATest.class.getName()); protected DataSource ds; protected static EJBContainer ec; protected UserTransaction t; @BeforeClass /** * Sets up the EJBContainer. */ public static void setupContainer() { System.out.println("Setting up container"); Properties properties = new Properties(); properties.put(EJBContainer.APP_NAME, "myapp"); ec = EJBContainer.createEJBContainer(properties); } @Rule /** * jUnit rule that starts and ends transactions before and after tests. Transactions are rolled back by default, * unless the @RollbackAfterTest(false) annotation is on the test method. */ public TestWatchman transactional = new TestWatchman() { @Override public void starting(FrameworkMethod method) { TransactionAttribute ta = method.getAnnotation(TransactionAttribute.class); if (ta == null || ta.value() != TransactionAttributeType.NOT_SUPPORTED) { log.info("Beginning transaction"); try { t = (UserTransaction) ec.getContext().lookup("java:comp/UserTransaction"); t.begin(); } catch (Exception ex) { throw new RuntimeException(ex); } } else { log.info("TransactionAttributeType.NOT_SUPPORTED found on test, skip starting transaction."); } } @Override public void finished(FrameworkMethod method) { TransactionAttribute ta = method.getAnnotation(TransactionAttribute.class); if (ta == null || ta.value() != TransactionAttributeType.NOT_SUPPORTED) { try { RollbackAfterTest rollback = method.getAnnotation(RollbackAfterTest.class); if (rollback == null || rollback.value()) { log.info("Rolling back transaction"); t.rollback(); } else { log.info("Committing transaction"); t.commit(); } } catch (Exception ex) { throw new RuntimeException(ex); } } } }; /** * Lookup the DataSource specified by the getDatasourceName method and make it available in the ds instance field. * * @throws NamingException When the datasource can't be found. */ @Before public final void setupDataSource() throws NamingException { ds = (DataSource) ec.getContext().lookup(getDatasourceName()); } /** * Lookup a session bean in the global JNDI namespace using the bean's class name. This only works for no-interface * session beans. * * @param clazz The implementation class of a no-interface session bean. * @return The bean instance. * @throws NamingException If the bean cannot be found. */ @SuppressWarnings("unchecked") public final T lookup(Class clazz) throws NamingException { Object bean = ec.getContext().lookup("java:global/myapp/classes/" + clazz.getSimpleName() + "!" + clazz.getName()); return (T) bean; } /** * Lookup a session bean in the global JNDI namespace using the bean's interface name and bean name. * This method should be used for session beans with a business interface. * session beans. * * @param clazz The implementation class of a no-interface session bean. * @param beanName The Session Bean's name. * @return The bean instance. * @throws NamingException If the bean cannot be found. */ @SuppressWarnings("unchecked") public final T lookup(String beanName, Class clazz) throws NamingException { Object bean = ec.getContext().lookup("java:global/myapp/classes/" + clazz.getSimpleName() + "!" + beanName); return (T) bean; } /** * Count the rows in a table using JDBC on the specified DataSource. * * @param table Name of the table * @return The number of rows in the table. */ public final int countRowsInTable(String table) { Connection con = null; int nrOfContacts = 0; ResultSet rs = null; try { con = ds.getConnection(); String stmt = "select count(*) nr from " + table; log.fine("Executing count query: '" + stmt + "'"); PreparedStatement ps = con.prepareStatement(stmt); rs = ps.executeQuery(); rs.next(); nrOfContacts = rs.getInt("nr"); } catch (SQLException ex) { throw new RuntimeException(ex); } finally { try { if (rs != null) rs.close(); if (con != null) con.close(); } catch (SQLException ex) { log.severe("Exception while closing connection: " + ex.getMessage()); } } return nrOfContacts; } /** * Count the rows in the table for a specified JPA entity. The table name is resolved in the following order:
* 1- @Table(name = ) if specified.
* 2- @Entity(name =) if specified.
* 3- The simple class name.
* * @param clazz The JPA entity. * @return The number of rows in the entity table. * @throws IllegalArgumentException If the specified class is not an entity. */ public final int countEntities(Class clazz) { Entity entityAnnotation = clazz.getAnnotation(Entity.class); if (entityAnnotation == null) { throw new IllegalArgumentException(clazz.getName() + " is not a JPA entity (@Entity is missing)"); } String tableName = resolveTableNameFromEntity(clazz, entityAnnotation); return countRowsInTable(tableName); } private String resolveTableNameFromEntity(Class clazz, Entity entityAnnotation) { String tableName = null; if (entityAnnotation.name() != null && entityAnnotation.name().length() > 0) { tableName = entityAnnotation.name(); } Table tableAnnotation = clazz.getAnnotation(Table.class); if (tableAnnotation != null && tableAnnotation.name() != null && tableAnnotation.name().length() > 0) { tableName = tableAnnotation.name(); } if (tableName == null) { tableName = clazz.getSimpleName(); } return tableName; } /** * Subclasses must override this method. * * @return The JNDI name of the DataSource to use for JDBC code. */ protected abstract String getDatasourceName(); }