#include "osl/rating/featureSet.h"
#include "osl/rating/group.h"
#include "osl/rating/feature.h"
#include "osl/rating/ratingEnv.h"
#include "osl/record/csaRecord.h"
#include "osl/record/csaString.h"
#include "osl/apply_move/applyMove.h"
#include "osl/move_generator/legalMoves.h"
#include "osl/container/moveVector.h"
#include "osl/stat/histogram.h"
#include "osl/stat/average.h"
#include "osl/oslConfig.h"

#include <cppunit/TestCase.h>
#include <cppunit/extensions/HelperMacros.h>
#include <string>
#include <fstream>
#include <iostream>

/**
 * CppUnit fixture for osl::rating::StandardFeatureSet.
 *
 *  - testMatch: for every legal move, the feature selected by a group's
 *               findMatch() must itself report a match for that move.
 *  - testCover: moves actually played in test kifu must be generated by
 *               generateLogProb() and must be ranked near the top on average.
 *  - testPass:  generateLogProb() on the initial position when the last
 *               move in the history is a pass.
 */
class StandardFeatureSetTest : public CppUnit::TestFixture 
{
  // Test registration; the order of CPPUNIT_TEST entries is the run order.
  CPPUNIT_TEST_SUITE(StandardFeatureSetTest);
  CPPUNIT_TEST(testMatch);
  CPPUNIT_TEST(testCover);
  CPPUNIT_TEST(testPass);
  CPPUNIT_TEST_SUITE_END();
public:
  void testMatch();
  void testCover();
  void testPass();
};

// Register the fixture with CppUnit's default test-factory registry.
CPPUNIT_TEST_SUITE_REGISTRATION(StandardFeatureSetTest);

using namespace osl;
using namespace osl::rating;
// Defined in another translation unit (presumably the test driver); when it
// is true, testFile() below skips its per-file report on std::cout.
extern bool isShortTest;

static void testFileMatch(std::string const& filename)
{
  static StandardFeatureSet feature_set;
  static int kifu_count = 0;
  if ((++kifu_count % 8) == 0)
    std::cerr << '.';
  const Record rec=CsaFile(filename).getRecord();
  NumEffectState state(rec.getInitialState());
  const vector<osl::Move> moves=rec.getMoves();
  MoveStack history;

  for (size_t i=0; i<moves.size(); ++i) {
    if (i > 150)
      break;
    const Move move = moves[i];
    RatingEnv env;
    env.history = history;
    env.make(state);
    
    {
      MoveVector moves;
      LegalMoves::generate(state, moves);
      for (size_t k=0; k<moves.size(); ++k) {
	for (size_t j=0; j<feature_set.groupSize(); ++j) {
	  int match = feature_set.group(j).findMatch(state, moves[k], env);
	  if (match < 0)
	    continue;
	  match += feature_set.range(j).first;
	  if (! feature_set.feature(match).match(state, moves[k], env)) {
	    std::cerr << feature_set.group(j).group_name << " " 
		      << feature_set.feature(match).name() << " " << match - feature_set.range(j).first 
		      << "\n" << state << moves[k];
	  }
	  CPPUNIT_ASSERT(feature_set.feature(match).match(state, moves[k], env));
	}
      }
    }

    ApplyMoveOfTurn::doMove(state, move);
    history.push(move);
  }
}

/**
 * Runs testFileMatch() on at most the first 100 records listed in the
 * test-data file "FILES". Fails immediately if "FILES" cannot be opened.
 */
void StandardFeatureSetTest::testMatch()
{
  std::ifstream file_list(OslConfig::testCsaFile("FILES"));
  CPPUNIT_ASSERT(file_list);

  const int max_files = 100;
  int processed = 0;
  std::string name;
  // the count is checked before each read, exactly one file per iteration
  while (processed < max_files && (file_list >> name)) {
    testFileMatch(OslConfig::testCsaFile(name));
    ++processed;
  }
}

static stat::Average all_average;
static void testFile(std::string const& filename)
{
  static StandardFeatureSet feature_set;
  static int kifu_count = 0;
  if ((++kifu_count % 8) == 0)
    std::cerr << '.';
  const Record rec=CsaFile(filename).getRecord();
  NumEffectState state(rec.getInitialState());
  const vector<osl::Move> moves=rec.getMoves();
  MoveStack history;
  const size_t limit = 1000;

  stat::Histogram stat(200,10);
  stat::Average order_average;
  for (size_t i=0; i<moves.size(); ++i) {
    if (i > 150)
      break;
    const Move move = moves[i];
    RatingEnv env;
    env.history = history;
    env.make(state);
    // 合法手生成のテスト
    MoveLogProbVector all_moves;
    feature_set.generateLogProb(state, env, limit, all_moves);
    for (MoveLogProbVector::const_iterator p=all_moves.begin(); p!=all_moves.end(); ++p) {
      if (p->getMove().isPass())
	continue;
      CPPUNIT_ASSERT(p->getMove().isValid());
      CPPUNIT_ASSERT(state.isValidMove(p->getMove()));
    }

    // 確率のテスト
    const MoveLogProb *found = all_moves.find(move);
    if (! found) {
      std::cerr << state << move << "\n";
    }
    CPPUNIT_ASSERT(found);
    stat.add(found->getLogProb());
    order_average.add(found - &*all_moves.begin());

    ApplyMoveOfTurn::doMove(state, move);
    history.push(move);
  }
  if (! isShortTest) {
    std::cout << filename 
	      << " average order " << order_average.getAverage() << "\n";
    if (order_average.getAverage() >= 10)
      stat.show(std::cout);
  }
  CPPUNIT_ASSERT(order_average.getAverage() < 10);
  all_average.merge(order_average);
}

/**
 * Runs testFile() on at most the first 100 records listed in the test-data
 * file "FILES". Then requires the rank of the played moves, averaged over
 * all of those records, to be below 6.
 */
void StandardFeatureSetTest::testCover()
{
  std::ifstream file_list(OslConfig::testCsaFile("FILES"));
  CPPUNIT_ASSERT(file_list);

  const int max_files = 100;
  int processed = 0;
  std::string name;
  while (processed < max_files && (file_list >> name)) {
    testFile(OslConfig::testCsaFile(name));
    ++processed;
  }

  // all_average is filled by testFile()
  CPPUNIT_ASSERT(all_average.getAverage() < 6);
}

void StandardFeatureSetTest::testPass()
{
  static StandardFeatureSet feature_set;
  NumEffectState state((SimpleState(HIRATE)));
  
  RatingEnv env;
  env.history.push(Move::PASS(WHITE));
  env.make(state);

  MoveLogProbVector all_moves;
  feature_set.generateLogProb(state, env, 1200, all_moves);
}

// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
