From db60c8841d63c660516bbd3b409804bca18cbb81 Mon Sep 17 00:00:00 2001 From: dominik martinez Date: Sat, 17 May 2025 18:12:10 -0700 Subject: [PATCH] Add i-amax --- src/cl-blas.lisp | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/cl-blas.lisp b/src/cl-blas.lisp index ca4c65c..1daae83 100644 --- a/src/cl-blas.lisp +++ b/src/cl-blas.lisp @@ -10,6 +10,7 @@ (simd+ f32.8+) (simd* f32.8*) (simd-horizontal+ f32.8-horizontal+) + (simd-horizontal-max f32.8-horizontal-max) (stride 8) (float single-float) (0f 0.0))) @@ -19,6 +20,7 @@ (simd+ f64.4+) (simd* f64.4*) (simd-horizontal+ f64.4-horizontal+) + (simd-horizontal-max f64.4-horizontal-max) (stride 4) (float double-float) (0f 0.0d0))) @@ -84,6 +86,19 @@ (,(parse-cond-clauses double-float-type-declarations) (,double-float-name ,@lambda-list))))))) +(declaim (inline simd-abs-f32.8 simd-abs-f64.4)) + +(defun simd-abs-f32.8 (pack) + (declare (f32.8 pack)) + (f32.8-sqrt (f32.8* pack pack))) + +(defun simd-abs-f64.4 (pack) + (declare (f64.4 pack)) + (f64.4-sqrt (f64.4* pack pack))) + +(push '(simd-abs simd-abs-f32.8) *single-float-fn-generic*) +(push '(simd-abs simd-abs-f64.4) *double-float-fn-generic*) + ;;;; Level 1 BLAS: vector, O(n) operations ;;; axpy @@ -179,7 +194,7 @@ ;;; asum (defblas asum ((x 1)) - (let* ((n (array-dimension x 0)) + (let* ((n (array-dimension x 0)) (n-block (* (floor n stride) stride)) (norm 0f)) (declare (fixnum n n-block) @@ -187,11 +202,25 @@ (loop with acc of-type simd = (simd 0) for i fixnum from 0 below n-block by stride for sel of-type simd = (simd-aref x i) - do (setf acc (simd+ acc (simd* sel sel))) + do (setf acc (simd+ acc (simd-abs sel))) finally (setf norm (simd-horizontal+ acc))) (loop for i fixnum from n-block below n do (setf norm (+ norm (expt (aref x i) 2)))) - norm)) + norm)) + +;;; i-amax + +(defblas i-amax ((x 1)) + (let* ((n (array-dimension x 0)) + (n-block (* (floor n stride) stride)) + (m 0f)) + (declare (fixnum n n-block) + (float m)) + (loop for i fixnum from 0 below n-block by stride + do (setf m (max m (simd-horizontal-max (simd-abs (simd-aref x i)))))) + (loop for i fixnum from n-block below n + do (setf m (max m (abs (aref x i))))) + m)) ;;;; Level 3 BLAS: matrix-matrix, O(n^3) operations