// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/external_protocol/auto_launch_protocols_policy_handler.h"

#include <memory>
#include <string>
#include <utility>

#include "base/strings/string_util.h"
#include "base/values.h"
#include "chrome/common/pref_names.h"
#include "components/policy/core/browser/policy_error_map.h"
#include "components/policy/core/common/policy_map.h"
#include "components/policy/core/common/policy_pref_names.h"
#include "components/policy/policy_constants.h"
#include "components/prefs/pref_value_map.h"
#include "components/strings/grit/components_strings.h"
#include "url/gurl.h"

namespace policy {

// Key, inside each dictionary entry of the policy list, whose string value is
// the protocol (URI scheme) the entry applies to.
const char AutoLaunchProtocolsPolicyHandler::kProtocolNameKey[] = "protocol";

// Key, inside each dictionary entry of the policy list, whose list value holds
// the origin-matching pattern strings configured for that protocol.
const char AutoLaunchProtocolsPolicyHandler::kOriginListKey[] =
    "allowed_origins";

namespace {
// The full set of characters RFC 3986 permits in a URI scheme: ASCII digits,
// ASCII letters of both cases, and the three symbols '+', '-' and '.'.
const char kValidProtocolChars[] =
    "0123456789"                  // DIGIT
    "abcdefghijklmnopqrstuvwxyz"  // ALPHA (lowercase)
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"  // ALPHA (uppercase)
    "+-.";                        // permitted symbols

// Returns true when |protocol| is a well-formed URI scheme as defined by
// RFC 3986, section 3.1: ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ).
bool IsValidProtocol(const base::StringPiece protocol) {
  // A scheme must be non-empty and must start with an ASCII letter.
  if (protocol.empty() || !base::IsAsciiAlpha(protocol.front()))
    return false;
  // The leading letter is itself a member of kValidProtocolChars, so scanning
  // the whole string is equivalent to scanning only the remaining characters.
  return base::ContainsOnlyChars(protocol, kValidProtocolChars);
}

// Coarse sanity check that rejects patterns which obviously are not bare
// origins. It catches patterns that include a path other than "/" (e.g.
// "https://example.com/path") or a query (e.g. "https://example.com/?q").
// Everything else is accepted; this is not a full pattern validator.
bool IsValidOriginMatchingPattern(const base::StringPiece origin_pattern) {
  const GURL url(origin_pattern);
  const bool has_non_root_path = url.has_path() && url.path_piece() != "/";
  return !has_non_root_path && !url.has_query();
}

}  // namespace

// Binds this handler to the AutoLaunchProtocolsFromOrigins policy, validating
// it against that policy's entry in |chrome_schema|. The validation strategy is
// SCHEMA_ALLOW_UNKNOWN (presumably tolerating unknown dictionary properties —
// see the policy Schema documentation).
AutoLaunchProtocolsPolicyHandler::AutoLaunchProtocolsPolicyHandler(
    const policy::Schema& chrome_schema)
    : SchemaValidatingPolicyHandler(
          key::kAutoLaunchProtocolsFromOrigins,
          chrome_schema.GetKnownProperty(key::kAutoLaunchProtocolsFromOrigins),
          SCHEMA_ALLOW_UNKNOWN) {}

AutoLaunchProtocolsPolicyHandler::~AutoLaunchProtocolsPolicyHandler() = default;

// Validates the AutoLaunchProtocolsFromOrigins policy and records problems in
// |errors|.
//
// Schema validation problems reported by CheckAndGetValue() are recorded in
// |errors|. Then each list entry gets an IDS_POLICY_VALUE_FORMAT_ERROR, keyed
// by its index, when any of the following holds:
//  - its protocol is not a valid URI scheme;
//  - it contains an origin pattern that is not a bare origin (one error per
//    offending pattern);
//  - its origin list is empty.
//
// Returns false if the policy is unset or fails schema validation. Otherwise
// returns true even when per-entry errors were recorded, so that
// ApplyPolicySettings() can drop the invalid entries and apply the valid ones.
bool AutoLaunchProtocolsPolicyHandler::CheckPolicySettings(
    const PolicyMap& policies,
    PolicyErrorMap* errors) {
  std::unique_ptr<base::Value> policy_value;
  // Pass |errors| so that schema validation failures are reported instead of
  // being silently discarded.
  if (!CheckAndGetValue(policies, errors, &policy_value) || !policy_value)
    return false;

  base::Value::ConstListView policy_list = policy_value->GetList();
  for (size_t i = 0; i < policy_list.size(); ++i) {
    const base::DictionaryValue& protocol_origins_map =
        base::Value::AsDictionaryValue(policy_list[i]);

    // Both keys are assumed present (presumably required by the policy
    // schema); this is only DCHECKed.
    const std::string* protocol = protocol_origins_map.FindStringKey(
        AutoLaunchProtocolsPolicyHandler::kProtocolNameKey);
    DCHECK(protocol);
    // If the protocol is invalid mark it as an error.
    if (!IsValidProtocol(*protocol)) {
      errors->AddError(policy::key::kAutoLaunchProtocolsFromOrigins, i,
                       IDS_POLICY_VALUE_FORMAT_ERROR);
    }

    const base::Value* origins_list = protocol_origins_map.FindListKey(
        AutoLaunchProtocolsPolicyHandler::kOriginListKey);
    DCHECK(origins_list);
    for (const auto& entry : origins_list->GetList()) {
      // Bind by reference; the pattern string does not need to be copied.
      const std::string& pattern = entry.GetString();
      // If it's not a valid origin pattern mark it as an error.
      if (!IsValidOriginMatchingPattern(pattern)) {
        errors->AddError(policy::key::kAutoLaunchProtocolsFromOrigins, i,
                         IDS_POLICY_VALUE_FORMAT_ERROR);
      }
    }
    // If the origin list is empty mark it as an error.
    if (origins_list->GetList().empty()) {
      errors->AddError(policy::key::kAutoLaunchProtocolsFromOrigins, i,
                       IDS_POLICY_VALUE_FORMAT_ERROR);
    }
  }

  // Always continue to ApplyPolicySettings which can remove invalid values and
  // apply the valid ones.
  return true;
}

// Writes the valid AutoLaunchProtocolsFromOrigins entries to the
// prefs::kAutoLaunchProtocolsFromOrigins pref as a list.
//
// The filtering works as follows:
//  - Entries whose protocol is not a valid URI scheme are dropped.
//  - Origin patterns that are not bare origins are removed from each
//    remaining entry.
//  - Entries left with no origin patterns are dropped.
//
// Nothing is written if the policy is unset or fails schema validation.
void AutoLaunchProtocolsPolicyHandler::ApplyPolicySettings(
    const PolicyMap& policies,
    PrefValueMap* prefs) {
  std::unique_ptr<base::Value> policy_value;
  // Errors were already reported by CheckPolicySettings(), so none are
  // collected here. Bail out instead of dereferencing a null |policy_value|
  // when the policy is unset or invalid.
  if (!CheckAndGetValue(policies, nullptr, &policy_value) || !policy_value)
    return;

  base::ListValue validated_pref_values;
  for (auto& protocol_origins_map : policy_value->GetList()) {
    // If the protocol is invalid skip the entry.
    const std::string* protocol = protocol_origins_map.FindStringKey(
        AutoLaunchProtocolsPolicyHandler::kProtocolNameKey);
    DCHECK(protocol);
    if (!IsValidProtocol(*protocol))
      continue;

    // Remove invalid patterns from the list.
    base::Value* origin_patterns_list = protocol_origins_map.FindListKey(
        AutoLaunchProtocolsPolicyHandler::kOriginListKey);
    DCHECK(origin_patterns_list);
    origin_patterns_list->EraseListValueIf([](const base::Value& pattern) {
      return !IsValidOriginMatchingPattern(pattern.GetString());
    });
    // If the origin list is empty skip the entry.
    if (origin_patterns_list->GetList().empty())
      continue;

    // |policy_value| is owned by this function and discarded after the loop,
    // so the entry can be moved into the pref list rather than deep-copied.
    validated_pref_values.Append(std::move(protocol_origins_map));
  }
  prefs->SetValue(prefs::kAutoLaunchProtocolsFromOrigins,
                  std::move(validated_pref_values));
}

}  // namespace policy
