001package com.basilv.examples.hibernate;
002
003import java.io.Serializable;
004
005import org.hibernate.EmptyInterceptor;
006import org.hibernate.type.Type;
007
008/**
009 * AuditInterceptor is used to populate the audit fields on onUpdate and onSave events.
010 *
011 */
012public class AuditInterceptor extends EmptyInterceptor
013{
014
015  private static ThreadLocal<String> userPerThread = new ThreadLocal<String>();
016
017  /**
018   * Store the user for the current thread.
019   * @param user Cannot be null or empty.
020   */
021  public static void setUserForCurrentThread(String user) {
022    userPerThread.set(user);
023  }
024
025  /**
026   * Get the user for the current thread.
027   * (Used primarily for testing).
028   * @return the current user.
029   */
030  public static String getUserForCurrentThread() {
031    return userPerThread.get();
032  }
033
034  @Override public boolean onFlushDirty(Object entity,
035    Serializable id, Object[] currentState,
036    Object[] previousState, String[] propertyNames,
037    Type[] types) {
038
039    boolean changed = false;
040
041    if (entity instanceof Auditable) {
042      changed = updateAuditable(currentState, propertyNames);
043    }
044    return changed;
045  }
046
047  @Override public boolean onSave(Object entity,
048    Serializable id, Object[] currentState,
049    String[] propertyNames, Type[] types) {
050
051    boolean changed = false;
052
053    if (entity instanceof Auditable) {
054      changed = updateAuditable(currentState, propertyNames);
055    }
056    return changed;
057
058  }
059
060  private boolean updateAuditable(Object[] currentState,
061    String[] propertyNames) {
062    boolean changed = false;
063    for (int i = 0; i < propertyNames.length; i++) {
064      if ("createUserId".equals(propertyNames[i])) {
065        if (currentState[i] == null) {
066          currentState[i] = userPerThread.get();
067          changed = true;
068        }
069      }
070      if ("updateUserId".equals(propertyNames[i])) {
071        currentState[i] = userPerThread.get();
072        changed = true;
073      }
074    }
075    return changed;
076  }
077
078}