使用快取(Caching)加快你的程式執行速度

Part I

老師講解的順序

▌Introduction

1. 函數是不變的

給定相同的輸入,就會回傳相同的輸出。

因此,對時間敏感的函數或使用亂數的函數將不起作用。你應該也不想使用有副作用的快取函數,因此你應該堅持使用純函數。

2. 函數參數必須是可哈希的(hashable)

因為 Python 的快取實作時,參數會作為字典中的鍵(keys)

3. 適用時機

只有在應用程式的生命週期中,多次使用相同的值重複呼叫函數時才有用;例如在 CPU、資源利用率、延遲等方面,運行該函數的成本很高。

老師是用費式數列(Fibonacci sequence)示範 Cache。

費式數列曾在 Python Deep Dive I 的 Section7: Scopes, Closures and Decorators 示範過,大家可以參考 Chris 兄當時的分享:

Python Deep Dive II 的 Section6: Generators 也示範過,大家可以參考 Andy 兄當時的分享:

▌First Principles

和老師習慣的介紹方法一樣,先逐步手刻一個簡單的範例,然後告訴你其實有更好的作法。

所以這裡看看就好,了解一下運作原理,無需執著。

from timeit import timeit

def fib(n):
    if n <= 1:
        return 1
    # print(f"calculating fib({n})") # 檢查到底執行了幾次
    return fib(n-1) + fib(n-2)

for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed:.2f} s")
# 輸出:中間省略部分資料
25 - 0.14 s
26 - 0.14 s
27 - 0.26 s
...
37 - 10.33 s
38 - 19.38 s
39 - 29.01 s

因為每次都要從頭算起,這個演算法是指數級的:O(2^n)。

觀察 print(f"calculating fib({n})") 指令會發現,時間都浪費在重複運算了。

這正是 Cache 發揮所長的地方。

建立一個全域字典,保存 n 的 key(n)和 value(fib(n))。

程式執行時,首先檢查 n 是否在字典中。

  • 如果在:直接傳回該值。(O(1) 的速度)

  • 如果不在:執行計算式,儲存到全域字典中,並傳回結果。

cache = {}

def fib(n):
    if n <= 1:
        return 1
    if n in cache:
        return cache[n]
    print(f"calculating fib({n})")
    result = fib(n-1) + fib(n-2)
    cache[n] = result
    return result

for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")
# 輸出:
25 - 4.1896999846358085e-05 s
26 - 5.4420002015831415e-06 s
27 - 4.383000032248674e-06 s
...
37 - 2.183999640692491e-06 s
38 - 2.148000021406915e-06 s
39 - 1.8249997992825229e-06 s

▌Closure/Decorator

手刻完成後,開始告訴你有哪些缺點(意指如果繼續手刻,還要做哪些事):

  1. 我們自己手刻的 Cache 是全域變數(全域字典)。

  2. 每次新的函式,我們都必須重新撰寫 Cache。

  3. 還有防呆措施,例如有人不小心改了全域字典的值…

我們在 Python Deep Dive I 學到的閉包(Closure),可以用裝飾器(Decorator)來解決這個問題。

def cache(fn):
    data_cache = {}

    def inner(*args):
        key = tuple(args)
        if key in data_cache:
            return data_cache[key]
        result = fn(*args)
        data_cache[key] = result
        return result
        
    return inner

@cache
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)     

for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")
# 輸出:
25 - 0.00019559500015020603 s
26 - 5.757000508310739e-06 s
27 - 5.309000698616728e-06 s
...
37 - 4.342000465840101e-06 s
38 - 3.821999598585535e-06 s
39 - 5.619000148726627e-06 s

▌@lru_cache Decorator

參考:快取替換演算法(Cache replacement policies)

使用閉包,雖然解決了部分問題,但還有這些問題(大家很熟悉吧 :rofl:):

  1. 未處理 keyword-only arguments

  2. 快取大小是無限的

所以,我們還是交給 Python 內建的 LRU Cache 函式 @lru_cache 吧。

LRU: Least Recently Used(最近最少使用)

from functools import lru_cache

@lru_cache  # 預設最大 128
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)   

for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")
# 輸出:
25 - 3.0309999601740856e-05 s
26 - 4.104000254301354e-06 s
27 - 3.4929998946608976e-06 s
...
37 - 1.6680005501257256e-06 s
38 - 2.1699997887481004e-06 s
39 - 1.485000211687293e-06 s

關於前面提到的問題:

  1. 設定快取大小
@lru_cache(maxsize=None) # 預設大小為 128
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

@lru_cache(maxsize=20)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

@lru_cache(maxsize=2)
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)  

# 當然,maxsize 設為 1 就會有問題。

簡化語法

就是把 lru_cache 改為 cache

- from functools import lru_cache
+ from functools import cache

- @lru_cache
+ @cache
def fib(n):
    if n <= 1:
        return 1
    return fib(n-1) + fib(n-2)   

for i in range(25, 40):
    elapsed = timeit(f"fib({i})", globals=globals(), number=1)
    print(f"{i} - {elapsed} s")

▌探討LRU Cache

def func(a, b, c):
    print(f"computing result for {a=}, {b=}, {c=}")
    return a + b + c

func(1, 2, 3)
# 輸出:computing result for a=1, b=2, c=3
# 輸出:6
func(1, 2, 3)
# 輸出:computing result for a=1, b=2, c=3
# 輸出:6

# =======================================
@cache
def func(a, b, c):
    print(f"computing result for {a=}, {b=}, {c=}")
    return a + b + c

func(1, 2, 3)
# 輸出:computing result for a=1, b=2, c=3
# 輸出:6
func(1, 2, 3)
# 這裡少了
# 輸出:6

Keyword arguments

Keyword arguments (named arguments)

我這裡執行會有問題,以下資料是用老師的

TypeError: cache.<locals>.inner() got an unexpected keyword argument 'a'

# 續
func(a=1, b=2, c=3)
# 輸出:computing result for a=1, b=2, c=3
# 輸出:6

func(a=1, b=2, c=3)
# 輸出:6

然後更改參數順序:

# 續
func(b=2, a=1, c=3)
# 輸出:computing result for a=1, b=2, c=3
# 輸出:6

func(a=1, b=2, c=3)
# 輸出:6

func(b=2, a=1, c=3)
# 輸出:6

@cached_property

物件屬性也可以使用 Cache(特別是計算屬性)

Caching Class Properties(@cached_property

大家熟悉的圓面積又來了。

from math import pi
from timeit import timeit

class Circle():
    def __init__(self, radius):
        self._radius = radius

    @property
    def area(self):
        return pi * (self._radius ** 2)

c = Circle(3)

timeit("c.area", globals=globals(), number=1_000_000)
# 輸出:0.3926595550001366

一樣先示範手刻:

class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None ## <=====

    @property
    def area(self):
        if self._area is None: ## <=====
            self._area = pi * (self._radius ** 2)
        return self._area

c = Circle(3)

timeit("c.area", globals=globals(), number=1_000_000)
# 輸出:0.14328316899991478 ## 比較上面的數字:0.3926595550001366

當然,前面提的手刻問題,這裡一樣存在。

也當然,Python 透過 @cached_property 解決了這個問題。

from functools import cached_property

class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None

    @cached_property
    def area(self):
        return pi * (self._radius ** 2)

c = Circle(3)

timeit("c.area", globals=globals(), number=1_000_000)
# 輸出 3:0.06766027399862651 ## 本次
# 輸出 2:0.14328316899991478 ## 上次
# 輸出 1:0.3926595550001366  ## 上上次

▌探討快取屬性與不可變性

Cached Properties and Mutability Caveats

快取屬性和可變性的注意事項

前面 Cache 屬性的前提,當然是屬性不變。

萬一真的變了呢?

以下說明如何防呆。

# 續
c = Circle(3)  ## <=====
c.area
# 輸出:28.274333882308138

c._radius = 1  ## <=====
c.area
# 輸出:28.274333882308138  ## <=====

方法一:手刻

class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None

    @property
    def area(self):
        if self._area is None:
            self._area = pi * (self._radius ** 2)
        return self._area

getter & setter

我們要透過實作 getter 和 setter,來避免預期不會改變的屬性被改變。

在 setter 中,當我們發現屬性被改變時,就(透過不同的方法)使 Cache 值失效。

方法一

透過 _area 這個屬性來判斷

class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None  ## <=====

    @property  ## <=====
    def radius(self):
        return self._radius

    @radius.setter  ## <=====
    def radius(self, value):
        if self._radius != value:
            self._area = None
        self._radius = value
    
    @property
    def area(self):
        if self._area is None:  ## <=====
            self._area = pi * (self._radius ** 2)
        return self._area
c = Circle(1)
c.area
# 輸出:3.141592653589793

c.radius = 2
c.area
# 輸出:12.566370614359172

方法二

@cached_property 刪除 property

class Circle():
    def __init__(self, radius):
        self._radius = radius
        self._area = None  ## same

    @property  ## same
    def radius(self):
        return self._radius

    @radius.setter  ## same
    def radius(self, value):
        if self._radius != value:
            del self.area  ## <===== Different
        self._radius = value
        
    @cached_property
    def area(self):
        ## if self._area is None:  ## <===== Different
        return pi * (self._radius ** 2)
c = Circle(1)
c.area
# 輸出:3.141592653589793

c.radius = 2
c.area
# 輸出:12.566370614359172

▌Caching Class Methods

一樣使用 cachelru_cache。但因為 cache 所在位置不同,處理方式也就不同。

  • cached_property(Caching Class Properties):instance level cache

  • cachelru_cache(Caching Class Methods):class level cache

class Circle:
    def __init__(self, r):
        self.r = r

    @cache
    def area(self):
        print(f"calculating area for {self.r=}")
        return pi * (self.r ** 2)

c = Circle(1)
c.area()
# 輸出:calculating area for self.r=1
# 輸出:3.141592653589793

c.r = 3
c.area()
# 輸出:3.141592653589793 
## 錯誤
c1.area()
# 輸出:calculating area for self.r=1
# 輸出:3.141592653589793

c2.area()
# 輸出:calculating area for self.r=1
# 輸出:3.141592653589793

## 沒有 cache 到

__eq__ & __hash__

透過 __eq__ 把 instance 指定到第二個

class Circle:
    def __init__(self, r):
        self.r = r

    @cache
    def area(self):
        print(f"calculating area for {self.r=}")
        return pi * (self.r ** 2)

    def __eq__(self, other):  ## <=====
        return self.r == other.r

    def __hash__(self):  ## <=====
        return hash(self.r)
c1 = Circle(1)
c2 = Circle(1)

c1 == c2
# 輸出:True
c1.area()
# 輸出:calculating area for self.r=1
# 輸出:3.141592653589793

c2.area()
# 輸出:3.141592653589793  ## 這次對了
c1.r = 2

c1 == c2
# 輸出:False

c1.area()
# 輸出:calculating area for self.r=2  # 重新計算無誤
# 輸出:12.566370614359172

c2.area()
# 輸出:calculating area for self.r=1  # 重新計算無誤
## 因為原本 c1, c2 是同一 instance,直到 c1 改變了半徑,使 instance 失效。
# 輸出:3.141592653589793

▌Conclusion and 3rd Party Library for more options


Part II

一些延伸探討

Cache vs. Global Valuables

Cache

Global Valuables

何時使用


Cache 的類別

Cache 置換策略

Cache 置換策略是指當 Cache 中資料已滿時,如何選擇要淘汰的資料。常見的 Cache 置換策略包括:

  • LRU (Least Recently Used):最近最少使用的資料優先淘汰。

  • LFU (Least Frequently Used):最少使用的資料優先淘汰。

  • TTL (Time To Live):超過指定生存時間的資料會被淘汰。

  • FIFO (First In, First Out):先進先出的資料優先淘汰。

特性 LRU 快取 LFU 快取 TTL 快取 FIFO 快取
置換策略 最近最少使用 最少使用 時間到期 先進先出
適用場景 熱點資料 冷門資料 資料有有效期 資料順序有意義
優點 命中率高 節省記憶體 資料不會過期 資料順序不變
缺點 可能淘汰重要資料

實作:內建函式庫


實作:第三方函式庫

以下一些第三方快取函式庫,最常使用的是第一個:cachetools

  1. cachetools 提供多種快取算法,包括 LRU(Least Recently Used)、LFU(Least Frequently Used)和其他一些變種。它也支援 TTL(Time-To-Live)過期機制。GitHub Repository

  2. dogpile.cache 提供一個通用的 API,用於在應用程式中使用各種不同的快取後端(如內存、Redis、Memcached等)。支援多種快取策略,包括 LRU、MRU(Most Recently Used)、隨機替換等。GitHub Repository

  3. diskcache 將快取數據保存在本地磁碟上,可用於需要持久性的快取需求。它支援多進程和多執行緒環境。GitHub Repository

  4. cachy 提供簡單且具有彈性的快取解決方案,支援多種後端,包括內存、文件、Redis等。GitHub Repository

  5. redis-py 如果你的應用程式使用了 Redis,你可以使用 redis-py 來實現基於 Redis 的快取。GitHub Repository

安裝

pip install cachetools

LRU (Least Recently Used)

cachetools 提供了 LRUCache 類別,用來實現 LRU 快取。

from cachetools import LRUCache
import time

# 創建一個 LRU 快取,最大容量為2
lru_cache = LRUCache(maxsize=2)

def expensive_function(x):
    # 模擬昂貴的計算
    time.sleep(1)
    return x * x

def get_cached_result(x):
    # 檢查快取中是否有結果
    if x in lru_cache:
        print("Using cached result.")
        return lru_cache[x]
    else:
        print("Cache miss. Calculating and caching.")
        result = expensive_function(x)
        # 將結果存入快取
        lru_cache[x] = result
        return result

# 第一次呼叫需要計算,結果被快取
result1 = get_cached_result(5)
print(result1)  # 輸出: Cache miss. Calculating and caching.

# 再次呼叫相同的輸入,由於結果被快取,無需重新計算
result2 = get_cached_result(5)
print(result2)  # 輸出: Using cached result.

# 增加不同的輸入,觸發 LRU 策略,淘汰最久未被使用的項目
result3 = get_cached_result(10)
print(result3)  # 輸出: Cache miss. Calculating and caching.

# 由於快取大小為2,再次呼叫之前的輸入,將觸發淘汰最久未被使用的項目
result4 = get_cached_result(5)
print(result4)  # 輸出: Cache miss. Calculating and caching.

LFU (Least Frequently Used)

cachetools 中,並沒有內建的 LFUCache 類別,必須自己實作。(範例略過)


TTL (Time-To-Live)

cachetools 提供了 TTLCache 類別,用來實現 TTL(Time-To-Live)快取。

from cachetools import TTLCache
import time

# 創建一個 TTL 快取,最大容量為2,每個項目的生存時間為5秒
ttl_cache = TTLCache(maxsize=2, ttl=5)

def expensive_function(x):
    # 模擬昂貴的計算
    time.sleep(1)
    return x * x

def get_cached_result(x):
    # 檢查快取中是否有結果
    if x in ttl_cache:
        print("Using cached result.")
        return ttl_cache[x]
    else:
        print("Cache miss. Calculating and caching.")
        result = expensive_function(x)
        # 將結果存入快取,並設定生存時間
        ttl_cache[x] = result
        return result

# 第一次呼叫需要計算,結果被快取
result1 = get_cached_result(5)
print(result1)  # 輸出: Cache miss. Calculating and caching.

# 再次呼叫相同的輸入,由於結果被快取,無需重新計算
result2 = get_cached_result(5)
print(result2)  # 輸出: Using cached result.

# 等待5秒,使得快取的項目過期
time.sleep(5)

# 快取過期後,再次呼叫需要重新計算
result3 = get_cached_result(5)
print(result3)  # 輸出: Cache miss. Calculating and caching.

FIFO(First-In-First-Out)

cachetools 中,並沒有內建的 FIFOCache 類別,必須自己實作。(範例略過)



費式數列

老師在舉例的時候,相信大家也有似曾相識的感覺。所以我去找了之前 Python Deep Dive 中,提及費式數列的章節。

1個讚