Pythonでナップサック問題

ナップサック問題とかメジャーなアルゴリズムすら綺麗さっぱり忘れてて困ります。リハビリにwikipediaを見ながらPythonで書いてみました。

ナップサック問題はn個の商品(それぞれ重さwと価値v)がある時に、キャパシティC以内の制約条件の元で最良の組合せを見つけるというものです。

それぞれの商品を1回しか選べない場合は、0-1ナップサック問題、複数回選択可能な時は123ナップサック問題と呼ばれていてアルゴリズムも違います。それぞれ書いてみたのが以下になります。

#!/usr/bin/python
# -*- coding: utf-8 -*-
# reference: http://en.wikipedia.org/wiki/Knapsack_problem

# 0-1 ナップサック問題 (2次元動的計画法)
# items = [{'w':weight, 'v':value}, {...}, {...}]
def knapsack_01(items, capacity):
  work = [[0 for row in range(capacity+1)] for col in range(len(items)+1)]
  for c in range(1, capacity+1):
    for n in range(1, len(items)+1):
      cur = items[n-1]
      if cur['w'] < = c:
        left = work[n-1][c - cur['w']] + cur['v']
        top  = work[n-1][c]
        work[n][c] = left if left > top else top
      else:
        work[n][c] = work[n][c-1]
  return work[-1][-1]

# unbound ナップサック問題 (1次元動的計画法)
# items = [{'w':weight, 'v':value}, {...}, {...}]
def knapsack_123(items, capacity):
  work = [0 for i in range(capacity+1)]
  for c in range(1, capacity+1):
    candidates = [item['v'] + work[c-item['w']]
                  for item in items if item['w'] < = c]
    work[c] = max(candidates) if candidates else 0
  return work[-1]

if __name__ == '__main__':
  sample_data = [{'w':2, 'v':4},
                 {'w':2, 'v':5},
                 {'w':1, 'v':2},
                 {'w':3, 'v':8}]
  
  print knapsack_01(sample_data,  7)
  print knapsack_123(sample_data, 7)

割とすっきり書けました。実際は価値を示すパラメータが複数ある方が普通な気がするので、vが複数ある時にも対応して、クラス化してみました。メインルーチンも最適値だけではなく、どのような組合せかも、同時に返却するようにしてみました。

class Knapsack:
  def __init__(self, items, cap, ccol='c', vcols=['v'], norm=None):
    self.items        = items
    self.capacity     = cap
    self.ccol = ccol
    self.vcols   = vcols
    self.norm         = norm if norm else Knapsack._norm
  
  def calc(self, unbound=False):
    if unbound:
      return self._calc_123()
    else:
      return self._calc_01()

  # 0-1ナップサック問題
  def _calc_01(self):
    work = [[[0,[]] for row in range(self.capacity+1)] for col in range(len(self.items)+1)]
    for n in range(1, len(self.items)+1):
      for c in range(1, self.capacity+1):
        cur = self.items[n-1]
        if cur[self.ccol] < = c:
          ref  = work[n-1][c - cur[self.ccol]]
          left = [ref[0] + self.norm(cur, self.vcols), copy.deepcopy(ref[1])]
          left[1].append(n-1)
          top  = work[n-1][c]
          work[n][c] = left if left[0] > top[0] else top
        else:
          work[n][c] = work[n][c-1]
    return work[-1][-1]

  # unboundナップサック問題
  def _calc_123(self):
    items = self.items
    work  = [[0,[]] for i in range(self.capacity + 1)]
    for c in range(1, self.capacity+1):
      norms = [[self.norm(items[i], self.vcols) + work[c - items[i][self.ccol]][0], i]
               for i in range(len(items)) if items[i][self.ccol] < = c]
      res = max(norms, key=lambda(x):x[0])
      ref = work[c - items[res[1]][self.ccol]]
      work[c] = copy.deepcopy([res[0], ref[1]])
      if len(norms):
        work[c][1].append(res[1])
    return work[-1]

  @staticmethod
  def _norm(item, vcols):
    import math
    return math.sqrt(sum([pow(item[col],2) for col in vcols]))

これで割と汎用的にナップサック問題を扱う事が出来るようになりました。使い方は以下の様な形になります。

sample_data = [{'w':2, 'v':4},
               {'w':2, 'v':5},
               {'w':1, 'v':2},
               {'w':3, 'v':8}]

sample_data2 = [{'w':2, 'v':4, 'v2':3},
                {'w':2, 'v':5, 'v2':4},
                {'w':1, 'v':2, 'v2':3},
                {'w':3, 'v':8, 'v2':7},
                {'w':5, 'v':11,'v2':9},
                {'w':4, 'v':4, 'v2':10},
                {'w':6, 'v':15,'v2':10},
                ]
p = Knapsack(sample_data, 10, ccol='w', vcols=['v'])
print p.calc()
print p.calc(unbound=True)
p2 = Knapsack(sample_data2, 10, ccol='w', vcols=['v', 'v2'])
print p2.calc()
print p2.calc(unbound=True)

上記プログラムを走らせてみると、出力は以下のようになります。ベンチマークはまたの機会で。

[19.0, [0, 1, 2, 3]]
[26.0, [3, 3, 1, 1]]
[31.409150939900499, [1, 2, 3, 5]]
[36.055512754639885, [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]

Leave a Reply

Your email address will not be published. Required fields are marked *