1use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use mas_data_model::{
9 Clock,
10 personal::{PersonalAccessToken, session::PersonalSession},
11};
12use mas_storage::personal::PersonalAccessTokenRepository;
13use rand::RngCore;
14use sha2::{Digest, Sha256};
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, tracing::ExecuteExt as _};
20
/// An implementation of [`PersonalAccessTokenRepository`] for a PostgreSQL
/// connection.
pub struct PgPersonalAccessTokenRepository<'c> {
    // Mutably borrowed connection; every query issued by this repository runs on it.
    conn: &'c mut PgConnection,
}
26
27impl<'c> PgPersonalAccessTokenRepository<'c> {
28 pub fn new(conn: &'c mut PgConnection) -> Self {
31 Self { conn }
32 }
33}
34
/// A raw row of the `personal_access_tokens` table, as selected by the
/// queries in this module.
///
/// The field names must stay identical to the selected column names:
/// `sqlx::query_as!` maps columns onto fields by name.
struct PersonalAccessTokenLookup {
    personal_access_token_id: Uuid,
    personal_session_id: Uuid,
    created_at: DateTime<Utc>,
    // `None` when the token was added without an `expires_after` duration.
    expires_at: Option<DateTime<Utc>>,
    // `None` until the token has been revoked.
    revoked_at: Option<DateTime<Utc>>,
}
42
43impl From<PersonalAccessTokenLookup> for PersonalAccessToken {
44 fn from(value: PersonalAccessTokenLookup) -> Self {
45 Self {
46 id: Ulid::from(value.personal_access_token_id),
47 session_id: Ulid::from(value.personal_session_id),
48 created_at: value.created_at,
49 expires_at: value.expires_at,
50 revoked_at: value.revoked_at,
51 }
52 }
53}
54
55#[async_trait]
56impl PersonalAccessTokenRepository for PgPersonalAccessTokenRepository<'_> {
57 type Error = DatabaseError;
58
59 #[tracing::instrument(
60 name = "db.personal_access_token.lookup",
61 skip_all,
62 fields(
63 db.query.text,
64 personal_access_token.id = %id,
65 ),
66 err,
67 )]
68 async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalAccessToken>, Self::Error> {
69 let res = sqlx::query_as!(
70 PersonalAccessTokenLookup,
71 r#"
72 SELECT personal_access_token_id
73 , personal_session_id
74 , created_at
75 , expires_at
76 , revoked_at
77
78 FROM personal_access_tokens
79
80 WHERE personal_access_token_id = $1
81 "#,
82 Uuid::from(id),
83 )
84 .traced()
85 .fetch_optional(&mut *self.conn)
86 .await?;
87
88 let Some(res) = res else { return Ok(None) };
89
90 Ok(Some(res.into()))
91 }
92
93 #[tracing::instrument(
94 name = "db.personal_access_token.find_by_token",
95 skip_all,
96 fields(
97 db.query.text,
98 ),
99 err,
100 )]
101 async fn find_by_token(
102 &mut self,
103 access_token: &str,
104 ) -> Result<Option<PersonalAccessToken>, Self::Error> {
105 let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
106
107 let res = sqlx::query_as!(
108 PersonalAccessTokenLookup,
109 r#"
110 SELECT personal_access_token_id
111 , personal_session_id
112 , created_at
113 , expires_at
114 , revoked_at
115
116 FROM personal_access_tokens
117
118 WHERE access_token_sha256 = $1
119 "#,
120 &token_sha256,
121 )
122 .traced()
123 .fetch_optional(&mut *self.conn)
124 .await?;
125
126 let Some(res) = res else { return Ok(None) };
127
128 Ok(Some(res.into()))
129 }
130
131 #[tracing::instrument(
132 name = "db.personal_access_token.find_active_for_session",
133 skip_all,
134 fields(
135 db.query.text,
136 ),
137 err,
138 )]
139 async fn find_active_for_session(
140 &mut self,
141 session_id: Ulid,
142 ) -> Result<Option<PersonalAccessToken>, Self::Error> {
143 let res: Option<PersonalAccessTokenLookup> = sqlx::query_as!(
144 PersonalAccessTokenLookup,
145 r#"
146 SELECT personal_access_token_id
147 , personal_session_id
148 , created_at
149 , expires_at
150 , revoked_at
151
152 FROM personal_access_tokens
153
154 WHERE personal_session_id = $1
155 AND revoked_at IS NULL
156 "#,
157 Uuid::from(session_id),
158 )
159 .traced()
160 .fetch_optional(&mut *self.conn)
161 .await?;
162
163 let Some(res) = res else { return Ok(None) };
164
165 Ok(Some(res.into()))
166 }
167
168 #[tracing::instrument(
169 name = "db.personal_access_token.add",
170 skip_all,
171 fields(
172 db.query.text,
173 personal_access_token.id,
174 %session.id,
175 ),
176 err,
177 )]
178 async fn add(
179 &mut self,
180 rng: &mut (dyn RngCore + Send),
181 clock: &dyn Clock,
182 session: &PersonalSession,
183 access_token: &str,
184 expires_after: Option<chrono::Duration>,
185 ) -> Result<PersonalAccessToken, Self::Error> {
186 let created_at = clock.now();
187 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
188 tracing::Span::current().record("personal_access_token.id", tracing::field::display(id));
189
190 let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
191
192 let expires_at = expires_after.map(|expires_after| created_at + expires_after);
193
194 sqlx::query!(
195 r#"
196 INSERT INTO personal_access_tokens
197 (personal_access_token_id, personal_session_id, access_token_sha256, created_at, expires_at)
198 VALUES ($1, $2, $3, $4, $5)
199 "#,
200 Uuid::from(id),
201 Uuid::from(session.id),
202 &token_sha256,
203 created_at,
204 expires_at,
205 )
206 .traced()
207 .execute(&mut *self.conn)
208 .await?;
209
210 Ok(PersonalAccessToken {
211 id,
212 session_id: session.id,
213 created_at,
214 expires_at,
215 revoked_at: None,
216 })
217 }
218
219 #[tracing::instrument(
220 name = "db.personal_access_token.revoke",
221 skip_all,
222 fields(
223 db.query.text,
224 %access_token.id,
225 personal_session.id = %access_token.session_id,
226 ),
227 err,
228 )]
229 async fn revoke(
230 &mut self,
231 clock: &dyn Clock,
232 mut access_token: PersonalAccessToken,
233 ) -> Result<PersonalAccessToken, Self::Error> {
234 let revoked_at = clock.now();
235 let res = sqlx::query!(
236 r#"
237 UPDATE personal_access_tokens
238 SET revoked_at = $2
239 WHERE personal_access_token_id = $1
240 "#,
241 Uuid::from(access_token.id),
242 revoked_at,
243 )
244 .traced()
245 .execute(&mut *self.conn)
246 .await?;
247
248 DatabaseError::ensure_affected_rows(&res, 1)?;
249
250 access_token.revoked_at = Some(revoked_at);
251 Ok(access_token)
252 }
253}