#!/usr/bin/perl -w
#
# Generate & display convergence charts from vw progress outputs
# Requires R to generate the charts
#
use Getopt::Std;
use vars qw ($opt_d $opt_x $opt_y $opt_t
	     $opt_w $opt_h $opt_q $opt_Q
	     $opt_o
);

# Image path used when -o is not given; get_args() removes any existing
# file at this path on startup.
my $TmpImgFile = '/tmp/vw-convergence.png';
# Loss label used in the default Y-axis label and the legend title:
# 'loss', or '%loss' when -Q is given.  Set by get_args().
my $LossStr;

#
# Print an optional error message to STDERR, then die with the usage text.
#
# Args:    optional list of strings; when non-empty they are printed
#          (followed by a newline) before the usage text.
# Returns: never - always terminates through die().
#
sub usage(@) {
    print STDERR @_, "\n" if (@_);

    die "Usage: $0 [options] [input_files...]

    If input_file is missing, input is read from stdin.
    Assuming inputs are one or more 'vw' progress reports
    Requires R to generate the chart

    Options:
	-d	Don't display the chart, only report the image file name
	-q	Convert from squared-loss X to abs-loss (apply sqrt(X))
	-Q	Convert from squared-loss X to (exp(sqrt(X)) - 1.0) * 100
	-o<IMG>	Output chart to <IMG> file
	-x<XL>	Use <XL> as X-axis label in chart
	-y<YL>	Use <YL> as Y-axis label in chart
	-t<T>	Use <T> as title of chart
	-w<W>	set image width in pixels (default 800)
	-h<H>	set image height in pixels (default 600)

    -o <image_file> is optional, if not given, will create a temp-file:
	$TmpImgFile
    and display it.
";
}

#
# App-specific conversion of a squared error measured on a log scale
# into a percentage: (e ** sqrt(x) - 1) * 100.
#
sub sqlosslog2pct($) {
    my ($sq_log_loss) = @_;

    my $ratio = exp(sqrt($sq_log_loss));
    return ($ratio - 1.0) * 100;
}

#
# Apply the loss transformation selected on the command line:
#   -Q       squared log-loss -> percent, via sqlosslog2pct()
#   -q       squared loss -> its square root
#   neither  the value is returned unchanged
# When both flags are given, -Q takes precedence.
#
sub transform_loss($) {
    my ($raw_loss) = @_;

    if ($opt_Q) {
        return sqlosslog2pct($raw_loss);
    }
    elsif ($opt_q) {
        return sqrt($raw_loss);
    }
    return $raw_loss;
}

#
# Parse one or more (possibly concatenated) vw progress reports, read
# through the diamond operator (files named in @ARGV, else STDIN).
#
# Returns a list of array refs, one per report, each holding the
# transformed average-loss values in the order they were read.
# Calls usage() when no loss value could be found in the input.
#
sub average_loss_arrays() {
    my @reports = ();
    my $current = [];

    while (defined(my $line = <>)) {
        if ($line =~ /^([0-9.]+)/) {
            # progress line: its leading number is the loss value we chart
            push(@$current, transform_loss($1));
        }
        elsif ($line =~ /^average loss\s*=\s*([0-9.]+)/) {
            # summary line: record the final value and close this report
            push(@$current, transform_loss($1));
            push(@reports, $current);
            $current = [];
        }
    }

    # A report without a summary line is still kept, with whatever
    # progress values were collected for it.
    push(@reports, $current) if (@$current);

    usage("Couldn't identify 'vw' progress report(s) in input")
        unless (@reports);

    return @reports;
}

#
# Escape backslashes and single quotes so that a string can be embedded
# inside a single-quoted R string literal.
#
sub _r_quote {
    my ($text) = @_;
    $text =~ s/([\\'])/\\$1/g;
    return $text;
}

#
# Build an R script that charts every loss series and pipe it into R.
#
# Args:
#   $loss_vec_ref  reference to a list of array refs, one per series
#   $imgfile       output image path; its extension selects the R device
#                  (.ps/.eps -> postscript, .jpg/.jpeg -> jpeg, else PNG)
#
# Reads the option globals $opt_w, $opt_h, $opt_x, $opt_y, $opt_t and
# the file-level $LossStr.  Dies when R cannot be started; warns when
# the pipe to R does not close cleanly.
#
sub do_plot($;$) {
    my ($loss_vec_ref, $imgfile) = @_;

    # User-supplied text ends up inside single-quoted R strings.
    my $r_imgfile = _r_quote($imgfile);
    my $r_xlab    = _r_quote($opt_x);
    my $r_ylab    = _r_quote($opt_y);
    my $r_title   = _r_quote($opt_t);

    # NOTE(review): R's postscript() documents width/height in inches,
    # while -w/-h are described as pixels - TODO confirm intended size.
    my $set_r_device_line =
        ($imgfile =~ /\.e?ps$/) ?
            "postscript(file='$r_imgfile', width=$opt_w, height=$opt_h)"
        : ($imgfile =~ /\.jpe?g$/) ?
            "jpeg(file='$r_imgfile', width=$opt_w, height=$opt_h)"
        : # default: if you don't have 'Cairo', just use 'png' instead
            "library(graphics); autoload('CairoPNG', 'Cairo');\n" .
            "if (exists('Cairo', mode='function')) {\n" .
                "CairoPNG(file='$r_imgfile', width=$opt_w, height=$opt_h)\n" .
            "} else {\n" .
                "png(file='$r_imgfile', width=$opt_w, height=$opt_h)\n" .
            "}"
    ;

    my $R_input = "$set_r_device_line;\n";
    my $lineno = 0;
    my @colors = (4, 2, 1, 3, 5, 6, 7);
    my $ncols = @colors;
    my @colorlist = ();
    my @namelist = ();
    my @pchs = ();
    foreach my $lossref (@$loss_vec_ref) {
        my $R_losses_array = sprintf('c(%s)', join(',', @$lossref));
        # cycle through the palette when there are more series than colors
        my $color = $colors[$lineno % $ncols];
        $R_input .= "loss = $R_losses_array;\n";
        # the first series creates the plot, later ones are drawn on top
        $R_input .=
            ($lineno == 0) ?
                "plot(loss, pch=20, t='o', col=$color, lwd=2, cex=1.25,\n" .
                "\tlab=c(10,20,7), xlab='$r_xlab', ylab='$r_ylab',\n" .
                "\tmain='$r_title', panel.first=grid(col='gray63'));\n"
            :
                "lines(loss, pch=20, t='o', col=$color, lwd=2, cex=1.25);\n";

        $lineno++;
        push(@colorlist, $color);
        # legend entry: the last value of the series
        push(@namelist, sprintf("%.4f", $lossref->[-1]));
        push(@pchs, 20);
    }
    $R_input .= sprintf(
        "legend('topright', legend=c(%s), col=c(%s), pch=c(%s), %s);\n",
            join(',', @namelist),
            join(',', @colorlist),
            join(',', @pchs),
            "inset=0.01, title='final mean $LossStr', pt.cex=2, lwd=2.5"
    );

    open(my $r_pipe, '|-', 'R', '--no-save', '--no-init-file')
        or die "$0: cannot run R: $!\n";
    print {$r_pipe} $R_input;
    unless (close($r_pipe)) {
        # close() on a pipe fails on a system error ($! set) or when the
        # child exited with a non-zero status ($? set)
        warn $! ? "$0: error closing pipe to R: $!\n"
                : "$0: R exited with status " . ($? >> 8) . "\n";
    }
}

#
# Parse the command line into the $opt_* globals and fill in defaults.
#
# Side effects:
#   - strips the directory part from $0
#   - sets $LossStr and defaults for -x, -y, -t, -w, -h and -o
#   - removes an existing temp image file
# Calls usage() on an invalid option or a non-numeric / non-positive
# -w / -h value; dies when the -o target already exists.
#
sub get_args() {
    $0 =~ s{.*/}{};
    # getopts() returns false on an unknown option or a missing argument
    getopts('dx:y:t:w:h:qQo:')
        or usage("$0: unrecognized option or missing option argument");

    $LossStr = ($opt_Q) ? '%loss' : 'loss';
    $opt_x = 'vw progress iteration' unless (defined $opt_x);
    $opt_y = "mean $LossStr"
        unless (defined $opt_y);
    $opt_t = "online training $opt_y convergence" unless (defined $opt_t);
    $opt_w = 800 unless (defined $opt_w);
    $opt_h = 600 unless (defined $opt_h);
    $opt_o = $TmpImgFile
        unless (defined $opt_o);

    # Width and height are spliced verbatim into the generated R code,
    # so accept plain positive decimal numbers only.
    foreach my $size ($opt_w, $opt_h) {
        usage("$0: -w and -h need a positive number, got '$size'")
            unless ($size =~ /^(?:[0-9]+\.?[0-9]*|\.[0-9]+)$/ && $size > 0);
    }

    unlink($TmpImgFile) if (-e $TmpImgFile);

    # never clobber a user-named output file
    if (-e $opt_o && $opt_o ne $TmpImgFile) {
        die "$0: image file '$opt_o' already exists: avoiding overwrite\n";
    }
}

#
# Show the generated image with the external 'display' program, or
# just report its path when -d was given.
#
# Args:   $imgfile - path of the image to show.
# Dies when the image file does not exist; warns when the viewer
# cannot be run or exits with a non-zero status.
#
sub display($) {
    my $imgfile = $_[0];

    die "$0: display($imgfile): $! - must be a bug\n"
        unless (-e $imgfile);

    if ($opt_d) {
        # User requested no display,
        # Be friendly, give a hint that everything is OK
        printf STDERR "image file is: %s\n", $imgfile;
        return;
    }
    # Postscript files are generated in portrait: need 90-degrees
    # rotation.  Same extension test as the device selection in do_plot().
    my @rotate = ($imgfile =~ /\.e?ps$/) ? ('-rotate', '90') : ();
    # NOTE(review): no command-line flag sets $opt_i in this file;
    # presumably extra viewer arguments - kept for compatibility.
    my @extra = (defined $opt_i) ? split(' ', $opt_i) : ();

    # List form bypasses the shell, so file names containing spaces or
    # shell metacharacters are passed through intact.
    system('display', @extra, @rotate, $imgfile) == 0
        or warn "$0: could not display '$imgfile' (status $?): $!\n";
    return;
}


# -- main: parse options, collect the loss series, chart them, show the chart
get_args();
my @loss_series = average_loss_arrays();
my $series_ref  = \@loss_series;
do_plot($series_ref, $opt_o);
display($opt_o);

