Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ env:
min_key: 1
max_key: 10000
socks_dir: /home/runner/work/breeze/socks
# TODO rust版本升级后,突然大量warnings,暂时允许dead_code警告
RUSTFLAGS: "-D warnings --allow dead_code"
RUSTFLAGS: "-D warnings"

jobs:
build:
Expand Down
7 changes: 7 additions & 0 deletions context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ pub struct ContextOption {
#[clap(long, help("private key path"), default_value("/var/private_key.pem"))]
pub key_path: String,

#[clap(
long,
help("redis private key path"),
default_value("/var/redis_private_key.pem")
)]
pub redis_key_path: String,

#[clap(long, help("region"), default_value(""))]
pub region: String,

Expand Down
31 changes: 25 additions & 6 deletions endpoint/src/redisservice/config.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
//use ds::time::Duration;

use std::{collections::HashSet, fmt::Debug};

use base64::{Engine as _, engine::general_purpose};
use serde::{Deserialize, Serialize};
//use sharding::distribution::{DIST_ABS_MODULA, DIST_MODULA};
use std::{collections::HashSet, fmt::Debug, fs};

use crate::{Timeout, TO_REDIS_M, TO_REDIS_S};
use crate::{TO_REDIS_M, TO_REDIS_S, Timeout};

// range/modrange 对应的distribution配置项如果有此后缀,不进行后端数量的校验
const NO_CHECK_SUFFIX: &str = "-nocheck";
Expand Down Expand Up @@ -42,6 +39,8 @@ pub struct Basic {
// master是否参与读
#[serde(default)]
pub(crate) master_read: bool,
#[serde(default)]
pub(crate) password: String,
}

impl RedisNamespace {
Expand Down Expand Up @@ -78,6 +77,17 @@ impl RedisNamespace {
return None;
}

// 解密密码
if !ns.basic.password.is_empty() {
match ns.decrypt_password() {
Ok(password) => ns.basic.password = password,
Err(e) => {
log::warn!("failed to decrypt password, e:{}", e);
return None;
}
}
}

log::debug!("parsed redis config:{}/{}", ns.basic.distribution, cfg);
return Some(ns);
}
Expand Down Expand Up @@ -139,4 +149,13 @@ impl RedisNamespace {

true
}

#[inline]
fn decrypt_password(&self) -> Result<String, Box<dyn std::error::Error>> {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decrypt_password 算法相同,把散落在各处的进行复用?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

路径不一样 先合了吧 验证一下

let key_pem = fs::read_to_string(&context::get().redis_key_path)?;
let encrypted_data = general_purpose::STANDARD.decode(self.basic.password.as_bytes())?;
let decrypted_data = ds::decrypt::decrypt_password(&key_pem, &encrypted_data)?;
let decrypted_string = String::from_utf8(decrypted_data)?;
Ok(decrypted_string)
}
}
31 changes: 27 additions & 4 deletions endpoint/src/redisservice/topo.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::{
Endpoint, Endpoints, PerformanceTuning, Topology,
dns::{DnsConfig, DnsLookup},
shards::Shard,
Endpoint, Endpoints, PerformanceTuning, Topology,
};
use discovery::TopologyWrite;
use protocol::{Protocol, RedisFlager, Request, Resource::Redis};
use protocol::{Protocol, RedisFlager, Request, ResOption, Resource::Redis};
use sharding::distribution::Distribute;
use sharding::hash::{Hash, HashKey, Hasher};

Expand All @@ -18,6 +18,7 @@ pub struct RedisService<E, P> {
distribute: Distribute,
parser: P,
cfg: Box<DnsConfig<RedisNamespace>>,
password: String,
}
impl<E, P> From<P> for RedisService<E, P> {
#[inline]
Expand All @@ -28,6 +29,7 @@ impl<E, P> From<P> for RedisService<E, P> {
hasher: Default::default(),
distribute: Default::default(),
cfg: Default::default(),
password: Default::default(),
}
}
}
Expand Down Expand Up @@ -188,6 +190,18 @@ where
assert_eq!(addrs.len(), self.cfg.shards_url.len());
// 到这之后,所有的shard都能解析出ip

// 如果密码不一致,则清空所有现有的shard
if self.password != self.cfg.basic.password {
self.shards.clear();
self.password = self.cfg.basic.password.clone();
}

// Redis认证只需要密码,无需用户名
let res_option = ResOption {
token: self.cfg.basic.password.clone(),
username: String::new(), // Redis不需要用户名
};

// 把所有的endpoints cache下来
let mut endpoints: Endpoints<'_, P, E> =
Endpoints::new(&self.cfg.service, &self.parser, Redis);
Expand All @@ -199,10 +213,18 @@ where
// 遍历所有的shards_url
addrs.iter().for_each(|ips| {
assert!(ips.len() >= 2);
let master = endpoints.take_or_build_one(&ips[0], self.cfg.timeout_master());
let master = endpoints.take_or_build_one_with_res(
&ips[0],
self.cfg.timeout_master(),
res_option.clone(),
);
// 第0个是master,如果master提供读,则从第0个开始。
let oft = if self.cfg.basic.master_read { 0 } else { 1 };
let slaves = endpoints.take_or_build(&ips[oft..], self.cfg.timeout_slave());
let slaves = endpoints.take_or_build_with_res(
&ips[oft..],
self.cfg.timeout_slave(),
res_option.clone(),
);
let shard = Shard::selector(
self.cfg.basic.selector.tuning_mode(),
master,
Expand All @@ -213,6 +235,7 @@ where
shard.check_region_len(ty, &self.cfg.service);
self.shards.push(shard);
});

Some(())
}
}
Expand Down
28 changes: 28 additions & 0 deletions endpoint/src/topo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ impl<'a, P: Protocol, E: Endpoint> Endpoints<'a, P, E> {
.pop()
.expect("take")
}

pub fn take_or_build_one_with_res(&mut self, addr: &str, to: Timeout, res: ResOption) -> E {
self.take_or_build_with_res(&[addr.to_owned()], to, res)
.pop()
.expect("take")
}

pub fn take_or_build(&mut self, addrs: &[String], to: Timeout) -> Vec<E> {
addrs
.iter()
Expand All @@ -137,6 +144,27 @@ impl<'a, P: Protocol, E: Endpoint> Endpoints<'a, P, E> {
.collect()
}

pub fn take_or_build_with_res(
&mut self,
addrs: &[String],
to: Timeout,
res: ResOption,
) -> Vec<E> {
addrs
.iter()
.map(|addr| {
self.cache
.get_mut(addr)
.map(|endpoints| endpoints.pop())
.flatten()
.unwrap_or_else(|| {
let p = self.parser.clone();
E::build_o(&addr, p, self.resource, self.service, to, res.clone())
})
})
.collect()
}

#[inline]
pub fn take_all(&mut self) -> Vec<E> {
self.cache
Expand Down
10 changes: 5 additions & 5 deletions protocol/src/kv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod reqpacket;
mod rsppacket;

mod mc2mysql;
pub use mc2mysql::{escape_mysql_and_push, MysqlBuilder, Strategy, VectorSqlBuilder};
pub use mc2mysql::{MysqlBuilder, Strategy, VectorSqlBuilder, escape_mysql_and_push};
use std::ops::Deref;

use self::common::proto::Text;
Expand All @@ -23,24 +23,24 @@ use self::rsppacket::ResponsePacket;

use super::Flag;
use super::Protocol;
use crate::kv::client::Client;
use crate::kv::error::Error;
use crate::HandShake;
use crate::HashedCommand;
use crate::RequestProcessor;
use crate::Stream;
use crate::kv::client::Client;
use crate::kv::error::Error;
use crate::{Command, Operation};
use ds::RingSlice;

use sharding::hash::Hash;

pub mod prelude {

#[doc(inline)]
pub use crate::kv::common::row::convert::FromRow;
#[doc(inline)]
pub use crate::kv::common::row::ColumnIndex;
#[doc(inline)]
pub use crate::kv::common::row::convert::FromRow;
#[doc(inline)]
pub use crate::kv::common::value::convert::{ConvIr, FromValue, ToValue};

// Trait for protocol markers [`crate::Binary`] and [`crate::Text`].
Expand Down
1 change: 1 addition & 0 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub use redis::RedisFlager;
pub use redis::packet::Packet;
pub mod req;
//pub mod resp;
#[allow(dead_code)]
pub mod kv;
pub mod metrics;
pub mod msgque;
Expand Down
59 changes: 56 additions & 3 deletions protocol/src/redis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,31 @@ pub use flag::RedisFlager;
pub(crate) mod packet;

use crate::{
Command, Commander, Error, HandShake, HashedCommand, Metric, MetricItem, MetricName, Protocol,
RequestProcessor, ResOption, Result, Stream, Writer,
redis::command::CommandType,
redis::{error::RedisError, packet::RequestPacket},
Command, Commander, Error, HashedCommand, Metric, MetricItem, MetricName, Protocol,
RequestProcessor, Result, Stream, Writer,
};
pub use packet::{transmute, Packet, ResponseContext};
pub use packet::{Packet, ResponseContext, transmute};
use sharding::hash::Hash;

#[derive(Clone, Default)]
pub struct Redis;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum HandShakeStatus {
Init = 0,
Sent = 1,
Success = 2,
}

impl Default for HandShakeStatus {
fn default() -> Self {
Self::Init
}
}

impl Redis {
#[inline]
fn parse_request_inner<S: Stream, H: Hash, P: RequestProcessor>(
Expand Down Expand Up @@ -91,9 +105,48 @@ impl Redis {
}

impl Protocol for Redis {
fn handshake(&self, stream: &mut impl Stream, option: &mut ResOption) -> Result<HandShake> {
let status = transmute(stream.context()).status;

match status {
HandShakeStatus::Init => {
// a two-bulk "AUTH" command.
let pass = &option.token;
let mut auth_cmd = Vec::with_capacity(32);
auth_cmd.extend_from_slice(b"*2\r\n$4\r\nAUTH\r\n$");
auth_cmd.extend_from_slice(pass.len().to_string().as_bytes());
auth_cmd.extend_from_slice(b"\r\n");
auth_cmd.extend_from_slice(pass.as_bytes());
auth_cmd.extend_from_slice(b"\r\n");

stream.write_all(&auth_cmd)?;
transmute(stream.context()).status = HandShakeStatus::Sent;
Ok(HandShake::Continue)
}

HandShakeStatus::Sent => {
let data = stream.slice();
if let Some(idx) = data.find_lf_cr(0) {
// response should be +OK\r\n
if data.start_with(0, b"+OK\r\n") {
stream.ignore(idx + 2);
transmute(stream.context()).status = HandShakeStatus::Success;
return Ok(HandShake::Success);
}
log::warn!("redis auth failed response:{:?}", data);
Err(Error::AuthFailed)
} else {
stream.reserve(8);
Ok(HandShake::Continue)
}
}
HandShakeStatus::Success => Ok(HandShake::Success),
}
}
#[inline]
fn config(&self) -> crate::Config {
crate::Config {
need_auth: true,
pipeline: true,
..Default::default()
}
Expand Down
17 changes: 11 additions & 6 deletions protocol/src/redis/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ use super::{
command::{CommandHasher, CommandProperties, CommandType},
error::RedisError,
};
use crate::{error::Error, redis::command, Flag, Result, StreamContext};
use crate::{
Flag, Result, StreamContext,
error::Error,
redis::{HandShakeStatus, command},
};
use ds::RingSlice;
use sharding::hash::Hash;

Expand Down Expand Up @@ -39,7 +43,8 @@ impl From<RequestContext> for StreamContext {
#[derive(Debug, Default, Clone, Copy)]
pub struct ResponseContext {
pub oft: usize,
pub bulk: usize,
pub bulk: u32,
pub status: HandShakeStatus,
}
#[inline]
pub fn transmute(ctx: &mut StreamContext) -> &mut ResponseContext {
Expand Down Expand Up @@ -607,7 +612,7 @@ impl Packet {
self.skip_multibulks_inner(&mut ctx.oft, &mut ctx.bulk)
.map_err(|e| {
if let Error::ProtocolIncomplete(_) = e {
Error::ProtocolIncomplete(ctx.bulk * 64)
Error::ProtocolIncomplete(ctx.bulk as usize * 64)
} else {
e
}
Expand Down Expand Up @@ -641,20 +646,20 @@ impl Packet {
// }
//协议完整才跳过,否则不做改动
#[inline]
pub fn full_skip_multibulks(&self, oft: &mut usize, bulks: &mut usize) -> Result<()> {
pub fn full_skip_multibulks(&self, oft: &mut usize, bulks: &mut u32) -> Result<()> {
let (mut oft_tmp, mut bulks_tmp) = (*oft, *bulks);
self.skip_multibulks_inner(&mut oft_tmp, &mut bulks_tmp)?;
*oft = oft_tmp;
*bulks = bulks_tmp;
Ok(())
}
#[inline]
pub fn skip_multibulks_inner(&self, oft: &mut usize, bulks: &mut usize) -> Result<()> {
pub fn skip_multibulks_inner(&self, oft: &mut usize, bulks: &mut u32) -> Result<()> {
while *bulks > 0 {
self.check_onetoken(*oft)?;
// 下面每种情况都确保了不会越界
match self.at(*oft) {
b'*' => *bulks = *bulks + self.num_of_bulks(oft)?,
b'*' => *bulks = *bulks + self.num_of_bulks(oft)? as u32,
// 能完整解析才跳过当前字符串:num个字节 + "\r\n" 2个字节
b'$' => self.skip_string_inner(oft).map(|_| {})?,
b'+' | b':' => self.line(oft)?,
Expand Down
Loading
Loading