//! Non-zero scalar type. use crate::{ ops::{Invert, Reduce, ReduceNonZero}, scalar::IsHigh, CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarPrimitive, SecretKey, }; use base16ct::HexDisplay; use core::{ fmt, ops::{Deref, Mul, Neg}, str, }; use crypto_bigint::{ArrayEncoding, Integer}; use ff::{Field, PrimeField}; use generic_array::{typenum::Unsigned, GenericArray}; use rand_core::CryptoRngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use zeroize::Zeroize; #[cfg(feature = "serde")] use serdect::serde::{de, ser, Deserialize, Serialize}; /// Non-zero scalar type. /// /// This type ensures that its value is not zero, ala `core::num::NonZero*`. /// To do this, the generic `S` type must impl both `Default` and /// `ConstantTimeEq`, with the requirement that `S::default()` returns 0. /// /// In the context of ECC, it's useful for ensuring that scalar multiplication /// cannot result in the point at infinity. #[derive(Clone)] pub struct NonZeroScalar where C: CurveArithmetic, { scalar: Scalar, } impl NonZeroScalar where C: CurveArithmetic, { /// Generate a random `NonZeroScalar`. pub fn random(mut rng: &mut impl CryptoRngCore) -> Self { // Use rejection sampling to eliminate zero values. // While this method isn't constant-time, the attacker shouldn't learn // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`. loop { if let Some(result) = Self::new(Field::random(&mut rng)).into() { break result; } } } /// Create a [`NonZeroScalar`] from a scalar. pub fn new(scalar: Scalar) -> CtOption { CtOption::new(Self { scalar }, !scalar.is_zero()) } /// Decode a [`NonZeroScalar`] from a big endian-serialized field element. pub fn from_repr(repr: FieldBytes) -> CtOption { Scalar::::from_repr(repr).and_then(Self::new) } /// Create a [`NonZeroScalar`] from a `C::Uint`. pub fn from_uint(uint: C::Uint) -> CtOption { ScalarPrimitive::new(uint).and_then(|scalar| Self::new(scalar.into())) } } impl AsRef> for NonZeroScalar where C: CurveArithmetic, { fn as_ref(&self) -> &Scalar { &self.scalar } } impl ConditionallySelectable for NonZeroScalar where C: CurveArithmetic, { fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { Self { scalar: Scalar::::conditional_select(&a.scalar, &b.scalar, choice), } } } impl ConstantTimeEq for NonZeroScalar where C: CurveArithmetic, { fn ct_eq(&self, other: &Self) -> Choice { self.scalar.ct_eq(&other.scalar) } } impl Copy for NonZeroScalar where C: CurveArithmetic {} impl Deref for NonZeroScalar where C: CurveArithmetic, { type Target = Scalar; fn deref(&self) -> &Scalar { &self.scalar } } impl From> for FieldBytes where C: CurveArithmetic, { fn from(scalar: NonZeroScalar) -> FieldBytes { Self::from(&scalar) } } impl From<&NonZeroScalar> for FieldBytes where C: CurveArithmetic, { fn from(scalar: &NonZeroScalar) -> FieldBytes { scalar.to_repr() } } impl From> for ScalarPrimitive where C: CurveArithmetic, { #[inline] fn from(scalar: NonZeroScalar) -> ScalarPrimitive { Self::from(&scalar) } } impl From<&NonZeroScalar> for ScalarPrimitive where C: CurveArithmetic, { fn from(scalar: &NonZeroScalar) -> ScalarPrimitive { ScalarPrimitive::from_bytes(&scalar.to_repr()).unwrap() } } impl From> for NonZeroScalar where C: CurveArithmetic, { fn from(sk: SecretKey) -> NonZeroScalar { Self::from(&sk) } } impl From<&SecretKey> for NonZeroScalar where C: CurveArithmetic, { fn from(sk: &SecretKey) -> NonZeroScalar { let scalar = sk.as_scalar_primitive().to_scalar(); debug_assert!(!bool::from(scalar.is_zero())); Self { scalar } } } impl Invert for NonZeroScalar where C: CurveArithmetic, Scalar: Invert>>, { type Output = Self; fn invert(&self) -> Self { Self { // This will always succeed since `scalar` will never be 0 scalar: Invert::invert(&self.scalar).unwrap(), } } fn invert_vartime(&self) -> Self::Output { Self { // This will always succeed since `scalar` will never be 0 scalar: Invert::invert_vartime(&self.scalar).unwrap(), } } } impl IsHigh for NonZeroScalar where C: CurveArithmetic, { fn is_high(&self) -> Choice { self.scalar.is_high() } } impl Neg for NonZeroScalar where C: CurveArithmetic, { type Output = NonZeroScalar; fn neg(self) -> NonZeroScalar { let scalar = -self.scalar; debug_assert!(!bool::from(scalar.is_zero())); NonZeroScalar { scalar } } } impl Mul> for NonZeroScalar where C: PrimeCurve + CurveArithmetic, { type Output = Self; #[inline] fn mul(self, other: Self) -> Self { Self::mul(self, &other) } } impl Mul<&NonZeroScalar> for NonZeroScalar where C: PrimeCurve + CurveArithmetic, { type Output = Self; fn mul(self, other: &Self) -> Self { // Multiplication is modulo a prime, so the product of two non-zero // scalars is also non-zero. let scalar = self.scalar * other.scalar; debug_assert!(!bool::from(scalar.is_zero())); NonZeroScalar { scalar } } } /// Note: this is a non-zero reduction, as it's impl'd for [`NonZeroScalar`]. impl Reduce for NonZeroScalar where C: CurveArithmetic, I: Integer + ArrayEncoding, Scalar: Reduce + ReduceNonZero, { type Bytes = as Reduce>::Bytes; fn reduce(n: I) -> Self { let scalar = Scalar::::reduce_nonzero(n); debug_assert!(!bool::from(scalar.is_zero())); Self { scalar } } fn reduce_bytes(bytes: &Self::Bytes) -> Self { let scalar = Scalar::::reduce_nonzero_bytes(bytes); debug_assert!(!bool::from(scalar.is_zero())); Self { scalar } } } /// Note: forwards to the [`Reduce`] impl. impl ReduceNonZero for NonZeroScalar where Self: Reduce, C: CurveArithmetic, I: Integer + ArrayEncoding, Scalar: Reduce + ReduceNonZero, { fn reduce_nonzero(n: I) -> Self { Self::reduce(n) } fn reduce_nonzero_bytes(bytes: &Self::Bytes) -> Self { Self::reduce_bytes(bytes) } } impl TryFrom<&[u8]> for NonZeroScalar where C: CurveArithmetic, { type Error = Error; fn try_from(bytes: &[u8]) -> Result { if bytes.len() == C::FieldBytesSize::USIZE { Option::from(NonZeroScalar::from_repr(GenericArray::clone_from_slice( bytes, ))) .ok_or(Error) } else { Err(Error) } } } impl Zeroize for NonZeroScalar where C: CurveArithmetic, { fn zeroize(&mut self) { // Use zeroize's volatile writes to ensure value is cleared. self.scalar.zeroize(); // Write a 1 instead of a 0 to ensure this type's non-zero invariant // is upheld. self.scalar = Scalar::::ONE; } } impl fmt::Display for NonZeroScalar where C: CurveArithmetic, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self:X}") } } impl fmt::LowerHex for NonZeroScalar where C: CurveArithmetic, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:x}", HexDisplay(&self.to_repr())) } } impl fmt::UpperHex for NonZeroScalar where C: CurveArithmetic, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:}", HexDisplay(&self.to_repr())) } } impl str::FromStr for NonZeroScalar where C: CurveArithmetic, { type Err = Error; fn from_str(hex: &str) -> Result { let mut bytes = FieldBytes::::default(); if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() { Option::from(Self::from_repr(bytes)).ok_or(Error) } else { Err(Error) } } } #[cfg(feature = "serde")] impl Serialize for NonZeroScalar where C: CurveArithmetic, { fn serialize(&self, serializer: S) -> Result where S: ser::Serializer, { ScalarPrimitive::from(self).serialize(serializer) } } #[cfg(feature = "serde")] impl<'de, C> Deserialize<'de> for NonZeroScalar where C: CurveArithmetic, { fn deserialize(deserializer: D) -> Result where D: de::Deserializer<'de>, { let scalar = ScalarPrimitive::deserialize(deserializer)?; Option::from(Self::new(scalar.into())) .ok_or_else(|| de::Error::custom("expected non-zero scalar")) } } #[cfg(all(test, feature = "dev"))] mod tests { use crate::dev::{NonZeroScalar, Scalar}; use ff::{Field, PrimeField}; use hex_literal::hex; use zeroize::Zeroize; #[test] fn round_trip() { let bytes = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721"); let scalar = NonZeroScalar::from_repr(bytes.into()).unwrap(); assert_eq!(&bytes, scalar.to_repr().as_slice()); } #[test] fn zeroize() { let mut scalar = NonZeroScalar::new(Scalar::from(42u64)).unwrap(); scalar.zeroize(); assert_eq!(*scalar, Scalar::ONE); } }