/*
  Copyright (C) 2005-2012 Steven L. Scott

  This library is free software; you can redistribute it and/or
  modify it under the terms of the GNU Lesser General Public
  License as published by the Free Software Foundation; either
  version 2.1 of the License, or (at your option) any later version.

  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
*/

#include <Models/PointProcess/PoissonClusterProcess.hpp>
#include <Models/PointProcess/HomogeneousPoissonProcess.hpp>
#include <cpputil/report_error.hpp>
#include <distributions.hpp>
#include <cpputil/math_utils.hpp>
#include <cpputil/lse.hpp>

namespace BOOM{

  namespace{
    // Returns true if 'target' is one of the pointers stored in 'vec'
    // (comparison is by pointer identity).
    template <class T>
    bool contains(const std::vector<T*> &vec, const T* target){
      typename std::vector<T*>::const_iterator position =
          std::find(vec.begin(), vec.end(), target);
      return position != vec.end();
    }

    // Computes the HMM state that follows an event.
    // Args:
    //   previous_state: The state before the event: 0, 1, 2, or 3.
    //   label: Name of the component process that produced the event
    //     ("background", "primary_birth", "primary_traffic",
    //     "primary_death", "secondary_traffic", or "secondary_death").
    // Returns:
    //   The state after the event: 0, 1, 2, or 3.  A combination of
    //   previous_state and label with no defined successor is passed to
    //   report_error; if report_error returns, 0 is returned.
    int determine_state(int previous_state, const string &label){
      switch(previous_state){
        case 0:
          if(label == "background") return 0;
          if(label == "primary_birth") return 1;
          break;
        case 1:
          if(label == "background"
             || label == "primary_traffic"
             || label == "secondary_traffic") {
            return 1;
          }
          if(label == "primary_death") return 2;
          if(label == "secondary_death") return 3;
          break;
        case 2:
          if(label == "secondary_death") return 0;
          if(label == "background" || label == "secondary_traffic") return 2;
          if(label == "primary_birth") return 1;
          break;
        case 3:
          if(label == "primary_traffic") return 1;
          if(label == "primary_death") return 2;
          if(label == "background") return 3;
          break;
        default:
          break;
      }
      ostringstream err;
      err << "could not determine the next state, with initial state = "
          << previous_state << " and next event produced by " << label << endl;
      report_error(err.str());
      return 0;
    }

    // Converts a matrix of log-scale values into probabilities that sum
    // to one, in place.  The largest entry is subtracted before
    // exponentiating to avoid overflow.
    // Returns:
    //   log(sum(exp(P))) evaluated on the original (log-scale) input.
    double normalize_filter(Matrix &P){
      const double log_shift = max(P);
      P -= log_shift;
      P.exp();
      const double normalizing_constant = sum(P);
      P /= normalizing_constant;
      return log_shift + log(normalizing_constant);
    }

    // Builds the (previous_state, current_state) key used to look up
    // entries in the responsible process map.
    inline std::pair<int, int> index(int i1, int i2) {
      return std::make_pair(i1, i2);
    }

  }  // unnamed namespace

  // Builds a model with no mark models.  This is the same as the
  // constructor taking mark models, called with two null mark models.
  PoissonClusterProcess::PoissonClusterProcess(
      const PoissonClusterComponentProcesses &components)
      : PoissonClusterProcess(components,
                              Ptr<MixtureComponent>(0),
                              Ptr<MixtureComponent>(0))
  {}

  // Builds a model from the six component processes in 'components' and
  // the two mark models.  The mark model Ptrs are stored as given (the
  // models are shared with the caller, not cloned), then initialize()
  // is called.
  // Args:
  //   components: The background, primary (birth, death, traffic), and
  //     secondary (traffic, death) Poisson processes.
  //   primary_mark_model: Mark model for events from primary processes.
  //   secondary_mark_model: Mark model for events from the background
  //     and secondary processes.
  PoissonClusterProcess::PoissonClusterProcess(
      const PoissonClusterComponentProcesses &components,
      Ptr<MixtureComponent> primary_mark_model,
      Ptr<MixtureComponent> secondary_mark_model)
      : background_(components.background),
        primary_birth_(components.primary_birth),
        primary_death_(components.primary_death),
        primary_traffic_(components.primary_traffic),
        secondary_traffic_(components.secondary_traffic),
        secondary_death_(components.secondary_death),
        primary_mark_model_(primary_mark_model),
        secondary_mark_model_(secondary_mark_model)
  {
    initialize();
  }

  //----------------------------------------------------------------------
  // Copy constructor.  Each component process is cloned, and so is each
  // mark model that is set in rhs, so the copy does not share those
  // model objects with rhs.
  // NOTE(review): members not listed in the initializer list (e.g.
  // known_source_store_, probability_of_activity_) are default
  // constructed rather than copied; confirm this matches how
  // DataPolicy(rhs) treats the data.
  PoissonClusterProcess::PoissonClusterProcess(const PoissonClusterProcess &rhs)
      : Model(rhs),
        ParamPolicy(rhs),
        DataPolicy(rhs),
        PriorPolicy(rhs),
        background_(rhs.background_->clone()),
        primary_birth_(rhs.primary_birth_->clone()),
        primary_death_(rhs.primary_death_->clone()),
        primary_traffic_(rhs.primary_traffic_->clone()),
        secondary_traffic_(rhs.secondary_traffic_->clone()),
        secondary_death_(rhs.secondary_death_->clone()),
        primary_mark_model_(0),
        secondary_mark_model_(0)
  {
    // The two mark models are checked independently, as in
    // clear_client_data.  Gating both on the primary model would
    // dereference a null secondary model, or silently drop a non-null
    // secondary model when the primary one is null.
    if(!!rhs.primary_mark_model_) {
      primary_mark_model_.reset(rhs.primary_mark_model_->clone());
    }
    if(!!rhs.secondary_mark_model_) {
      secondary_mark_model_.reset(rhs.secondary_mark_model_->clone());
    }
    initialize();
  }

  //----------------------------------------------------------------------
  // Returns a newly allocated copy made by the copy constructor.  The
  // caller takes ownership of the returned pointer.
  PoissonClusterProcess * PoissonClusterProcess::clone()const{
    PoissonClusterProcess *copy = new PoissonClusterProcess(*this);
    return copy;
  }

  //----------------------------------------------------------------------
  // Replaces both mark models.  The internal state maps are then rebuilt
  // and the models are re-registered with the parameter policy.
  void PoissonClusterProcess::set_mark_models(
      Ptr<MixtureComponent> primary_mark_model,
      Ptr<MixtureComponent> secondary_mark_model){
    secondary_mark_model_ = secondary_mark_model;
    primary_mark_model_ = primary_mark_model;

    fill_state_maps();
    register_models_with_param_policy();
  }

  //----------------------------------------------------------------------
  // Calls clear_data() on every component process and on each mark model
  // that is set, then zeros every matrix in probability_of_activity_ and
  // probability_of_responsibility_.  The matrices themselves (one per
  // data series) are kept.
  void PoissonClusterProcess::clear_client_data(){
    background_->clear_data();
    primary_birth_->clear_data();
    primary_death_->clear_data();
    primary_traffic_->clear_data();
    secondary_death_->clear_data();
    secondary_traffic_->clear_data();

    if(!!primary_mark_model_) {
      primary_mark_model_->clear_data();
    }
    if(!!secondary_mark_model_) {
      secondary_mark_model_->clear_data();
    }

    const size_t number_of_series = probability_of_activity_.size();
    for(size_t series = 0; series < number_of_series; ++series) {
      probability_of_activity_[series] = 0;
      probability_of_responsibility_[series] = 0;
    }
  }

  //----------------------------------------------------------------------
  // Data augmentation step.  Clears the data held by the component
  // models, then for each observed point process:
  //   * runs the forward filter (accumulating last_loglike_), and
  //   * draws the hidden states and responsible processes backward in
  //     time, assigning each event to a component process.
  // Known event sources (from add_supervised_data) are used when present.
  void PoissonClusterProcess::impute_latent_data(RNG &rng){
    const std::vector<Ptr<PointProcess> > &data(dat());
    last_loglike_ = 0;
    clear_client_data();
    std::vector<int> no_known_source;

    for(size_t i = 0; i < data.size(); ++i){
      SourceMap::iterator known = known_source_store_.find(data[i]);
      const std::vector<int> *source = &no_known_source;
      if(known != known_source_store_.end()){
        source = &known->second;
      }

      const PointProcess &process(*data[i]);
      last_loglike_ += filter(process, *source);
      backward_sampling(rng,
                        process,
                        *source,
                        probability_of_activity_[i],
                        probability_of_responsibility_[i]);
    }
  }

  //----------------------------------------------------------------------
  // Draws new parameters for each component process, then for each mark
  // model that is set, from their posterior distributions given the data
  // currently assigned to them.
  void PoissonClusterProcess::sample_client_posterior(){
    background_->sample_posterior();
    primary_birth_->sample_posterior();
    primary_death_->sample_posterior();
    primary_traffic_->sample_posterior();
    secondary_traffic_->sample_posterior();
    secondary_death_->sample_posterior();

    if(!!primary_mark_model_) {
      primary_mark_model_->sample_posterior();
    }
    if(!!secondary_mark_model_) {
      secondary_mark_model_->sample_posterior();
    }
  }

  //----------------------------------------------------------------------
  // Log prior density of the full model.  This is the sum of the log
  // priors of the six component processes and of each mark model that
  // is set.
  double PoissonClusterProcess::logpri()const{
    double ans = background_->logpri()
        + primary_birth_->logpri()
        + primary_death_->logpri()
        + primary_traffic_->logpri()
        + secondary_traffic_->logpri()
        + secondary_death_->logpri();
    // Check each mark model on its own, as sample_client_posterior does.
    // Gating the secondary model on the primary one would dereference a
    // null secondary model, or omit the prior of a secondary model that
    // is being sampled.
    if(!!primary_mark_model_){
      ans += primary_mark_model_->logpri();
    }
    if(!!secondary_mark_model_){
      ans += secondary_mark_model_->logpri();
    }
    return ans;
  }

  //----------------------------------------------------------------------
  double PoissonClusterProcess::conditional_event_loglikelihood(
      int r, int s, const PointProcessEvent &event,
      double logp_primary, double logp_secondary, int source)const{
    std::vector<const PoissonProcess *> responsible_processes =
        get_responsible_processes(r, s, source);
    int n = responsible_processes.size();
    if(n==0) return negative_infinity();
    const DateTime &t(event.timestamp());

    if(n==1){
      const PoissonProcess *process = responsible_processes[0];
      return log(process->event_rate(t))
          + (primary(process) ? logp_primary : logp_secondary);
    }else{
      Vector wsp(n);
      for(int i = 0; i < n; ++i){
        const PoissonProcess *process = responsible_processes[i];
        wsp[i] = log(process->event_rate(t)) +
            (primary(process) ? logp_primary : logp_secondary);
      }
      return lse(wsp);
    }

  }

  //----------------------------------------------------------------------
  double PoissonClusterProcess::conditional_cumulative_hazard(
      const DateTime &t0, const DateTime &t1, int r)const{
    double ans = 0;
    const std::vector<PoissonProcess *> &active(active_processes_[r]);
    for(int i = 0; i < active.size(); ++i){
      ans += active[i]->expected_number_of_events(t0, t1);
    }
    return ans;
  }

  //----------------------------------------------------------------------
  // Number of hidden Markov states: one per entry of activity_state_.
  int PoissonClusterProcess::number_of_hmm_states()const{
    const int number_of_states = activity_state_.size();
    return number_of_states;
  }

  //----------------------------------------------------------------------
  // Runs the forward filter over every event in 'data'.
  // Args:
  //   data: The point process to filter.
  //   source: Known source of each event, or empty if none is known.
  //     Events without a known source are passed to fwd_1 as -1.
  // Returns:
  //   The log likelihood of 'data': the initial term from
  //   initialize_filter plus the one-step log predictive densities.
  double PoissonClusterProcess::filter(
      const PointProcess &data, const std::vector<int> &source){
    // The filter starts at the beginning of the observation window,
    // which is at or before the first event.
    double log_likelihood = initialize_filter(data);
    const bool have_source = !source.empty();
    const int number_of_events = data.number_of_events();
    for(int t = 0; t < number_of_events; ++t){
      const int src = have_source ? source[t] : -1;
      log_likelihood += fwd_1(data, t, src);
    }
    return log_likelihood;
  }

  //----------------------------------------------------------------------
  // Sets the a priori state distribution (pi0_) of the filter at the
  // beginning of the observation window.  Also makes sure filter_ holds
  // at least one S x S matrix per event, where S is the number of HMM
  // states.  Does nothing if 'data' has no events.
  // Returns:
  //   The log likelihood contribution of the initial state, which is
  //   currently always 0.
  double PoissonClusterProcess::initialize_filter(const PointProcess &data){
    const int n = data.number_of_events();
    if(n == 0) return 0;
    const int S = number_of_hmm_states();

    // NOTE: StationaryDistribution currently also starts from the
    // uniform distribution.
    if(initialization_strategy_ == UniformInitialState
       || initialization_strategy_ == StationaryDistribution){
      pi0_ = 1.0 / S;
    }else{
      report_error("unknown initialization_strategy");
    }

    while(filter_.size() < static_cast<size_t>(n)){
      filter_.push_back(Matrix(S, S));
    }

    // Matrices left over from earlier calls may be too small.
    if(nrow(filter_[0]) < S){
      for(auto &transition_matrix : filter_){
        transition_matrix.resize(S, S);
      }
    }
    return 0;
  }

  //----------------------------------------------------------------------
  // return log(p(events[t] | events[0..t-1]).
  double PoissonClusterProcess::fwd_1(const PointProcess &data,
                                      int t,
                                      int source){
    Matrix &P(filter_[t]);
    P = negative_infinity();
    int S = number_of_hmm_states();
    const DateTime & t0(t==0 ?
                        data.window_begin() :
                        data.event(t-1).timestamp());
    const PointProcessEvent & event(data.event(t));
    const DateTime & t1(event.timestamp());
    double logp_primary = 0;
    double logp_secondary = 0;
    if(!!primary_mark_model_ && event.has_mark()){
      logp_primary = primary_mark_model_->pdf(event.mark(), true);
      logp_secondary = secondary_mark_model_->pdf(event.mark(), true);
    }
    // TODO(stevescott):  remove comments
    // if(source == 1){
    //   logp_secondary == negative_infinity();
    // }
    // if(source == 0){
    //   logp_primary == negative_infinity();
    // }
    for(int r = 0; r < S; ++r){
      const std::vector<int> &target(legal_target_transitions_[r]);
      double log_prior_hazard = log(pi0_[r]) -
          conditional_cumulative_hazard(t0, t1, r);
      for(int ss = 0; ss < target.size(); ++ss){
        int s = target[ss];
        P(r, s) = log_prior_hazard +
            conditional_event_loglikelihood(
                r, s, event,logp_primary, logp_secondary, source);
      }
    }

    double loglike = normalize_filter(P);
    pi0_ = one_ * P;  // pi0_ is now the marginal of time t;
    return loglike;
  }

  //----------------------------------------------------------------------
  // Backward-sampling half of forward-filtering / backward-sampling.
  // First draws the state after the last event from pi0_ (left by
  // filter()).  Then, walking backward over the events, for each event t
  // it:
  //   * draws the state before event t from column current_state of
  //     filter_[t],
  //   * adds the interval ending at event t to the exposure of the
  //     processes active in that state,
  //   * draws the process responsible for event t and gives it the event
  //     (and its mark), and
  //   * adds the draw to column t of probability_of_activity and of
  //     probability_of_responsibility.
  // Args:
  //   rng: Random number generator.
  //   data: The point process being imputed.
  //   source: Known source of each event (0 = background/secondary,
  //     1 = primary, < 0 unknown), or empty if no sources are known.
  //   probability_of_activity, probability_of_responsibility: Matrices
  //     with one column per event that accumulate the draws.
  void PoissonClusterProcess::backward_sampling(
      RNG &rng,
      const PointProcess &data,
      const std::vector<int> &source,
      Matrix & probability_of_activity,
      Matrix & probability_of_responsibility){

    const int n = data.number_of_events();
    if(n == 0){
      probability_of_responsibility = 0;
      probability_of_activity = 0;
      return;
    }

    // 'source' is empty when no event sources are known.  In that case
    // report -1 (the "unknown" code used by filter()) instead of indexing
    // an empty vector.  This also applies in the diagnostic message below.
    const bool have_source = !source.empty();
    auto source_at = [&source, have_source](int i) {
      return have_source ? source[i] : -1;
    };

    // Sample the hmm state after the final event.
    int current_state = rmulti_mt(rng, pi0_);

    for(int t = n - 1; t >= 0; --t){
      // Draw the state of the hidden Markov chain before event t, given
      // the state after it.
      int previous_state = -1;
      try {
        previous_state = draw_previous_state(rng, t, current_state);
      } catch(std::exception &e) {
        ostringstream err;
        err << e.what() << endl
            << "Error occurred in PoissonClusterProcess::backward_sampling"
            << " at time " << t << " (counting from 0)."
            << endl
            << "Current state at time t is " << current_state
            << "." << endl;

        if(t > 0) {
          err << "source[t-1] = " << source_at(t - 1)
              << " filter[t-1] = " << endl
              << filter_[t-1] << endl;
        }
        err << "source[t] = " << source_at(t)
            << " filter[t] = " << endl
            << filter_[t] << endl;
        if (t < n - 1) {
          err << "source[t+1] = " << source_at(t + 1)
              << " filter[t+1] = " << endl
              << filter_[t+1] << endl;
        }
        report_error(err.str());
      }

      update_exposure_time(data, t, previous_state, current_state);
      PoissonProcess * responsible_process = assign_responsibility(
          rng, data, t, previous_state, current_state, source_at(t));
      attribute_event(data.event(t), responsible_process);

      // previous_state is the state in effect during the interval that
      // ends with event t.  update_exposure_time charges that interval
      // to the processes active in previous_state, so it tells which
      // processes were active when event t occurred.  Record it in
      // column t, the same state whose distribution backward_smoothing
      // records in column t.  Each column gets exactly one draw.  (The
      // old code also recorded the final state in column n-1, counting
      // that column twice, and never filled column 0.)
      record_activity(probability_of_activity.col(t), previous_state);
      record_responsibility(probability_of_responsibility.col(t),
                            responsible_process);
      current_state = previous_state;
    }
  }

  //----------------------------------------------------------------------
  void PoissonClusterProcess::backward_smoothing(
      const PointProcess &data,
      const std::vector<int> &source,
      Matrix &probability_of_activity,
      Matrix &probability_of_responsibility){
    // Backward pass of the forward-filter / backward-smoother.
    //
    // On entry, filter_[t](r, s) is the filtered joint distribution of
    // the state before (r) and after (s) event t, given events 0..t, as
    // written by fwd_1.  pi0_ is the filtered marginal of the state after
    // the last event.  NOTE(review): this assumes filter() was just run on
    // this same 'data'; confirm at call sites.
    //
    // On exit:
    //   * filter_[t] holds the smoothed joint distribution given all events.
    //   * Column t of probability_of_activity holds the smoothed activity
    //     distribution for the state before event t.
    //   * Column t of probability_of_responsibility holds the smoothed
    //     responsibility distribution for event t.
    int n = data.number_of_events();
    if(n==0){
      probability_of_responsibility = 0;
      probability_of_activity = 0;
      return;
    }
    if(ncol(probability_of_activity) != n
       || ncol(probability_of_responsibility) != n) {
      report_error("wrong size probability matrices in "
                   "PoissonClusterProcess::backward_smoothing");
    }
    bool have_source = !source.empty();
    for(int t = n-1; t>=0; --t){
      Matrix &transition_density(filter_[t]);
      // Smooth first, then record.  Recording before smoothing (as the
      // code used to) stores filtered rather than smoothed probabilities
      // for every t < n-1.  After this call pi0_ is the smoothed marginal
      // of the state before event t, which the step at t-1 needs.
      backward_smoothing_step(transition_density, pi0_);
      record_activity_distribution(
          probability_of_activity.col(t),
          transition_density);
      int src = have_source ? source[t] : -1;
      record_responsibility_distribution(
          probability_of_responsibility.col(t),
          transition_density,
          data.event(t),
          src);
    }
  }

  // One step of the backward smoothing recursion.
  // Args:
  //   transition_density: On entry, the filtered joint distribution
  //     P(r, s) of the states before (r) and after (s) an event.  On
  //     exit, the smoothed joint distribution
  //     P(r, s) * marginal[s] / sum_r' P(r', s).
  //   marginal: On entry, the smoothed marginal of the state after the
  //     event.  On exit, the smoothed marginal of the state before the
  //     event (the row sums of the smoothed joint).
  void PoissonClusterProcess::backward_smoothing_step(
      Matrix &transition_density, Vector &marginal){
    // Column sums: the filtered marginal of the state after the event.
    wsp_ = one_ * transition_density;
    const int S = ncol(transition_density);
    for(int s = 0; s < S; ++s){
      // Each column is rescaled by smoothed / filtered marginal.  A state
      // with zero filtered probability also has zero smoothed probability,
      // so map 0/0 to 0.
      wsp_[s] = wsp_[s] > 0 ? marginal[s] / wsp_[s] : 0.0;
    }
    for(int r = 0; r < nrow(transition_density); ++r){
      transition_density.row(r) *= wsp_;
    }
    marginal = transition_density * one_;
  }

  //----------------------------------------------------------------------
  // Draws the state before event t, given that the state after event t
  // is current_state.  The draw uses column current_state of filter_[t]
  // as the (unnormalized) weights.
  int PoissonClusterProcess::draw_previous_state(
      RNG &rng, int t, int current_state){
    Matrix &joint_transition(filter_[t]);
    const int previous_state =
        rmulti_mt(rng, joint_transition.col(current_state));
    return previous_state;
  }

  //----------------------------------------------------------------------
  bool PoissonClusterProcess::primary(
      const PoissonProcess *process)const{
    return process == primary_traffic_.get()
        || process == primary_birth_.get()
        || process == primary_death_.get();
  }

  bool PoissonClusterProcess::secondary(
      const PoissonProcess *process)const{
    return process == background_.get()
        || process == secondary_traffic_.get()
        || process == secondary_death_.get();
  }

  //----------------------------------------------------------------------
  std::vector<PoissonProcess *>
  PoissonClusterProcess::get_responsible_processes(int r, int s, int source){
    ResponsibleProcessMap::iterator it =
        responsible_process_map_.find(index(r, s));
    if(it == responsible_process_map_.end()){
      return std::vector<PoissonProcess *>(0);
    }
    return subset_matching_source(it->second, source);
  }

  //----------------------------------------------------------------------
  std::vector<const PoissonProcess *>
  PoissonClusterProcess::get_responsible_processes(int r, int s, int source)const{
    ResponsibleProcessMap::const_iterator it =
        responsible_process_map_.find(index(r, s));
    if(it == responsible_process_map_.end()){
      return std::vector<const PoissonProcess *>(0);
    }
    return subset_matching_source(it->second, source);
  }

  //----------------------------------------------------------------------
  // Returns the candidates that match 'source' (see matches_source),
  // keeping their original order.  A negative source returns every
  // candidate.
  std::vector<PoissonProcess *>
  PoissonClusterProcess::subset_matching_source(
      std::vector<PoissonProcess *> &candidates, int source){
    if(source < 0) return candidates;
    std::vector<PoissonProcess *> matching;
    matching.reserve(candidates.size());
    for(PoissonProcess *candidate : candidates){
      if(matches_source(candidate, source)){
        matching.push_back(candidate);
      }
    }
    return matching;
  }

  //----------------------------------------------------------------------
  // Const version: returns the candidates that match 'source', as
  // pointers to const, keeping their original order.  A negative source
  // returns every candidate.
  std::vector<const PoissonProcess *>
  PoissonClusterProcess::subset_matching_source(
      const std::vector<PoissonProcess *> & candidates, int source)const{
    std::vector<const PoissonProcess *> matching;
    if(source < 0){
      matching.assign(candidates.begin(), candidates.end());
      return matching;
    }
    matching.reserve(candidates.size());
    for(const PoissonProcess *candidate : candidates){
      if(matches_source(candidate, source)){
        matching.push_back(candidate);
      }
    }
    return matching;
  }

  //----------------------------------------------------------------------
  // True if 'process' is consistent with an event source code:
  //   source < 0: unknown source, so every process matches.
  //   source == 1: primary processes match.
  //   source == 0: background and secondary processes match.
  // Any other value is passed to report_error.
  bool PoissonClusterProcess::matches_source(
      const PoissonProcess *process, int source)const{
    switch(source){
      case 1:
        return primary(process);
      case 0:
        return secondary(process);
      default:
        if(source < 0) return true;
        break;
    }
    report_error("unknown process, source combination in "
                 "PoissonClusterProcess::matches_source");
    return false;
  }

  //----------------------------------------------------------------------
  // Returns the process responsible for event t, given the hidden-state
  // transition previous_state -> current_state that came with it.  If
  // source < 0 (the usual case) the source of the event is unknown.
  // Otherwise only processes consistent with the source are considered.
  // When several processes could have produced the event, one is drawn
  // with probability proportional to its event rate at the event time
  // times the mark density for its group (primary or secondary).
  PoissonProcess * PoissonClusterProcess::assign_responsibility(
      RNG &rng, const PointProcess &data, int t,
      int previous_state, int current_state, int source){
    std::vector<PoissonProcess *> candidates(
        get_responsible_processes(previous_state, current_state, source));
    const int n = candidates.size();
    if(n == 1) return candidates[0];
    if(n == 0){
      std::ostringstream err;
      err << "trouble in PoissonClusterProcess::assign_responsibility: "
          << "no potential candidates in transition from state "
          << previous_state << " to " << current_state
          << " with source = " << source << "." << endl;
      report_error(err.str());
    }

    // Several processes could have produced the event: sample one from
    // its full conditional distribution.
    const PointProcessEvent &event(data.event(t));
    double logp_primary = 0;
    double logp_secondary = 0;
    if(event.has_mark() && !!primary_mark_model_){
      logp_primary = primary_mark_model_->pdf(event.mark(), true);
      logp_secondary = secondary_mark_model_->pdf(event.mark(), true);
    }

    const DateTime &time(event.timestamp());
    Vector log_weights(n);
    for(int i = 0; i < n; ++i){
      PoissonProcess *candidate = candidates[i];
      const double log_mark = primary(candidate) ? logp_primary : logp_secondary;
      log_weights[i] = log(candidate->event_rate(time)) + log_mark;
    }
    log_weights.normalize_logprob();
    const int chosen = rmulti_mt(rng, log_weights);
    return candidates[chosen];
  }

  //----------------------------------------------------------------------
  // Gives 'event' to responsible_process.  If the event has a mark and
  // mark models are in use, the mark is also added to the mark model
  // that mark_model() returns for that process.
  void PoissonClusterProcess::attribute_event(
      const PointProcessEvent &event,
      PoissonProcess * responsible_process){
    responsible_process->add_event(event.timestamp());
    const bool record_mark = event.has_mark() && !!primary_mark_model_;
    if(record_mark){
      mark_model(responsible_process)->add_data(event.mark_ptr());
    }
  }

  //----------------------------------------------------------------------
  void PoissonClusterProcess::update_exposure_time(
      const PointProcess &data,
      int current_time,
      int previous_state,
      int current_state){
    std::vector<PoissonProcess *> &running(active_processes_[previous_state]);
    const DateTime &then(current_time > 0 ?
                         data.event(current_time - 1).timestamp() :
                         data.window_begin());
    const DateTime &now(data.event(current_time).timestamp());
    for(int process = 0; process < running.size(); ++process){
      running[process]->add_exposure_window(then, now);
    }
  }

  //----------------------------------------------------------------------
  void PoissonClusterProcess::clear_data(){
    DataPolicy::clear_data();
    clear_client_data();
    probability_of_responsibility_.clear();
    probability_of_activity_.clear();
  }

  //----------------------------------------------------------------------
  // Adds a generic data point.  It is converted to a PointProcess with
  // DAT and passed to the PointProcess overload.
  void PoissonClusterProcess::add_data(Ptr<Data> dp){
    Ptr<PointProcess> point_process = DAT(dp);
    add_data(point_process);
  }

  //----------------------------------------------------------------------
  // Adds a point process.  Also creates zero-filled activity and
  // responsibility matrices for it, with 3 rows and one column per event.
  void PoissonClusterProcess::add_data(Ptr<PointProcess> dp){
    const int number_of_events = dp->number_of_events();
    const int number_of_rows = 3;
    probability_of_activity_.push_back(
        Matrix(number_of_rows, number_of_events, 0.0));
    probability_of_responsibility_.push_back(
        Matrix(number_of_rows, number_of_events, 0.0));
    DataPolicy::add_data(dp);
  }

  //----------------------------------------------------------------------
  // Adds a point process whose event sources are fully or partly known.
  // Args:
  //   dp: The point process.
  //   source: One entry per event in dp: 0 (background or secondary
  //     process), 1 (primary process), or < 0 (source unknown).
  // Both arguments are checked before any model state changes.  If
  // report_error throws, nothing has been added to the model.
  void PoissonClusterProcess::add_supervised_data(
      Ptr<PointProcess> dp, const std::vector<int> &source) {
    if(static_cast<size_t>(dp->number_of_events()) != source.size()){
      ostringstream err;
      err << "Error in PoissonClusterProcess::add_supervised_data." << endl
          << "The size of source (" << source.size() << ") does not match the"
          << " number of events in the corresponding point process ("
          << dp->number_of_events() << ")";
      report_error(err.str());
    }
    for(size_t i = 0; i < source.size(); ++i){
      if(source[i] > 1){
        ostringstream err;
        err << "Error in PoissonClusterProcess::add_supervised_data." << endl
            << "source[" << i << "] = " << source[i] << endl
            << "legal values are " << endl
            << "  0 (background or secondary process)" << endl
            << "  1 (primary process)" << endl
            << " < 0 (source unknown)" << endl;
        report_error(err.str());
      }
    }

    // Only modify the model once the input has been validated.
    add_data(dp);
    known_source_store_[dp] = source;
  }

  //----------------------------------------------------------------------
  // Simulates the cluster process on the window [t0, t1].
  //
  // Each component process is simulated over the whole window.  The
  // pooled events are then replayed in time order through the hidden
  // state machine, starting in state 0.  An event is kept only if the
  // process that generated it is active in the current state, and each
  // kept event moves the state forward via determine_state.
  // Args:
  //   t0, t1: Beginning and end of the simulation window.
  //   primary_event_simulator: Mark generator for events from primary
  //     processes.
  //   secondary_event_simulator: Mark generator for events from all
  //     other processes.
  // Returns:
  //   The kept events, as a PointProcess on [t0, t1].
  PointProcess PoissonClusterProcess::simulate(
      const DateTime &t0, const DateTime &t1,
      std::function<Data*()> primary_event_simulator,
      std::function<Data*()> secondary_event_simulator)const{
    // Component processes paired with the labels determine_state expects,
    // listed in the order they are simulated.
    typedef std::pair<PoissonProcess *, string> LabeledProcess;
    std::vector<LabeledProcess> components;
    components.push_back(LabeledProcess(background_.get(), "background"));
    components.push_back(LabeledProcess(primary_birth_.get(), "primary_birth"));
    components.push_back(LabeledProcess(primary_traffic_.get(), "primary_traffic"));
    components.push_back(LabeledProcess(primary_death_.get(), "primary_death"));
    components.push_back(LabeledProcess(secondary_death_.get(), "secondary_death"));
    components.push_back(LabeledProcess(secondary_traffic_.get(), "secondary_traffic"));

    std::map<PoissonProcess *, string> labels;
    // Maps each event to the process that generated it.  The map keeps
    // its keys (the events) sorted in time order.
    typedef std::map<PointProcessEvent, PoissonProcess *> EventMap;
    EventMap event_map;

    for(size_t c = 0; c < components.size(); ++c){
      PoissonProcess *process = components[c].first;
      labels[process] = components[c].second;
      std::function<Data*()> &mark_generator = primary(process)
          ? primary_event_simulator
          : secondary_event_simulator;
      PointProcess events = process->simulate(t0, t1, mark_generator);
      for(int i = 0; i < events.number_of_events(); ++i){
        event_map[events.event(i)] = process;
      }
    }

    PointProcess ans(t0, t1);
    int state = 0;
    for(EventMap::const_iterator it = event_map.begin();
        it != event_map.end(); ++it){
      PoissonProcess *generator = it->second;
      // Events from processes inactive in the current state are discarded.
      if(!contains(active_processes_[state], generator)) continue;
      ans.add_event(it->first);
      // Move the hidden state forward according to which process fired.
      state = determine_state(state, labels[generator]);
    }
    return ans;
  }
  //----------------------------------------------------------------------
  // Activity matrices, one per data series, filled by impute_latent_data.
  const std::vector<Mat> &
  PoissonClusterProcess::probability_of_activity()const{
    const std::vector<Mat> &activity = probability_of_activity_;
    return activity;
  }
  //----------------------------------------------------------------------
  // Responsibility matrices, one per data series, filled by
  // impute_latent_data.
  const std::vector<Mat> &
  PoissonClusterProcess::probability_of_responsibility()const{
    const std::vector<Mat> &responsibility = probability_of_responsibility_;
    return responsibility;
  }

  //----------------------------------------------------------------------
  // Adds one to each element of 'probs' whose index is selected by
  // activity_state_[state].
  void PoissonClusterProcess::record_activity(VectorView probs, int state) {
    const Selector &included(activity_state_[state]);
    const int number_included = included.nvars();
    for(int k = 0; k < number_included; ++k){
      ++probs[included.indx(k)];
    }
  }

  // Adds the marginal probability of each previous state r (the row sums
  // of transition_distribution) to every element of 'probs' selected by
  // activity_state_[r].
  void PoissonClusterProcess::record_activity_distribution(
      VectorView probs,
      const Matrix &transition_distribution){
    // Row sums: the distribution of the state before the event.
    wsp_ = transition_distribution * one_;
    for(int r = 0; r < wsp_.size(); ++r){
      const double state_probability = wsp_[r];
      const Selector &included(activity_state_[r]);
      for(int k = 0; k < included.nvars(); ++k) {
        probs[included.indx(k)] += state_probability;
      }
    }
  }

  //----------------------------------------------------------------------
  // Adds one to the slot of 'probs' for the responsible process:
  // 0 = background, 1 = any primary process, 2 = anything else.
  void PoissonClusterProcess::record_responsibility(
      VectorView probs, PoissonProcess * process){
    int slot = 2;
    if(process == background_.get()){
      slot = 0;
    }else if(primary(process)){
      slot = 1;
    }
    ++probs[slot];
  }

  //----------------------------------------------------------------------
  // Spreads the probability of each (r, s) transition in
  // transition_distribution over the slots of 'probs' (background,
  // primary, secondary) by calling allocate_probability.  With mark
  // models in use, a known source sets the mark log density of the
  // excluded group to negative infinity.
  void PoissonClusterProcess::record_responsibility_distribution(
      VectorView probs,
      const Matrix &transition_distribution,
      const PointProcessEvent &event,
      int source){
    double logp_primary = 0;
    double logp_secondary = 0;
    if(event.has_mark() && !!primary_mark_model_){
      logp_primary = (source == 0)
          ? negative_infinity()
          : primary_mark_model_->pdf(event.mark(), true);
      logp_secondary = (source == 1)
          ? negative_infinity()
          : secondary_mark_model_->pdf(event.mark(), true);
    }

    const int number_of_states = nrow(transition_distribution);
    const DateTime &when(event.timestamp());
    for(int previous = 0; previous < number_of_states; ++previous){
      for(int current = 0; current < number_of_states; ++current){
        allocate_probability(previous,
                             current,
                             probs,
                             transition_distribution(previous, current),
                             logp_primary,
                             logp_secondary,
                             when,
                             source);
      }
    }
  }

  //----------------------------------------------------------------------
  // True if the hidden chain may move from state r to state s.  Both
  // states must be in 0..3; an out-of-range s (checked first) or r is
  // passed to report_error.
  //   from 0: to 0 or 1
  //   from 1 or 3: to 1, 2, or 3
  //   from 2: to 0, 1, or 2
  bool PoissonClusterProcess::legal_transition(int r, int s)const{
    const bool s_in_range = s >= 0 && s <= 3;
    if (!s_in_range) {
      ostringstream err;
      err << "Illegal value of s (" << s
          <<  ") in PoissonClusterProcess::legal_transition." <<endl
          << "Legal values are 0, 1, 2, 3." << endl;
      report_error(err.str());
    }
    const bool r_in_range = r >= 0 && r <= 3;
    if (!r_in_range) {
      ostringstream err;
      err << "Illegal value of r (" << r
          <<  ") in PoissonClusterProcess::legal_transition." <<endl
          << "Legal values are 0, 1, 2, 3." << endl;
      report_error(err.str());
    }

    switch (r) {
      case 0:
        return s <= 1;
      case 1:
      case 3:
        return s > 0;
      case 2:
        return s < 3;
      default:
        return false;
    }
  }

  //----------------------------------------------------------------------
  // Distributes 'transition_probability' among the processes that could
  // have produced the event responsible for moving from 'previous_state'
  // to 'current_state', accumulating the result into 'process_probs'.
  //
  // Args:
  //   previous_state, current_state: HMM state indices (0-3).  Illegal
  //     transitions (per legal_transition) contribute nothing.
  //   process_probs: Accumulator indexed as 0 = background,
  //     1 = primary, 2 = secondary.  Entries are incremented, never
  //     overwritten.
  //   transition_probability: The probability mass to allocate.
  //   logp_primary, logp_secondary: Added to the log event rates of the
  //     primary process and the background/secondary processes when all
  //     processes are active and the source is unknown.  Presumably the
  //     log density of the event's mark under the primary and secondary
  //     mark models -- TODO confirm against callers.
  //   timestamp: Time of the event, used to evaluate event rates.
  //   source: Known source of the event: negative = unknown,
  //     1 = primary, 0 = background or secondary (see check_source).
  void PoissonClusterProcess::allocate_probability(
      int previous_state,
      int current_state,
      VectorView process_probs,
      double transition_probability,
      double logp_primary,
      double logp_secondary,
      const DateTime &timestamp,
      int source){
    if(!legal_transition(previous_state, current_state)) return;
    bool primary = true;
    bool secondary_or_background = false;
    if(previous_state == 0) {
      if(current_state == 0) {
        // 100 -> 100: The background process must have produced the
        // event.
        check_source(transition_probability, source, secondary_or_background);
        process_probs[0] += transition_probability;
      } else if(current_state == 1) {
        // 100 -> 111: A primary session birth.  A primary event.
        check_source(transition_probability, source, primary);
        process_probs[1] += transition_probability;
      }
    } else if(previous_state == 1) {
      if(current_state == 1){
        // 111 -> 111:  All processes are active.  wsp_ uses the same
        // layout as process_probs: [background, primary, secondary].
        wsp_.resize(3);
        if (source < 0) {
          // Unknown source: weight each process by its event rate times
          // the corresponding mark term, computed on the log scale.
          wsp_[0] = log(background_->event_rate(timestamp)) + logp_secondary;
          wsp_[1] = log(primary_traffic_->event_rate(timestamp)) + logp_primary;
          wsp_[2] = log(secondary_traffic_->event_rate(timestamp)) + logp_secondary;
          wsp_.normalize_logprob();
        } else if (source == 1) {
          // Known primary event: all mass goes to primary traffic.
          wsp_ = 0;
          wsp_[1] = 1.0;
        } else if (source == 0) {
          // Known non-primary event: split between background and
          // secondary traffic in proportion to their event rates.  The
          // secondary share belongs in slot 2; slot 1 is the primary
          // process.
          wsp_ = 0;
          wsp_[0] = background_->event_rate(timestamp);
          wsp_[2] = secondary_traffic_->event_rate(timestamp);
          wsp_.normalize_prob();
        }
        process_probs.axpy(wsp_, transition_probability);
      } else if(current_state == 2) {
        // 111 -> 101: Death of a primary session, which is a primary
        // event.
        check_source(transition_probability, source, primary);
        process_probs[1] += transition_probability;
      } else if(current_state == 3) {
        // 111 -> 110: End of a secondary session, which is a
        // secondary event.
        check_source(transition_probability, source, secondary_or_background);
        process_probs[2] += transition_probability;
      }
    } else if (previous_state == 2) {
      if (current_state == 0) {
        // 101 -> 100: End of a secondary session, which is a
        // secondary event.
        check_source(transition_probability, source, secondary_or_background);
        process_probs[2] += transition_probability;
      } else if (current_state == 1) {
        // 101 -> 111:  New primary session, a primary event
        check_source(transition_probability, source, primary);
        process_probs[1] += transition_probability;
      } else if (current_state == 2) {
        // 101 -> 101: Either a background event or a secondary event
        // took place.  Both processes share the same mark model, so
        // their full conditional distribution is determined by the
        // relative event rates.
        check_source(transition_probability, source, secondary_or_background);
        wsp_.resize(3);
        wsp_[0] = background_->event_rate(timestamp);
        wsp_[1] = 0;
        wsp_[2] = secondary_traffic_->event_rate(timestamp);
        wsp_.normalize_prob();
        process_probs.axpy(wsp_, transition_probability);
      }
    } else if (previous_state == 3) {
      if (current_state == 1) {
        // 110 -> 111 Start of a new secondary session, caused by a
        // primary event.
        check_source(transition_probability, source, primary);
        process_probs[1] += transition_probability;
      } else if (current_state == 2) {
        // 110 -> 101: End of a primary session, a primary event
        check_source(transition_probability, source, primary);
        process_probs[1] += transition_probability;
      } else if (current_state == 3) {
        // 110 -> 110: The secondary process is off in both instances,
        // so the event must have been produced by the background
        // process.
        check_source(transition_probability, source, secondary_or_background);
        process_probs[0] += transition_probability;
      }
    }
  }

  //----------------------------------------------------------------------
  // Returns the mark model associated with 'process'.  The primary() and
  // secondary() classifiers are defined elsewhere in this class.
  // Reports an error (and returns null) if 'process' is neither.
  MixtureComponent * PoissonClusterProcess::mark_model(
      const PoissonProcess * process){
    if (primary(process)) {
      return primary_mark_model_.get();
    } else if (secondary(process)) {
      return secondary_mark_model_.get();
    }
    report_error("Unknown process passed to PoissonClusterProcess::mark_model().");
    return 0;
  }

  //----------------------------------------------------------------------
  // Const lookup of the mark model used by 'process'.
  // - Returns null for every process when no primary mark model is set.
  // - Background, primary death, secondary traffic and secondary death
  //   use the secondary mark model.
  // - Primary traffic and primary birth use the primary mark model.
  // Any other process triggers report_error and a null return.
  const MixtureComponent * PoissonClusterProcess::mark_model(
      const PoissonProcess * process)const{
    // NOTE(review): only the primary mark model is tested here; a null
    // secondary mark model is returned as-is below.
    if (!primary_mark_model_) {
      return 0;
    }

    const bool uses_secondary_marks =
        process == background_.get()
        || process == primary_death_.get()
        || process == secondary_traffic_.get()
        || process == secondary_death_.get();
    if (uses_secondary_marks) {
      return secondary_mark_model_.get();
    }

    const bool uses_primary_marks =
        process == primary_traffic_.get()
        || process == primary_birth_.get();
    if (uses_primary_marks) {
      return primary_mark_model_.get();
    }

    report_error(
        "Unknown process passed to PoissonClusterProcess::mark_model.");
    return 0;
  }

  //======================================================================
  // Begin private functions

  //----------------------------------------------------------------------
  // Rebuilds the ParamPolicy's model list from scratch.  The six Poisson
  // processes are always registered, in a fixed order; each mark model is
  // registered only when it is set.
  void PoissonClusterProcess::register_models_with_param_policy(){
    // Drop any previously registered models before re-adding.
    ParamPolicy::clear();

    ParamPolicy::add_model(background_);
    ParamPolicy::add_model(primary_birth_);
    ParamPolicy::add_model(primary_death_);
    ParamPolicy::add_model(primary_traffic_);
    ParamPolicy::add_model(secondary_traffic_);
    ParamPolicy::add_model(secondary_death_);

    // Optional components.
    if (primary_mark_model_.get() != 0) {
      ParamPolicy::add_model(primary_mark_model_);
    }
    if (secondary_mark_model_.get() != 0) {
      ParamPolicy::add_model(secondary_mark_model_);
    }
  }

  //----------------------------------------------------------------------
  // Sets up internal bookkeeping: state tables, parameter registration,
  // filter vectors, the initial-state strategy, and the workspace vector.
  // NOTE(review): the calls look order-sensitive -- fill_state_maps()
  // populates the state tables before number_of_hmm_states() is used (in
  // setup_filter() and below); presumably that count depends on them.
  void PoissonClusterProcess::initialize(){
    fill_state_maps();
    register_models_with_param_policy();
    setup_filter();
    initialization_strategy_ = UniformInitialState;
    // Workspace sized to the number of hidden states.  allocate_probability()
    // resizes it to 3 when it needs per-process weights.
    wsp_.resize(number_of_hmm_states());
  }

  //----------------------------------------------------------------------
  void PoissonClusterProcess::setup_filter(){
    int S = number_of_hmm_states();
    pi0_.resize(S); pi0_ = 1.0/S;
    one_.resize(S); one_ = 1.0;
  }

  //----------------------------------------------------------------------
  // Sanity check that 'probability' is not assigned to a process class
  // ruled out by a known event source.
  //
  // Args:
  //   probability: The probability being assigned.
  //   source: Negative if unknown, 1 for a primary event, 0 for a
  //     non-primary (background or secondary) event.
  //   primary: True if the probability is going to a primary process.
  //
  // Passes silently when the source is unknown, when it agrees with
  // 'primary', or when 'probability' is below 1e-4.  Otherwise it calls
  // report_error.
  void PoissonClusterProcess::check_source(
      double probability, int source, bool primary){
    const bool source_unknown = source < 0;
    const bool source_agrees = primary ? (source == 1) : (source == 0);
    const bool negligible = probability < .0001;
    if (source_unknown || source_agrees || negligible) {
      return;
    }

    ostringstream err;
    err << "Positive probability was assigned to an impossible event, "
        << "based on known source. " << endl
        << "Source = " << source << " but probability " << probability
        << " was assigned to a " <<  (primary ? "primary" : "non-primary")
        << " process." << endl;
    report_error(err.str());
  }

  //----------------------------------------------------------------------
  void PoissonClusterProcess::fill_state_maps(){
    // The three states in each selector are the background proces,
    // the primary process, and the secondary process

    // This is the mapping from hmm state index to which processes are
    // active.  Its size determines the number of hmm states.
    activity_state_.resize(4);
    activity_state_[0] = Selector("100");  // only bg process is active
    activity_state_[1] = Selector("111");  // primary, secondary both active
    activity_state_[2] = Selector("101");  // secondary active, primary not
    activity_state_[3] = Selector("110");  // primary active, secondary not

    //----------------------------------------------------------------------
    // Legal transitions are:
    // 0: 100 -> [ 100, 111,    ,     ]
    // 1: 111 -> [    , 111, 101, 110 ]
    // 2: 101 -> [ 100, 111, 101,     ]
    // 3: 110 -> [    , 111, 101, 110 ]
    // Note that you can always do a self transition if the background
    // process produces an event.
    std::vector<int> legal_transitions_from_state_0;
    legal_transitions_from_state_0.push_back(0);
    legal_transitions_from_state_0.push_back(1);

    std::vector<int> legal_transitions_from_state_1;
    legal_transitions_from_state_1.push_back(0);
    legal_transitions_from_state_1.push_back(1);
    legal_transitions_from_state_1.push_back(2);
    legal_transitions_from_state_1.push_back(3);

    std::vector<int> legal_transitions_from_state_2;
    legal_transitions_from_state_2.push_back(0);
    legal_transitions_from_state_2.push_back(1);
    legal_transitions_from_state_2.push_back(2);

    std::vector<int> legal_transitions_from_state_3;
    legal_transitions_from_state_3.push_back(1);
    legal_transitions_from_state_3.push_back(2);
    legal_transitions_from_state_3.push_back(3);

    legal_target_transitions_.resize(4);
    legal_target_transitions_[0] = legal_transitions_from_state_0;
    legal_target_transitions_[1] = legal_transitions_from_state_1;
    legal_target_transitions_[2] = legal_transitions_from_state_2;
    legal_target_transitions_[3] = legal_transitions_from_state_3;

    //----------------------------------------------------------------------
    std::vector<PoissonProcess *> active_processes_in_state_0;
    active_processes_in_state_0.push_back(background_.get());
    active_processes_in_state_0.push_back(primary_birth_.get());

    std::vector<PoissonProcess *> active_processes_in_state_1;
    active_processes_in_state_1.push_back(background_.get());
    active_processes_in_state_1.push_back(primary_traffic_.get());
    active_processes_in_state_1.push_back(secondary_traffic_.get());
    active_processes_in_state_1.push_back(secondary_death_.get());
    active_processes_in_state_1.push_back(primary_death_.get());

    std::vector<PoissonProcess *> active_processes_in_state_2;
    active_processes_in_state_2.push_back(background_.get());
    active_processes_in_state_2.push_back(secondary_traffic_.get());
    active_processes_in_state_2.push_back(secondary_death_.get());
    active_processes_in_state_2.push_back(primary_birth_.get());

    std::vector<PoissonProcess *> active_processes_in_state_3;
    active_processes_in_state_3.push_back(background_.get());
    active_processes_in_state_3.push_back(primary_traffic_.get());
    active_processes_in_state_3.push_back(primary_death_.get());

    active_processes_.resize(4);
    active_processes_[0] = active_processes_in_state_0;
    active_processes_[1] = active_processes_in_state_1;
    active_processes_[2] = active_processes_in_state_2;
    active_processes_[3] = active_processes_in_state_3;

    //----------------------------------------------------------------------
    // The map of responsible processes keeps track of which set of
    // processes is potentially responsible for the r, s transition
    std::vector<PoissonProcess *> responsible_process_0_0;
    responsible_process_map_[index(0, 0)] =
        std::vector<PoissonProcess *>(1, background_.get());

    responsible_process_map_[index(0, 1)] =
        std::vector<PoissonProcess *>(1, primary_birth_.get());

    std::vector<PoissonProcess *> tmp;
    tmp.push_back(background_.get());
    tmp.push_back(primary_traffic_.get());
    tmp.push_back(secondary_traffic_.get());
    responsible_process_map_[index(1, 1)] = tmp;
    tmp.clear();

    responsible_process_map_[index(1, 2)] =
        std::vector<PoissonProcess *>(1, primary_death_.get());
    responsible_process_map_[index(1, 3)] =
        std::vector<PoissonProcess *>(1, secondary_death_.get());

    responsible_process_map_[index(2, 0)] =
        std::vector<PoissonProcess *>(1, secondary_death_.get());
    responsible_process_map_[index(2, 1)] =
        std::vector<PoissonProcess *>(1, primary_birth_.get());

    tmp.push_back(background_.get());
    tmp.push_back(secondary_traffic_.get());
    responsible_process_map_[index(2, 2)] = tmp;
    tmp.clear();

    responsible_process_map_[index(3, 1)] =
        std::vector<PoissonProcess *>(1, primary_traffic_.get());
    responsible_process_map_[index(3, 2)] =
        std::vector<PoissonProcess *>(1, primary_death_.get());
    responsible_process_map_[index(3, 3)] =
        std::vector<PoissonProcess *>(1, background_.get());
  }


} // namespace BOOM
