import heapq
import sys
import xml.etree.cElementTree as ET
from collections import Counter
from os import path as OSPath
from PIL import Image, ImageDraw
MODE_RECTANGLE = 1
MODE_ELLIPSE = 2
MODE_ROUNDED_RECTANGLE = 3
MODE = MODE_RECTANGLE
LEAF_SIZE = 4
PADDING = 1
FILL_COLOR = (0, 0, 0)
SAVE_FRAMES = False
ERROR_RATE = 0.5
AREA_POWER = 0.25
OUTPUT_SCALE = 1
def weighted_average(hist):
total = sum(hist)
value = sum(i * x for i, x in enumerate(hist)) / total
error = sum(x * (value - i) ** 2 for i, x in enumerate(hist)) / total
error = error ** 0.5
return value, error
def color_from_histogram(hist):
r, re = weighted_average(hist[:256])
g, ge = weighted_average(hist[256:512])
b, be = weighted_average(hist[512:768])
e = re * 0.2989 + ge * 0.5870 + be * 0.1140
return (r, g, b), e
def rounded_rectangle(draw, box, radius, color):
l, t, r, b = box
d = radius * 2
draw.ellipse((l, t, l + d, t + d), color)
draw.ellipse((r - d, t, r, t + d), color)
draw.ellipse((l, b - d, l + d, b), color)
draw.ellipse((r - d, b - d, r, b), color)
d = radius
draw.rectangle((l, t + d, r, b - d), color)
draw.rectangle((l + d, t, r - d, b), color)
class Quad(object):
def __init__(self, model, box, depth):
self.model = model
self.box = box
self.depth = depth
hist = self.model.im.crop(self.box).histogram()
self.color, self.error = color_from_histogram(hist)
self.leaf = self.is_leaf()
self.area = self.compute_area()
self.children = []
def is_leaf(self):
l, t, r, b = self.box
return int(r - l <= LEAF_SIZE or b - t <= LEAF_SIZE)
def compute_area(self):
l, t, r, b = self.box
return (r - l) * (b - t)
def split(self):
l, t, r, b = self.box
lr = l + (r - l) / 2
tb = t + (b - t) / 2
depth = self.depth + 1
tl = Quad(self.model, (l, t, lr, tb), depth)
tr = Quad(self.model, (lr, t, r, tb), depth)
bl = Quad(self.model, (l, tb, lr, b), depth)
br = Quad(self.model, (lr, tb, r, b), depth)
self.children = (tl, tr, bl, br)
return self.children
def get_leaf_nodes(self, max_depth=None):
if not self.children:
return [self]
if max_depth is not None and self.depth >= max_depth:
return [self]
result = []
for child in self.children:
result.extend(child.get_leaf_nodes(max_depth))
return result
class Model(object):
def __init__(self, path):
self.im = Image.open(path).convert('RGB')
self.width, self.height = self.im.size
self.heap = []
self.root = Quad(self, (0, 0, self.width, self.height), 0)
self.error_sum = self.root.error * self.root.area
self.push(self.root)
self.filename = OSPath.basename(path)
@property
def quads(self):
return [x[-1] for x in self.heap]
def average_error(self):
return self.error_sum / (self.width * self.height)
def push(self, quad):
score = -quad.error * (quad.area ** AREA_POWER)
heapq.heappush(self.heap, (quad.leaf, score, quad))
def pop(self):
return heapq.heappop(self.heap)[-1]
def split(self):
quad = self.pop()
self.error_sum -= quad.error * quad.area
children = quad.split()
for child in children:
self.push(child)
self.error_sum += child.error * child.area
def render(self, path, max_depth=None):
m = OUTPUT_SCALE
dx, dy = (PADDING, PADDING)
im = Image.new('RGB', (self.width * m + dx, self.height * m + dy))
draw = ImageDraw.Draw(im)
draw.rectangle((0, 0, self.width * m, self.height * m), FILL_COLOR)
svg_root = ET.Element("svg")
svg_root.set("xmlns", "http://www.w3.org/2000/svg")
svg_root.set("viewBox", "0 0 " + str(self.width +
PADDING) + " " + str(self.height + PADDING))
svg_root.set("width", str((self.width + PADDING) * m))
svg_root.set("height", str((self.height + PADDING) * m))
svg_title = ET.SubElement(svg_root, "title")
svg_title.text = self.filename
svg_rect = ET.SubElement(svg_root, "rect")
svg_rect.set("x", "0")
svg_rect.set("y", "0")
svg_rect.set("width", "100%") svg_rect.set("height", "100%") svg_rect.set("fill", "rgb(" + str(FILL_COLOR[0]) + ", " + str(
FILL_COLOR[1]) + ", " + str(FILL_COLOR[2]) + ")")
svg_contents = ET.SubElement(svg_root, "g")
for quad in self.root.get_leaf_nodes(max_depth):
l, t, r, b = quad.box
box = (l * m + dx, t * m + dy, r * m - 1, b * m - 1)
if MODE == MODE_ELLIPSE:
draw.ellipse(box, quad.color)
svg_rect = ET.SubElement(svg_contents, "rect")
svg_rect.set("x", str(quad.box[0] + PADDING))
svg_rect.set("y", str(quad.box[1] + PADDING))
svg_rect.set("rx", "100%")
svg_rect.set("ry", "100%")
svg_rect.set("width", str(
quad.box[2] - (quad.box[0] + PADDING)))
svg_rect.set("height", str(
quad.box[3] - (quad.box[1] + PADDING)))
svg_rect.set("fill", "rgb(" + str(quad.color[0]) + ", " + str(
quad.color[1]) + ", " + str(quad.color[2]) + ")")
elif MODE == MODE_ROUNDED_RECTANGLE:
radius = m * min((r - l), (b - t)) / 4
rounded_rectangle(draw, box, radius, quad.color)
svg_rect = ET.SubElement(svg_contents, "rect")
svg_rect.set("x", str(quad.box[0] + PADDING))
svg_rect.set("y", str(quad.box[1] + PADDING))
svg_rect.set("rx", str(radius))
svg_rect.set("ry", str(radius))
svg_rect.set("width", str(
quad.box[2] - (quad.box[0] + PADDING)))
svg_rect.set("height", str(
quad.box[3] - (quad.box[1] + PADDING)))
svg_rect.set("fill", "rgb(" + str(quad.color[0]) + ", " + str(
quad.color[1]) + ", " + str(quad.color[2]) + ")")
else:
draw.rectangle(box, quad.color)
svg_rect = ET.SubElement(svg_contents, "rect")
svg_rect.set("x", str(quad.box[0] + PADDING))
svg_rect.set("y", str(quad.box[1] + PADDING))
svg_rect.set("width", str(
quad.box[2] - (quad.box[0] + PADDING)))
svg_rect.set("height", str(
quad.box[3] - (quad.box[1] + PADDING)))
svg_rect.set("fill", "rgb(" + str(quad.color[0]) + ", " + str(
quad.color[1]) + ", " + str(quad.color[2]) + ")")
del draw
im.save(path, 'PNG')
tree = ET.ElementTree(svg_root)
tree.write("output.svg")
def main():
args = sys.argv[1:]
if len(args) != 2:
print('Usage: python main.py input_image iterations')
return
model = Model(args[0])
ITERATIONS = int(args[1])
previous = None
for i in range(ITERATIONS):
error = model.average_error()
if previous is None or previous - error > ERROR_RATE:
print(i, error)
if SAVE_FRAMES:
model.render('frames/%06d.png' % i)
previous = error
model.split()
model.render('output.png')
print('-' * 32)
depth = Counter(x.depth for x in model.quads)
for key in sorted(depth):
value = depth[key]
n = 4 ** key
pct = 100.0 * value / n
print('%3d %8d %8d %8.2f%%' % (key, n, value, pct))
print('-' * 32)
print(' %8d %8.2f%%' % (len(model.quads), 100))
if __name__ == '__main__':
main()