
torch.matmul
torch.matmul은 가장 일반적인 행렬 곱셈 함수이다.
다양한 차원의 텐서 간 곱셈을 지원한다.
브로드캐스팅으로 인해 의도치 않은 결과가 나올 수 있다.
입력 텐서의 차원에 따라 다르게 동작한다.
1D x 1D: 내적(dot product)
2D x 2D: 일반적인 행렬 곱셈
1D x 2D 또는 2D x 1D: 벡터-행렬 곱셈
ND x ND (N > 2): 배치 행렬 곱셈
import torch
# 2D x 2D
a = torch.randn(3, 4)
b = torch.randn(4, 2)
c = torch.matmul(a, b)
print("2D x 2D:")
print("a shape:", a.shape) # torch.Size([3, 4])
print("b shape:", b.shape) # torch.Size([4, 2])
print("c shape:", c.shape) # torch.Size([3, 2])
# 1D x 2D
d = torch.randn(4)
e = torch.matmul(d, b)
print("\n1D x 2D:")
print("d shape:", d.shape) # torch.Size([4])
print("b shape:", b.shape) # torch.Size([4, 2])
print("e shape:", e.shape) # torch.Size([2])
# 3D x 3D (batch matrix multiplication)
f = torch.randn(2, 3, 4)
g = torch.randn(2, 4, 2)
h = torch.matmul(f, g)
print("\n3D x 3D:")
print("f shape:", f.shape) # torch.Size([2, 3, 4])
print("g shape:", g.shape) # torch.Size([2, 4, 2])
print("h shape:", h.shape) # torch.Size([2, 3, 2])
torch.mm
torch.mm은 2차원 텐서(행렬) 간의 곱셈에 특화된 함수이다.
오직 2D 텐서 간의 곱셈만 수행한다.
브로드캐스팅을 지원하지 않는다.
따라서 입력 크기가 정확히 맞아야 한다: (n x m) x (m x p) -> (n x p)
의도한 대로 정확한 행렬 곱셈을 수행하고자 할 때 유용하다.
import torch
a = torch.randn(3, 4)
b = torch.randn(4, 2)
c = torch.mm(a, b)
print("a shape:", a.shape) # torch.Size([3, 4])
print("b shape:", b.shape) # torch.Size([4, 2])
print("c shape:", c.shape) # torch.Size([3, 2])
# 아래 코드는 에러를 발생 (3D 텐서 사용 불가)
# d = torch.randn(2, 3, 4)
# e = torch.randn(2, 4, 2)
# f = torch.mm(d, e) # 에러
torch.bmm
torch.bmm은 배치 행렬 곱셈(batch matrix multiplication)을 위한 함수이다.
3차원 텐서 간의 곱셈만 수행한다.
첫 번째 차원(배치 크기)이 동일해야 한다.
브로드캐스팅을 지원하지 않는다.
입력 크기: (b x n x m) x (b x m x p) -> (b x n x p)
여러 개의 행렬 곱셈을 동시에 수행할 때 효율적이다.
import torch
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
c = torch.bmm(a, b)
print("a shape:", a.shape) # torch.Size([10, 3, 4])
print("b shape:", b.shape) # torch.Size([10, 4, 5])
print("c shape:", c.shape) # torch.Size([10, 3, 5])
# 아래 코드는 에러를 발생 (배치 크기가 다름)
# d = torch.randn(10, 3, 4)
# e = torch.randn(9, 4, 5)
# f = torch.bmm(d, e) # 에러
torch.bmm 예시 in Attention
scores = torch.bmm(query, keys.transpose(1, 2))
attention_weights = F.softmax(scores, dim=-1)
torch.bmm은 배치 행렬 곱셈(batch matrix multiplication)을 수행한다.
여기서는 배치 처리 및 효율성, 차원유지를 위해 사용되었다. 행렬의 변화 과정은 아래와 같다.
query shape: (batch_size, seq_len_q, d_k)
keys shape: (batch_size, seq_len_k, d_k)
라고 가정할 때,
keys.transpose(1, 2)
결과 shape: (batch_size, d_k, seq_len_k)
torch.bmm(query, keys.transpose(1, 2))
연산: (batch_size, seq_len_q, d_k) x (batch_size, d_k, seq_len_k)
결과 shape (scores): (batch_size, seq_len_q, seq_len_k)
F.softmax(scores, dim=-1)
softmax를 마지막 차원에 적용
attention_weights shape: (batch_size, seq_len_q, seq_len_k)
keys를 전치하는 이유는 query와의 내적을 위해서이다.
bmm 연산은 각 query 벡터와 모든 key 벡터 간의 유사도를 계산한다.
결과적으로 각 query에 대한 모든 key의 관련성 점수를 얻게 된다.
softmax를 적용하여 이 점수를 확률 분포로 변환한다.
'DATA, AI' 카테고리의 다른 글
GPU의 제한된 vram 환경에서 효율적으로 모델을 학습하는 방법 (1) | 2024.11.20 |
---|---|
huggingface로 협업하기 (2) | 2024.10.29 |
입출력 형태에 따른 자연어 처리 Task의 이해 (2) | 2024.10.02 |
NLP 논문 리스트 (3) | 2024.09.13 |
데이터의 분석 : seaborn plot의 활용 (0) | 2024.08.26 |
개발새발라이프
hi there🙌