//! Conversions between `Decimal` and the PostgreSQL `NUMERIC` binary wire
//! format, provided through the `FromSql` and `ToSql` traits.

use crate::postgres::common::*;
use crate::Decimal;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::io::Cursor;
impl<'a> FromSql<'a> for Decimal {
    // Decimals are represented as follows:
    // Header:
    //  u16 numGroups
    //  i16 weightFirstGroup (10000^weight)
    //  u16 sign (0x0000 = positive, 0x4000 = negative, 0xC000 = NaN)
    //  i16 dscale. Number of digits (in base 10) to print after decimal separator
    //
    //  Pseudo code :
    //  const Decimals [
    //          0.0000000000000000000000000001,
    //          0.000000000000000000000001,
    //          0.00000000000000000001,
    //          0.0000000000000001,
    //          0.000000000001,
    //          0.00000001,
    //          0.0001,
    //          1,
    //          10000,
    //          100000000,
    //          1000000000000,
    //          10000000000000000,
    //          100000000000000000000,
    //          1000000000000000000000000,
    //          10000000000000000000000000000
    //  ]
    //  overflow = false
    //  result = 0
    //  for i = 0, weight = weightFirstGroup + 7; i < numGroups; i++, weight--
    //    group = read.u16
    //    if weight < 0 or weight > MaxNum
    //       overflow = true
    //    else
    //       result += Decimals[weight] * group
    //  sign == 0x4000 ? -result : result
    //
    // So if we were to take the number: 3950.123456
    //
    //  Stored on Disk:
    //    00 03 00 00 00 00 00 06 0F 6E 04 D2 15 E0
    //
    //  Number of groups: 00 03
    //  Weight of first group: 00 00
    //  Sign: 00 00
    //  DScale: 00 06
    //
    // 0F 6E = 3950
    //   result = result + 3950 * 1;
    // 04 D2 = 1234
    //   result = result + 1234 * 0.0001;
    // 15 E0 = 5600
    //   result = result + 5600 * 0.00000001;
    //

    /// Decodes a binary `NUMERIC` value into a `Decimal`.
    ///
    /// Reads the four big-endian header words (group count, weight of the
    /// first group, sign, display scale), then `num_groups` base-10000 digit
    /// groups, and hands them to `Decimal::from_postgres`.
    ///
    /// # Errors
    ///
    /// * An I/O error if `raw` ends before the header or all announced groups
    ///   have been read.
    /// * An error if the sign word is `0xC000` (NaN) or any value other than
    ///   positive (`0x0000`) / negative (`0x4000`); such values cannot be
    ///   expressed as a `Decimal` and previously decoded silently as a
    ///   non-negative number.
    fn from_sql(_: &Type, raw: &[u8]) -> Result<Decimal, Box<dyn std::error::Error + 'static + Sync + Send>> {
        let mut raw = Cursor::new(raw);

        let num_groups = raw.read_u16::<BigEndian>()?;
        // The first group is scaled by 10000^weight.
        let weight = raw.read_i16::<BigEndian>()?;
        // Sign: 0x0000 = positive, 0x4000 = negative, 0xC000 = NaN
        let sign = raw.read_u16::<BigEndian>()?;
        // Number of digits (in base 10) to print after decimal separator
        let scale = raw.read_u16::<BigEndian>()?;

        // Reject everything that is not a plain positive/negative number
        // instead of quietly treating it as positive.
        let neg = match sign {
            0x0000 => false,
            0x4000 => true,
            0xC000 => return Err("NaN is not supported in Decimal".into()),
            other => {
                return Err(format!("unsupported NUMERIC sign 0x{:04X} for Decimal", other).into());
            }
        };

        // Read all of the groups, stopping at the first short read.
        let groups = (0..num_groups)
            .map(|_| raw.read_u16::<BigEndian>())
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self::from_postgres(PostgresDecimal {
            neg,
            weight,
            scale,
            digits: groups.into_iter(),
        }))
    }

    /// Only the PostgreSQL `NUMERIC` type can be decoded into a `Decimal`.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }
}
impl ToSql for Decimal {
    /// Encodes this `Decimal` in the binary `NUMERIC` layout and appends it
    /// to `out`.
    ///
    /// The output is four big-endian header words (group count, weight of the
    /// first group, sign, display scale) followed by one 16-bit word per digit
    /// group, all taken from `self.to_postgres()`.
    ///
    /// Always returns `IsNull::No` on success.
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) if the number of digit groups
    /// does not fit the 16-bit group-count header field. The check happens
    /// before anything is written, so `out` is untouched on failure.
    fn to_sql(
        &self,
        _: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> {
        let PostgresDecimal {
            neg,
            weight,
            scale,
            digits,
        } = self.to_postgres();

        let num_digits = digits.len();
        // Validate the header field up front so a failure leaves `out` unmodified.
        let group_count: u16 = num_digits.try_into()?;

        // Reserve bytes: 8 for the header plus 2 per digit group.
        out.reserve(8 + num_digits * 2);

        // Number of groups
        out.put_u16(group_count);
        // Weight of first group
        out.put_i16(weight);
        // Sign
        out.put_u16(if neg { 0x4000 } else { 0x0000 });
        // DScale
        out.put_u16(scale);
        // Now process the number
        for digit in digits.iter() {
            out.put_i16(*digit);
        }

        Ok(IsNull::No)
    }

    /// Only the PostgreSQL `NUMERIC` type can be produced from a `Decimal`.
    fn accepts(ty: &Type) -> bool {
        matches!(*ty, Type::NUMERIC)
    }

    to_sql_checked!();
}
#[cfg(test)]
mod test {
    use super::*;
    use ::postgres::{Client, NoTls};
    use core::str::FromStr;

    /// Gets the URL for connecting to PostgreSQL for testing. Set the POSTGRES_URL
    /// environment variable to change from the default of "postgres://postgres@localhost".
    fn get_postgres_url() -> String {
        // Fall back to the documented default when the variable is unset or unreadable.
        std::env::var("POSTGRES_URL").unwrap_or_else(|_| "postgres://postgres@localhost".to_string())
    }

    /// Round-trip fixtures shared by the read and write tests.
    ///
    /// Each entry is `(precision, scale, sent, expected)`: `sent` is cast to
    /// `NUMERIC(precision, scale)` by the server and the value read back must
    /// print as `expected`.
    pub static TEST_DECIMALS: &[(u32, u32, &str, &str)] = &[
        // precision, scale, sent, expected
        (35, 6, "3950.123456", "3950.123456"),
        (35, 2, "3950.123456", "3950.12"),
        (35, 2, "3950.1256", "3950.13"),
        (10, 2, "3950.123456", "3950.12"),
        (35, 6, "3950", "3950.000000"),
        (4, 0, "3950", "3950"),
        (35, 6, "0.1", "0.100000"),
        (35, 6, "0.01", "0.010000"),
        (35, 6, "0.001", "0.001000"),
        (35, 6, "0.0001", "0.000100"),
        (35, 6, "0.00001", "0.000010"),
        (35, 6, "0.000001", "0.000001"),
        (35, 6, "1", "1.000000"),
        (35, 6, "-100", "-100.000000"),
        (35, 6, "-123.456", "-123.456000"),
        (35, 6, "119996.25", "119996.250000"),
        (35, 6, "1000000", "1000000.000000"),
        (35, 6, "9999999.99999", "9999999.999990"),
        (35, 6, "12340.56789", "12340.567890"),
        // Scale is only 28 since that is the maximum we can represent.
        (65, 30, "1.2", "1.2000000000000000000000000000"),
        // Pi - rounded at scale 28
        (
            65,
            30,
            "3.141592653589793238462643383279",
            "3.1415926535897932384626433833",
        ),
        (
            65,
            34,
            "3.1415926535897932384626433832795028",
            "3.1415926535897932384626433833",
        ),
        // Unrounded number
        (
            65,
            34,
            "1.234567890123456789012345678950000",
            "1.2345678901234567890123456790",
        ),
        (
            65,
            34, // No rounding due to 49999 after significant digits
            "1.234567890123456789012345678949999",
            "1.2345678901234567890123456789",
        ),
        // 0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF (96 bit)
        (35, 0, "79228162514264337593543950335", "79228162514264337593543950335"),
        // 0x0FFF_FFFF_FFFF_FFFF_FFFF_FFFF (95 bit)
        (35, 1, "4951760157141521099596496895", "4951760157141521099596496895.0"),
        // 0x1000_0000_0000_0000_0000_0000
        (35, 1, "4951760157141521099596496896", "4951760157141521099596496896.0"),
        (35, 6, "18446744073709551615", "18446744073709551615.000000"),
        (35, 6, "-18446744073709551615", "-18446744073709551615.000000"),
        (35, 6, "0.10001", "0.100010"),
        (35, 6, "0.12345", "0.123450"),
    ];

    /// A SQL `NULL` numeric must be read back as `None`.
    #[test]
    fn test_null() {
        let mut client = match Client::connect(&get_postgres_url(), NoTls) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };

        // Test NULL
        let result: Option<Decimal> = match client.query("SELECT NULL::numeric", &[]) {
            Ok(x) => x.iter().next().unwrap().get(0),
            Err(err) => panic!("{:#?}", err),
        };
        assert_eq!(None, result);
    }

    /// Async (`tokio-postgres`) variant of `test_null`.
    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_test_null() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);

        let statement = client.prepare(&"SELECT NULL::numeric").await.unwrap();
        let rows = client.query(&statement, &[]).await.unwrap();
        let result: Option<Decimal> = rows.iter().next().unwrap().get(0);

        assert_eq!(None, result);
    }

    /// A value far below the smallest representable `Decimal` reads back as zero.
    #[test]
    fn read_very_small_numeric_type() {
        let mut client = match Client::connect(&get_postgres_url(), NoTls) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        let result: Decimal = match client.query("SELECT 1e-130::NUMERIC(130, 0)", &[]) {
            Ok(x) => x.iter().next().unwrap().get(0),
            Err(err) => panic!("error - {:#?}", err),
        };
        // We compare this to zero since it is so small that it is effectively zero
        assert_eq!(Decimal::ZERO, result);
    }

    /// Lets the server produce each fixture value and checks how it is decoded.
    #[test]
    fn read_numeric_type() {
        let mut client = match Client::connect(&get_postgres_url(), NoTls) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let result: Decimal =
                match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) {
                    Ok(x) => x.iter().next().unwrap().get(0),
                    Err(err) => panic!("SELECT {}::NUMERIC({}, {}), error - {:#?}", sent, precision, scale, err),
                };
            assert_eq!(
                expected,
                result.to_string(),
                "NUMERIC({}, {}) sent: {}",
                precision,
                scale,
                sent
            );
        }
    }

    /// Async (`tokio-postgres`) variant of `read_numeric_type`.
    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_read_numeric_type() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let statement = client
                .prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();
            let rows = client.query(&statement, &[]).await.unwrap();
            let result: Decimal = rows.iter().next().unwrap().get(0);

            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    /// Sends each fixture value as a bound parameter and checks the round trip.
    #[test]
    fn write_numeric_type() {
        let mut client = match Client::connect(&get_postgres_url(), NoTls) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let number = Decimal::from_str(sent).unwrap();
            let result: Decimal =
                match client.query(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale), &[&number]) {
                    Ok(x) => x.iter().next().unwrap().get(0),
                    Err(err) => panic!("{:#?}", err),
                };
            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    /// Async (`tokio-postgres`) variant of `write_numeric_type`.
    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_write_numeric_type() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);

        for &(precision, scale, sent, expected) in TEST_DECIMALS.iter() {
            let statement = client
                .prepare(&*format!("SELECT $1::NUMERIC({}, {})", precision, scale))
                .await
                .unwrap();
            let number = Decimal::from_str(sent).unwrap();
            let rows = client.query(&statement, &[&number]).await.unwrap();
            let result: Decimal = rows.iter().next().unwrap().get(0);

            assert_eq!(expected, result.to_string(), "NUMERIC({}, {})", precision, scale);
        }
    }

    /// A value that does not fit the requested precision must fail with
    /// SQLSTATE `22003`.
    #[test]
    fn numeric_overflow() {
        let tests = [(4, 4, "3950.1234")];
        let mut client = match Client::connect(&get_postgres_url(), NoTls) {
            Ok(x) => x,
            Err(err) => panic!("{:#?}", err),
        };
        for &(precision, scale, sent) in tests.iter() {
            match client.query(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale), &[]) {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => {
                    assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code");
                }
            };
        }
    }

    /// Async (`tokio-postgres`) variant of `numeric_overflow`.
    #[tokio::test]
    #[cfg(feature = "tokio-pg")]
    async fn async_numeric_overflow() {
        use futures::future::FutureExt;
        use tokio_postgres::connect;

        let tests = [(4, 4, "3950.1234")];
        let (client, connection) = connect(&get_postgres_url(), NoTls).await.unwrap();
        let connection = connection.map(|e| e.unwrap());
        tokio::spawn(connection);

        for &(precision, scale, sent) in tests.iter() {
            let statement = client
                .prepare(&*format!("SELECT {}::NUMERIC({}, {})", sent, precision, scale))
                .await
                .unwrap();

            match client.query(&statement, &[]).await {
                Ok(_) => panic!(
                    "Expected numeric overflow for {}::NUMERIC({}, {})",
                    sent, precision, scale
                ),
                Err(err) => assert_eq!("22003", err.code().unwrap().code(), "Unexpected error code"),
            }
        }
    }
}