diff --git a/grive/src/main.cc b/grive/src/main.cc index 54d5560..8009d12 100644 --- a/grive/src/main.cc +++ b/grive/src/main.cc @@ -111,9 +111,9 @@ int Main( int argc, char **argv ) ( "help,h", "Produce help message" ) ( "version,v", "Display Grive version" ) ( "auth,a", "Request authorization token" ) - ( "id,i", po::value(), "Authentication ID") - ( "secret,e", po::value(), "Authentication secret") - ( "print-url", "Only print url for request") + ( "id,i", po::value(), "Authentication ID") + ( "secret,e", po::value(), "Authentication secret") + ( "print-url", "Only print url for request") ( "path,p", po::value(), "Path to working copy root") ( "dir,s", po::value(), "Single subdirectory to sync") ( "verbose,V", "Verbose mode. Enable more messages than normal.") @@ -185,34 +185,32 @@ int Main( int argc, char **argv ) : default_secret ; OAuth2 token( http.get(), id, secret ) ; - + if ( vm.count("print-url") ) { - std::cout << token.MakeAuthURL() << std::endl ; + std::cout << token.MakeAuthURL() << std::endl ; return 0 ; } - + std::cout << "-----------------------\n" - << "Please go to this URL and get an authentication code:\n\n" + << "Please open this URL in your browser to authenticate Grive2:\n\n" << token.MakeAuthURL() << std::endl ; - - std::cout - << "\n-----------------------\n" - << "Please input the authentication code here: " << std::endl ; - std::string code ; - std::cin >> code ; - - token.Auth( code ) ; - + + if ( !token.GetCode() ) + { + std::cout << "Authentication failed\n"; + return -1; + } + // save to config config.Set( "id", Val( id ) ) ; config.Set( "secret", Val( secret ) ) ; config.Set( "refresh_token", Val( token.RefreshToken() ) ) ; config.Save() ; } - + std::string refresh_token ; std::string id ; std::string secret ; @@ -231,7 +229,7 @@ int Main( int argc, char **argv ) return -1; } - + OAuth2 token( http.get(), refresh_token, id, secret ) ; AuthAgent agent( token, http.get() ) ; v2::Syncer2 syncer( &agent ); diff --git a/libgrive/src/protocol/OAuth2.cc b/libgrive/src/protocol/OAuth2.cc index db1858a..478f108 100644 --- a/libgrive/src/protocol/OAuth2.cc +++ b/libgrive/src/protocol/OAuth2.cc @@ -25,6 +25,13 @@ #include "http/Header.hh" #include "util/log/Log.hh" +#include +#include +#include +#include +#include +#include + // for debugging #include @@ -50,18 +57,29 @@ OAuth2::OAuth2( const std::string& client_id, const std::string& client_secret ) : m_agent( agent ), + m_port( 0 ), + m_socket( -1 ), m_client_id( client_id ), m_client_secret( client_secret ) { } -void OAuth2::Auth( const std::string& auth_code ) +OAuth2::~OAuth2() +{ + if ( m_socket >= 0 ) + { + close( m_socket ); + m_socket = -1; + } +} + +bool OAuth2::Auth( const std::string& auth_code ) { std::string post = "code=" + auth_code + "&client_id=" + m_client_id + "&client_secret=" + m_client_secret + - "&redirect_uri=" + "urn:ietf:wg:oauth:2.0:oob" + + "&redirect_uri=http%3A%2F%2Flocalhost:" + std::to_string( m_port ) + "%2Fauth" + "&grant_type=authorization_code" ; http::ValResponse resp ; @@ -77,19 +95,120 @@ void OAuth2::Auth( const std::string& auth_code ) { Log( "Failed to obtain auth token: HTTP %1%, body: %2%", code, m_agent->LastError(), log::error ) ; - BOOST_THROW_EXCEPTION( AuthFailed() ); + return false; } + + return true; } std::string OAuth2::MakeAuthURL() { + if ( !m_port ) + { + sockaddr_storage addr = { 0 }; + addr.ss_family = AF_INET; + m_socket = socket( AF_INET, SOCK_STREAM, 0 ); + if ( m_socket < 0 ) + throw std::runtime_error( std::string("socket: ") + strerror(errno) ); + if ( bind( m_socket, (sockaddr*)&addr, sizeof( addr ) ) < 0 ) + { + close( m_socket ); + m_socket = -1; + throw std::runtime_error( std::string("bind: ") + strerror(errno) ); + } + socklen_t len = sizeof( addr ); + if ( getsockname( m_socket, (sockaddr *)&addr, &len ) == -1 ) + { + close( m_socket ); + m_socket = -1; + throw std::runtime_error( std::string("getsockname: ") + strerror(errno) ); + } + m_port = ntohs(((sockaddr_in*)&addr)->sin_port); + if ( listen( m_socket, 128 ) < 0 ) + { + close( m_socket ); + m_socket = -1; + m_port = 0; + throw std::runtime_error( std::string("listen: ") + strerror(errno) ); + } + } return "https://accounts.google.com/o/oauth2/auth" "?scope=" + m_agent->Escape( "https://www.googleapis.com/auth/drive" ) + - "&redirect_uri=urn:ietf:wg:oauth:2.0:oob" + "&redirect_uri=http%3A%2F%2Flocalhost:" + std::to_string( m_port ) + "%2Fauth" + "&response_type=code" "&client_id=" + m_client_id ; } +bool OAuth2::GetCode( ) +{ + sockaddr_storage addr = { 0 }; + int peer_fd = -1; + while ( peer_fd < 0 ) + { + socklen_t peer_addr_size = sizeof( addr ); + peer_fd = accept( m_socket, (sockaddr*)&addr, &peer_addr_size ); + if ( peer_fd == -1 && errno != EAGAIN && errno != EINTR ) + throw std::runtime_error( std::string("accept: ") + strerror(errno) ); + } + fcntl( peer_fd, F_SETFL, fcntl( peer_fd, F_GETFL, 0 ) | O_NONBLOCK ); + struct pollfd pfd = (struct pollfd){ + .fd = peer_fd, + .events = POLLIN|POLLRDHUP, + }; + char buf[4096]; + std::string request; + while ( true ) + { + pfd.revents = 0; + poll( &pfd, 1, -1 ); + if ( pfd.revents & POLLRDHUP ) + break; + int r = 1; + while ( r > 0 ) + { + r = read( peer_fd, buf, sizeof( buf ) ); + if ( r > 0 ) + request += std::string( buf, r ); + else if ( r == 0 ) + break; + else if ( errno != EAGAIN && errno != EINTR ) + throw std::runtime_error( std::string("read: ") + strerror(errno) ); + } + if ( r == 0 || ( r < 0 && request.find( "\n" ) > 0 ) ) // GET ... HTTP/1.1\r\n + break; + } + bool ok = false; + if ( request.substr( 0, 10 ) == "GET /auth?" ) + { + std::string line = request; + int p = line.find( "\n" ); + if ( p > 0 ) + line = line.substr( 0, p ); + p = line.rfind( " " ); + if ( p > 0 ) + line = line.substr( 0, p ); + p = line.find( "code=" ); + if ( p > 0 ) + line = line.substr( p+5 ); + p = line.find( "&" ); + if ( p > 0 ) + line = line.substr( 0, p ); + ok = Auth( line ); + } + std::string response = ( ok + ? "Authenticated successfully. Please close the page" + : "Authentication error. Please try again" ); + response = "HTTP/1.1 200 OK\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Connection: close\r\n" + "\r\n"+ + response+ + "\r\n"; + write( peer_fd, response.c_str(), response.size() ); + close( peer_fd ); + return ok; +} + void OAuth2::Refresh( ) { std::string post = diff --git a/libgrive/src/protocol/OAuth2.hh b/libgrive/src/protocol/OAuth2.hh index e9a23da..ae40db7 100644 --- a/libgrive/src/protocol/OAuth2.hh +++ b/libgrive/src/protocol/OAuth2.hh @@ -41,13 +41,15 @@ public : const std::string& refresh_code, const std::string& client_id, const std::string& client_secret ) ; + ~OAuth2( ) ; std::string Str() const ; std::string MakeAuthURL() ; - void Auth( const std::string& auth_code ) ; + bool Auth( const std::string& auth_code ) ; void Refresh( ) ; + bool GetCode( ) ; std::string RefreshToken( ) const ; std::string AccessToken( ) const ; @@ -59,7 +61,9 @@ private : std::string m_access ; std::string m_refresh ; http::Agent* m_agent ; - + int m_port ; + int m_socket ; + const std::string m_client_id ; const std::string m_client_secret ; } ; diff --git a/libgrive/src/util/Config.cc b/libgrive/src/util/Config.cc index c08972f..7d17a55 100644 --- a/libgrive/src/util/Config.cc +++ b/libgrive/src/util/Config.cc @@ -84,7 +84,7 @@ void Config::Save( ) void Config::Set( const std::string& key, const Val& value ) { - m_file.Add( key, value ) ; + m_file.Set( key, value ) ; } Val Config::Get( const std::string& key ) const