mas_storage_pg/personal/
access_token.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use mas_data_model::{
9    Clock,
10    personal::{PersonalAccessToken, session::PersonalSession},
11};
12use mas_storage::personal::PersonalAccessTokenRepository;
13use rand::RngCore;
14use sha2::{Digest, Sha256};
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, tracing::ExecuteExt as _};
20
21/// An implementation of [`PersonalAccessTokenRepository`] for a PostgreSQL
22/// connection
23pub struct PgPersonalAccessTokenRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgPersonalAccessTokenRepository<'c> {
28    /// Create a new [`PgPersonalAccessTokenRepository`] from an active
29    /// PostgreSQL connection
30    pub fn new(conn: &'c mut PgConnection) -> Self {
31        Self { conn }
32    }
33}
34
35struct PersonalAccessTokenLookup {
36    personal_access_token_id: Uuid,
37    personal_session_id: Uuid,
38    created_at: DateTime<Utc>,
39    expires_at: Option<DateTime<Utc>>,
40    revoked_at: Option<DateTime<Utc>>,
41}
42
43impl From<PersonalAccessTokenLookup> for PersonalAccessToken {
44    fn from(value: PersonalAccessTokenLookup) -> Self {
45        Self {
46            id: Ulid::from(value.personal_access_token_id),
47            session_id: Ulid::from(value.personal_session_id),
48            created_at: value.created_at,
49            expires_at: value.expires_at,
50            revoked_at: value.revoked_at,
51        }
52    }
53}
54
55#[async_trait]
56impl PersonalAccessTokenRepository for PgPersonalAccessTokenRepository<'_> {
57    type Error = DatabaseError;
58
59    #[tracing::instrument(
60        name = "db.personal_access_token.lookup",
61        skip_all,
62        fields(
63            db.query.text,
64            personal_access_token.id = %id,
65        ),
66        err,
67    )]
68    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalAccessToken>, Self::Error> {
69        let res = sqlx::query_as!(
70            PersonalAccessTokenLookup,
71            r#"
72                SELECT personal_access_token_id
73                     , personal_session_id
74                     , created_at
75                     , expires_at
76                     , revoked_at
77
78                FROM personal_access_tokens
79
80                WHERE personal_access_token_id = $1
81            "#,
82            Uuid::from(id),
83        )
84        .traced()
85        .fetch_optional(&mut *self.conn)
86        .await?;
87
88        let Some(res) = res else { return Ok(None) };
89
90        Ok(Some(res.into()))
91    }
92
93    #[tracing::instrument(
94        name = "db.personal_access_token.find_by_token",
95        skip_all,
96        fields(
97            db.query.text,
98        ),
99        err,
100    )]
101    async fn find_by_token(
102        &mut self,
103        access_token: &str,
104    ) -> Result<Option<PersonalAccessToken>, Self::Error> {
105        let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
106
107        let res = sqlx::query_as!(
108            PersonalAccessTokenLookup,
109            r#"
110                SELECT personal_access_token_id
111                     , personal_session_id
112                     , created_at
113                     , expires_at
114                     , revoked_at
115
116                FROM personal_access_tokens
117
118                WHERE access_token_sha256 = $1
119            "#,
120            &token_sha256,
121        )
122        .traced()
123        .fetch_optional(&mut *self.conn)
124        .await?;
125
126        let Some(res) = res else { return Ok(None) };
127
128        Ok(Some(res.into()))
129    }
130
131    #[tracing::instrument(
132        name = "db.personal_access_token.find_active_for_session",
133        skip_all,
134        fields(
135            db.query.text,
136        ),
137        err,
138    )]
139    async fn find_active_for_session(
140        &mut self,
141        session_id: Ulid,
142    ) -> Result<Option<PersonalAccessToken>, Self::Error> {
143        let res: Option<PersonalAccessTokenLookup> = sqlx::query_as!(
144            PersonalAccessTokenLookup,
145            r#"
146                SELECT personal_access_token_id
147                     , personal_session_id
148                     , created_at
149                     , expires_at
150                     , revoked_at
151
152                FROM personal_access_tokens
153
154                WHERE personal_session_id = $1
155                AND revoked_at IS NULL
156            "#,
157            Uuid::from(session_id),
158        )
159        .traced()
160        .fetch_optional(&mut *self.conn)
161        .await?;
162
163        let Some(res) = res else { return Ok(None) };
164
165        Ok(Some(res.into()))
166    }
167
168    #[tracing::instrument(
169        name = "db.personal_access_token.add",
170        skip_all,
171        fields(
172            db.query.text,
173            personal_access_token.id,
174            %session.id,
175        ),
176        err,
177    )]
178    async fn add(
179        &mut self,
180        rng: &mut (dyn RngCore + Send),
181        clock: &dyn Clock,
182        session: &PersonalSession,
183        access_token: &str,
184        expires_after: Option<chrono::Duration>,
185    ) -> Result<PersonalAccessToken, Self::Error> {
186        let created_at = clock.now();
187        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
188        tracing::Span::current().record("personal_access_token.id", tracing::field::display(id));
189
190        let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
191
192        let expires_at = expires_after.map(|expires_after| created_at + expires_after);
193
194        sqlx::query!(
195            r#"
196                INSERT INTO personal_access_tokens
197                    (personal_access_token_id, personal_session_id, access_token_sha256, created_at, expires_at)
198                VALUES ($1, $2, $3, $4, $5)
199            "#,
200            Uuid::from(id),
201            Uuid::from(session.id),
202            &token_sha256,
203            created_at,
204            expires_at,
205        )
206        .traced()
207        .execute(&mut *self.conn)
208        .await?;
209
210        Ok(PersonalAccessToken {
211            id,
212            session_id: session.id,
213            created_at,
214            expires_at,
215            revoked_at: None,
216        })
217    }
218
219    #[tracing::instrument(
220        name = "db.personal_access_token.revoke",
221        skip_all,
222        fields(
223            db.query.text,
224            %access_token.id,
225            personal_session.id = %access_token.session_id,
226        ),
227        err,
228    )]
229    async fn revoke(
230        &mut self,
231        clock: &dyn Clock,
232        mut access_token: PersonalAccessToken,
233    ) -> Result<PersonalAccessToken, Self::Error> {
234        let revoked_at = clock.now();
235        let res = sqlx::query!(
236            r#"
237                UPDATE personal_access_tokens
238                SET revoked_at = $2
239                WHERE personal_access_token_id = $1
240            "#,
241            Uuid::from(access_token.id),
242            revoked_at,
243        )
244        .traced()
245        .execute(&mut *self.conn)
246        .await?;
247
248        DatabaseError::ensure_affected_rows(&res, 1)?;
249
250        access_token.revoked_at = Some(revoked_at);
251        Ok(access_token)
252    }
253}