CodeIQ:激ムズ!頭脳を【上限解放】~電脳体キュラゲタムを討伐せよ

私自身が表題の問題を解いた時のプログラムについて解説します。
問題の詳細は「激ムズ!頭脳を【上限解放】~電脳体キュラゲタムを討伐せよ」(CodeIQ)を参照してください。

問題の概要

問題を引用します。

【問題詳細】
マップ上を移動し、キュラゲタムの生息エリア(マップ上で最も高度の大きい場所)に向かいます。
マップ上のキュラゲタム生息エリアをすべて通りキャンプ(スタート地点)に戻るまでの最短時間を求めてください。

◎マップ
H×Wの格子状になっている文字列が与えられます。各地点には、0~9のレベルでその地点の高度情報が与えられ、移動方法・移動にかかる時間は次のように決まっています。

◎移動ルール
●隣り合う上下左右のマスのみ移動できます。
●高低差が無い場合: 1マス移動するのに3分かかります。
●高低差が「1」の上りの場合: 1マス移動するのに11分かかります。
●高低差が「1」の下りの場合: 1マス移動するのに2分かかります。
●高低差が「2」以上の場合: 上り・下り問わず移動することはできません。

例えば、2×3のマップで、左上から右上まで移動する場合、単純に右に進むと11+2で13分かかります。
1 2 1
1 1 1
※色のついた部分を移動する

しかし以下のように高低差を避けて回り込めば、3+3+3+3で12分になります。
1 2 1
1 1 1
※色のついた部分を移動する

◎注意点
●キャンプ(スタート地点)はマップ上の左上とします。
●キャンプ(スタート地点)からキュラゲタムの生息エリアまでのルートは必ず存在する入力データとなっています。
●マップ上の同じエリアを何回通ってもかまいません。
●キュラゲタムの生息地点数は最大15です。

【入力】
標準入力の1行目は、マップの縦サイズH・横サイズWが半角スペース区切り、2行目以降のH行分にマップデータとなる文字列が与えられます。

【出力】
標準出力に、最短時間を表す整数値を出力してください。

私のプログラム

Javaで解答しています。

import java.io.*;
import java.util.*;

public class Main {
    /**
     * 座標クラス
     */
    class Point {
        int x;
        int y;

        public Point(int x,int y){
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object o){
            if(this == o){  return true; }
            if(o == null){  return false; }
            if(getClass() != o.getClass()){ return false; }
            Point p = (Point)o;
            return (this.x == p.x) && (this.y == p.y);
        }

        @Override
        public int hashCode() {
            int hash = 5;
            hash = 79 * hash + this.x;
            hash = 79 * hash + this.y;
            return hash;
        }

        @Override
        public String toString(){
            return "(" + String.valueOf(x) + "," + String.valueOf(y) + ")";
        }
    }

    /**
     * 2地点間のコストを管理するクラス
     */
    class P2PCost {
        Point p1;
        Point p2;
        int cost;

        public P2PCost(Point p1, Point p2, int c){
            this.p1 = p1;
            this.p2 = p2;
            this.cost = c;
        }

        @Override
        public String toString(){
            return p1.toString() + p2.toString() + ":" + String.valueOf(cost);
        }
    }

    public class Tsp{
        /**
         *	地点間距離テーブル
         *	          TO
         *	       A   B   C
         *	F| A   0 144 172
         *	R| B  90   0 105
         *	O| C 118 105   0
         *	M|
         */
         public int[][] DistTable;

         /**
         *	計算時メモ領域
         *	地点A〜Cがある場合、出発地点A以外について記録する
         *	訪問済み  現在位置
         *	C B:    B   C
         *	0 0:    -   -		<= 全地点未訪問
         *	0 1:  277   -		<= A->C->Bに移動した場合
         *	1 0:    - 249		<= A->B->Cに移動した場合
         *	1 1:  144 172		<= 初期値はAから他の地点までの距離
         */
         private int[][] Memo;

         /** 地点の最大要素番号(要素数-1) */
         private int N;

         /**
         *	コンストラクタ
         *	@param n 地点数
         */
         public Tsp(int n){
             DistTable = new int[n][n];
             N = n-1;
             Memo = new int[1<<(N)][];
         }

         /**
         *	到達地点に達した時のコスト配列を初期化して返す。
         *	要素は出発地点を除くので要素数-1
         */
        private int[] createMemoRow(){
            int[] ret = new int[N];
            Arrays.fill(ret, Short.MAX_VALUE);
            return ret;
        }

        /**
        *	動的計画法で巡回セールスマン問題を解く
        */
        public long tsp(){
            // 初期値を設定
            Memo[(1<<N)-1] = createMemoRow();
            for(int i=0; i<N; i++){
                Memo[(1<<N)-1][i]=DistTable[0][i+1];
            }

            for(int v=(1<<N)-2; v>0; v--){
                int[] tmp = createMemoRow();
                for(int i=0; i<N; i++){			// 現在地点 i
                    if( ((1 << i) & v) != 0){	// 現在地点を通過済み

                        for(int j=0; j<N; j++){		// 次の地点 j
                            if(((1<<j)&v) == 0){	// 次の地点は未到達
                                int n = DistTable[i+1][j+1] + Memo[v | (1<<j)][j];
                                tmp[i] = Math.min(tmp[i], n);
                            }
                        }
                    }
                }
                Memo[v] = tmp;
            }

            // 最後の到達点からスタート地点までのコストを計算
            long ret = Short.MAX_VALUE;
            for(int i=0; i<N; i++){
                int x = DistTable[i+1][0] + Memo[1<<i][i];
                ret = Math.min(ret, x);
            }
            return ret;
        }

        public void printTable(int[][] table){
            for(int[] r:table){
                if(r == null){
                    System.out.println("null");
                    continue;
                }
                for(int i:r){
                    System.out.printf("%-3d ", i);
                }
                System.out.println("");
            }
        }

        public void printDistTable(){
            printTable(DistTable);
        }
    }

    /** 地図 */
    int[][] QMap;

    /** 最高地点間の最短経路を求めるためのメモ */
    int[][] QMemo;

    /** 最高地点間の最短経路を求める時の探索方向リスト */
    Point[] Direction;

    /** 最大高さ  */
    int MaxH = 0;

    /** 最高地点(スタート視点を含む)のリスト */
    ArrayList<Point> HPoints;

    Main(){
        HPoints = new ArrayList<>();
        HPoints.add(new Point(1,1));    // スタート地点を設定しておく
        Direction = new Point[4];
        Direction[0] = new Point(-1,0); //上
        Direction[1] = new Point(+1,0); //下
        Direction[2] = new Point(0,-1); //左
        Direction[3] = new Point(0,+1); //右
    }

    /**
     * 地図を初期化すると最高地点間のメモを作成する
     * 計算を楽にするため一回り大きく作る
     * @param s サイズをスペース区切りした文字列
     */
    void initMap(String s){
        String[] xy = s.split(" ");
        int y = Integer.parseInt(xy[0]) + 2;
        int x = Integer.parseInt(xy[1]) + 2;

        QMap = new int[y][];
        QMemo = new int[y][];

        for(int i=0; i<y; i++){
            QMap[i] = new int[x];
            QMemo[i] = new int[x];

            for(int j=0; j<x; j++){
                QMap[i][j] = -2;
            }
        }
    }

    /**
     * 地図を設定する
     * @param l 地図の行番号(0始まり)
     * @param s 地図の値
     */
    void setMap(int l, String s){
        for(int i=0; i<s.length(); i++){
            char c = s.charAt(i);
            int y = l+1;
            int x = i+1;
            QMap[y][x] = c-0x30;

            if(QMap[y][x] > MaxH){
                MaxH = QMap[y][x];
            }
        }
    }

    /**
     * 最高地点のリストを作成する
     * 地図は外周が設定されているので1〜要素数-1までを処理する
     */
    void setHighPoints(){
        for(int y=1; y<QMap.length-1; y++){
            for(int x=1; x<QMap[y].length-1; x++){
                if(QMap[y][x] == MaxH){
                    Point p = new Point(x,y);
                    HPoints.add(p);
                }
            }
        }
    }

    /**
     * QMemoを初期化する
     */
    void clearQMemo(){
        for(int i=0; i<QMemo.length; i++){
            for(int j=0; j<QMemo[i].length; j++){
                QMemo[i][j] = Integer.MAX_VALUE;
            }
        }
    }

    int getPointCost(Point now, Point next){
        int d = QMap[next.y][next.x] - QMap[now.y][now.x];
        switch(d){
            case -1:
                return 2;
            case 1:
                return 11;
            case 0:
                return 3;
            default:
                return -1;
        }

    }

    /**
     * 2点間の最短経路のコストを求める
     * @param st スタート地点座標
     * @param hps 終了地点座標
     * @return コストのリスト。順番はhpsと同じ。
     */
    ArrayList<Integer> calcPathH(Point st, ArrayList<Point> hps){
        Deque<Point> queue = new ArrayDeque<>();
        clearQMemo();

        // スタート地点の訪問履歴をセット
        QMemo[st.y][st.x] = 0;
        queue.addLast(st);

        while(!queue.isEmpty()){
            Point cur = queue.pollFirst();

            for(Point d: Direction){
                Point next = new Point(cur.x + d.x, cur.y + d.y);
                int c = getPointCost(cur, next) ;
                // 移動できない地点は無視
                if(c < 0){
                    continue;
                }

                // 次の地点までの累積移動コスト
               int nc = c + QMemo[cur.y][cur.x];
                if(nc < QMemo[next.y][next.x]){
                    QMemo[next.y][next.x] = nc;
                    queue.addLast(next);
                }
            }
        }
        ArrayList<Integer> ret = new ArrayList<>();
        for(Point p: hps){
            ret.add(QMemo[p.y][p.x]);
        }
        return ret;
    }

    /**
     * 最高地点(とスタート地点)間の移動にかかるコストを求める
     */
    Tsp getHighPointCosts(){
        /** 巡回セールスマン問題を解くためのオブジェクト */
        Tsp tsp = new Tsp(HPoints.size());

        for(int i=0; i<HPoints.size(); i++){
            Point st = HPoints.get(i);
            ArrayList<Integer> costs = calcPathH(st, HPoints);

            for(int j=0; j<HPoints.size(); j++){
                tsp.DistTable[i][j] = costs.get(j);
            }
        }
        return tsp;
    }

// DEBUG -->
    /**
     * 地図を表示する。
     * デバッグ用。
     */
    void printMap(){
        for (int[] q : QMap) {
            for (int j = 0; j < q.length; j++) {
                System.out.print(q[j]);
            }
            System.out.println();
        }
    }

    void printQMemo(){
        for(int[] q: QMemo){
            for(int i=0; i<q.length; i++){
                if(q[i] != Integer.MAX_VALUE){
                    System.out.printf("%03d ", q[i]);
                }
                else{
                    System.out.print("xxx ");
                }
            }
            System.out.println();
        }
    }
// <-- DEBUG


    public static void main(String[] args) throws IOException {
        Main qrage = new Main();

        try (BufferedReader stdReader = new BufferedReader(new InputStreamReader(System.in))) {
            String line;

            for(int i=0; (line = stdReader.readLine()) != null; i++) {
                if(i==0){
                    qrage.initMap(line.trim());
                }
                else{
                    qrage.setMap(i-1, line.trim());
                }
                if(i >= qrage.QMap.length-2){
                    break;
                }
            }
        }
        qrage.setHighPoints();
        Tsp tsp = qrage.getHighPointCosts();
        System.out.println(tsp.tsp());
        return;
    }
}

解説

私がCodeIQの問題にチャレンジするようになってからこれまで(2016/2/28現在)最高の難易度(★★★★)の問題です。想定時間30分となっていますが、私の能力ではとてもそんな時間ではできません。それどころか、何度もギブアップしようと思ったほどです。

基本方針

非常に難しい問題ですが、問題自体は理解しやすいです(良問と言えます)。
実際のところ、基本方針は結構容易に考えついたのです。
私が考えた基本方針は次の通りです。
  • 探索を2段階に分ける
  • 1段階目で経由地点(最高地点とスタート地点)間のコストを算出する
  • 2段階目でスタート地点から開始して全地点を経由し、スタート地点に戻る最短経路を求める
この考え方自体は正しいと思います(実際、この通りにやってパスしています)。
当初、計算量は別として1段階目の探索は幅優先探索で、2段階目はダイクストラ法で解けるのではないか、と思っていました……。

巡回セールスマン問題

しかし、よくよく考えると2段階目の探索はダイクストラ法では解けそうにないことに気づきました。ダイクストラ法はスタートからゴールまで任意の経路のうち最短経路を求めるアルゴリズムです。この問題はすべての地点を経由して戻って来る最短経路を求めなければなりません。
私の大学での専攻は生物学でアルゴリズムについて専門的に学んだことはありませんし、会社に入ってからもCodeIQで要求されるような仕事はほとんどありません。ですが、「すべての地点を経由して戻って来る最短経路を求める」問題はあまりにも有名なので記憶にヒットしました。

これは「巡回セールスマン問題」じゃないか?

Wikipediaで調べてみるとやはり巡回セールスマン問題のようです。
ここで途方にくれました。「巡回セールスマン問題」は効率的に解を求めることのできない問題として有名なものです。この段階ではとりあえず有効な解法を思いつけませんでした。

とりあえずやってみる

「巡回セールスマン問題」が効率的に解けないなら入力パターンはそれを考慮しているかも、と甘い期待を元にとりあえず2段階目の経路を貪欲法でやってみました。当然そんなに甘くはなくNG。
ここでもう1点1段階目の探索にも間違いがあることがわかりました。
テストケースの1つ目が(記憶が正しければ)最高地点が1つしかないにもかかわらずNGになっています。テストケースの2つ目と例題は正しい結果を得られていたので1段階目のロジックに何らかのミスがあることは明らかでした。

1段階目の探索を見直す

私の回答は時間切れで打ち切られているので正しい答えを教えてくれません。仕方ないので1つ目のテストケースを紙に書いて最短経路を手計算で求めて、自分のプログラムの計算結果と照らし合わせます。
元のプログラムでは初めてゴール(最高地点)にたどり着いた時点で探索を打ち切っていたのがまずいことがわかりました。ArrayList<Integer> calcPathH(Point st, ArrayList<Point> hps)を修正し、queueが空になるまで探索を継続するように修正しました。
結果、1段階目の探索は正しい結果を得られるようになりました。

「巡回セールスマン問題」を動的計画法で解く

巡回セールスマン問題を考えたり、調べたりしているうちに動的計画法かメモ化再帰でできないか? と思いつきます。要素数があまりにも多いなら時間切れでしょうが、ある程度の数ならまともな時間で解ける可能性があります。
とりあえず、動的計画法で考えてみることにしました。ですが、履歴をどう記録すれば良いのかがわかりません(苦笑)。こういう時は先人の知恵に頼るしかないのでGoogleで「巡回セールスマン問題 動的計画法」とやって調べます。
Ptyhonで記述されたわかりやすいサイトを見つけたのでそれを参考にしました。class Tspがそれです。元のコードはクラスにはしていませんでしたが見通しを良くするためクラスにしました。また、元のコードは無向グラフでしたが、この問題では有向グラフなのでそれに対応できるようにTsp#DistTableを設定します。

これで、マップが広くなって(記憶が正しければ200×200くらい)時間切れになるまでは正しい結果が得られるようになりました。ロジックは正しそうですが高速化が必要です。

1段階目の探索の高速化

おそらく1回目の探索回数が多すぎるのでここを高速化する必要があることはわかっていましたが、一応プロファイリングして確かめます。思った通り、1回目の探索にほとんどの時間がかかっています。
考えている間にあることに気づきました。私は1回目の探索の結果が間違っているのを修正するためにゴールに達した時点で探索を打ち切る処理を削除し、queueが空になるまで探索を続けるように修正しました。これは次のように言い換えることができます。

すべての地点を探索し、目的地点までの最短経路を算出する」

重要なポイントは下線部分です。探索を途中で打ち切らなくしたので、1回の探索で1つの開始点からその他の目的地までの最短経路が全てわかっているはずです。実際に検証した結果その通りの結果でした。これで計算量はO(n2)からO(n)に激減します。
修正したコードでテストケース中最大のマップをテストしてみるとローカルの開発環境で0.7秒余りになりました。

テストーケースをクリア

これまでが時間切れのため(途中で時間切れやエラーを生じた場合は正答を教えてくれないため)本当に正しい結果になっているかはローカルではわかりませんが、投稿した結果全テストケースをパスしました。

雑感

さすがに★★★★だけあって非常に難しかったというのが感想です。
解法の基本方針はイメージできているのにその細部がわからないという問題は余り経験がありません。
ですが、諦めずにチャレンジし続けた結果、達成できたというのは嬉しかったですし、やればできるものだという自信にもなりました。