CodeIQ:巡回サンタクロース問題

私自身が表題の問題を解いた時のプログラムについて解説します。
問題の詳細は「巡回サンタクロース問題」(CodeIQ)を参照してください。

問題の概要

次のようにパラメータが入力されます。

4
2
3
1
4

1行目はパラメータの入力数です。
2行目以降がパラメータです。
この例だとサンタクロースはA~Dの家にプレゼントを配る必要があり、それぞれの家に配る荷物の重さはA:2、B:3、C:1、D:4となります。 また各家の距離は2つの家に配るプレゼントの重さの積に等しいとなっています。例えばA-Bは距離6、A-Cは距離2、B-Dは距離12です。
サンタクロースは(まだ配っていない荷物の重さの合計)×(家と家の間の距離)だけエネルギーを消費します。
エネルギーを最小にする経路を求め、その時に消費するエネルギーを出力せよ、と言う問題です。

私のプログラム

Pythonで解答しています。

import fileinput

#==============================================================================
#	探索中の経路を管理するクラス
#------------------------------------------------------------------------------
class Candidate:
	def __init__(self, path, less):
		self.Path = path		# 通過済みの場所
		self.Less = less		# まだ行ってない場所

	def __str__(self):
		return str(self.Path) + str(self.Less)

	#	可能性のある経路のリストを返す
	#	閾値を超えてしまう経路は切り捨てる
	def getNextCandidates(self, th):
		ret = []
		for i in self.Less:
			p = self.Path[:]
			l = self.Less[:]

			p.append(i)
			l.remove(i)

			if th >= calcCost(p):
				c = Candidate(p, l)
				ret.append(c)

		return ret

	# 通過済みの場所にかかるコストを計算する
	def calcPathCost(self):
		return calcCost(self.Path)

#==============================================================================
#	枝刈りの基準となる並びを生成する
#	要素を大きい方からMax1,Max2,Max3,...,Min3,Min2,Min1とした場合、
#		Max1, Min1,Max2,Min2,...
#	の並びは結構いい線を行くのでこれを基準にする
#------------------------------------------------------------------------------
def getThresholdList(lst):
	l = lst[:]
	a = []
	even = True

	while len(l):
		if even:
			n = max(l)
		else:
			n = min(l)

		a.append(n)
		l.remove(n)
		even = not even

	return a

#==============================================================================
#	リストの経路の移動にかかるコストを計算する
#------------------------------------------------------------------------------
def calcCost(l):
	total = 0
	pos = 1
	w = Weight

	for i in l[1:]:
		w -= l[pos-1]
		total += w * l[pos-1] * l[pos]
		pos += 1

	return total

#==============================================================================
#	次の候補リストを作成する
#	Thresholdを超えたコストのかかる経路は無視する
#------------------------------------------------------------------------------
def getNewCandidates(lst, th):
	new_lst = []

	for i in lst:
		cost = i.calcPathCost()
		new_lst += i.getNextCandidates(th)

	return new_lst

#==============================================================================
#	最小コストの経路を返す
#------------------------------------------------------------------------------
def getMin(lst, th):
	mi = th	# 最小コスト
	pos = -1

	for i, c in enumerate(lst):
		cost = c.calcPathCost()
		if cost <= mi:
			mi = cost
			pos = i

	return lst[pos]

def printResults(lst):
	for i in lst:
		c = i.calcPathCost()
		print(str(i) + " -> " +str(c))

#==============================================================================
# main
#------------------------------------------------------------------------------
LineNo = 0		# 読み込んだ行数
Presents = []	# 読み込んだ値のリスト
Threshold = 0	# 閾値
Weight = 0		# 全重量

for line in fileinput.input():
	if not line.strip():
		continue

	if LineNo:
		i = int(line.strip())
		Presents.append(i)
		Weight += i
	LineNo += 1

MaxMin = getThresholdList(Presents)
Threshold = calcCost(MaxMin)

# 少なくとも要素の半分まではMax1, Min1, Max2, Min2, ...の順で並ぶらしい
Mid = len(MaxMin) // 2
new_cand = [Candidate(MaxMin[:Mid], MaxMin[Mid:])]

for i in range(len(MaxMin) - Mid):
	new_cand = getNewCandidates(new_cand, Threshold)

mi = getMin(new_cand, Threshold)
print(str(mi.calcPathCost()))

解説

この問題は一度実行時間オーバーで失敗しています。
最大の入力数の場合で3秒を切っていたので大丈夫か? と思っていましたがダメでした。

計算量の問題

この問題に挑戦したあたりではCodeIQの問題の傾向もかなりわかってきていて、問題を読んだだけで計算量が大変になることはわかりました。
それでどうするかという話になるのですがアプローチは2つです。

  • 漸化式を求める
  • 効率的な枝刈りを行う
最初の方法はわかりそうになかったのでとりあえず単純にシミュレートしてパターンを探ってみることにしました。
(注記:実際にはパターンから漸化式にはできたのですが、なぜそうなるのかを自力で証明できなかったのでプログラムには採用しませんでした)

効率的な枝刈りの基準

シミュレートして入力値が7くらいまでやった時に次のことに気づきました。
コメントにも書いてありますが、要素を大きい方からMax1,Max2,Max3,...,Min3,Min2,Min1とした場合、Max1, Min1,Max2,Min2,...の並びは最善ではありませんが常にかなり効率的な経路になります。きちんと計算していないのでわかりませんが上位1割くらいの候補には入っていそうでした。

これを求めているのがgetThresholdList()です。

それでは具体的にプログラムの説明をします。

Candidateクラス

このクラスは経路情報を記録し、次の行き先候補を返す機能と通過した経路にかかったコストを計算する機能を提供します。
すべての経路を探索するまでの間はこのクラスのインスタンスのリストが候補リストになります。

getNextCandidates()は次の候補となる経路を返します。同時に枝刈りをしていて、経路の途中でコストが基準を超えた経路を無視します。

calcPathCost()で確定した経路を通るのにかかったコストを計算しています。
実際に計算しているのはメンバ関数ではないcalcCost()です。なぜメンバではない関数で計算させているかというと、わずかなりともこちらの方がメモリ効率が良さそうに思えたからです。実際にはどうかわかりませんし、実務で書くプログラムならメンバにすると思います。

経路探索を行う

経路探索を行うのはgetNewCandidates()です。
第一引数には経路の候補リストが、第二引数には枝刈り基準のコストが与えられます。
getNewCandidates()は1回呼び出されるごとに未探査の経路を1つ追加して (Candidate#getNextCandidates()で枝を刈った上で新たな候補を追加して)新たなリストを返します。これを入力されたパラメータの回数だけ繰り返せば全経路探索ができます(131〜132行目)

ここに掲載しているのはテストケースを通過したコードなので最初のものとは違います。
最初のコードは128〜129行目の処理が違っていて、最大の値を持つ経路だけを通過済みのCandidateインスタンスをnew_candの要素としていました(最も大きな値を最初に回るべきなのは明らかです。移動距離0で最大の荷物を下ろせるのですから)。その最初のコードをローカルで動かしたら入力値が11経路の場合で3秒を少し下回る(2.7秒)程度の性能でした。しかし、テストパターンは1秒以内で処理が完了しないとNGだったのでプログラムを見直します。

プログラムの見直し

テストケースは時間切れでダメでしたが、シミュレーション自体は正しいはずです(時間切れのパターンは正解出力が示されないので本当に合っているかはわかりませんが、シミュレーションのロジックは正しいという自信があったので自分のコードを信じることにしました)
シミュレーション結果の経路は次のようになります。

( 4経路の場合) 4, 1, 3, 2
( 5経路の場合) 5, 1, 4, 2, 3
( 6経路の場合) 6, 1, 5, 2, 3, 4
( 7経路の場合) 7, 1, 6, 2, 4, 3, 5
( 8経路の場合) 8, 1, 7, 2, 5, 3, 4, 6
( 9経路の場合) 9, 1, 8, 2, 7, 3, 5, 4, 6
(10経路の場合) 10, 1, 9, 2, 8, 3, 6, 4, 5, 7

これを見るとパターンがあります。
経路が5まではMax1,Min1,Max2,Min2,...の並びが最後までですが、経路6以上になると途中でその時点で最も大きい値ではなく次に大きい値になった後、再び大小を繰り返し、最後にパターンが変わった時点で最も大きな値が登場するという並びになります。

パターンがあそうなことはわかってそれをプログラムに落とし込むこともできそうですが、私にはなぜそうなるのかを証明できませんでした。なので、そちらの方向は諦めて計算量を減らすだけでお茶を濁すことにしました(^^;。
前述のパターンから少なくとも経路の半分までは大小を繰り返すのは確実に見えるので、そのロジックを実装しました(128〜129行目と131行目)。
128〜129行目で全経路の半分までの経路を大小の並びで決め打ちし、131行目ではそれ以降の経路だけをシミュレートしています。結果的に11経路のパターンでも元の6経路分しか計算しなくて良くなったので大幅に計算量が減少します。この問題はパターンが階乗で増えてゆくので大幅に計算量が減少します。

雑感

結局これでテストケースをパスしました。
多分、このプログラムを使って経路数20くらいまでシミュレートすると法則からもっと高速なプログラムを作れると思いますが、ちょっとそこまでする根性がなくここで終了としました。