import numpy as np
def top_k_sampling(logits, k=40, temperature=0.7):
# 1. 温度缩放
logits = logits / temperature
# 2. 找到Top-K的值,其余位置设为负无穷
top_k_val = np.sort(logits)[-k]
logits[logits < top_k_val] = -float('inf')
# 3. Softmax 归一化
exp_logits = np.exp(logits - np.max(logits))
probs = exp_logits / np.sum(exp_logits)
# 4. 多项式分布采样
return np.random.choice(len(logits), p=probs)