This commit is contained in:
dominik martinez 2025-05-21 18:21:28 -07:00
parent 225accdb63
commit e4e449f562

View file

@ -62,6 +62,7 @@
(defmacro defblas (name args &body body) (defmacro defblas (name args &body body)
;; ARGS should be of the form ((NAME RANK)+). ;; ARGS should be of the form ((NAME RANK)+).
;; If RANK is 0, then this NAME is a scalar ;; If RANK is 0, then this NAME is a scalar
;; Also it can just be NAME, which in that case the type is fixnum
(let* ((single-float-type-declarations (parse-type-declarations args 'single-float)) (let* ((single-float-type-declarations (parse-type-declarations args 'single-float))
(double-float-type-declarations (parse-type-declarations args 'double-float)) (double-float-type-declarations (parse-type-declarations args 'double-float))
(single-float-name (alexandria:symbolicate name '-single-float)) (single-float-name (alexandria:symbolicate name '-single-float))
@ -224,6 +225,15 @@
;;;; Level 2 BLAS: matrix-vector, O(n^2) operations ;;;; Level 2 BLAS: matrix-vector, O(n^2) operations
(defblas transpose ((A 2))
(loop with m fixnum = (array-dimension A 0)
with n fixnum = (array-dimension A 1)
with B of-type (simple-array float 2) = (make-array `(,n ,m) :element-type 'float)
for i fixnum from 0 below m
do (loop for j fixnum from 0 below n
do (setf (aref B j i) (aref A i j)))
finally (return B)))
;;; gemv ;;; gemv
(defblas gemv ((A 2) (x 1) (y 1) (alpha 0) (beta 0)) (defblas gemv ((A 2) (x 1) (y 1) (alpha 0) (beta 0))
@ -241,16 +251,27 @@
(loop for j fixnum from n-block below n (loop for j fixnum from n-block below n
sum (* alpha (aref A i j) (aref x j)) of-type float)))))) sum (* alpha (aref A i j) (aref x j)) of-type float))))))
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations ;;; ger
(defblas transpose ((A 2)) (defblas ger ((A 2) (x 1) (y 1) (alpha 0))
(loop with m fixnum = (array-dimension A 0) (let* ((m (array-dimension A 0))
with n fixnum = (array-dimension A 1) (n (array-dimension A 1))
with B of-type (simple-array float 2) = (make-array `(,n ,m) :element-type 'float) (n-block (* (floor n stride) stride))
for i fixnum from 0 below m (alpha-vec (simd alpha)))
do (loop for j fixnum from 0 below n (declare (fixnum m n n-block)
do (setf (aref B j i) (aref A i j))) (simd alpha-vec))
finally (return B))) (loop for i fixnum from 0 below m
for x-i float = (aref x i)
for x-i-vec of-type simd = (simd x-i)
for alpha-x-i-vec of-type simd = (simd* alpha-vec x-i-vec)
do (loop for j fixnum from 0 below n-block by stride
for y-j-vec of-type simd = (simd-aref y j)
for product of-type simd = (simd* alpha-x-i-vec y-j-vec)
do (setf (simd-aref A i j) (simd+ (simd-aref A i j) product)))
(loop for j fixnum from n-block below n
do (setf (aref A i j) (+ (aref A i j) (* alpha x-i (aref y j))))))))
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations
;;; gemm ;;; gemm