
DATA, AI2024. 9. 7. 02:09PyTorch: torch.bmm,matmul,mm 그리고 Attention 가중치의 계산
torch.matmultorch.matmul은 가장 일반적인 행렬 곱셈 함수이다.다양한 차원의 텐서 간 곱셈을 지원한다.브로드캐스팅으로 인해 의도치 않은 결과가 나올 수 있다. 입력 텐서의 차원에 따라 다르게 동작한다.1D x 1D: 내적(dot product)2D x 2D: 일반적인 행렬 곱셈1D x 2D 또는 2D x 1D: 벡터-행렬 곱셈ND x ND (N > 2): 배치 행렬 곱셈import torch# 2D x 2Da = torch.randn(3, 4)b = torch.randn(4, 2)c = torch.matmul(a, b)print("2D x 2D:")print("a shape:", a.shape) # torch.Size([3, 4])print("b shape:", b.shape) # ..