Add i-amax

This commit is contained in:
dominik martinez 2025-05-17 18:12:10 -07:00
parent a65cb5b040
commit db60c8841d

View file

@ -10,6 +10,7 @@
(simd+ f32.8+) (simd+ f32.8+)
(simd* f32.8*) (simd* f32.8*)
(simd-horizontal+ f32.8-horizontal+) (simd-horizontal+ f32.8-horizontal+)
(simd-horizontal-max f32.8-horizontal-max)
(stride 8) (stride 8)
(float single-float) (float single-float)
(0f 0.0))) (0f 0.0)))
@ -19,6 +20,7 @@
(simd+ f64.4+) (simd+ f64.4+)
(simd* f64.4*) (simd* f64.4*)
(simd-horizontal+ f64.4-horizontal+) (simd-horizontal+ f64.4-horizontal+)
(simd-horizontal-max f64.4-horizontal-max)
(stride 4) (stride 4)
(float double-float) (float double-float)
(0f 0.0d0))) (0f 0.0d0)))
@ -84,6 +86,19 @@
(,(parse-cond-clauses double-float-type-declarations) (,(parse-cond-clauses double-float-type-declarations)
(,double-float-name ,@lambda-list))))))) (,double-float-name ,@lambda-list)))))))
(declaim (inline simd-abs-f32.8 simd-abs-f64.4))
(defun simd-abs-f32.8 (pack)
(declare (f32.8 pack))
(f32.8-sqrt (f32.8* pack pack)))
(defun simd-abs-f64.4 (pack)
(declare (f64.4 pack))
(f64.4-sqrt (f64.4* pack pack)))
(push '(simd-abs simd-abs-f32.8) *single-float-fn-generic*)
(push '(simd-abs simd-abs-f64.4) *double-float-fn-generic*)
;;;; Level 1 BLAS: vector, O(n) operations ;;;; Level 1 BLAS: vector, O(n) operations
;;; axpy ;;; axpy
@ -179,7 +194,7 @@
;;; asum ;;; asum
(defblas asum ((x 1)) (defblas asum ((x 1))
(let* ((n (array-dimension x 0)) (let* ((n (array-dimension x 0))
(n-block (* (floor n stride) stride)) (n-block (* (floor n stride) stride))
(norm 0f)) (norm 0f))
(declare (fixnum n n-block) (declare (fixnum n n-block)
@ -187,11 +202,25 @@
(loop with acc of-type simd = (simd 0) (loop with acc of-type simd = (simd 0)
for i fixnum from 0 below n-block by stride for i fixnum from 0 below n-block by stride
for sel of-type simd = (simd-aref x i) for sel of-type simd = (simd-aref x i)
do (setf acc (simd+ acc (simd* sel sel))) do (setf acc (simd+ acc (simd-abs sel)))
finally (setf norm (simd-horizontal+ acc))) finally (setf norm (simd-horizontal+ acc)))
(loop for i fixnum from n-block below n (loop for i fixnum from n-block below n
do (setf norm (+ norm (expt (aref x i) 2)))) do (setf norm (+ norm (expt (aref x i) 2))))
norm)) norm))
;;; i-amax
(defblas i-amax ((x 1))
(let* ((n (array-dimension x 0))
(n-block (* (floor n stride) stride))
(m 0f))
(declare (fixnum n n-block)
(float m))
(loop for i fixnum from 0 below n-block by stride
do (setf m (max m (simd-horizontal-max (simd-abs (simd-aref x i))))))
(loop for i fixnum from n-block below n
do (setf m (max m (abs (aref x i)))))
m))
;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations ;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations